{-# LANGUAGE FlexibleContexts #-} module Data.Eigen.Util ( -- | Basic operations on rows and columns on rows and columns on rows and -- columns on rows and columns rowAdd , rowsAdd , colAdd , colsAdd , scaleRow , scaleCol -- | Matrix creation from list , fromList' -- | stacking functions , hstack , vstack -- | Function to manipulate matrices , delRow , delRows , delCol , delCols -- | Kronecker product of two matrix , kronecker -- | Display matrix , pprint , pprintIO ) where import Data.Eigen.Matrix as E import Data.Maybe (fromJust) import Data.List as L import Data.Vector.Storable as V import Text.Printf (printf, PrintfArg) -- | to2DList turns a 1-d list to 2D list. to2DList _ [] = [] to2DList n es = row : to2DList n rest where ( row, rest ) = L.splitAt n es {- | Alternative implementation of fromList. It accepts a flatten list of -- elements and number of columns. -- No tests are performed to check if number of elements in list are sufficient. --} fromList' :: E.Elem a b => Int -> [ a ] -> E.Matrix a b fromList' n elems = E.fromList $ to2DList n elems -- | Stack matrices horizontallly hstack :: E.Elem a b => [ E.Matrix a b ] -> E.Matrix a b hstack mats = E.generate allrows allcols generateFunc where allcols = L.sum ncols allrows = E.rows $ L.head mats ncols = L.map E.cols mats whichMat c = fromJust $ L.findIndex (>c) $ L.scanl1 (+) ncols generateFunc i j = (mats!!matId) E.! (i, j - (L.sum $ L.take matId ncols) ) where matId = whichMat j {- | Stack given matrices vertically. It uses the following property -- vstack [a, b, c ..] = ( hstack [a',b',c'.. ] )' where M' is transpose of -- matrix M. -- -- TODO: This is computationally inefficient than implementing is directly like -- hstack. -} vstack :: E.Elem a b => [ E.Matrix a b ] -> E.Matrix a b vstack mats = E.transpose $ hstack $ L.map E.transpose mats -- | Kronecker matric multiplication. kronecker :: E.Elem a b => E.Matrix a b -> E.Matrix a b -> E.Matrix a b kronecker mat1 mat2 = E.imap (\i j v -> (mat1 E.! (div i x, div j y)) * ( mat2 E.! (rem i x, rem j y)) ) resMat where [ (p,q), (x,y) ] = [ E.dims mat1, E.dims mat2 ] (r, c) = ( p*x, q*y ) resMat = E.zero r c -- | rowAdd r1 = r1 + k * r2 rowAdd :: E.Elem a b => Int -> (a, Int) -> E.Matrix a b -> E.Matrix a b rowAdd r1 (k,r2) mat = E.imap ( \i j v -> if i == r1 then v + k * ( mat E.! (r2,j) ) else v ) mat -- | colAdd c1 = c1 + k * c2 colAdd :: E.Elem a b => Int -> (a, Int ) -> E.Matrix a b -> E.Matrix a b colAdd c1 (k, c2) mat = E.imap ( \i j v -> if j == c1 then v + k * ( mat E.! (i,c2) ) else v ) mat {- | Adds a list of given columns with a list weights to the first column in the list. - Note that first value in the list of weights is ignored -} colsAdd :: E.Elem a b => [ Int ] -> [ a ] -> E.Matrix a b -> E.Matrix a b colsAdd (c:[]) _ m = m colsAdd (c:c1:cols) (w:w1:ws) m = colsAdd (c:cols) (w:ws) $ colAdd c (w1,c1) m {- | Adds a list of given rows with a list weights to the first row in the list. - Note that first value in the list of weights is ignored -} rowsAdd :: E.Elem a b => [ Int ] -> [ a ] -> E.Matrix a b -> E.Matrix a b rowsAdd (r:[]) _ m = m rowsAdd (r:r1:rows) (w:w1:ws) m = rowsAdd (r:rows) (w:ws) $ rowAdd r (w1,r1) m -- | scale a row by a factor scaleRow :: E.Elem a b => Int -> a -> E.Matrix a b -> E.Matrix a b scaleRow row c mat = E.imap ( \i j v -> if i == row then c * v else v ) mat -- | scale a column by a factor scaleCol :: E.Elem a b => Int -> a -> E.Matrix a b -> E.Matrix a b scaleCol col c mat = E.imap ( \i j v -> if j == col then c * v else v ) mat -- Utility function to delete given element from the list deleteAt :: Int -> [a] -> [a] deleteAt n ls = let (ys,zs) = L.splitAt n ls in ys L.++ (L.tail zs) -- | delete a row delRow :: E.Elem a b => Int -> E.Matrix a b -> E.Matrix a b delRow r mat = E.fromList $ deleteAt r $ E.toList mat -- | delete a column delCol :: E.Elem a b => Int -> E.Matrix a b -> E.Matrix a b delCol c mat = E.fromList $ L.map (\row -> deleteAt c row) $ E.toList mat -- | delete list of given rows delRows :: E.Elem a b => [ Int ] -> E.Matrix a b -> E.Matrix a b delRows rows mat = delRows' (L.sort rows) mat where delRows' [] mat = mat delRows' (r:rs) mat = delRows' (L.map (\e->e-1) rs) (delRow r mat) -- | delete list of given columns delCols :: E.Elem a b => [ Int ] -> E.Matrix a b -> E.Matrix a b delCols cols mat = delCols' (L.sort cols) mat where delCols' [] mat = mat delCols' (c:cs) mat = delCols' (L.map (\e->e-1) cs) (delCol c mat) -- | Pretty print the matrix pprint :: (PrintfArg a, E.Elem a b) => E.Matrix a b -> String pprint mat = L.unlines $ -- construct each row L.map (\i -> L.unwords $ L.map (\e -> printf "%.5f" e) (E.row i mat)) $ -- all rows L.take (E.rows mat) [0,1..] -- | print matrix in IO monad pprintIO :: (PrintfArg a, E.Elem a b) => E.Matrix a b -> IO () pprintIO mat = do putStrLn $ pprint mat