{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE InstanceSigs #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE StrictData #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE TypeFamilies #-}
module Data.Matrix.Static.Sparse
(
SparseMatrix(..)
, Zero(..)
, C.dim
, C.rows
, C.cols
, (C.!)
, C.takeDiag
, C.unsafeIndex
, C.unsafeTakeRow
, C.unsafeTakeColumn
, C.empty
, fromTriplet
, fromTripletC
, toTriplet
, C.fromVector
, C.fromList
, C.unsafeFromVector
, diag
, diagRect
, C.flatten
, C.toList
, C.convertAny
) where
import Control.Monad.ST (runST)
import Data.Bits (shiftR)
import Data.Complex
import qualified Data.Foldable as F
import qualified Data.List as L
import Data.Tuple (swap)
import Foreign.C.Types
import GHC.TypeLits (type (<=))
import Text.Printf (printf)

import Conduit
import Data.Conduit.Internal (zipSinks)
import Data.Singletons
import qualified Data.Vector.Generic as G
import qualified Data.Vector.Generic.Mutable as GM
import qualified Data.Vector.Storable as S
import qualified Data.Vector.Storable.Mutable as SM

import qualified Data.Matrix.Static.Dense as D
import qualified Data.Matrix.Static.Generic as C
import Data.Matrix.Static.Sparse.Mutable
-- | Associates 'SparseMatrix' with its mutable counterpart 'MSparseMatrix'
-- (imported from "Data.Matrix.Static.Sparse.Mutable") through the
-- 'C.Mutable' type family of the generic matrix interface.
type instance C.Mutable SparseMatrix = MSparseMatrix
-- | Element types with a distinguished zero value. A 'SparseMatrix' reports
-- 'zero' for every entry it does not store, and the 'Eq' superclass lets
-- callers test values against 'zero' (e.g. to skip storing them).
class Eq a => Zero a where
  zero :: a

instance Zero Int where zero = 0

instance Zero Float where zero = 0

instance Zero CFloat where zero = 0

instance Zero Double where zero = 0

instance Zero (Complex Float) where zero = 0

instance Zero (Complex Double) where zero = 0

-- | Lists use the empty list as their zero.
instance Eq a => Zero [a] where zero = []
-- | A sparse matrix with type-level dimensions @r x c@, stored column by
-- column in a compressed layout:
--
-- * the @v a@ argument holds the values of the stored entries, grouped by
--   column;
-- * the first @S.Vector CInt@ holds the row index of each stored entry;
-- * the second @S.Vector CInt@ holds @c + 1@ column offsets, so the entries
--   of column @j@ occupy positions @[outer ! j, outer ! (j + 1))@.
--
-- Element lookup binary-searches the row indices of a column, so within each
-- column the row indices are expected to be in ascending order.
--
-- The constructor captures the 'SingI' dictionaries of both dimensions;
-- pattern matching on it brings them back into scope.
data SparseMatrix :: C.MatrixKind where
  SparseMatrix :: (SingI r, SingI c)
               => (v a)            -- stored values
               -> (S.Vector CInt)  -- row index of each stored value
               -> (S.Vector CInt)  -- column offsets
               -> SparseMatrix r c v a
-- | Structural equality: two matrices are equal exactly when their value,
-- row-index and column-offset vectors are all equal.
instance (G.Vector v a, Eq (v a)) => Eq (SparseMatrix r c v a) where
  SparseMatrix xs rowsA offsetsA == SparseMatrix ys rowsB offsetsB =
    and [xs == ys, rowsA == rowsB, offsetsA == offsetsB]
-- | Renders a @(rows x cols)@ header line followed by one line per row
-- returned by 'C.toRows', with the row's elements separated by spaces.
instance (G.Vector v a, Zero a, Show a) => Show (SparseMatrix r c v a) where
  show mat = printf "(%d x %d)\n%s" nr nc rendered
    where
      (nr, nc) = C.dim mat
      rendered = unlines [unwords (show <$> G.toList row) | row <- C.toRows mat]
-- | Sparse matrices implement the generic 'C.Matrix' interface.
--
-- The mutable conversions ('C.thaw', 'C.unsafeThaw', 'C.freeze',
-- 'C.unsafeFreeze') are not implemented for this representation yet; calling
-- them fails with a descriptive error instead of a bare 'undefined'.
instance (G.Vector v a, Zero a) => C.Matrix SparseMatrix v a where
  -- Matching on the constructor brings its 'SingI' dictionaries into scope,
  -- which 'sing' needs in order to read the type-level dimensions.
  dim :: forall r c. SparseMatrix r c v a -> (Int, Int)
  dim (SparseMatrix _ _ _) = (r, c)
    where
      r = fromIntegral $ fromSing (sing :: Sing r)
      c = fromIntegral $ fromSing (sing :: Sing c)
  {-# INLINE dim #-}

  -- Look up entry (i, j) by binary-searching row i among the stored row
  -- indices of column j, i.e. positions outer ! j .. outer ! (j + 1) - 1.
  -- Entries that are not stored read as 'zero'. (i, j) is not bounds-checked,
  -- and the search relies on row indices being ascending within the column.
  unsafeIndex (SparseMatrix vec inner outer) (i, j) =
    case binarySearchByBounds inner (fromIntegral i) r0 r1 of
      Nothing -> zero
      Just k -> vec `G.unsafeIndex` k
    where
      r0 = fromIntegral $ outer `S.unsafeIndex` j
      r1 = fromIntegral $ outer `S.unsafeIndex` (j + 1) - 1
  {-# INLINE unsafeIndex #-}

  -- The flat vector is read column-major: position p becomes the entry
  -- (p `mod` rows, p `div` rows). Elements equal to 'zero' are not stored.
  unsafeFromVector :: forall r c. (G.Vector v a, SingI r, SingI c)
                   => v a -> SparseMatrix r c v a
  unsafeFromVector vec = fromTriplet vec'
    where
      vec' = map (\((a, b), c) -> (a, b, c)) $ filter ((/= zero) . snd) $
        zipWith (\i x -> (toIndex i, x)) [0 ..] $ G.toList vec
      toIndex i = swap $ i `divMod` r
      r = fromIntegral $ fromSing (sing :: Sing r)
  {-# INLINE unsafeFromVector #-}

  -- Re-stream the stored entries with row and column swapped, then rebuild.
  transpose mat = runIdentity $ fromTripletC source
    where
      source = toTriplet mat .| mapC (\(i, j, x) -> (j, i, x))
  {-# INLINE transpose #-}

  thaw = error "Data.Matrix.Static.Sparse: thaw is not implemented for SparseMatrix"
  {-# INLINE thaw #-}

  unsafeThaw = error "Data.Matrix.Static.Sparse: unsafeThaw is not implemented for SparseMatrix"
  {-# INLINE unsafeThaw #-}

  freeze = error "Data.Matrix.Static.Sparse: freeze is not implemented for SparseMatrix"
  {-# INLINE freeze #-}

  unsafeFreeze = error "Data.Matrix.Static.Sparse: unsafeFreeze is not implemented for SparseMatrix"
  {-# INLINE unsafeFreeze #-}

  -- Apply f to the stored values only. The sparsity pattern is reused as-is,
  -- so unstored entries stay unstored regardless of what f does to 'zero'.
  map f (SparseMatrix vec inner outer) = SparseMatrix (G.map f vec) inner outer
  {-# INLINE map #-}

  -- Like 'map', but f also receives the (row, column) of each stored entry.
  -- New values are collected in the order 'toTriplet' emits them, which is
  -- storage order, so they line up with the reused index vectors.
  imap f mat@(SparseMatrix _ inner outer) = SparseMatrix vec' inner outer
    where
      vec' = runST $ runConduit $ toTriplet mat .| mapC g .| sinkVector
      g (i, j, x) = f (i, j) x
  {-# INLINE imap #-}
-- | Build a sparse matrix from @(row, column, value)@ triplets.
--
-- Entries are stored column by column and, within each column, in ascending
-- row order. Element lookup binary-searches the row indices of a column, so
-- that order is required for lookups to find stored entries. Triplets may
-- therefore be given in any order: they are stably sorted by (column, row)
-- first. Base's list merge sort treats already-ordered input as a single
-- run, so pre-sorted input stays cheap. Duplicate coordinates are kept as
-- separate stored entries, in input order.
--
-- A column index outside @[0, c)@ raises a vector index error instead of
-- writing out of bounds. Row indices are not validated.
fromTriplet :: forall t r c v a. (Traversable t, G.Vector v a, SingI r, SingI c)
            => t (Int, Int, a) -> SparseMatrix r c v a
fromTriplet triplets = SparseMatrix val inner outer
  where
    -- Stable sort keyed on (column, row); ties keep their input order.
    entries = L.sortOn (\(i, j, _) -> (j, i)) $ F.toList triplets
    nnz = length entries
    c = fromIntegral $ fromSing (sing :: Sing c)

    -- Column offsets: prefix sums of the per-column entry counts (length
    -- c + 1), so column j owns positions [outer ! j, outer ! (j + 1)).
    outer = S.scanl (+) 0 $ S.create $ do
      counts <- SM.replicate c 0
      -- Bounds-checked on purpose: the column index comes from the caller.
      F.forM_ entries $ \(_, j, _) -> SM.modify counts (+ 1) j
      return counts

    -- Scatter pass: 'next' holds the next free position of every column.
    -- Thawing 'outer' forces the counting pass, so every column index used
    -- below has already been validated.
    (val, inner) = runST $ do
      next <- S.thaw outer
      vals <- GM.new nnz
      rowIxs <- SM.new nnz
      F.forM_ entries $ \(i, j, x) -> do
        idx <- fromIntegral <$> SM.unsafeRead next j
        GM.unsafeWrite vals idx x
        SM.unsafeWrite rowIxs idx $ fromIntegral i
        SM.unsafeModify next (+ 1) j
      (,) <$> G.unsafeFreeze vals <*> S.unsafeFreeze rowIxs
{-# INLINE fromTriplet #-}
-- | Build a sparse matrix from a conduit of @(row, column, value)@ triplets.
--
-- The source is drained exactly once into a list, and the matrix is then
-- assembled by 'fromTriplet'. Running the source once means its effects
-- happen a single time. It also means a source that would not replay
-- identically (e.g. one reading external input) cannot produce per-column
-- counts that disagree with the entries written. Storage order is the same
-- as 'fromTriplet' produces for the same triplets.
fromTripletC :: forall m r c v a. (Monad m, G.Vector v a, SingI r, SingI c)
             => ConduitT () (Int, Int, a) m ()
             -> m (SparseMatrix r c v a)
fromTripletC triplets = fromTriplet <$> runConduit (triplets .| sinkList)
{-# INLINE fromTripletC #-}
-- | Stream the stored entries as @(row, column, value)@ triplets, column by
-- column and, within a column, in storage order.
toTriplet :: (Monad m, G.Vector v a, SingI r, SingI c)
          => SparseMatrix r c v a -> ConduitT i (Int, Int, a) m ()
toTriplet (SparseMatrix val inner outer) =
  G.ifoldM_ emitColumn (fromIntegral $ G.head outer) outer
  where
    -- Offset number p closes column p - 1. Emit the positions from the
    -- previous offset up to (not including) this one, then pass this offset
    -- on as the next start. Offset 0 starts and ends at the same position,
    -- so it emits nothing.
    emitColumn from p to = do
      let to' = fromIntegral to
      mapM_ (yield . entry (p - 1)) [from .. to' - 1]
      return to'
    entry col k =
      (fromIntegral $ inner `G.unsafeIndex` k, col, val `G.unsafeIndex` k)
{-# INLINE toTriplet #-}
-- | Square matrix whose main diagonal is the given column vector; this is
-- 'diagRect' specialised to @n x n@.
diag :: (G.Vector v a, Zero a, SingI n)
     => D.Matrix n 1 v a
     -> SparseMatrix n n v a
diag d = diagRect d
{-# INLINE diag #-}
-- | Rectangular @r x c@ matrix whose leading diagonal holds the given
-- @n@-element column vector: element @k@ is stored at @(k, k)@. Every entry
-- of the vector is stored, including ones equal to 'zero'.
--
-- The column-offset vector always has @c + 1@ elements. Columns @n .. c-1@
-- hold no entries, so their offsets stay at @n@. Previously only @n + 1@
-- offsets were built, so any column @j >= n@ was read past the end of the
-- vector.
diagRect :: forall v a r c n. (G.Vector v a, Zero a, SingI r, SingI c, n <= r, n <= c)
         => D.Matrix n 1 v a
         -> SparseMatrix r c v a
diagRect d = SparseMatrix (C.flatten d) inner outer
  where
    n = C.rows d
    nCols = fromIntegral $ fromSing (sing :: Sing c)
    -- Row index of entry k is k itself.
    inner = S.enumFromN 0 n
    -- Column j is preceded by min j n stored entries.
    outer = S.generate (nCols + 1) $ \j -> fromIntegral (min j n)
{-# INLINE diagRect #-}
-- | Binary search for @x@ in the slice @vec[lo .. hi]@ (both bounds
-- inclusive), which is assumed to be sorted in ascending order. Returns the
-- position of a matching element, or 'Nothing' when the slice is empty
-- (@lo > hi@) or does not contain @x@. Positions are not bounds-checked.
binarySearchByBounds :: S.Vector CInt -> CInt -> Int -> Int -> Maybe Int
binarySearchByBounds vec x = search
  where
    search !lo !hi
      | lo > hi = Nothing
      | otherwise =
          let mid = (lo + hi) `shiftR` 1
          in case x `compare` (vec `S.unsafeIndex` mid) of
               EQ -> Just mid
               LT -> search lo (mid - 1)
               GT -> search (mid + 1) hi
{-# INLINE binarySearchByBounds #-}