module Math.LinearAlgebra.Sparse.Matrix
where
import Data.Functor
import Data.Foldable as F
import Data.List as L
import Data.IntMap as M hiding ((!))
import Data.Monoid
import Math.LinearAlgebra.Sparse.Vector
-- | Raw storage of a sparse matrix: an 'SVec' of rows, where every row is
--   itself an 'SVec' of elements. In this module the outer key is used as the
--   row index and the inner key as the column index.
type SMx α = SVec (SVec α)
-- | Sparse matrix: declared dimensions together with the stored rows.
data SparseMatrix α = SM
  { dims :: (Int,Int)  -- ^ (height, width) of the matrix
  , mx   :: SMx α      -- ^ stored rows, keyed by row index
  }
  deriving Eq
-- | Maps a function over every stored element; dimensions are untouched.
instance Functor SparseMatrix where
  fmap f (SM d rows) = SM d (fmap (fmap f) rows)
-- | Element-wise arithmetic on sparse matrices.
--
--   * '+' adds matching elements (union of the stored entries);
--   * '*' multiplies matching elements (intersection of the stored entries),
--     i.e. it is NOT the matrix product (see '×' for that);
--   * the result dimensions are the component-wise maximum of the operands';
--   * 'fromInteger' builds a 1x1 diagonal matrix.
--
--   Elements that become zero (e.g. in @a + negate a@) are dropped from the
--   storage, and so are rows left empty, so that the derived 'Eq' and
--   'isZeroMx' agree with the mathematical value of the result.
instance (Eq α, Num α) => Num (SparseMatrix α) where
  SM (h1, w1) m + SM (h2, w2) n =
      SM (max h1 h2, max w1 w2) (prune (unionWith (unionWith (+)) m n))
    where
      -- drop zero elements first, then rows that became empty
      prune = M.filter (not . M.null) . M.map (M.filter (/= 0))

  SM (h1, w1) m * SM (h2, w2) n =
      SM (max h1 h2, max w1 w2)
         (prune (intersectionWith (intersectionWith (*)) m n))
    where
      -- drop zero elements first, then rows that became empty
      prune = M.filter (not . M.null) . M.map (M.filter (/= 0))

  negate      = fmap negate
  fromInteger x = diagonalMx [fromInteger x]
  abs         = fmap abs
  signum      = fmap signum
-- | Number of rows ('height') and number of columns ('width') as recorded
--   in the matrix dimensions.
height, width :: SparseMatrix α -> Int
height (SM (h, _) _) = h
width  (SM (_, w) _) = w
-- | Replaces the recorded dimensions of a matrix; stored rows are kept as is
--   (nothing is truncated or checked).
setSize :: (Num α) => (Int,Int) -> SparseMatrix α -> SparseMatrix α
setSize s (SM _ rows) = SM s rows
-- | Matrix of size 0x0 with no stored rows.
emptyMx :: SparseMatrix α
emptyMx = SM { dims = (0, 0), mx = M.empty }
-- | Zero matrix of the given (height, width): no stored rows at all.
zeroMx :: Num α => (Int, Int) -> SparseMatrix α
zeroMx (h, w) = SM (h, w) M.empty
-- | 'isZeroVec' holds when the vector stores no elements;
--   'isNotZeroVec' is its negation.
isZeroVec, isNotZeroVec :: SparseVector α -> Bool
isZeroVec v    = M.null (vec v)
isNotZeroVec v = not (isZeroVec v)
-- | 'isZeroMx' holds when the matrix stores no rows;
--   'isNotZeroMx' is its negation.
isZeroMx, isNotZeroMx :: SparseMatrix α -> Bool
isZeroMx (SM _ rows) = M.null rows
isNotZeroMx m        = not (isZeroMx m)
-- | Identity matrix of size @n@ x @n@: a diagonal matrix of @n@ ones.
idMx :: (Num α, Eq α) => Int -> SparseMatrix α
idMx n = diagonalMx [ 1 | _ <- [1 .. n] ]
-- | Adds the vector @v@ as a row at index @i@.
--
--   * If the matrix stores no rows, the result is a single-row matrix of size
--     @(1, dim v)@ with @v@ at row 1 (the index @i@ is ignored in this case).
--   * Otherwise the height grows by one, the width becomes the maximum of the
--     old width and @dim v@, and the row is placed with 'addElem' (a zero
--     vector is passed as 'Nothing'; presumably 'addElem' shifts later
--     indices — confirm in "Math.LinearAlgebra.Sparse.Vector").
addRow :: (Num α) => SparseVector α -> Index -> SparseMatrix α -> SparseMatrix α
addRow v i m
  -- 'M.singleton' is qualified: an unqualified 'singleton' is ambiguous with
  -- 'Data.List.singleton' on recent versions of base.
  | isZeroMx m = SM (1, dim v) (M.singleton 1 (vec v))
  | otherwise  = SM (height m + 1, max (width m) (dim v))
                    (addElem mbv i (mx m))
  where
    mbv = if isZeroVec v then Nothing else Just (vec v)
