{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE InstanceSigs #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE StrictData #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE TypeFamilies #-}
module Data.Matrix.Static.Sparse
(
SparseMatrix(..)
, Zero(..)
, C.dim
, C.rows
, C.cols
, (C.!)
, C.takeDiag
, C.unsafeIndex
, C.unsafeTakeRow
, C.unsafeTakeColumn
, C.empty
, C.fromVector
, C.fromList
, C.unsafeFromVector
, diag
, diagRect
, C.flatten
, C.toList
, C.convertAny
) where
import Control.DeepSeq
import qualified Data.Vector.Generic as G
import qualified Data.Vector.Storable as S
import qualified Data.Vector.Storable.Mutable as SM
import Data.Singletons
import Control.Monad
import Data.Bits (shiftR)
import Text.Printf (printf)
import GHC.TypeLits (type (<=))
import Foreign.C.Types
import Data.Complex
import qualified Data.Matrix.Static.Dense as D
import qualified Data.Matrix.Static.Generic as C
import Data.Matrix.Static.Sparse.Mutable
-- | Associate the immutable 'SparseMatrix' with its mutable counterpart
-- 'MSparseMatrix' (imported from "Data.Matrix.Static.Sparse.Mutable").
type instance C.Mutable SparseMatrix = MSparseMatrix
-- | Element types that have a distinguished "zero" value.
--
-- Sparse construction drops every entry equal to 'zero', and looking up an
-- entry that is not stored yields 'zero'.
class Eq a => Zero a where
  -- | The value that stands in for an absent (unstored) entry.
  zero :: a

instance Zero Int where
  zero = 0

instance Zero Float where
  zero = 0

instance Zero Double where
  zero = 0

instance Zero (Complex Float) where
  zero = 0

instance Zero (Complex Double) where
  zero = 0

-- | Lists use the empty list as their zero element.
instance Eq a => Zero [a] where
  zero = mempty
-- | A statically sized @r x c@ sparse matrix in compressed sparse column
-- (CSC) layout, parameterised by the value-vector type @v@ and the
-- element type @a@. The dimensions are carried at the type level.
data SparseMatrix :: C.MatrixKind where
  SparseMatrix
    :: (SingI r, SingI c)
    => v a
    -- ^ stored values, grouped column by column
    -> S.Vector CInt
    -- ^ row index of each stored value (same positions as the values)
    -> S.Vector CInt
    -- ^ column pointers: column @j@ occupies positions
    --   @outer[j] .. outer[j+1] - 1@; one entry per column plus an end offset
    -> SparseMatrix r c v a
-- | Render as a @(rows x cols)@ header line followed by one line per row,
-- with that row's entries separated by single spaces.
instance (G.Vector v a, Zero a, Show a) => Show (SparseMatrix r c v a) where
  show mat = printf "(%d x %d)\n%s" nRows nCols body
    where
      (nRows, nCols) = C.dim mat
      renderRow = unwords . fmap show . G.toList
      body = unlines (renderRow <$> C.toRows mat)
-- | Fully evaluate a sparse matrix: the stored values and both index
-- vectors. Forcing the index vectors explicitly means this instance does
-- not silently depend on the module's StrictData setting.
instance (NFData (v a)) => NFData (SparseMatrix r c v a) where
  rnf (SparseMatrix vals inner outer) =
    rnf vals `seq` rnf inner `seq` rnf outer
-- | Sparse matrices in compressed sparse column (CSC) form.
--
-- Column @j@'s stored entries sit at positions @outer[j] .. outer[j+1] - 1@
-- of the value vector, and their row indices at the same positions of
-- @inner@. Entries that are not stored read back as 'zero'.
instance (G.Vector v a, Zero a) => C.Matrix SparseMatrix v a where
  -- Dimensions come from the type-level row/column counts.
  dim :: forall r c. SparseMatrix r c v a -> (Int, Int)
  dim (SparseMatrix _ _ _) = (r, c)
    where
      r = fromIntegral $ fromSing (sing :: Sing r)
      c = fromIntegral $ fromSing (sing :: Sing c)
  {-# INLINE dim #-}

  -- Binary-search row i among the row indices stored for column j.
  -- Bounds are not checked.
  unsafeIndex (SparseMatrix vec inner outer) (i, j) =
    case binarySearchByBounds inner (fromIntegral i) r0 r1 of
      Nothing -> zero
      Just k -> vec `G.unsafeIndex` k
    where
      r0 = fromIntegral $ outer `S.unsafeIndex` j
      r1 = fromIntegral $ outer `S.unsafeIndex` (j + 1) - 1
  {-# INLINE unsafeIndex #-}

  -- Build from a dense column-major vector of length r * c, keeping only
  -- the entries that differ from 'zero'.
  unsafeFromVector :: forall r c. (G.Vector v a, SingI r, SingI c)
                   => v a -> SparseMatrix r c v a
  unsafeFromVector vec = SparseMatrix
    (G.generate n (G.unsafeIndex vec . S.unsafeIndex nz))
    inner outer
    where
      -- Ascending linear positions of the non-zero entries.
      nz = S.filter (\i -> vec `G.unsafeIndex` i /= zero) $ S.enumFromN 0 (r * c)
      n = S.length nz
      -- In an r-row column-major layout, position x is at row (x mod r)
      -- and column (x div r). Using c here would be wrong for non-square
      -- matrices.
      inner = S.map (fromIntegral . (`mod` r)) nz
      -- Column pointers: count the entries of each column into slot
      -- (column + 1), then prefix-sum so slot j holds the start of column j.
      outer = S.create $ do
        v <- SM.replicate (c + 1) 0
        S.forM_ nz $ \x -> SM.unsafeModify v succ (x `div` r + 1)
        forM_ [1 .. c] $ \i -> do
          x <- SM.unsafeRead v (i - 1)
          SM.unsafeModify v (+ x) i
        return v
      r = fromIntegral $ fromSing (sing :: Sing r)
      c = fromIntegral $ fromSing (sing :: Sing c)
  {-# INLINE unsafeFromVector #-}

  thaw = error "Data.Matrix.Static.Sparse.thaw: not implemented"
  {-# INLINE thaw #-}

  unsafeThaw = error "Data.Matrix.Static.Sparse.unsafeThaw: not implemented"
  {-# INLINE unsafeThaw #-}

  freeze = error "Data.Matrix.Static.Sparse.freeze: not implemented"
  {-# INLINE freeze #-}

  unsafeFreeze = error "Data.Matrix.Static.Sparse.unsafeFreeze: not implemented"
  {-# INLINE unsafeFreeze #-}

  -- Applies f to the stored values only; the sparsity pattern is kept.
  -- NOTE(review): implicit zeros are not passed through f, so the result
  -- is only the element-wise map when f zero == zero.
  map f (SparseMatrix vec inner outer) = SparseMatrix (G.map f vec) inner outer
  {-# INLINE map #-}

  imap = error "Data.Matrix.Static.Sparse.imap: not implemented"
-- | Square @n x n@ sparse matrix whose diagonal is the given column vector.
-- This is 'diagRect' with equal row and column counts.
diag :: (G.Vector v a, Zero a, SingI n)
     => D.Matrix n 1 v a
     -> SparseMatrix n n v a
diag diagonal = diagRect diagonal
{-# INLINE diag #-}
-- | Build an @r x c@ sparse matrix whose first @n@ diagonal entries come
-- from the column vector @d@. All other entries are zero.
--
-- Column @j < n@ stores exactly one entry (at row @j@), and columns
-- @j >= n@ are empty. CSC needs @c + 1@ column pointers, so pointer @j@ is
-- @min j n@. The previous version emitted only @n + 1@ pointers, which made
-- 'C.unsafeIndex' read out of bounds when @c > n@.
diagRect :: forall v a r c n.
            (G.Vector v a, Zero a, SingI r, SingI c, n <= r, n <= c)
         => D.Matrix n 1 v a
         -> SparseMatrix r c v a
diagRect d = SparseMatrix (C.flatten d) inner outer
  where
    n = C.rows d
    c = fromIntegral $ fromSing (sing :: Sing c)
    -- Entry k sits on row k.
    inner = S.enumFromN 0 n
    -- Pointers grow by one per diagonal column, then stay at n.
    outer = S.generate (c + 1) (fromIntegral . min n)
{-# INLINE diagRect #-}
-- | Look for @x@ in the slice @vec[lo .. hi]@ (inclusive bounds) and return
-- the index of a matching element. Returns 'Nothing' when the slice is
-- empty (@lo > hi@) or holds no match. The slice must be sorted in
-- ascending order for the bisection to be valid. Bounds are not checked.
binarySearchByBounds :: S.Vector CInt -> CInt -> Int -> Int -> Maybe Int
binarySearchByBounds vec x = go
  where
    go !lo !hi
      | lo > hi = Nothing
      | otherwise =
          let mid = (lo + hi) `shiftR` 1
          in case compare x (vec `S.unsafeIndex` mid) of
               EQ -> Just mid
               LT -> go lo (mid - 1)
               GT -> go (mid + 1) hi
{-# INLINE binarySearchByBounds #-}