module Data.Eigen.Util (
rowAdd
, rowsAdd
, colAdd
, colsAdd
, scaleRow
, scaleCol
, fromList'
, hstack
, vstack
, delRow
, delRows
, delCol
, delCols
, kronecker
, pprint
, pprintIO
) where
import Data.Eigen.Matrix as E
import Data.Maybe (fromJust)
import Data.List as L
import Data.Vector.Storable as V
import Text.Printf (printf, PrintfArg)
-- | Split a list into consecutive chunks of @n@ elements each; the final
-- chunk holds whatever remains and may be shorter.
--
-- An empty input always gives @[]@. A non-positive @n@ with non-empty input
-- is an error: a zero-width chunk never consumes input, so the split would
-- otherwise produce an endless list of empty chunks.
to2DList :: Int -> [a] -> [[a]]
to2DList _ [] = []
to2DList n es
  | n <= 0 = error "Data.Eigen.Util.to2DList: chunk size must be positive"
  | otherwise = chunk : to2DList n rest
  where
    (chunk, rest) = L.splitAt n es
-- | Build a matrix from a flat list by cutting it into consecutive runs of
-- @n@ elements (via 'to2DList') and handing those runs to 'E.fromList' as
-- the matrix rows.
--
-- NOTE(review): if @length elems@ is not a multiple of @n@ the last run is
-- shorter; how 'E.fromList' treats a ragged last row is not visible here.
fromList' :: E.Elem a b => Int -> [ a ] -> E.Matrix a b
fromList' n = E.fromList . to2DList n
-- | Concatenate matrices horizontally (left to right).
--
-- Every input must have the same number of rows; the result has that many
-- rows and the sum of the inputs' column counts as its column count.
-- An empty list yields a 0x0 matrix. Inputs with differing row counts are
-- rejected with an error instead of being silently truncated.
hstack :: E.Elem a b => [ E.Matrix a b ] -> E.Matrix a b
hstack [] = E.zero 0 0
hstack mats@(firstMat : _)
  | L.any ((/= allrows) . E.rows) mats =
      error "Data.Eigen.Util.hstack: input matrices differ in row count"
  | otherwise = E.generate allrows allcols generateFunc
  where
    allrows = E.rows firstMat
    ncols = L.map E.cols mats
    allcols = L.sum ncols
    -- Exclusive end column of each input inside the result.
    colEnds = L.scanl1 (+) ncols
    -- First column of each input inside the result.
    colStarts = L.scanl (+) 0 ncols
    -- Index of the input that owns result column c; total for
    -- 0 <= c < allcols because the last entry of colEnds is allcols.
    whichMat c = fromJust $ L.findIndex (> c) colEnds
    -- Translate result column j into a column of the owning input.
    generateFunc i j = (mats !! matId) E.! (i, j - (colStarts !! matId))
      where
        matId = whichMat j
-- | Concatenate matrices vertically (top to bottom): transpose every input,
-- join the transposes side by side with 'hstack', then transpose back.
vstack :: E.Elem a b => [ E.Matrix a b ] -> E.Matrix a b
vstack = E.transpose . hstack . L.map E.transpose
-- | Kronecker product of two matrices.
--
-- For a @p x q@ matrix @mat1@ and an @x x y@ matrix @mat2@ the result is a
-- @(p*x) x (q*y)@ matrix whose entry @(i, j)@ is
-- @mat1 ! (i `div` x, j `div` y) * mat2 ! (i `rem` x, j `rem` y)@,
-- i.e. block @(bi, bj)@ of the result is @mat1 ! (bi, bj)@ times @mat2@.
kronecker :: E.Elem a b => E.Matrix a b -> E.Matrix a b -> E.Matrix a b
kronecker mat1 mat2 = E.generate (p * x) (q * y) entry
  where
    (p, q) = E.dims mat1
    (x, y) = E.dims mat2
    -- (i `div` x, j `div` y) selects the block (a coefficient of mat1);
    -- (i `rem` x, j `rem` y) is the position inside that block (in mat2).
    entry i j =
      (mat1 E.! (i `div` x, j `div` y)) * (mat2 E.! (i `rem` x, j `rem` y))
-- | Elementary row operation: add @k@ times row @r2@ to row @r1@.
--
-- Source entries are read from the input matrix, so when @r1 == r2@ each
-- entry @v@ of that row becomes @v + k * v@. All other rows are unchanged.
rowAdd :: E.Elem a b => Int -> (a, Int) -> E.Matrix a b -> E.Matrix a b
rowAdd r1 (k, r2) mat = E.imap update mat
  where
    update i j v
      | i /= r1 = v
      | otherwise = v + k * (mat E.! (r2, j))
-- | Elementary column operation: add @k@ times column @c2@ to column @c1@.
--
-- Source entries are read from the input matrix, so when @c1 == c2@ each
-- entry @v@ of that column becomes @v + k * v@. All other columns are
-- unchanged.
colAdd :: E.Elem a b => Int -> (a, Int ) -> E.Matrix a b -> E.Matrix a b
colAdd c1 (k, c2) mat = E.imap update mat
  where
    update i j v
      | j /= c1 = v
      | otherwise = v + k * (mat E.! (i, c2))