-- | Adds the vector @v@ as a column at index @j@.
--
--   * If the matrix stores no rows, the result is a single-column matrix of
--     size @(dim v, 1)@ holding the elements of @v@ in column 1 (the index
--     @j@ is ignored in this case).
--   * Otherwise the width grows by one, the height becomes the maximum of the
--     old height and @dim v@, and every stored row gets the corresponding
--     element of @v@ (if any) placed at @j@ with 'addElem'. Only rows already
--     stored in the matrix are visited.
addCol :: (Num α) => SparseVector α -> Index -> SparseMatrix α -> SparseMatrix α
addCol v j m
  -- 'M.singleton' is qualified: an unqualified 'singleton' is ambiguous with
  -- 'Data.List.singleton' on recent versions of base.
  | isZeroMx m = SM (dim v, 1) (M.map (M.singleton 1) (vec v))
  | otherwise  = SM (max (height m) (dim v), width m + 1)
                    (M.mapWithKey insCol (mx m))
  where
    insCol i r = addElem (M.lookup i (vec v)) j r
-- | Adds a zero row (as wide as the matrix) at index @i@ via 'addRow'.
addZeroRow :: Num α => Index -> SparseMatrix α -> SparseMatrix α
addZeroRow i m = addRow blank i m
  where
    blank = zeroVec (width m)
-- | Adds a zero column (as tall as the matrix) at index @i@ via 'addCol'.
addZeroCol :: Num α => Index -> SparseMatrix α -> SparseMatrix α
addZeroCol i m = addCol blank i m
  where
    blank = zeroVec (height m)
-- | Deletes the row at index @i@ with 'delElem' and decreases the height by
--   one. A matrix that stores no rows is returned unchanged.
delRow :: (Num α) => Index -> SparseMatrix α -> SparseMatrix α
delRow i m
  | isZeroMx m = m
  | otherwise  = SM (height m - 1, width m) (delElem i (mx m))
-- | Deletes the column at index @j@ (applying 'delElem' to every stored row)
--   and decreases the width by one. A matrix that stores no rows is returned
--   unchanged.
delCol :: (Num α) => Index -> SparseMatrix α -> SparseMatrix α
delCol j m
  | isZeroMx m = m
  | otherwise  = SM (height m, width m - 1) (M.map (delElem j) (mx m))
-- | Deletes row @i@ first and then column @j@.
delRowCol :: Num α => Index -> Index -> SparseMatrix α -> SparseMatrix α
delRowCol i j = delCol j . delRow i
-- | Splits the stored rows by a predicate using 'partitionMap'. The first
--   matrix gets the rows satisfying @p@ and has as many rows as were
--   selected; the second gets the remaining rows and the remaining height
--   @h - st@. Both keep the original width.
--   NOTE(review): presumably 'partitionMap' re-indexes the rows of each part —
--   confirm in "Math.LinearAlgebra.Sparse.Vector".
partitionMx :: (Num α) => (SparseVector α -> Bool) -> SparseMatrix α -> (SparseMatrix α, SparseMatrix α)
partitionMx p (SM (h,w) m) = (SM (st, w) t, SM (h - st, w) f)
  where
    (t, f) = partitionMap (p . SV w) m
    st     = size t
-- | Splits the stored rows by a predicate, keeping row indices and the
--   original dimensions in both results: the first matrix holds the rows
--   satisfying @p@, the second one holds the rest.
separateMx :: (Num α) => (SparseVector α -> Bool) -> SparseMatrix α -> (SparseMatrix α, SparseMatrix α)
separateMx p (SM (h,w) m) = (SM (h, w) matching, SM (h, w) rest)
  where
    (matching, rest) = M.partition (p . SV w) m
