{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE StrictData #-}
{-# LANGUAGE InstanceSigs #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
module Data.Matrix.Static.Dense
(
Matrix(..)
, C.dim
, C.rows
, C.cols
, (C.!)
, C.takeRow
, C.takeColumn
, C.takeDiag
, C.unsafeIndex
, C.unsafeTakeRow
, C.unsafeTakeColumn
, C.empty
, C.matrix
, C.fromVector
, C.fromList
, C.fromRows
, C.fromColumns
, C.unsafeFromVector
, replicate
, diag
, diagRect
, C.flatten
, C.toRows
, C.toColumns
, C.toList
, convert
, C.convertAny
, C.transpose
, C.map
, C.mapM
, C.imap
, C.imapM
, zip
, zip3
, zipWith
, zipWith3
, zipWith4
, zipWith5
, zipWith6
, izipWith
, izipWith3
, zipWithM
, zipWithM_
, unzip
, unzip3
, generate
, C.thaw
, C.unsafeThaw
, C.freeze
, C.unsafeFreeze
, C.create
, sum
, all
, any
) where
import Control.Monad (liftM)
import qualified Data.Vector.Generic as G
import Prelude hiding ( replicate, mapM, mapM_, zipWith, map
, sequence, sequence_, zip, unzip, zipWith3
, zip3, unzip3, sum, all, any )
import GHC.TypeLits (type (<=))
import Data.Singletons
import Data.Tuple (swap)
import qualified Data.List as L
import Text.Printf (printf)
import Data.Store (Store(..), Size(..))
import Foreign.Storable (sizeOf)
import Data.Matrix.Static.Dense.Mutable (MMatrix (..))
import qualified Data.Matrix.Static.Dense.Mutable as DM
import qualified Data.Matrix.Static.Generic as C
-- | The mutable counterpart of 'Matrix', produced by 'C.thaw' and consumed by
-- 'C.freeze'.
type instance C.Mutable Matrix = MMatrix
-- | A matrix with @r@ rows and @c@ columns whose entries live in one flat
-- vector of type @v a@. Entries are laid out column-major: entry @(i, j)@
-- is at offset @i + j * r@ (see 'C.unsafeIndex' in the 'C.Matrix' instance).
-- The 'SingI' evidence for both dimensions is captured by the constructor,
-- so pattern matching on 'Matrix' recovers the sizes at runtime.
data Matrix :: C.MatrixKind where
    Matrix :: (SingI r, SingI c) => v a -> Matrix r c v a
-- | Binary serialisation via "Data.Store".
--
-- Layout: the row count and the column count (each an 'Int'), followed by
-- the 'Store' encoding of the underlying vector. On 'peek' the stored
-- dimensions are checked against the type-level @r@ and @c@.
instance (G.Vector v a, Store (v a), SingI r, SingI c) =>
    Store (Matrix r c v a) where
    size = VarSize $ \(Matrix vec) ->
        -- Two 'Int' header fields plus whatever the vector itself needs.
        2 * sizeOf (0 :: Int) + vecSize vec
      where
        -- Handle both 'Size' constructors; the previous version hit
        -- 'undefined' whenever the vector's size was 'ConstSize'.
        vecSize :: v a -> Int
        vecSize = case (size :: Size (v a)) of
            VarSize f   -> f
            ConstSize n -> const n
    poke mat@(Matrix vec) = poke r >> poke c >> poke vec
      where
        (r,c) = C.dim mat
    peek = do
        r' <- peek
        c' <- peek
        -- Use 'fail' (a 'PeekException') rather than 'error' so a dimension
        -- mismatch surfaces as a decoding failure that 'decode' can report,
        -- instead of an uncaught 'ErrorCall'.
        if r' /= r || c' /= c
            then fail $ "Dimensions donot match: " <> show (r,c) <> " /= " <> show (r',c')
            else Matrix <$> peek
      where
        r = fromIntegral $ fromSing (sing :: Sing r) :: Int
        c = fromIntegral $ fromSing (sing :: Sing c) :: Int
-- | Renders a header @(rows x cols)@ followed by one line per row, with
-- entries separated by single spaces.
instance (G.Vector v a, Show a) => Show (Matrix r c v a) where
    show m = printf "(%d x %d)\n%s" nRows nCols body
      where
        (nRows, nCols) = C.dim m
        renderRow row = unwords (L.map show (G.toList row))
        body = unlines (L.map renderRow (C.toRows m))
-- | Two matrices of the same (type-level) shape are equal exactly when their
-- underlying vectors are equal.
instance (G.Vector v a, Eq (v a)) => Eq (Matrix r c v a) where
    Matrix xs == Matrix ys = xs == ys
-- | Element-wise arithmetic. Note that '*' is the element-wise (Hadamard)
-- product, not matrix multiplication.
--
-- Consistent with that element-wise reading, 'signum' is applied per entry,
-- and an integer literal denotes the matrix whose every entry is that
-- literal. Previously both were 'undefined'.
instance (SingI r, SingI c, G.Vector v a, Num a) =>
    Num (Matrix r c v a) where
    m1 + m2 = zipWith (+) m1 m2
    m1 - m2 = zipWith (-) m1 m2
    m1 * m2 = zipWith (*) m1 m2
    negate = C.map negate
    abs = C.map abs
    signum = C.map signum
    fromInteger = replicate . fromInteger
-- | Element-wise division. A rational literal denotes the matrix whose every
-- entry is that literal; previously 'fromRational' was 'undefined'.
instance (SingI r, SingI c, G.Vector v a, Fractional a) =>
    Fractional (Matrix r c v a) where
    m1 / m2 = zipWith (/) m1 m2
    recip = C.map recip
    fromRational = replicate . fromRational
-- | Dense matrices are stored column-major: entry @(i, j)@ of an @r x c@
-- matrix sits at offset @i + j * r@ of the flat vector.
instance G.Vector v a => C.Matrix Matrix v a where
    -- Dimensions come from the singletons captured by the constructor.
    dim :: forall r c. Matrix r c v a -> (Int, Int)
    dim (Matrix _) =
        ( fromIntegral (fromSing (sing :: Sing r))
        , fromIntegral (fromSing (sing :: Sing c))
        )
    {-# INLINE dim #-}

    unsafeIndex m@(Matrix vec) (i, j) = G.unsafeIndex vec (i + j * nRows)
      where
        nRows = fst (C.dim m)
    {-# INLINE unsafeIndex #-}

    unsafeFromVector vec = Matrix vec
    {-# INLINE unsafeFromVector #-}

    -- A column is a contiguous run of @nRows@ entries.
    unsafeTakeColumn m@(Matrix vec) j = G.slice (j * nRows) nRows vec
      where
        nRows = fst (C.dim m)
    {-# INLINE unsafeTakeColumn #-}

    flatten (Matrix vec) = vec
    {-# INLINE flatten #-}

    -- Row and column vectors share their layout with their transpose, so the
    -- storage is reused; otherwise flat index x of the result (c x r) maps
    -- back to entry (x `div` c, x `mod` c) of the input.
    transpose m@(Matrix vec)
        | nRows == 1 || nCols == 1 = Matrix vec
        | otherwise =
            Matrix (G.generate (nRows * nCols) (C.unsafeIndex m . (`divMod` nCols)))
      where
        (nRows, nCols) = C.dim m
    {-# INLINE transpose #-}

    thaw (Matrix vec) = fmap MMatrix (G.thaw vec)
    {-# INLINE thaw #-}

    unsafeThaw (Matrix vec) = fmap MMatrix (G.unsafeThaw vec)
    {-# INLINE unsafeThaw #-}

    freeze (MMatrix vec) = fmap Matrix (G.freeze vec)
    {-# INLINE freeze #-}

    unsafeFreeze (MMatrix vec) = fmap Matrix (G.unsafeFreeze vec)
    {-# INLINE unsafeFreeze #-}

    map f (Matrix vec) = Matrix (G.map f vec)
    {-# INLINE map #-}

    imap f m@(Matrix vec) = Matrix (G.imap (f . toIndex (C.rows m)) vec)
    {-# INLINE imap #-}

    imapM_ f m@(Matrix vec) = G.imapM_ (f . toIndex (C.rows m)) vec
    {-# INLINE imapM_ #-}

    sequence (Matrix vec) = fmap Matrix (G.sequence vec)
    {-# INLINE sequence #-}

    sequence_ (Matrix vec) = G.sequence_ vec
    {-# INLINE sequence_ #-}
{-# INLINE sequence_ #-}
-- | A matrix whose every entry is the given value.
replicate :: forall r c v a. (G.Vector v a, SingI r, SingI c)
          => a -> Matrix r c v a
replicate x = C.unsafeFromVector (G.replicate (nRows * nCols) x)
  where
    nRows = fromIntegral (fromSing (sing :: Sing r))
    nCols = fromIntegral (fromSing (sing :: Sing c))
{-# INLINE replicate #-}
-- | Square matrix whose main diagonal holds the entries of the given column
-- vector; every other entry is @z0@.
diag :: (G.Vector v a, SingI n)
     => a               -- ^ value used for all off-diagonal entries
     -> Matrix n 1 v a  -- ^ diagonal entries, as a column vector
     -> Matrix n n v a
diag z0 d = C.create $ do
    mat <- DM.replicate z0
    -- 'C.imapM_' yields indices (i, 0) for an n x 1 vector; the value must go
    -- to (i, i), not (i, 0) (which would fill the first column instead).
    C.imapM_ (\(i, _) x -> DM.unsafeWrite mat (i, i) x) d
    return mat
{-# INLINE diag #-}
-- | Rectangular @r x c@ matrix whose leading diagonal holds the entries of
-- the given column vector; every other entry is @z0@. The constraints
-- @n <= r@ and @n <= c@ keep every diagonal position in bounds.
diagRect :: (G.Vector v a, SingI r, SingI c, n <= r, n <= c)
         => a               -- ^ value used for all off-diagonal entries
         -> Matrix n 1 v a  -- ^ diagonal entries, as a column vector
         -> Matrix r c v a
diagRect z0 d = C.create $ do
    mat <- DM.replicate z0
    -- 'C.imapM_' yields indices (i, 0) for an n x 1 vector; the value must go
    -- to (i, i), not (i, 0) (which would fill the first column instead).
    C.imapM_ (\(i, _) x -> DM.unsafeWrite mat (i, i) x) d
    return mat
{-# INLINE diagRect #-}
-- | Combine two matrices of the same shape entry by entry.
zipWith :: (G.Vector v a, G.Vector v b, G.Vector v c
           , SingI n, SingI m )
        => (a -> b -> c)
        -> Matrix n m v a -> Matrix n m v b -> Matrix n m v c
zipWith f m1 m2 = C.unsafeFromVector combined
  where
    combined = G.zipWith f (C.flatten m1) (C.flatten m2)
{-# INLINE zipWith #-}
-- | Combine three matrices of the same shape entry by entry.
zipWith3 :: (G.Vector v a, G.Vector v b, G.Vector v c, G.Vector v d
            , SingI n, SingI m )
         => (a -> b -> c -> d)
         -> Matrix n m v a -> Matrix n m v b -> Matrix n m v c
         -> Matrix n m v d
zipWith3 f m1 m2 m3 = C.unsafeFromVector combined
  where
    combined = G.zipWith3 f (C.flatten m1) (C.flatten m2) (C.flatten m3)
{-# INLINE zipWith3 #-}
-- | Combine four matrices of the same shape entry by entry.
zipWith4 :: (G.Vector v a, G.Vector v b, G.Vector v c, G.Vector v d, G.Vector v e
            , SingI n, SingI m )
         => (a -> b -> c -> d -> e)
         -> Matrix n m v a
         -> Matrix n m v b
         -> Matrix n m v c
         -> Matrix n m v d
         -> Matrix n m v e
zipWith4 f m1 m2 m3 m4 = C.unsafeFromVector combined
  where
    combined = G.zipWith4 f (C.flatten m1) (C.flatten m2)
                            (C.flatten m3) (C.flatten m4)
{-# INLINE zipWith4 #-}
-- | Combine five matrices of the same shape entry by entry.
zipWith5 :: ( G.Vector v a, G.Vector v b, G.Vector v c,G.Vector v d
            , G.Vector v e, G.Vector v f
            , SingI n, SingI m )
         => (a -> b -> c -> d -> e -> f)
         -> Matrix n m v a
         -> Matrix n m v b
         -> Matrix n m v c
         -> Matrix n m v d
         -> Matrix n m v e
         -> Matrix n m v f
zipWith5 f m1 m2 m3 m4 m5 = C.unsafeFromVector combined
  where
    combined = G.zipWith5 f (C.flatten m1) (C.flatten m2) (C.flatten m3)
                            (C.flatten m4) (C.flatten m5)
{-# INLINE zipWith5 #-}
-- | Combine six matrices of the same shape entry by entry.
zipWith6 :: ( G.Vector v a, G.Vector v b, G.Vector v c, G.Vector v d
            , G.Vector v e, G.Vector v f, G.Vector v g
            , SingI n, SingI m )
         => (a -> b -> c -> d -> e -> f -> g)
         -> Matrix n m v a
         -> Matrix n m v b
         -> Matrix n m v c
         -> Matrix n m v d
         -> Matrix n m v e
         -> Matrix n m v f
         -> Matrix n m v g
zipWith6 f m1 m2 m3 m4 m5 m6 = C.unsafeFromVector combined
  where
    combined = G.zipWith6 f (C.flatten m1) (C.flatten m2) (C.flatten m3)
                            (C.flatten m4) (C.flatten m5) (C.flatten m6)
{-# INLINE zipWith6 #-}
-- | Like 'zipWith', but the combining function also receives the
-- @(row, column)@ index of each entry.
izipWith :: ( G.Vector v a, G.Vector v b, G.Vector v c
            , SingI n, SingI m )
         => ((Int, Int) -> a -> b -> c)
         -> Matrix n m v a -> Matrix n m v b -> Matrix n m v c
izipWith f m1 m2 = C.unsafeFromVector combined
  where
    -- Translate flat (column-major) offsets into (row, column) pairs.
    combined = G.izipWith (f . toIndex (C.rows m1)) (C.flatten m1) (C.flatten m2)
{-# INLINE izipWith #-}
-- | Like 'zipWith3', but the combining function also receives the
-- @(row, column)@ index of each entry.
izipWith3 :: ( G.Vector v a, G.Vector v b, G.Vector v c, G.Vector v d
             , SingI n, SingI m )
          => ((Int, Int) -> a -> b -> c -> d)
          -> Matrix n m v a -> Matrix n m v b -> Matrix n m v c
          -> Matrix n m v d
izipWith3 f m1 m2 m3 = C.unsafeFromVector combined
  where
    -- Translate flat (column-major) offsets into (row, column) pairs.
    combined = G.izipWith3 (f . toIndex (C.rows m1))
                           (C.flatten m1) (C.flatten m2) (C.flatten m3)
{-# INLINE izipWith3 #-}
-- | Pair up the entries of two matrices of the same shape.
zip :: (SingI n, SingI m, G.Vector v a, G.Vector v b, G.Vector v (a,b))
    => Matrix n m v a -> Matrix n m v b -> Matrix n m v (a,b)
zip m1 m2 = C.unsafeFromVector paired
  where
    paired = G.zip (C.flatten m1) (C.flatten m2)
{-# INLINE zip #-}
-- | Group the entries of three matrices of the same shape into triples.
zip3 :: (SingI n, SingI m, G.Vector v a, G.Vector v b, G.Vector v c, G.Vector v (a,b,c))
     => Matrix n m v a
     -> Matrix n m v b
     -> Matrix n m v c
     -> Matrix n m v (a,b,c)
zip3 m1 m2 m3 = C.unsafeFromVector tripled
  where
    tripled = G.zip3 (C.flatten m1) (C.flatten m2) (C.flatten m3)
{-# INLINE zip3 #-}
-- | Monadic 'zipWith': run the action on each pair of entries and collect
-- the results into a matrix of the same shape.
zipWithM :: ( G.Vector v a, G.Vector v b, G.Vector v c
            , Monad monad, SingI n, SingI m )
         => (a -> b -> monad c)
         -> Matrix n m v a -> Matrix n m v b -> monad (Matrix n m v c)
zipWithM f m1 m2 = do
    results <- G.zipWithM f (C.flatten m1) (C.flatten m2)
    return (C.unsafeFromVector results)
{-# INLINE zipWithM #-}
-- | Run a monadic action on each pair of corresponding entries, discarding
-- the results.
zipWithM_ :: (G.Vector v a, G.Vector v b, G.Vector v c, Monad monad)
          => (a -> b -> monad c)
          -> Matrix n m v a -> Matrix n m v b -> monad ()
zipWithM_ f m1 m2 = G.zipWithM_ f xs ys
  where
    xs = C.flatten m1
    ys = C.flatten m2
{-# INLINE zipWithM_ #-}
-- | Split a matrix of pairs into two matrices of the same shape.
unzip :: ( G.Vector v a, G.Vector v b, G.Vector v (a,b)
         , SingI n, SingI m )
      => Matrix n m v (a,b) -> (Matrix n m v a, Matrix n m v b )
unzip m =
    let (firsts, seconds) = G.unzip (C.flatten m)
    in (C.unsafeFromVector firsts, C.unsafeFromVector seconds)
{-# INLINE unzip #-}
-- | Split a matrix of triples into three matrices of the same shape.
unzip3 :: ( G.Vector v a, G.Vector v b, G.Vector v c, G.Vector v (a,b,c)
          , SingI n, SingI m )
       => Matrix n m v (a, b, c)
       -> (Matrix n m v a, Matrix n m v b, Matrix n m v c)
unzip3 m =
    let (xs, ys, zs) = G.unzip3 (C.flatten m)
    in (C.unsafeFromVector xs, C.unsafeFromVector ys, C.unsafeFromVector zs)
{-# INLINE unzip3 #-}
-- | Build an @r x c@ matrix by evaluating the function at every
-- @(row, column)@ index.
--
-- Storage is column-major, so flat offset @i@ corresponds to
-- @(i `mod` r, i `div` r)@. The previous version passed @i `divMod` r@
-- (that is, @(column, row)@) to the function. That transposed the indices
-- and, for non-square matrices, produced out-of-range row indices.
generate :: forall r c v a. (G.Vector v a, SingI r, SingI c)
         => ((Int, Int) -> a) -> Matrix r c v a
generate f = C.unsafeFromVector . G.generate (r*c) $ f . toIndex r
  where
    r = fromIntegral $ fromSing (sing :: Sing r)
    c = fromIntegral $ fromSing (sing :: Sing c)
{-# INLINE generate #-}
-- | Change the underlying vector representation, keeping shape and entries.
convert :: (G.Vector v a, G.Vector w a) => Matrix r c v a -> Matrix r c w a
convert m = case m of
    Matrix vec -> Matrix (G.convert vec)
{-# INLINE convert #-}
-- | Sum of all entries.
sum :: (Num a, G.Vector v a) => Matrix r c v a -> a
sum = G.sum . C.flatten
{-# INLINE sum #-}
-- | 'True' when every entry satisfies the predicate.
all :: G.Vector v a => (a -> Bool) -> Matrix r c v a -> Bool
all p = G.all p . C.flatten
{-# INLINE all #-}
-- | 'True' when at least one entry satisfies the predicate.
any :: G.Vector v a => (a -> Bool) -> Matrix r c v a -> Bool
any p = G.any p . C.flatten
{-# INLINE any #-}
-- | Convert a flat column-major offset into a @(row, column)@ pair, given the
-- number of rows.
toIndex :: Int         -- ^ number of rows
        -> Int         -- ^ flat offset into the underlying vector
        -> (Int, Int)
toIndex r i = case i `divMod` r of
    (col, row) -> (row, col)
{-# INLINE toIndex #-}