module Data.Sparse.SpMatrix where
import Data.Sparse.Utils
import Data.Sparse.Types
import Numeric.Eps
import Numeric.LinearAlgebra.Class
import Data.Sparse.IntMap2.IntMap2
import qualified Data.IntMap as IM
import Data.Maybe
-- | A sparse matrix: its declared dimensions plus a nested 'IM.IntMap' of
-- stored entries, keyed first by row index and then by column index.
data SpMatrix a = SM
  { smDim  :: (Rows, Cols)             -- ^ (number of rows, number of columns)
  , smData :: IM.IntMap (IM.IntMap a)  -- ^ row index -> (column index -> entry)
  } deriving Eq
-- | Human-readable size summary: row and column counts, number of stored
-- entries and the ratio reported by 'infoSM'.
sizeStr :: SpMatrix a -> String
sizeStr sm =
  let SMInfo nnz density = infoSM sm
  in unwords
       [ "(", show (nrows sm), "rows,"
       , show (ncols sm), "columns ) ,"
       , show nnz, "NZ ( sparsity", show density, ")"
       ]
-- | Renders the size summary from 'sizeStr' followed by the row-indexed
-- association list of stored rows.
instance Show a => Show (SpMatrix a) where
  show sm = concat ["SM ", sizeStr sm, " ", show (IM.toList (smData sm))]
-- | Applies a function to every stored entry; dimensions and the set of
-- stored positions are unchanged.
instance Functor SpMatrix where
  fmap f (SM d rows) = SM d (IM.map (IM.map f) rows)
-- | Entry-wise union / intersection of two matrices, delegating to the
-- nested 'IM.IntMap' instances. Dimensions are combined with 'maxTup'
-- (union) and 'minTup' (intersection).
instance Set SpMatrix where
  liftU2 f (SM d1 m1) (SM d2 m2) = SM (maxTup d1 d2) (liftU2 (liftU2 f) m1 m2)
  liftI2 f (SM d1 m1) (SM d2 m2) = SM (minTup d1 d2) (liftI2 (liftI2 f) m1 m2)
-- | Matrix addition as the entry-wise union with (+); 'zero' is the empty
-- 0x0 matrix.
instance Additive SpMatrix where
  zero = SM (0, 0) IM.empty
  a ^+^ b = liftU2 (+) a b
-- | The size of a matrix is its (rows, columns) pair.
instance FiniteDim SpMatrix where
  type FDSize SpMatrix = (Rows, Cols)
  dim (SM d _) = d
-- | Exposes the nested row/column 'IM.IntMap' backing the matrix.
instance HasData SpMatrix a where
  type HDData SpMatrix a = IM.IntMap (IM.IntMap a)
  dat (SM _ rows) = rows
-- | Sparsity ratio of a matrix, as computed by 'spySM'.
instance Sparse SpMatrix a where
  spy m = spySM m
-- | An @m x n@ matrix with no stored entries.
zeroSM :: Rows -> Cols -> SpMatrix a
zeroSM m n = SM { smDim = (m, n), smData = IM.empty }
-- | Square @n x n@ matrix with the given values on the main diagonal
-- (offset 0 of 'mkSubDiagonal').
mkDiagonal :: Int -> [a] -> SpMatrix a
mkDiagonal n xs = mkSubDiagonal n 0 xs
-- | The @n x n@ identity matrix (ones stored on the main diagonal).
eye :: Num a => Int -> SpMatrix a
eye n = mkDiagonal n (take n (repeat 1))
-- | Permutation matrix built by pairing row @k@ (for @k = 0 .. n-1@) with
-- the @k@-th element of @iis@ and applying those row swaps, in order, to the
-- @n x n@ identity (see 'permutPairsSM').
permutationSM :: Num a => Int -> [IxRow] -> SpMatrix a
permutationSM n iis = permutPairsSM n (zip [0 .. n - 1] iis)
-- | Starting from the @n x n@ identity, applies each @(r1, r2)@ row swap in
-- list order (left to right) via 'swapRows'.
permutPairsSM :: Num a => Int -> [(IxRow, IxRow)] -> SpMatrix a
permutPairsSM n iix = foldl (\acc (r1, r2) -> swapRows r1 r2 acc) (eye n) iix
-- | @mkSubDiagonal n o xx@ builds an @n x n@ matrix with the values of @xx@
-- laid along the diagonal at offset @o@:
--
-- * @o >= 0@: entries at @(k, k + o)@ (main diagonal / above it);
-- * @o <  0@: entries at @(k + |o|, k)@ (below the main diagonal).
--
-- Surplus elements of @xx@ are ignored ('zip3' truncates). Errors when
-- @|o| >= n@.
mkSubDiagonal :: Int -> Int -> [a] -> SpMatrix a
mkSubDiagonal n o xx
  | abs o < n = if o >= 0 then fz ii jj xx else fz jj ii xx
  | otherwise = error "mkSubDiagonal : offset > dimension"
  where
    ii = [0 .. n - 1]
    jj = [abs o .. n - 1]
    fz a b x = fromListSM (n, n) (zip3 a b x)