-- | Element lookup by (row, column) index; anything not stored reads as 0.
(#) :: (Num α) => SparseMatrix α -> (Index,Index) -> α
m # (i,j) =
  case M.lookup i (mx m) of
    Nothing -> 0
    Just r  -> findWithDefault 0 j r
-- | Row @i@ of the matrix as a sparse vector of dimension @width m@;
--   a row that is not stored yields an empty (zero) vector.
row :: (Num α) => SparseMatrix α -> Index -> SparseVector α
row m i = SV (width m) stored
  where
    stored = case M.lookup i (mx m) of
               Just r  -> r
               Nothing -> M.empty
-- | Column @i@ of the matrix, obtained as row @i@ of the transposed matrix.
col :: (Num α, Eq α) => SparseMatrix α -> Index -> SparseVector α
col m i = row (trans m) i
-- | Applies @f@ to the stored row @i@ (wrapped as a vector of the matrix
--   width). A row that is not stored is left absent ('M.adjust' semantics).
updRow :: (Num α) => SparseMatrix α -> (SparseVector α -> SparseVector α) -> Index -> SparseMatrix α
updRow (SM d rows) f i = SM d (M.adjust onRaw i rows)
  where
    -- wrap the raw row, run the update, unwrap again
    onRaw r = vec (f (SV (snd d) r))
-- | Removes row @i@ from the storage (the row becomes zero);
--   dimensions and other row indices are unchanged.
eraseRow :: (Num α) => SparseMatrix α -> Index -> SparseMatrix α
eraseRow (SM d rows) i = SM d (M.delete i rows)
-- | Erases the element at (i, j) using 'eraseInVec'; if row @i@ is zero
--   afterwards it is removed from the storage as well.
erase :: (Num α) => SparseMatrix α -> (Index,Index) -> SparseMatrix α
erase m (i,j)
  | isZeroVec (row cleared i) = eraseRow cleared i
  | otherwise                 = cleared
  where
    cleared = updRow m (`eraseInVec` j) i
-- | Inserts an element at (i, j). Inserting 0 erases the element instead, so
--   zeros are not stored. A non-zero value replaces any existing element at
--   that position ('M.union' is left-biased and the new singleton row is its
--   left argument). Dimensions are not changed.
ins :: (Num α, Eq α) => SparseMatrix α -> ((Index,Index), α) -> SparseMatrix α
m `ins` ((i,j),0) = m `erase` (i,j)
m `ins` ((i,j),x) = m { mx = newMx }
  where
    -- 'M.insertWith' replaces 'M.insertWith'', which no longer exists in
    -- current versions of containers.
    newMx = M.insertWith M.union i (M.singleton j x) (mx m)
-- | Indices of the stored rows satisfying the predicate, in ascending order.
findRowIndices :: (SparseVector α -> Bool) -> SparseMatrix α -> [Key]
findRowIndices p m = M.foldrWithKey pick [] (mx m)
  where
    pick i r found
      | p (SV (width m) r) = i : found
      | otherwise          = found
-- | Indices of the stored rows satisfying the predicate, in descending order.
findRowIndicesR :: (SparseVector α -> Bool) -> SparseMatrix α -> [Key]
findRowIndicesR p m = M.foldlWithKey pick [] (mx m)
  where
    pick found i r
      | p (SV (width m) r) = i : found
      | otherwise          = found
-- | Square matrix with the given list on the main diagonal (indices start
--   at 1). The matrix size equals the length of the list; zero entries are
--   not stored.
diagonalMx :: (Num α, Eq α) => [α] -> SparseMatrix α
diagonalMx = L.foldl' add emptyMx  -- strict fold: no thunk chain on long lists
  where
    -- grow the matrix by one row/column and put @x@ in the new corner
    add m x = let i = height m + 1
              in  setSize (i, i) (m `ins` ((i, i), x))
-- | Builds a matrix from a list of rows by appending each one with 'addRow'
--   at index @height + 1@, starting from 'emptyMx'.
fromRows :: (Num α) => [SparseVector α] -> SparseMatrix α
fromRows = L.foldl' append emptyMx  -- strict fold: no thunk chain on long lists
  where
    append m r = addRow r (height m + 1) m
-- | Converts a matrix to an association list. The first entry is always
--   @(dimensions, 0)@; it is followed by the non-zero stored elements in
--   ascending (row, column) order.
toAssocList :: (Num α, Eq α) => SparseMatrix α -> [ ((Index,Index), α) ]
toAssocList (SM s m) = (s, 0) : concatMap entries (M.toAscList m)
  where
    entries (i, r) = [ ((i, j), x) | (j, x) <- M.toAscList r, x /= 0 ]
-- | Converts an association list to a matrix. Elements are inserted with
--   'ins' (so zeros are not stored, and later entries override earlier ones).
--
--   The dimensions are the component-wise maximum of all indices in the
--   list, zero-valued entries included, so the @(dimensions, 0)@ entry
--   produced by 'toAssocList' restores the original size. An empty list
--   yields the 0x0 matrix instead of failing.
fromAssocList :: (Num α, Eq α) => [ ((Index,Index), α) ] -> SparseMatrix α
fromAssocList l = (L.foldl' ins emptyMx l) { dims = bounds }
  where
    bounds = L.foldl' grow (0, 0) (L.map fst l)
    -- Component-wise maximum: a lexicographic maximum of the pairs would
    -- report the column of the last row rather than the widest column.
    grow (h, w) (i, j) =
      let h' = max h i
          w' = max w j
      in  h' `seq` w' `seq` (h', w')
-- | Dense representation: a list of @height m@ rows with @width m@ elements
--   each, with 0 for every element that is not stored.
fillMx :: (Num α) => SparseMatrix α -> [[α]]
fillMx m = L.map denseRow [1 .. height m]
  where
    denseRow i = L.map (\j -> m # (i, j)) [1 .. width m]
-- | Builds a sparse matrix from a list of dense rows. The height is the
--   number of rows, the width is the length of the FIRST row; rows are
--   numbered from 1 and converted with 'sparseList', and rows that end up
--   empty are not stored.
sparseMx :: (Num α, Eq α) => [[α]] -> SparseMatrix α
sparseMx [] = emptyMx
sparseMx m@(r:_) = SM (length m, length r) (M.fromList kept)
  where
    numbered = zipWith (\i xs -> (i, vec (sparseList xs))) [1 ..] m
    kept     = L.filter (not . M.null . snd) numbered
-- | Shows a matrix by rendering its dense form with 'showSparseMatrix'.
instance (Show α, Eq α, Num α) => Show (SparseMatrix α) where
  show m = showSparseMatrix (fillMx m)
-- | Renders a dense matrix: a @(rows, columns)@ header line followed by one
--   line per row of the form @[a|b|c]@, with cells padded by 'column' so that
--   each column has a uniform width. An empty list renders as @"[]"@.
showSparseMatrix :: (Show α, Eq α, Num α) => [[α]] -> String
showSparseMatrix [] = "[]"
showSparseMatrix m@(firstRow:_) = header ++ unlines (L.map bracket cells)
  where
    header    = show (length m, length firstRow) ++ ": \n"
    -- pad column by column, then switch back to row-major order
    cells     = transpose (L.map column (transpose m))
    bracket r = "[" ++ intercalate "|" r ++ "]"
-- | Renders one matrix column: every element is shown with 'showNonZero' and
--   left-padded with spaces up to the width of the longest cell, so the
--   cells come out right-aligned.
column :: (Show α, Eq α, Num α) => [α] -> [String]
column c = L.map pad cells
  where
    cells  = L.map showNonZero c
    -- only evaluated when there is at least one cell to pad
    widest = L.maximum (L.map length cells)
    pad x  = replicate (widest - length x) ' ' ++ x
-- | Transposes a matrix: every position (i, j) inside the declared
--   dimensions is read with '#' and inserted at (j, i) with 'ins' (zeros are
--   therefore not stored). The result has the dimensions swapped.
trans :: (Num α, Eq α) => SparseMatrix α -> SparseMatrix α
trans m = (F.foldl' step emptyMx cells) { dims = (width m, height m) }
  where
    cells = [ (i, j) | i <- [1 .. height m], j <- [1 .. width m] ]
    step acc (i, j) = acc `ins` ((j, i), m # (i, j))
-- | Matrix-by-vector multiplication; alphabetic alias for '×·'.
mulMV :: (Num α, Eq α) => SparseMatrix α -> SparseVector α -> SparseVector α
mulMV m v = m ×· v
-- | Matrix-by-vector multiplication: combines the vector with every stored
--   row using '··' and keeps only the non-zero results. The result has
--   dimension equal to the matrix height.
(×·) :: (Num α, Eq α) => SparseMatrix α -> SparseVector α -> SparseVector α
SM (h, _) rows ×· SV _ v = SV h nonZero
  where
    products = M.map (\r -> v ·· r) rows
    nonZero  = M.filter (/= 0) products
-- | Vector-by-matrix multiplication; alphabetic alias for '·×'.
mulVM :: (Num α, Eq α) => SparseVector α -> SparseMatrix α -> SparseVector α
mulVM v m = v ·× m
-- | Vector-by-matrix multiplication, computed as the transposed matrix
--   multiplied by the vector.
(·×) :: (Num α, Eq α) => SparseVector α -> SparseMatrix α -> SparseVector α
(·×) v m = (×·) (trans m) v
-- | Matrix-by-matrix multiplication; alphabetic alias for '×'.
mul :: (Num α, Eq α) => SparseMatrix α -> SparseMatrix α -> SparseMatrix α
mul a b = a × b
-- | Matrix-by-matrix multiplication. Every stored row of @a@ is combined
--   (via '··') with every stored row of the transposed @b@; zero results and
--   rows left empty are dropped. The result has size
--   @(height a, width b)@.
(×) :: (Num α, Eq α) => SparseMatrix α -> SparseMatrix α -> SparseMatrix α
a × b = SM (height a, width b) (M.filter (not . M.null) (M.map mulRow (mx a)))
  where
    -- rows of the transposed @b@, i.e. columns of @b@; computed once
    bCols       = mx (trans b)
    mulRow aRow = M.filter (/= 0) (M.map (aRow ··) bCols)