{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE ExplicitNamespaces #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE InstanceSigs #-}
module Data.Matrix.Static.LinearAlgebra
( module Data.Matrix.Static.LinearAlgebra.Types
, Arithmetic(..)
, Factorization(..)
, LinearAlgebra(..)
, zeros
, ones
, inverse
, eig
, svd
, cond
) where
import qualified Data.Vector.Storable as VS
import Data.Vector.Storable (Vector)
import System.IO.Unsafe (unsafePerformIO)
import Data.Complex (Complex)
import Data.Singletons.Prelude hiding ((@@), type (==), type (-), type (<=))
import Data.Type.Bool (If)
import Data.Type.Equality
import GHC.TypeLits
import qualified Data.Matrix.Static.Dense as D
import qualified Data.Matrix.Static.Sparse as S
import qualified Data.Matrix.Static.Generic.Mutable as CM
import qualified Data.Matrix.Static.Generic as C
import qualified Data.Matrix.Static.Internal as Internal
import Data.Matrix.Static.LinearAlgebra.Types
import Data.Matrix.Static.LinearAlgebra.Internal
-- | Binary arithmetic between two, possibly different, matrix
-- representations (e.g. dense 'D.Matrix' and sparse 'S.SparseMatrix').
--
-- The result representation @mat3@ is fixed by a type-level rule stated in
-- each method's context:
--
-- * For the matrix product, '%+%' and '%-%', it is @mat1@ when both operands
--   share a representation, and a dense 'D.Matrix' otherwise.
-- * For '%*%', it is @mat1@ when both agree, and a 'S.SparseMatrix' otherwise.
--
-- NOTE(review): every operator in this class is declared @infixr 8@.
-- Consequently @a %-% b %-% c@ parses as @a %-% (b %-% c)@, i.e. @a - b + c@.
-- Mixing the matrix product with '%+%' also groups to the right. Parenthesize
-- chains explicitly. The fixities are kept because changing them would
-- silently re-parse existing callers.
class Arithmetic (mat1 :: C.MatrixKind) (mat2 :: C.MatrixKind) where
    -- | Matrix product of an @n x p@ and a @p x m@ matrix, giving @n x m@.
    (@@) :: ( Numeric a, SingI n, SingI m
            , If (mat1 == mat2) mat1 D.Matrix ~ mat3 )
         => mat1 n p Vector a
         -> mat2 p m Vector a
         -> mat3 n m Vector a
    infixr 8 @@

    -- | Sum of two @n x m@ matrices.
    (%+%) :: ( Numeric a, SingI n, SingI m
             , If (mat1 == mat2) mat1 D.Matrix ~ mat3 )
          => mat1 n m Vector a
          -> mat2 n m Vector a
          -> mat3 n m Vector a
    infixr 8 %+%

    -- | Difference of two @n x m@ matrices (see the fixity note above).
    (%-%) :: ( Numeric a, SingI n, SingI m
             , If (mat1 == mat2) mat1 D.Matrix ~ mat3 )
          => mat1 n m Vector a
          -> mat2 n m Vector a
          -> mat3 n m Vector a
    infixr 8 %-%

    -- | Product of two matrices of the same @n x m@ shape. This is presumably
    -- element-wise: the dense/dense instance uses the matrix 'Num' @(*)@, and
    -- the sparse/sparse instance calls @c_ss_cmul@.
    (%*%) :: ( Numeric a, SingI n, SingI m
             , If (mat1 == mat2) mat1 S.SparseMatrix ~ mat3 )
          => mat1 n m Vector a
          -> mat2 n m Vector a
          -> mat3 n m Vector a
    infixr 8 %*%
-- | Dense with dense. The matrix product calls the C kernel @c_dd_mul@
-- through 'withFun2'. The other three operators are the dense matrix 'Num'
-- operations @(+)@, @(-)@ and @(*)@.
instance Arithmetic D.Matrix D.Matrix where
    x @@ y = withFun2 Internal.c_dd_mul x y
    x %+% y = x + y
    x %-% y = x - y
    x %*% y = x * y
-- | Dense (left) with sparse (right).
--
-- The matrix product, sum and difference all produce a dense result. The
-- element-wise product would be sparse by the class's type rule, but it is
-- not implemented for this pairing and fails with an explicit error message.
instance Arithmetic D.Matrix S.SparseMatrix where
    (@@) = withDS Internal.c_ds_mul

    -- Addition commutes, so swap the operands and reuse the sparse + dense
    -- kernel from the @Arithmetic S.SparseMatrix D.Matrix@ instance.
    a %+% b = b %+% a

    -- a - b is computed as a + (-b), negating only the sparse operand.
    a %-% b = a %+% C.map negate b

    -- Previously 'undefined'. Fail with a message that names the missing case.
    _ %*% _ = error
        "Data.Matrix.Static.LinearAlgebra.(%*%): element-wise product of a dense and a sparse matrix is not implemented"
-- | Sparse (left) with dense (right).
--
-- The matrix product and sum call the C kernels @c_sd_mul@ and @c_sd_plus@,
-- and produce a dense result. The difference reuses the sum. The
-- element-wise product would be sparse by the class's type rule, but it is
-- not implemented for this pairing and fails with an explicit error message.
instance Arithmetic S.SparseMatrix D.Matrix where
    (@@) = withSD Internal.c_sd_mul

    (%+%) = withSD Internal.c_sd_plus

    -- a - b is computed as a + (-b), so the sparse + dense kernel is reused.
    a %-% b = a %+% C.map negate b

    -- Previously 'undefined'. Fail with a message that names the missing case.
    _ %*% _ = error
        "Data.Matrix.Static.LinearAlgebra.(%*%): element-wise product of a sparse and a dense matrix is not implemented"
-- | Sparse with sparse. Every result is sparse. The product, sum and
-- element-wise product each call a dedicated C kernel through 'withSS'.
-- Subtraction adds the negated right operand.
instance Arithmetic S.SparseMatrix S.SparseMatrix where
    lhs @@ rhs = withSS Internal.c_ss_mul lhs rhs
    lhs %+% rhs = withSS Internal.c_ss_plus lhs rhs
    lhs %-% rhs = lhs %+% C.map negate rhs
    lhs %*% rhs = withSS Internal.c_ss_cmul lhs rhs
-- | Basic operations shared by every matrix representation.
--
-- Only 'ident' must be provided by an instance. 'colSum' and 'rowSum' have
-- defaults that visit the input with 'C.imapM_' and accumulate into a freshly
-- zeroed dense result.
class LinearAlgebra (mat :: C.MatrixKind) where
    -- | The @n x n@ identity matrix in this representation.
    ident :: (Numeric a, SingI n) => mat n n Vector a

    -- | Column sums of an @m x n@ matrix, returned as a dense @1 x n@ row.
    colSum :: (Numeric a, SingI n, C.Matrix mat Vector a)
           => mat m n Vector a
           -> Matrix 1 n a
    colSum input = D.create $ do
        sums <- CM.replicate 0
        -- Each visited entry (row, col) is added into slot (0, col).
        C.imapM_ (\(_, col) x -> CM.unsafeModify sums (+ x) (0, col)) input
        pure sums
    {-# INLINE colSum #-}

    -- | Row sums of an @m x n@ matrix, returned as a dense @m x 1@ column.
    rowSum :: (Numeric a, SingI m, C.Matrix mat Vector a)
           => mat m n Vector a
           -> Matrix m 1 a
    rowSum input = D.create $ do
        sums <- CM.replicate 0
        -- Each visited entry (row, col) is added into slot (row, 0).
        C.imapM_ (\(row, _) x -> CM.unsafeModify sums (+ x) (row, 0)) input
        pure sums
    {-# INLINE rowSum #-}
-- | Dense identity: 'D.diag' with fill value @0@, applied to a dense vector of
-- ones (presumably the first argument fills the off-diagonal entries).
instance LinearAlgebra D.Matrix where
    ident = D.diag 0 (D.replicate 1)
-- | Sparse identity: 'S.diag' applied to a dense vector of ones.
instance LinearAlgebra S.SparseMatrix where
    ident = S.diag (D.replicate 1)
-- | Partial eigen-decompositions and Cholesky factorization for a matrix
-- representation. The eigen-solvers take real ('Double') square input.
class Factorization mat where
    -- | @eigS k m@ returns two things for the square @n x n@ matrix @m@:
    --
    -- * @k@ eigenvalues, as a @k x 1@ column.
    -- * Presumably the matching eigenvectors, as the @k@ columns of an
    --   @n x k@ matrix.
    --
    -- Both results are complex-valued. The type requires @k <= n - 2@.
    --
    -- NOTE(review): which @k@ eigenvalues are selected (e.g. largest
    -- magnitude) is decided by the C routine and is not visible here.
    eigS :: (SingI k, SingI n, k <= n - 2)
         => Sing k
         -> mat n n Vector Double
         -> (Matrix k 1 (Complex Double), Matrix n k (Complex Double))
    -- | Real-valued counterpart of 'eigS', with the looser bound @k <= n - 1@.
    -- The name suggests it is intended for symmetric input — TODO confirm.
    -- Symmetry is not checked by this signature.
    eigSH :: (SingI k, SingI n, k <= n - 1)
          => Sing k
          -> mat n n Vector Double
          -> (Matrix k 1 Double, Matrix n k Double)
    -- | Cholesky factorization of a square matrix, returned in the same
    -- representation.
    --
    -- NOTE(review): which triangular factor is produced, and what happens for
    -- input that is not positive definite, depends on the instance.
    cholesky :: (Numeric a, SingI n) => mat n n Vector a -> mat n n Vector a
-- | Dense factorizations, delegating to C routines in
-- "Data.Matrix.Static.Internal".
instance Factorization D.Matrix where
    -- If every entry is zero, the C solver is skipped and a fixed answer is
    -- returned: zero eigenvalues and an all-ones eigenvector block.
    -- Presumably this guards against the solver misbehaving on a zero
    -- matrix — TODO confirm.
    eigS s mat
        | D.all (== 0) mat = (D.replicate 0, D.replicate 1)
        | otherwise = unsafePerformIO $ do
            vals <- CM.new
            vecs <- CM.new
            -- The C routine writes directly into the two fresh buffers. Its
            -- return value is discarded.
            _ <- unsafeWith' vals $ \pVals _ _ ->
                unsafeWith' vecs $ \pVecs _ _ ->
                    unsafeWith mat $ \pMat dim _ ->
                        Internal.c_eigs count pVals pVecs pMat dim
            (,) <$> C.unsafeFreeze vals <*> C.unsafeFreeze vecs
      where
        -- Number of requested eigenvalues, taken from the type-level @k@.
        count = fromIntegral (fromSing s)
    {-# INLINE eigS #-}

    -- Same structure as 'eigS', calling @c_eigsh@ and producing real results.
    eigSH s mat
        | D.all (== 0) mat = (D.replicate 0, D.replicate 1)
        | otherwise = unsafePerformIO $ do
            vals <- CM.new
            vecs <- CM.new
            _ <- unsafeWith' vals $ \pVals _ _ ->
                unsafeWith' vecs $ \pVecs _ _ ->
                    unsafeWith mat $ \pMat dim _ ->
                        Internal.c_eigsh count pVals pVecs pMat dim
            (,) <$> C.unsafeFreeze vals <*> C.unsafeFreeze vecs
      where
        count = fromIntegral (fromSing s)
    {-# INLINE eigSH #-}

    -- Forward the callback arguments supplied by 'withFun1' to
    -- @c_cholesky@. Only the first three and the fifth are used. The third is
    -- assumed to be a dimension, which is unambiguous for a square matrix —
    -- TODO confirm against withFun1.
    cholesky = withFun1 $ \code pIn size _ pOut _ _ ->
        Internal.c_cholesky code pIn pOut size
    {-# INLINE cholesky #-}
-- | Sparse factorizations. The matrix is handed to C through 'unsafeWithS'
-- (pointers named values / inner / outer in the callback — presumably the
-- compressed-storage arrays; TODO confirm), together with its dimension and
-- size.
--
-- NOTE(review): unlike the dense instance, 'eigS' / 'eigSH' have no all-zero
-- short-circuit here. Behaviour on an all-zero matrix is whatever the C
-- routine does.
instance Factorization S.SparseMatrix where
    eigS s mat = unsafePerformIO $ do
        m1 <- CM.new
        m2 <- CM.new
        -- Note the argument order expected by c_seigs: outer before inner.
        _ <- unsafeWith' m1 $ \v1 _ _ -> unsafeWith' m2 $ \v2 _ _ ->
            unsafeWithS mat $ \pv pin po n _ size ->
                Internal.c_seigs k v1 v2 pv po pin n size
        m1' <- C.unsafeFreeze m1
        m2' <- C.unsafeFreeze m2
        return (m1', m2')
      where
        -- Number of requested eigenvalues, taken from the type-level @k@.
        k = fromIntegral $ fromSing s
    {-# INLINE eigS #-}

    eigSH s mat = unsafePerformIO $ do
        m1 <- CM.new
        m2 <- CM.new
        _ <- unsafeWith' m1 $ \v1 _ _ -> unsafeWith' m2 $ \v2 _ _ ->
            unsafeWithS mat $ \pv pin po n _ size ->
                Internal.c_seigsh k v1 v2 pv po pin n size
        m1' <- C.unsafeFreeze m1
        m2' <- C.unsafeFreeze m2
        return (m1', m2')
      where
        k = fromIntegral $ fromSing s
    {-# INLINE eigSH #-}

    -- Not implemented for sparse matrices. This was previously 'undefined';
    -- now it fails with a message that names the missing case.
    cholesky _ = error
        "Data.Matrix.Static.LinearAlgebra.cholesky: not implemented for SparseMatrix"
-- | The real scalar type underlying an element type. Real types map to
-- themselves, and @'Complex' t@ maps to @t@. It is the element type of the
-- singular values returned by 'svd' and of the result of 'cond'.
--
-- The equations do not overlap, so their order is irrelevant.
type family R a where
    R Double = Double
    R (Complex Double) = Double
    R Float = Float
    R (Complex Float) = Float
-- | An @m x n@ dense matrix of 'Double' whose every entry is zero.
zeros :: (SingI m, SingI n) => Matrix m n Double
zeros = D.replicate (0 :: Double)
{-# INLINE zeros #-}
-- | An @m x n@ dense matrix of 'Double' whose every entry is one.
ones :: (SingI m, SingI n) => Matrix m n Double
ones = D.replicate (1 :: Double)
{-# INLINE ones #-}
-- | Inverse of a square dense matrix, computed by the C routine @c_inverse@
-- through 'withFun1'.
--
-- NOTE(review): no singularity check happens on the Haskell side. What a
-- singular input yields is up to the C code — confirm.
inverse :: (SingI n, Numeric a) => Matrix n n a -> Matrix n n a
inverse m = withFun1 Internal.c_inverse m
{-# INLINE inverse #-}
-- | Full eigen-decomposition of a real square matrix. It returns two things:
--
-- * The @n@ eigenvalues, as an @n x 1@ complex column.
-- * Presumably the matching eigenvectors, as the columns of an @n x n@
--   complex matrix.
--
-- Both outputs are filled in place by the C routine @c_eig@. Its return
-- value is ignored.
eig :: forall n . SingI n
    => Matrix n n Double
    -> (Matrix n 1 (Complex Double), Matrix n n (Complex Double))
eig mat = unsafePerformIO $ do
    vals <- CM.new
    vecs <- CM.new
    _ <- unsafeWith' vals $ \pVals _ _ ->
        unsafeWith' vecs $ \pVecs _ _ ->
            unsafeWith mat $ \pMat dim _ -> Internal.c_eig pVals pVecs pMat dim
    (,) <$> C.unsafeFreeze vals <*> C.unsafeFreeze vecs
{-# INLINE eig #-}
-- | Thin singular value decomposition of an @n x p@ matrix, with
-- @m = min n p@. It returns, in order:
--
-- * @U@, an @n x m@ matrix.
-- * The singular values, as a real @m x 1@ column of element type @R a@.
-- * @V@, a @p x m@ matrix.
--
-- The work is done by the C routine @c_bdcsvd@, which receives a scalar type
-- tag from 'foreignType'. Its status is passed to 'checkResult'.
svd :: forall n p a m. (Numeric (R a), Numeric a, SingI n, SingI p, SingI m, m ~ Min n p)
    => Matrix n p a
    -> (Matrix n m a, Matrix m 1 (R a), Matrix p m a)
svd mat = unsafePerformIO $ do
    uOut <- CM.new
    sOut <- CM.new
    vOut <- CM.new
    checkResult $
        unsafeWith' uOut $ \pU _ _ ->
        unsafeWith' sOut $ \pS _ _ ->
        unsafeWith' vOut $ \pV _ _ ->
        unsafeWith mat $ \pX rows cols ->
            Internal.c_bdcsvd (foreignType (undefined :: a)) pU pS pV pX rows cols
    (,,) <$> C.unsafeFreeze uOut <*> C.unsafeFreeze sOut <*> C.unsafeFreeze vOut
{-# INLINE svd #-}
-- | Condition number of a dense matrix: the ratio of the largest to the
-- smallest singular value, as computed by 'svd'.
--
-- Singular values that are exactly zero are discarded first. For a
-- rank-deficient but non-zero matrix, the ratio is therefore taken over the
-- non-zero singular values only.
--
-- If every singular value is zero (e.g. the zero matrix), the result is
-- @1/0@, which is positive infinity for 'Float'/'Double'. Previously this
-- case crashed, because 'VS.maximum' and 'VS.minimum' are partial on the
-- empty vector.
--
-- NOTE(review): the zero test is exact, so tiny round-off singular values
-- are kept and can make the result very large.
cond :: ( Numeric a, Numeric (R a), Ord (R a), Fractional (R a)
        , SingI n, SingI m, SingI (Min n m))
     => Matrix n m a -> R a
cond mat
    | VS.null val = 1 / 0
    | otherwise = VS.maximum val / VS.minimum val
  where
    val = VS.filter (/= 0) $ D.flatten s
    (_, s, _) = svd mat
{-# INLINE cond #-}