-- | Add a weighted combination of source columns to a target column.
--
-- @colsAdd (c : srcs) (w : ws) m@ adds @(ws !! k) *@ column @(srcs !! k)@
-- to column @c@ for every source, one 'colAdd' at a time from left to
-- right. Because the steps run sequentially, a source equal to @c@ sees the
-- already-updated target column. The first weight @w@ is paired with the
-- target and is never used.
--
-- An empty or single-element column list returns @m@ unchanged; supplying
-- fewer weights than columns is an error.
colsAdd :: E.Elem a b => [ Int ] -> [ a ] -> E.Matrix a b -> E.Matrix a b
colsAdd [] _ m = m
colsAdd [_] _ m = m
colsAdd (c : c1 : rest) (w : w1 : ws) m =
  colsAdd (c : rest) (w : ws) $ colAdd c (w1, c1) m
colsAdd _ _ _ = error "Data.Eigen.Util.colsAdd: fewer weights than columns"
-- | Add a weighted combination of source rows to a target row.
--
-- @rowsAdd (r : srcs) (w : ws) m@ adds @(ws !! k) *@ row @(srcs !! k)@ to
-- row @r@ for every source, one 'rowAdd' at a time from left to right.
-- Because the steps run sequentially, a source equal to @r@ sees the
-- already-updated target row. The first weight @w@ is paired with the
-- target and is never used.
--
-- An empty or single-element row list returns @m@ unchanged; supplying
-- fewer weights than rows is an error.
rowsAdd :: E.Elem a b => [ Int ] -> [ a ] -> E.Matrix a b -> E.Matrix a b
rowsAdd [] _ m = m
rowsAdd [_] _ m = m
rowsAdd (r : r1 : rest) (w : w1 : ws) m =
  rowsAdd (r : rest) (w : ws) $ rowAdd r (w1, r1) m
rowsAdd _ _ _ = error "Data.Eigen.Util.rowsAdd: fewer weights than rows"
-- | Multiply every entry of row @row@ by @c@; other rows are untouched.
scaleRow :: E.Elem a b => Int -> a -> E.Matrix a b -> E.Matrix a b
scaleRow target c mat = E.imap scale mat
  where
    scale i _ v
      | i == target = c * v
      | otherwise = v
-- | Multiply every entry of column @col@ by @c@; other columns are
-- untouched.
scaleCol :: E.Elem a b => Int -> a -> E.Matrix a b -> E.Matrix a b
scaleCol target c mat = E.imap scale mat
  where
    scale _ j v
      | j == target = c * v
      | otherwise = v
-- | Remove the element at zero-based index @n@.
--
-- Errors if @n@ is negative or not less than the list length. Previously a
-- negative index silently dropped the head and a too-large index failed
-- inside 'tail'.
deleteAt :: Int -> [a] -> [a]
deleteAt n ls
  | n < 0 = outOfRange
  | otherwise = case L.splitAt n ls of
      (ys, _ : zs) -> ys L.++ zs
      (_, []) -> outOfRange
  where
    outOfRange = error "Data.Eigen.Util.deleteAt: index out of range"
-- | Remove row @r@: convert to a list of rows, drop entry @r@ with
-- 'deleteAt', and rebuild the matrix.
delRow :: E.Elem a b => Int -> E.Matrix a b -> E.Matrix a b
delRow r = E.fromList . deleteAt r . E.toList
-- | Remove column @c@: convert to a list of rows, drop position @c@ from
-- each row with 'deleteAt', and rebuild the matrix.
delCol :: E.Elem a b => Int -> E.Matrix a b -> E.Matrix a b
delCol c = E.fromList . L.map (deleteAt c) . E.toList
-- | Remove several rows at once.
--
-- Indices refer to rows of the /input/ matrix. Their order does not matter
-- and duplicates are ignored. (Previously a duplicate index deleted an
-- unrelated row after the index shift.) Out-of-range indices are handled by
-- 'delRow'.
delRows :: E.Elem a b => [ Int ] -> E.Matrix a b -> E.Matrix a b
delRows rowIdxs mat = L.foldr delRow mat uniqueAscending
  where
    -- Sorted, duplicate-free indices. 'L.group' never produces an empty
    -- group, so 'L.head' is safe here.
    uniqueAscending = L.map L.head $ L.group $ L.sort rowIdxs
    -- 'L.foldr' applies 'delRow' to the largest index first, so removing a
    -- row never shifts the indices that are still waiting to be removed.
-- | Remove several columns at once.
--
-- Indices refer to columns of the /input/ matrix. Their order does not
-- matter and duplicates are ignored. (Previously a duplicate index deleted
-- an unrelated column after the index shift.) Out-of-range indices are
-- handled by 'delCol'.
delCols :: E.Elem a b => [ Int ] -> E.Matrix a b -> E.Matrix a b
delCols colIdxs mat = L.foldr delCol mat uniqueAscending
  where
    -- Sorted, duplicate-free indices. 'L.group' never produces an empty
    -- group, so 'L.head' is safe here.
    uniqueAscending = L.map L.head $ L.group $ L.sort colIdxs
    -- 'L.foldr' applies 'delCol' to the largest index first, so removing a
    -- column never shifts the indices that are still waiting to be removed.
-- | Render a matrix as text: one line per row, entries formatted with
-- @printf "%.5f"@ and separated by single spaces. Every line, including the
-- last, ends with a newline.
pprint :: (PrintfArg a, E.Elem a b) => E.Matrix a b -> String
pprint mat = L.unlines
  [ L.unwords [ printf "%.5f" e | e <- E.row i mat ]
  | i <- [0 .. E.rows mat - 1]
  ]
-- | Print 'pprint''s rendering to stdout. 'pprint' already terminates every
-- row with a newline and 'putStrLn' appends one more, so the output ends
-- with a blank line.
pprintIO :: (PrintfArg a, E.Elem a b) => E.Matrix a b -> IO ()
pprintIO = putStrLn . pprint