-- | Stores @x@ at position @(i, j)@, replacing any existing entry there
-- (via 'insertIM2'). Errors if @(i, j)@ fails the 'inBounds02' check
-- against the matrix dimensions.
insertSpMatrix :: IxRow -> IxCol -> a -> SpMatrix a -> SpMatrix a
insertSpMatrix i j x s =
  if inBounds02 sz (i, j)
    then SM sz (insertIM2 i j x (immSM s))
    else error "insertSpMatrix : index out of bounds"
  where
    sz = dim s
-- | Inserts every @(row, column, value)@ triple into an existing matrix,
-- left to right, so later triples overwrite earlier ones at the same
-- position. Each insertion is bounds-checked by 'insertSpMatrix'.
fromListSM' :: Foldable t => t (IxRow, IxCol, a) -> SpMatrix a -> SpMatrix a
fromListSM' iix sm = foldl (\acc (i, j, x) -> insertSpMatrix i j x acc) sm iix
-- | Builds an @m x n@ matrix from @(row, column, value)@ triples, starting
-- from the empty matrix 'zeroSM'.
fromListSM :: Foldable t => (Int, Int) -> t (IxRow, IxCol, a) -> SpMatrix a
fromListSM (m, n) = flip fromListSM' (zeroSM m n)
-- | Builds a matrix with @m@ rows from a flat list of values; the column
-- count is @length ll `div` m@ and positions come from 'denseIxArray2'.
fromListDenseSM :: Int -> [a] -> SpMatrix a
fromListDenseSM m ll =
  let n = length ll `div` m
  in fromListSM (m, n) (denseIxArray2 m ll)
-- | Every position of the matrix in row-major order as
-- @(row, column, value)@, with 0 for positions that have no stored entry.
toDenseListSM :: Num t => SpMatrix t -> [(IxRow, IxCol, t)]
toDenseListSM m =
  [(i, j, m @@ (i, j)) | i <- [0 .. nrows m - 1], j <- [0 .. ncols m - 1]]
-- | The stored entry at @(i, j)@, or 'Nothing' if the row or the column
-- entry is absent. No bounds check is performed.
lookupSM :: SpMatrix a -> IxRow -> IxCol -> Maybe a
lookupSM (SM _ im) i j = do
  row <- IM.lookup i im
  IM.lookup j row
-- | Entry lookup that defaults to 0 for positions with no stored entry.
-- 'lookupWD_SM' and '(@@!)' do no bounds check; '(@@)' errors when the
-- index fails 'isValidIxSM'.
lookupWD_SM, (@@!), (@@) :: Num a => SpMatrix a -> (IxRow, IxCol) -> a
lookupWD_SM sm (i, j) = case lookupSM sm i j of
  Just x  -> x
  Nothing -> 0

-- | Same default-to-zero lookup, directly on a nested 'IM.IntMap'.
lookupWD_IM :: Num a => IM.IntMap (IM.IntMap a) -> (IxRow, IxCol) -> a
lookupWD_IM im (i, j) = fromMaybe 0 $ do
  row <- IM.lookup i im
  IM.lookup j row

sm @@! ij = lookupWD_SM sm ij

m @@ d =
  if isValidIxSM m d
    then m @@! d
    else error $ "@@ : incompatible indices : matrix size is " ++ show (dim m) ++ ", but user looked up " ++ show d
