module Math.LinearAlgebra.Sparse.Matrix
(
SMx, SparseMatrix (..),
height, width, setSize, emptyMx, zeroMx, isZeroMx, isNotZeroMx , idMx,
(//), hconcat, vconcat, sizedBlockMx, sizedBlockSMx, blockMx, blockSMx,
addRow, addCol, addZeroRow, addZeroCol, delRow, delCol, delRowCol, separateMx,
(#), row, col, updRow, eraseRow, erase, ins, findRowIndices, findRowIndicesR, popRow, (|>), (<|), replaceRow, exchangeRows, mapOnRows,
diagonalMx, mainDiag, fromRows, toAssocList, fromAssocListWithSize, fromAssocList, fillMx, sparseMx,
trans,
mulMV, (×·) , mulVM, (·×) , mul, (×),
)
where
import Math.LinearAlgebra.Sparse.IntMapUtilities
import Data.Functor
import Data.Foldable as F
import Data.List as L
import Data.IntMap as M hiding ((!))
import Data.Monoid
import Math.LinearAlgebra.Sparse.Vector
-- | Raw storage of a sparse matrix: an 'SVec' of rows, where every row is
--   itself an 'SVec' of entries. Throughout this module the value is
--   manipulated with "Data.IntMap" functions, the outer key being the row
--   index and the inner key the column index.
type SMx α = SVec (SVec α)
-- | Sparse matrix: explicit dimensions together with the stored rows.
data SparseMatrix α = SM
    { dims :: (Int,Int)  -- ^ size as @(height, width)@, see 'height' and 'width'
    , mx   :: SMx α      -- ^ stored rows, keyed by row index
    } deriving (Eq)
-- | Applies a function to every stored entry; dimensions are untouched.
--   Entries that are not stored are not visited.
instance Functor SparseMatrix where
    fmap f (SM size rows) = SM size (M.map (M.map f) rows)
-- | Element-wise arithmetic on sparse matrices.
--
--   * '+' adds matching entries; '*' multiplies matching entries
--     (element-wise product, /not/ the matrix product — see '×').
--   * The result takes the larger height and the larger width of the operands.
--   * Entries that become zero are dropped, and rows left empty are dropped
--     afterwards, so that the result stores no explicit zeros (the same
--     clean-up that '×' and '×·' perform).
--   * @fromInteger 0@ is the empty matrix; any other literal becomes a
--     1×1 matrix holding that value.
instance (Eq α, Num α) => Num (SparseMatrix α) where
    (SM (h1,w1) m) + (SM (h2,w2) n)
        = SM (max h1 h2, max w1 w2)
        $ M.filter (not . M.null)
        $ M.map (M.filter (0 /=))
        $ M.unionWith (M.unionWith (+)) m n
    (SM (h1,w1) m) * (SM (h2,w2) n)
        = SM (max h1 h2, max w1 w2)
        $ M.filter (not . M.null)
        $ M.map (M.filter (0 /=))
        $ M.intersectionWith (M.intersectionWith (*)) m n
    negate        = fmap negate
    fromInteger 0 = emptyMx
    fromInteger x = diagonalMx [fromInteger x]
    abs           = fmap abs
    signum        = fmap signum
-- | Horizontal concatenation: the right operand is placed to the right of
--   the left one by shifting its column keys by the left operand's width.
--   The result is as tall as the taller operand and as wide as both together.
instance Monoid (SparseMatrix α) where
    mempty = emptyMx
    mappend (SM (h1,w1) left) (SM (h2,w2) right) = SM size rows
      where
        size    = (max h1 h2, w1 + w2)
        shifted = M.map (shiftKeys w1) right
        rows    = M.unionWith M.union left shifted
-- | Renders the matrix by first expanding it to a dense list of rows.
instance (Show α, Eq α, Num α) => Show (SparseMatrix α) where
    show m = showSparseMatrix (fillMx m)
-- | Pretty-prints a dense matrix: a size header followed by one bracketed
--   line per row, cells separated by @|@ and padded per column by 'column'.
showSparseMatrix :: (Show α, Eq α, Num α) => [[α]] -> String
showSparseMatrix [] = "(0,0):\n[]\n"
showSparseMatrix m@(firstRow : _) = header ++ body
  where
    header        = show (length m, length firstRow) ++ ": \n"
    -- Cells are formatted column by column, then turned back into rows.
    aligned       = transpose (L.map column (transpose m))
    showRow cells = "[" ++ intercalate "|" cells ++ "]"
    body          = unlines (L.map showRow aligned)
-- | Renders one column of values and left-pads every cell with spaces so
--   that all cells of the column have the width of the widest one.
--   Zero values are rendered by 'showNonZero'.
column :: (Show α, Eq α, Num α) => [α] -> [String]
column c = L.map pad cells
  where
    cells  = L.map showNonZero c
    -- Only demanded when there is at least one cell, so an empty column is safe.
    widest = L.maximum (L.map length cells)
    -- Right-align: prepend exactly the number of missing characters.
    pad s  = replicate (widest - length s) ' ' ++ s
-- | Shows a value, except that zero is rendered as a single blank.
showNonZero :: (Show α, Eq α, Num α) => α -> String
showNonZero x
    | x == 0    = " "
    | otherwise = show x
-- | Number of rows ('height') and number of columns ('width'), as recorded
--   in 'dims'.
height, width :: SparseMatrix α -> Int
height m = fst (dims m)
width  m = snd (dims m)
-- | Overwrites the recorded dimensions; stored entries are left as they are
--   (nothing is truncated or re-indexed).
setSize :: (Num α) => (Int,Int) -> SparseMatrix α -> SparseMatrix α
setSize newDims (SM _ rows) = SM newDims rows
-- | The 0×0 matrix with no stored entries.
emptyMx :: SparseMatrix α
emptyMx = SM { dims = (0,0), mx = M.empty }
-- | Zero matrix of the given @(height, width)@: no stored entries.
zeroMx :: Num α => (Int, Int) -> SparseMatrix α
zeroMx (h,w) = emptyMx { dims = (h,w) }
-- | 'isZeroMx' holds when the matrix stores no rows at all;
--   'isNotZeroMx' is its negation.
isZeroMx, isNotZeroMx :: SparseMatrix α -> Bool
isZeroMx m    = M.null (mx m)
isNotZeroMx m = not (isZeroMx m)
-- | Identity matrix of the given size, built as a diagonal of ones.
idMx :: (Num α, Eq α) => Int -> SparseMatrix α
idMx n = diagonalMx ones
  where
    ones = L.replicate n 1
-- | Vertical concatenation: the second matrix is placed below the first by
--   shifting its row keys by the first matrix's height. The result is as
--   wide as the wider operand.
(//) :: SparseMatrix α -> SparseMatrix α -> SparseMatrix α
(SM (h1,w1) top) // (SM (h2,w2) bottom) = SM size rows
  where
    size = (h1 + h2, max w1 w2)
    rows = M.union top (shiftKeys h1 bottom)
-- | Concatenates a list of matrices left-to-right ('hconcat', using '<>')
--   or top-to-bottom ('vconcat', using '//'), starting from 'emptyMx'.
hconcat, vconcat :: [SparseMatrix α] -> SparseMatrix α
hconcat blocks = L.foldl' (<>) emptyMx blocks
vconcat blocks = L.foldl' (//) emptyMx blocks
-- | Like 'blockMx', but every block is first given the same size, so that
--   blocks with no stored entries still occupy their place.
sizedBlockMx :: Num α => (Int, Int) -> [[SparseMatrix α]] -> SparseMatrix α
sizedBlockMx s blocks = blockMx (L.map (L.map (setSize s)) blocks)
-- | 'sizedBlockMx' for blocks that are themselves held in a sparse matrix;
--   the outer matrix is expanded with 'fillMx' first.
sizedBlockSMx :: (Eq α, Num α) => (Int, Int) -> SparseMatrix (SparseMatrix α) -> SparseMatrix α
sizedBlockSMx s blocks = sizedBlockMx s (fillMx blocks)
-- | Assembles a block matrix: each inner list is joined horizontally, and
--   the resulting strips are stacked vertically.
blockMx :: [[SparseMatrix α]] -> SparseMatrix α
blockMx blockRows = vconcat (L.map hconcat blockRows)
-- | 'blockMx' for blocks that are themselves held in a sparse matrix;
--   the outer matrix is expanded with 'fillMx' first.
blockSMx :: (Eq α, Num α) => SparseMatrix (SparseMatrix α) -> SparseMatrix α
blockSMx blocks = blockMx (fillMx blocks)
-- | Inserts a row at the given index via 'addElem'. The height grows by one
--   and the width becomes the larger of the old width and the vector's
--   dimension. A zero vector is passed to 'addElem' as 'Nothing'.
--   NOTE(review): 'addElem' is defined in the IntMapUtilities module;
--   presumably it shifts the following keys up by one — confirm there.
addRow :: (Num α) => SparseVector α -> Index -> SparseMatrix α -> SparseMatrix α
addRow v i m = SM newDims (addElem newRow i (mx m))
  where
    newDims = (height m + 1, max (width m) (dim v))
    newRow
        | isZeroVec v = Nothing
        | otherwise   = Just (vec v)
-- | Inserts a column at the given index. The width grows by one and the
--   height becomes the larger of the old height and the vector's dimension.
--
--   Every stored row is widened with 'addElem'. Entries of the new column
--   that fall into rows the matrix does not store (all-zero rows) get a
--   fresh single-entry row; previously such entries were silently lost.
--   NOTE(review): 'addElem' is defined in the IntMapUtilities module;
--   presumably it shifts the following keys up by one — confirm there.
addCol :: (Num α) => SparseVector α -> Index -> SparseMatrix α -> SparseMatrix α
addCol v j m = SM (max (height m) (dim v), width m + 1) (M.union widened fresh)
  where
    widened    = M.mapWithKey insCol (mx m)
    insCol i r = addElem (M.lookup i (vec v)) j r
    -- 'M.union' is left-biased, so these rows are only used where the
    -- matrix had no stored row.
    fresh      = M.map (M.singleton j) (vec v)
-- | Inserts an all-zero row (as wide as the matrix) at the given index.
addZeroRow :: Num α => Index -> SparseMatrix α -> SparseMatrix α
addZeroRow i m = addRow blank i m
  where
    blank = zeroVec (width m)
-- | Inserts an all-zero column (as tall as the matrix) at the given index.
addZeroCol :: Num α => Index -> SparseMatrix α -> SparseMatrix α
addZeroCol i m = addCol blank i m
  where
    blank = zeroVec (height m)
-- | Removes the row at the given index; the height shrinks by one.
--   A matrix with no stored rows only has its size adjusted.
--   NOTE(review): 'delElem' is defined in the IntMapUtilities module;
--   presumably it also shifts the following keys down by one — confirm there.
delRow :: (Num α) => Index -> SparseMatrix α -> SparseMatrix α
delRow i m
    | isZeroMx m = setSize shrunk m
    | otherwise  = SM shrunk (delElem i (mx m))
  where
    shrunk = (height m - 1, width m)
-- | Removes the column at the given index from every stored row; the width
--   shrinks by one. A matrix with no stored rows only has its size adjusted.
--   NOTE(review): 'delElem' is defined in the IntMapUtilities module;
--   presumably it also shifts the following keys down by one — confirm there.
delCol :: (Num α) => Index -> SparseMatrix α -> SparseMatrix α
delCol j m
    | isZeroMx m = setSize shrunk m
    | otherwise  = SM shrunk (M.map (delElem j) (mx m))
  where
    shrunk = (height m, width m - 1)
-- | Removes row @i@ first and then column @j@.
delRowCol :: Num α => Index -> Index -> SparseMatrix α -> SparseMatrix α
delRowCol i j = delCol j . delRow i
-- | Splits the stored rows by a predicate into two matrices of the same
--   width. The first gets as many rows as satisfied the predicate; the
--   second gets the remaining height (@h - st@).
--   NOTE(review): 'partitionMap' is defined in the IntMapUtilities module;
--   presumably it renumbers the keys of both halves — confirm there.
partitionMx :: (Num α) => (SparseVector α -> Bool) -> SparseMatrix α -> (SparseMatrix α, SparseMatrix α)
partitionMx p (SM (h,w) m) = (SM (st,w) t, SM (h - st,w) f)
  where
    (t,f) = partitionMap (p . SV w) m
    st    = M.size t
-- | Splits the stored rows by a predicate into two matrices. Unlike
--   'partitionMx', both results keep the original size and every row keeps
--   its original index.
separateMx :: (Num α) => (SparseVector α -> Bool) -> SparseMatrix α -> (SparseMatrix α, SparseMatrix α)
separateMx p (SM (h,w) m) = (SM (h,w) t, SM (h,w) f)
  where
    (t,f) = M.partition (p . SV w) m
-- | Looks up the entry at @(row, column)@; anything not stored reads as 0.
(#) :: (Num α) => SparseMatrix α -> (Index,Index) -> α
m # (i,j) =
    case M.lookup i (mx m) of
        Nothing -> 0
        Just r  -> M.findWithDefault 0 j r
-- | Extracts the row at the given index as a sparse vector whose dimension
--   is the matrix width; a row that is not stored yields an empty vector.
row :: (Num α) => SparseMatrix α -> Index -> SparseVector α
row m i = SV (width m) stored
  where
    stored = M.findWithDefault M.empty i (mx m)
-- | Extracts the column at the given index as a sparse vector whose
--   dimension is the matrix height, collecting the entry (if any) that each
--   stored row holds in that column.
col :: (Num α, Eq α) => SparseMatrix α -> Index -> SparseVector α
col m j = M.foldlWithKey' pick (zeroVec (height m)) (mx m)
  where
    pick acc i r =
        case M.lookup j r of
            Nothing -> acc
            Just x  -> acc `vecIns` (i,x)
-- | Applies a vector transformation to the row at the given index.
--   Only a stored row is affected ('M.adjust' leaves absent keys alone).
updRow :: (Num α) => (SparseVector α -> SparseVector α) -> Index -> SparseMatrix α -> SparseMatrix α
updRow f i (SM size rows) = SM size (M.adjust onRaw i rows)
  where
    -- Wrap the raw row with the matrix width, transform, unwrap again.
    onRaw raw = vec (f (SV (snd size) raw))
-- | Drops the stored row at the given index, i.e. makes it a zero row.
--   The matrix size does not change.
eraseRow :: (Num α) => Index -> SparseMatrix α -> SparseMatrix α
eraseRow i (SM size rows) = SM size (M.delete i rows)
-- | Removes the stored entry at @(row, column)@. If that leaves the row
--   without entries, the row itself is dropped as well.
erase :: (Num α) => SparseMatrix α -> (Index,Index) -> SparseMatrix α
erase m (i,j)
    | isZeroVec (cleared `row` i) = eraseRow i cleared
    | otherwise                   = cleared
  where
    cleared = updRow (\r -> r `eraseInVec` j) i m
-- | Writes a value at @(row, column)@. Writing 0 erases the entry instead,
--   so no explicit zero is ever stored. An existing entry is overwritten
--   ('M.union' is left-biased and the new entry is its left argument).
--   The recorded matrix size is not changed.
--
--   Uses 'M.insertWith' rather than the deprecated @insertWith'@, which is
--   no longer provided by recent versions of containers.
ins :: (Num α, Eq α) => SparseMatrix α -> ((Index,Index), α) -> SparseMatrix α
ins m ((i,j),0) = m `erase` (i,j)
ins m ((i,j),x) = m { mx = newMx }
  where
    newMx = M.insertWith M.union i (M.singleton j x) (mx m)
-- | Indices of the stored rows that satisfy the predicate, in ascending
--   order. Rows that are not stored are never tested.
findRowIndices :: (SparseVector α -> Bool) -> SparseMatrix α -> [Int]
findRowIndices p m = M.foldrWithKey keep [] (mx m)
  where
    keep i raw found
        | p (SV (width m) raw) = i : found
        | otherwise            = found
-- | Indices of the stored rows that satisfy the predicate, in descending
--   order. Rows that are not stored are never tested.
findRowIndicesR :: (SparseVector α -> Bool) -> SparseMatrix α -> [Int]
findRowIndicesR p m = M.foldlWithKey keep [] (mx m)
  where
    keep found i raw
        | p (SV (width m) raw) = i : found
        | otherwise            = found
-- | Returns the row at the given index together with the matrix from which
--   that row has been deleted.
popRow :: Num α => Index -> SparseMatrix α -> (SparseVector α, SparseMatrix α)
popRow i m = (taken, rest)
  where
    taken = m `row` i
    rest  = delRow i m
-- | Adds a row on top of the matrix (at index 1).
(|>) :: Num α => SparseVector α -> SparseMatrix α -> SparseMatrix α
(|>) r = addRow r 1
-- | Appends a row below the last row of the matrix.
(<|) :: Num α => SparseMatrix α -> SparseVector α -> SparseMatrix α
(<|) m r = addRow r below m
  where
    below = height m + 1
-- | Overwrites the row at the given index. A zero vector removes the stored
--   row instead, so that no empty row is kept. The size is not changed.
replaceRow :: Num α => SparseVector α -> Index -> SparseMatrix α -> SparseMatrix α
replaceRow r i m =
    if isZeroVec r
        then eraseRow i m
        else m { mx = M.insert i (vec r) (mx m) }
-- | Swaps two rows; both rows are read from the original matrix before
--   either is overwritten.
exchangeRows :: Num α => Index -> Index -> SparseMatrix α -> SparseMatrix α
exchangeRows i j m
    | i == j    = m
    | otherwise = replaceRow rowI j (replaceRow rowJ i m)
  where
    rowI = m `row` i
    rowJ = m `row` j
-- | Applies a vector transformation to every stored row. Each row is handed
--   over as a vector of the matrix width; the matrix size is kept as is.
mapOnRows :: (SparseVector α -> SparseVector β) -> SparseMatrix α -> SparseMatrix β
mapOnRows f (SM size rows) = SM size (M.map onRaw rows)
  where
    onRaw raw = vec (f (SV (snd size) raw))
-- | Square matrix with the given values on the main diagonal, in order.
--   Zero values occupy their diagonal position but are not stored
--   (see 'ins').
--
--   Folds strictly ('L.foldl'') so that long lists do not pile up a chain
--   of suspended steps.
diagonalMx :: (Num α, Eq α) => [α] -> SparseMatrix α
diagonalMx = L.foldl' extend emptyMx
  where
    -- Grow the matrix by one row and one column and place @x@ in the corner.
    extend m x =
        let i = height m + 1
        in  setSize (i,i) (m `ins` ((i,i),x))
-- | The main diagonal as a sparse vector; its length is the smaller of the
--   matrix height and width.
mainDiag :: (Eq α, Num α) => SparseMatrix α -> SparseVector α
mainDiag m = sparseList (L.map onDiag [1 .. side])
  where
    side     = min (height m) (width m)
    onDiag i = m # (i,i)
-- | Builds a matrix from a list of rows, appending them top to bottom with
--   '<|'. Uses the strict left fold 'L.foldl'' instead of the lazy 'foldl'.
fromRows :: (Num α) => [SparseVector α] -> SparseMatrix α
fromRows = L.foldl' (<|) emptyMx
-- | Lists the non-zero entries as @((row, column), value)@ in ascending
--   row, then column, order. The list always starts with the pair
--   @(size, 0)@, which carries the matrix dimensions.
toAssocList :: (Num α, Eq α) => SparseMatrix α -> [ ((Index,Index), α) ]
toAssocList (SM s m) = (s, 0) : entries
  where
    entries = do
        (i, r) <- M.toAscList m
        (j, x) <- M.toAscList r
        if x /= 0 then [((i,j), x)] else []
-- | Builds a matrix of the given size by inserting the listed entries one
--   after another into a zero matrix; later entries overwrite earlier ones
--   at the same position (see 'ins').
fromAssocListWithSize :: (Num α, Eq α) => (Int,Int) -> [ ((Index,Index), α) ] -> SparseMatrix α
fromAssocListWithSize s = L.foldl' ins (zeroMx s)
-- | Builds a matrix from @((row, column), value)@ pairs. The size is the
--   largest row index paired with the largest column index that occur in
--   the list (zero-valued pairs count too, so the size marker produced by
--   'toAssocList' is honoured).
--
--   The maxima are taken per component: comparing whole index pairs is
--   lexicographic and would report the wrong width whenever the widest
--   entry is not in the last row. An empty list yields 'emptyMx'.
fromAssocList :: (Num α, Eq α) => [ ((Index,Index), α) ] -> SparseMatrix α
fromAssocList [] = emptyMx
fromAssocList l  = fromAssocListWithSize (L.maximum is, L.maximum js) l
  where
    (is, js) = unzip (fmap fst l)
-- | Expands the matrix to a dense list of rows, reading indices
--   @1 .. height@ and @1 .. width@; entries that are not stored become 0.
fillMx :: (Num α) => SparseMatrix α -> [[α]]
fillMx m = L.map denseRow [1 .. height m]
  where
    denseRow i = L.map (\j -> m # (i,j)) [1 .. width m]
-- | Converts a dense list of rows to a sparse matrix. The height is the
--   number of rows and the width is the length of the first row. Rows are
--   numbered from 1, and rows that turn out empty are not stored.
sparseMx :: (Num α, Eq α) => [[α]] -> SparseMatrix α
sparseMx [] = emptyMx
sparseMx m@(firstRow : _) = SM (length m, length firstRow) (M.fromList stored)
  where
    indexed = zip [1..] (L.map (vec . sparseList) m)
    stored  = L.filter (not . M.null . snd) indexed
-- | Transposes the matrix: every stored entry @(i, j)@ is re-inserted at
--   @(j, i)@ and height and width are swapped.
trans :: (Num α, Eq α) => SparseMatrix α -> SparseMatrix α
trans m = setSize (width m, height m) flipped
  where
    flipped         = M.foldlWithKey' flipRow emptyMx (mx m)
    flipRow acc i r = M.foldlWithKey' (\a j x -> a `ins` ((j,i),x)) acc r
-- | Matrix-by-vector product; alphabetic alias of '×·'.
mulMV :: (Num α, Eq α) => SparseMatrix α -> SparseVector α -> SparseVector α
mulMV m v = m ×· v
-- | Matrix-by-vector product. Every stored row is combined with the vector
--   using @··@, zero results are dropped, and the result has the matrix
--   height as its dimension.
--   NOTE(review): @··@ comes from the Vector module; presumably it is the
--   dot product of two raw sparse vectors — confirm there.
(×·) :: (Num α, Eq α) => SparseMatrix α -> SparseVector α -> SparseVector α
(SM (h,_) m) ×· (SV _ v) = SV h nonZero
  where
    products = M.map (\r -> v ·· r) m
    nonZero  = M.filter (\x -> 0 /= x) products
-- | Vector-by-matrix product; alphabetic alias of '·×'.
mulVM :: (Num α, Eq α) => SparseVector α -> SparseMatrix α -> SparseVector α
mulVM v m = v ·× m
-- | Vector-by-matrix product, computed as the transposed matrix times the
--   vector.
(·×) :: (Num α, Eq α) => SparseVector α -> SparseMatrix α -> SparseVector α
(·×) v m = transposed ×· v
  where
    transposed = trans m
-- | Matrix-by-matrix product; alphabetic alias of '×'.
mul :: (Num α, Eq α) => SparseMatrix α -> SparseMatrix α -> SparseMatrix α
mul a b = a × b
-- | Matrix-by-matrix product. The right operand is transposed once, then
--   every stored row of the left operand is combined with every stored row
--   of that transpose using @··@. Zero results and rows left empty are
--   dropped. The result has the height of the left and the width of the
--   right operand.
--   NOTE(review): @··@ comes from the Vector module; presumably it is the
--   dot product of two raw sparse vectors — confirm there.
(×) :: (Num α, Eq α) => SparseMatrix α -> SparseMatrix α -> SparseMatrix α
a × b = SM (height a, width b) rows
  where
    bColumns       = mx (trans b)
    rowTimesB aRow = M.filter (0 /=) (M.map (aRow ··) bColumns)
    rows           = M.filter (not . M.null) (M.map rowTimesB (mx a))