-- While working on this module you are encouraged to fix any compiler
-- warnings it produces. See
--     http://hackage.haskell.org/trac/ghc/wiki/WorkingConventions#Warnings
-- for details.

-----------------------------------------------------------------------------
-- |
-- Module      :  MatrixList
-- Copyright   :  (c) Lennart Schmitt
-- License     :  BSD-style (see the file libraries/base/LICENSE)
-- 
-- Maintainer  :  lennart...schmitt@<nospam>gmail.com
-- Stability   :  experimental
-- Portability :  portable
--
-- This module implements a list of matrices and a few functions to handle them.
-----------------------------------------------------------------------------

module Numeric.MatrixList (
         -- * Data Types
         MatrixList, 
         RawMatrixList, 
         -- * Functions  
         Numeric.MatrixList.toLists,
         countMatrixElements,
         countMatrixes,
         countMatrixesCols,
         averages,
         crossProduct,
         mapVectors,
         foldVectors,
         mapElements,
         transposeAll,
         countRows,
         countAllRows
         )where

import Numeric.Vector 
import Numeric.Matrix

-- | A list of matrices: a plain Haskell list of 'Matrix' values.
type MatrixList e = [Matrix e]

-- | The plain-list form of a list of matrices: the outer list holds the
-- matrices, each matrix is a list of vectors, and each vector is a list of
-- elements. When produced by 'toLists', the vectors are the matrix columns.
type RawMatrixList e = [[[e]]]

-- | Converts a 'MatrixList' into its plain-list form, a 'RawMatrixList'.
--
-- Each matrix is split into its columns with 'toColumns', and every column
-- is then turned into an ordinary list with 'toList'.
toLists :: (Numeric.Matrix.Element a) => MatrixList a -> RawMatrixList a
toLists = map (map toList . toColumns)

-- | Counts the elements of every vector of every matrix.
--
-- The result holds one list per matrix, with one count per vector of that
-- matrix. It does /not/ give a single total per matrix.
--
-- Not used by any other function in this module.
countMatrixElements :: RawMatrixList a -> RawMatrix Double
countMatrixElements = map (map count)

-- | Returns how many matrices the given list contains.
countMatrixes :: RawMatrixList a -> Double
countMatrixes matrices = count matrices

-- | Counts the number of cols of the matrices.
--
-- Only the first vector of the first matrix is inspected, and its element
-- count is returned. This assumes that all matrices share the same structure.
--
-- Returns @0@ when the list of matrices is empty or when the first matrix
-- contains no vectors, so the function is total.
countMatrixesCols :: RawMatrixList a -> Double
countMatrixesCols ((firstVector : _) : _) = count firstVector
countMatrixesCols _                        = 0

-- | Computes the average of every vector (col) of every matrix.
--
-- >   averages [[[1,2],[2,1]],[[2,3],[3,4]]] == [[1.5,1.5],[2.5,3.5]]
averages :: RawMatrixList Double -> RawMatrix Double
averages matrices = [map average vectors | vectors <- matrices]

-- | Multiplies together the elements of every vector (col) of every matrix.
--
-- Despite its name, this is /not/ the vector cross product. Each vector is
-- reduced to the product of its elements.
--
-- > crossProduct  [[[1,2],[2,1]],[[2,3],[3,4]]] == [[2.0,2.0],[6.0,12.0]]
crossProduct :: RawMatrixList Double -> RawMatrix Double
crossProduct = map (map product)

-- | Applies a function to every vector of every matrix, keeping the
-- list-of-matrices structure intact.
--
-- > mapVectors (map (1+)) [[[1,2],[2,1]],[[2,3],[3,4]]] == [[[2.0,3.0],[3.0,2.0]],[[3.0,4.0],[4.0,5.0]]]
mapVectors :: ([a] -> [b]) -> RawMatrixList a -> RawMatrixList b
mapVectors = map . map

-- | Reduces every vector of every matrix to a single value.
--
-- The result holds one list per matrix, with one reduced value per vector
-- of that matrix.
--
-- > foldVectors sum [[[1,2],[2,1]],[[2,3],[3,4]]] == [[3.0,3.0],[5.0,7.0]]
foldVectors :: ([a] -> b) -> RawMatrixList a -> RawMatrix b
foldVectors f = map (map f)

-- | Applies a function to every single element of every matrix.
--
-- > mapElements (1+) [[[1,2],[2,1]],[[2,3],[3,4]]] == [[[2.0,3.0],[3.0,2.0]],[[3.0,4.0],[4.0,5.0]]]
mapElements :: (a -> b) -> RawMatrixList a -> RawMatrixList b
mapElements f = map (map (map f))

-- | Transposes every matrix in a list of matrices.
transposeAll :: RawMatrixList a -> RawMatrixList a
transposeAll matrices = [transpose matrix | matrix <- matrices]

-- | Counts the rows of every matrix in the list, i.e. how many vectors each
-- matrix holds.
--
-- NOTE(review): for matrices produced by 'toLists' the vectors are the
-- columns ('toColumns'), so this then counts columns. Confirm the intent.
--
-- > countRows [[[1,2],[2,1]],[[2,3],[3,4],[1,1]]] == [2.0,3.0]
countRows :: RawMatrixList a -> RawVector Double
countRows matrices = [count matrix | matrix <- matrices]

-- | Counts the rows of all matrices together, i.e. the total number of
-- vectors across the whole list.
--
-- > countAllRows [[[1,2],[2,1]],[[2,3],[3,4],[1,1]]] == 5.0
countAllRows :: RawMatrixList a -> Double
countAllRows matrices = sum (map count matrices)