-- | Keeps only the stored entries for which the predicate (given row
-- index, column index and value) holds; dimensions are unchanged.
filterSM :: (IM.Key -> IM.Key -> a -> Bool) -> SpMatrix a -> SpMatrix a
filterSM f sm = SM (smDim sm) (ifilterIM2 f (smData sm))
-- | Triangular / diagonal parts of a matrix, selected by comparing each
-- stored entry's column index with its row index:
-- 'extractSubDiag' keeps column < row, 'extractSuperDiag' keeps
-- column > row, 'extractDiag' keeps column == row.
extractDiag, extractSuperDiag, extractSubDiag :: SpMatrix a -> SpMatrix a
extractDiag = filterSM (\i j _ -> j == i)
extractSuperDiag = filterSM (\i j _ -> j > i)
extractSubDiag = filterSM (\i j _ -> j < i)
-- | Extracts the entries whose row lies in @[i1, i2]@ and column in
-- @[j1, j2]@ (both ranges inclusive). Row and column keys of the kept
-- entries are relabelled with @fi@ and @gi@, rows left empty are dropped,
-- and the result's dimensions are @(i2 - i1 + 1, j2 - j1 + 1)@.
--
-- With @id@ relabelling the kept entries keep their original indices,
-- which may then lie outside the new dimensions.
--
-- Errors unless all four bounds are valid for the source matrix (checked
-- with 'inBounds0') and both ranges are non-empty (@i1 <= i2@, @j1 <= j2@).
extractSubmatrixSM
  :: (IM.Key -> IM.Key)  -- row-key relabelling
  -> (IM.Key -> IM.Key)  -- column-key relabelling
  -> SpMatrix a
  -> (IxRow, IxRow)      -- inclusive row range (i1, i2)
  -> (IxCol, IxCol)      -- inclusive column range (j1, j2)
  -> SpMatrix a
extractSubmatrixSM fi gi (SM (r, c) im) (i1, i2) (j1, j2)
  | q = SM (m', n') imm'
  | otherwise = error $ "extractSubmatrixSM : invalid index " ++ show (i1, i2) ++ ", " ++ show (j1, j2)
  where
    imm' = mapKeysIM2 fi gi $
           IM.filter (not . IM.null) $
           ifilterIM2 ff im
    ff i j _ = i1 <= i && i <= i2 && j1 <= j && j <= j2
    (m', n') = (i2 - i1 + 1, j2 - j1 + 1)
    -- both ranges must be in bounds AND non-empty; previously a reversed
    -- column range slipped through and produced n' <= 0
    q = inBounds0 r i1 &&
        inBounds0 r i2 &&
        inBounds0 c j1 &&
        inBounds0 c j2 &&
        i2 >= i1 &&
        j2 >= j1
-- | Like 'extractSubmatrix', but shifts the kept entries' keys so the
-- extracted block is indexed from @(0, 0)@: row @i@ becomes @i - i1@ and
-- column @j@ becomes @j - j1@.
extractSubmatrixRebalanceKeys ::
  SpMatrix a -> (IxRow, IxRow) -> (IxCol, IxCol) -> SpMatrix a
extractSubmatrixRebalanceKeys mm (i1, i2) (j1, j2) =
  extractSubmatrixSM (subtract i1) (subtract j1) mm (i1, i2) (j1, j2)
-- | Extracts an inclusive row/column range, keeping the entries' original
-- indices (no key relabelling).
extractSubmatrix :: SpMatrix a -> (IxRow, IxRow) -> (IxCol, IxCol) -> SpMatrix a
extractSubmatrix sm rows cols = extractSubmatrixSM id id sm rows cols
-- | Row @i@ as a 1 x ncols matrix, keeping the original row/column keys.
extractRowSM :: SpMatrix a -> IxRow -> SpMatrix a
extractRowSM sm i = extractSubmatrix sm (i, i) (0, ncols sm - 1)
-- | The part of row @i@ between columns @j1@ and @j2@ (inclusive), with
-- original keys.
extractSubRowSM :: SpMatrix a -> IxRow -> (IxCol, IxCol) -> SpMatrix a
extractSubRowSM sm i cols = extractSubmatrix sm (i, i) cols
-- | Like 'extractSubRowSM', but re-indexes the result from @(0, 0)@.
extractSubRowSM_RK :: SpMatrix a -> IxRow -> (IxCol, IxCol) -> SpMatrix a
extractSubRowSM_RK sm i cols = extractSubmatrixRebalanceKeys sm (i, i) cols
-- | Column @j@ as an nrows x 1 matrix, keeping the original row/column keys.
extractColSM :: SpMatrix a -> IxCol -> SpMatrix a
extractColSM sm j = extractSubmatrix sm (0, nrows sm - 1) (j, j)
-- | The part of column @j@ between rows @i1@ and @i2@ (inclusive), with
-- original keys.
extractSubColSM :: SpMatrix a -> IxCol -> (IxRow, IxRow) -> SpMatrix a
extractSubColSM sm j rows = extractSubmatrix sm rows (j, j)
-- | Like 'extractSubColSM', but re-indexes the result from @(0, 0)@.
extractSubColSM_RK :: SpMatrix a -> IxCol -> (IxRow, IxRow) -> SpMatrix a
extractSubColSM_RK sm j rows = extractSubmatrixRebalanceKeys sm rows (j, j)
-- | Whether a (row, column) index passes 'inBounds02' for the matrix size.
isValidIxSM :: SpMatrix a -> (Int, Int) -> Bool
isValidIxSM mm ij = inBounds02 (dim mm) ij
-- | True when the row and column counts agree.
isSquareSM :: SpMatrix a -> Bool
isSquareSM m = let (r, c) = dim m in r == c
-- | True when the number of rows whose only stored entry sits on the
-- diagonal (column index == row index) equals 'nrows'.
--
-- NOTE(review): this is a structural test. A row with no stored entries
-- (e.g. any row of 'zeroSM', or a missing diagonal element) makes it
-- return False even though all off-diagonal entries are zero.
isDiagonalSM :: SpMatrix a -> Bool
isDiagonalSM m = IM.size diagRows == nrows m
  where
    diagRows = IM.filterWithKey onlyOnDiagonal (immSM m)
    onlyOnDiagonal irow row = IM.size row == 1 && IM.member irow row
-- | Tests orthogonality by comparing @roundZeroOneSM (A^T A)@ structurally
-- (derived 'Eq') with the identity of size 'ncols'.
-- NOTE(review): relies on 'roundZeroOne' snapping near-0/near-1 values —
-- confirm its tolerance.
isOrthogonalSM :: SpMatrix Double -> Bool
isOrthogonalSM sm@(SM (_, n) _) = eye n == gram
  where
    gram = roundZeroOneSM (transposeSM sm ## sm)
-- | The nested row/column map of stored entries.
immSM :: SpMatrix t -> IM.IntMap (IM.IntMap t)
immSM = smData
-- | The (rows, columns) size of the matrix.
dimSM :: SpMatrix t -> (Rows, Cols)
dimSM = smDim
-- | Total number of positions (rows * columns), stored or not.
nelSM :: SpMatrix t -> Int
nelSM m = let (nr, nc) = smDim m in nr * nc
-- | Number of rows.
nrows :: SpMatrix a -> Rows
nrows m = fst (dim m)
-- | Number of columns.
ncols :: SpMatrix a -> Cols
ncols m = snd (dim m)
-- | Summary statistics of a matrix (see 'infoSM').
data SMInfo = SMInfo
  { smNz  :: Int     -- ^ number of stored entries
  , smSpy :: Double  -- ^ stored entries / total positions
  } deriving (Eq, Show)
-- | Stored-entry count and fill ratio of a matrix.
infoSM :: SpMatrix a -> SMInfo
infoSM s = SMInfo { smNz = nzSM s, smSpy = spySM s }
-- | Number of stored entries, summed over all rows.
nzSM :: SpMatrix a -> Int
nzSM = IM.foldr (\row acc -> IM.size row + acc) 0 . immSM
-- | Fill ratio: stored entries divided by total positions ('nelSM').
-- A matrix with zero positions yields 0/0 in the target type.
spySM :: Fractional b => SpMatrix a -> b
spySM s =
  let stored = fromIntegral (nzSM s)
      total  = fromIntegral (nelSM s)
  in stored / total
-- | Number of stored entries in row @i@; errors if @i@ fails 'inBounds0'
-- against 'nrows'.
nzRow :: SpMatrix a -> IM.Key -> Int
nzRow s i =
  if inBounds0 (nrows s) i
    then nzRowU s i
    else error "nzRow : index out of bounds"

-- | Unchecked 'nzRow': a row with no stored entries (or an index with no
-- row) yields 0.
nzRowU :: SpMatrix a -> IM.Key -> Int
nzRowU s i = case IM.lookup i (immSM s) of
  Nothing  -> 0
  Just row -> IM.size row
-- | Lower component of 'bwBoundsSM'.
bwMinSM :: SpMatrix a -> Int
bwMinSM s = fst (bwBoundsSM s)
-- | Upper component of 'bwBoundsSM'.
bwMaxSM :: SpMatrix a -> Int
bwMaxSM s = snd (bwBoundsSM s)
-- | Smallest and largest row span over all non-empty rows, where a row's
-- span is (last stored column - first stored column + 1).
--
-- Previously this returned the spans of the first and last row *keys*
-- ('IM.findMin' / 'IM.findMax' on the row map select by key, not by value),
-- and crashed on rows whose inner map was empty. Errors if the matrix has
-- no stored entries at all.
bwBoundsSM :: SpMatrix a -> (Int, Int)
bwBoundsSM s
  | null widths = error "bwBoundsSM : matrix has no stored entries"
  | otherwise = (minimum widths, maximum widths)
  where
    widths =
      [ fst (IM.findMax row) - fst (IM.findMin row) + 1
      | row <- IM.elems (immSM s)
      , not (IM.null row)
      ]
-- | Stacks @bottom@ below @top@: bottom's row keys are shifted down by
-- @nrows top@, the row counts add up and the column count is the larger
-- of the two.
-- NOTE(review): assumes top's stored rows lie in @[0, nrows top)@; on a key
-- clash the left-biased 'IM.union' keeps top's row.
vertStackSM, (-=-) :: SpMatrix a -> SpMatrix a -> SpMatrix a
vertStackSM top bottom =
  let offset  = nrows top
      shifted = IM.mapKeys (+ offset) (immSM bottom)
      rows'   = offset + nrows bottom
      cols'   = max (ncols top) (ncols bottom)
  in SM (rows', cols') (IM.union (immSM top) shifted)
(-=-) = vertStackSM
-- | Places @right@ to the right of @left@, implemented by transposing,
-- stacking vertically with 'vertStackSM', and transposing back.
horizStackSM, (-||-) :: SpMatrix a -> SpMatrix a -> SpMatrix a
horizStackSM left right =
  transposeSM (vertStackSM (transposeSM left) (transposeSM right))
(-||-) = horizStackSM
-- | Folds over the stored entries using 'foldlIM2'.
foldlSM :: (a -> b -> b) -> b -> SpMatrix a -> b
foldlSM f z sm = foldlIM2 f z (smData sm)
-- | Index-aware fold over the stored entries using 'ifoldlIM2''.
ifoldlSM :: (IM.Key -> IM.Key -> a -> b -> b) -> b -> SpMatrix a -> b
ifoldlSM f z sm = ifoldlIM2' f z (smData sm)
-- | Count of stored entries below the diagonal, via 'countSubdiagonalNZ'.
countSubdiagonalNZSM :: SpMatrix a -> Int
countSubdiagonalNZSM = countSubdiagonalNZ . smData
-- | Indices of the stored sub-diagonal entries, via 'subdiagIndices'.
subdiagIndicesSM :: SpMatrix a -> [(IxRow, IxCol)]
subdiagIndicesSM sm = subdiagIndices (smData sm)
-- | Drops entries whose magnitude is below 'eps' (NaN entries are dropped
-- too, since the comparison is False for them).
sparsifyIM2 :: IM.IntMap (IM.IntMap Double) -> IM.IntMap (IM.IntMap Double)
sparsifyIM2 = ifilterIM2 keep
  where
    keep _ _ x = eps <= abs x
-- | Matrix-level 'sparsifyIM2'; dimensions are kept.
sparsifySM :: SpMatrix Double -> SpMatrix Double
sparsifySM sm = sm { smData = sparsifyIM2 (smData sm) }
-- | Applies 'roundZeroOne' to every stored entry, then drops the entries
-- that end up below 'eps' in magnitude.
roundZeroOneSM :: SpMatrix Double -> SpMatrix Double
roundZeroOneSM sm = sparsifySM (sm { smData = mapIM2 roundZeroOne (smData sm) })
-- | Exchanges rows @i1@ and @i2@ (no bounds check; see 'swapRowsSafe').
--
-- A row with no stored entries is simply absent from the row map, so it is
-- swapped as "absent": the destination key is deleted instead of the
-- previous @IM.!@ lookup crashing on it.
swapRows :: IxRow -> IxRow -> SpMatrix a -> SpMatrix a
swapRows i1 i2 (SM d im) = SM d (setRow i1 ro2 (setRow i2 ro1 im))
  where
    ro1 = IM.lookup i1 im
    ro2 = IM.lookup i2 im
    -- install the row if present, otherwise make sure the key is absent
    setRow i = maybe (IM.delete i) (IM.insert i)
-- | 'swapRows' guarded by an 'inBounds02' check of both indices against the
-- row count; errors otherwise.
swapRowsSafe :: IxRow -> IxRow -> SpMatrix a -> SpMatrix a
swapRowsSafe i1 i2 m =
  if inBounds02 (nro, nro) (i1, i2)
    then swapRows i1 i2 m
    else error $ "swapRowsSafe : index out of bounds " ++ show (i1, i2)
  where
    nro = nrows m
-- | Transpose: dimensions are swapped and entries moved via 'transposeIM2'.
transposeSM, (#^) :: SpMatrix a -> SpMatrix a
transposeSM (SM (r, c) rows) = SM { smDim = (c, r), smData = transposeIM2 rows }
(#^) = transposeSM
-- | Multiplies every stored entry on the right by the scalar @a@.
matScale :: Num a => a -> SpMatrix a -> SpMatrix a
matScale a m = fmap (\x -> x * a) m
-- | Frobenius norm: square root of the sum of the squared stored entries
-- (equivalently, sqrt of trace(A^T A)).
--
-- The previous version summed *every* entry of A^T A (or A A^T) rather
-- than its trace, which is wrong for most matrices (e.g. the 2x2 all-ones
-- matrix gave sqrt 8 instead of 2) and could even take the square root of
-- a negative number. Summing squares directly also avoids a matrix product.
normFrobenius :: SpMatrix Double -> Double
normFrobenius m = sqrt $ foldlSM (\x acc -> acc + x * x) 0 m
-- | Matrix product; errors unless @ncols m1 == nrows m2@.
matMat, (##) :: Num a => SpMatrix a -> SpMatrix a -> SpMatrix a
matMat m1 m2
  | c1 == r2 = matMatU m1 m2
  | otherwise = error $ "matMat : incompatible matrix sizes" ++ show (d1, d2)
  where
    d1@(_, c1) = dim m1
    d2@(r2, _) = dim m2

-- | Unchecked product. Entry @(i, j)@ is the 'dot' of column @j@ of @m2@
-- (obtained with 'transposeIM2') with row @i@ of @m1@. There is one output
-- row per row key present in @m1@, and each such row gets an entry for
-- every column key of @m2@, even when the product is zero (the
-- @Sparsified@ variants filter those out).
matMatU :: Num a => SpMatrix a -> SpMatrix a -> SpMatrix a
matMatU m1 m2 = SM (nrows m1, ncols m2) im
  where
    -- transposed once and shared by all rows of m1 (previously rebuilt
    -- inside the per-row lambda)
    cols2 = transposeIM2 (immSM m2)
    im = fmap (\vm1 -> (`dot` vm1) <$> cols2) (immSM m1)

(##) = matMat
-- | Matrix product followed by 'sparsifySM', dropping entries whose
-- magnitude is below 'eps'.
matMatSparsified, (#~#) :: SpMatrix Double -> SpMatrix Double -> SpMatrix Double
matMatSparsified m1 m2 = sparsifySM (m1 `matMat` m2)
(#~#) = matMatSparsified
-- | Sparsified @A^T B@.
(#^#) :: SpMatrix Double -> SpMatrix Double -> SpMatrix Double
(#^#) a b = matMatSparsified (transposeSM a) b
-- | Sparsified @A B^T@.
(##^) :: SpMatrix Double -> SpMatrix Double -> SpMatrix Double
(##^) a b = matMatSparsified a (transposeSM b)
-- | Partial inner product of row @i@ of @a@ with column @j@ of @b@:
-- the sum over @k@ in @[0 .. n]@ (upper bound inclusive) of
-- @a(i,k) * b(k,j)@, with absent entries counting as 0.
--
-- Requires @ncols a == nrows b@, @i@ a valid row of @a@, @j@ a valid
-- column of @b@, and @n <= ncols a@. The column index used to be checked
-- against @a@'s dimensions, which is only correct when @b@ has as many
-- columns as @a@.
contractSub :: Num a => SpMatrix a -> SpMatrix a -> IxRow -> IxCol -> Int -> a
contractSub a b i j n
  | ncols a == nrows b
    && inBounds0 (nrows a) i
    && inBounds0 (ncols b) j
    && n <= ncols a =
      sum [(a @@! (i, k)) * (b @@! (k, j)) | k <- [0 .. n]]
  | otherwise =
      error "contractSub : incompatible sizes, index out of bounds, or n > ncols a"