{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE ScopedTypeVariables #-}
#if __GLASGOW_HASKELL__ >= 805
{-# LANGUAGE ExplicitNamespaces #-}
{-# LANGUAGE NoStarIsType #-}
#endif
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
module Eigen.Matrix
(
Matrix(..)
, Vec(..)
, MatrixXf
, MatrixXd
, MatrixXcf
, MatrixXcd
, Elem
, C
, natToInt
, Row(..)
, Col(..)
, encode
, decode
, null
, square
, rows
, cols
, dims
, empty
, constant
, zero
, ones
, identity
, random
, diagonal
, (!)
, coeff
, generate
, sum
, prod
, mean
, trace
, all
, any
, count
, norm
, squaredNorm
, blueNorm
, hypotNorm
, determinant
, add
, sub
, mul
, map
, imap
, TriangularMode(..)
, triangularView
, filter
, ifilter
, length
, foldl
, foldl'
, inverse
, adjoint
, transpose
, conjugate
, normalize
, modify
, block
, unsafeFreeze
, unsafeWith
, fromList
, toList
) where
import Control.Monad (when)
import Control.Monad.ST (ST)
import Prelude hiding
(map, null, filter, length, foldl, any, all, sum)
import Control.Monad (forM_)
import Control.Monad.Primitive (PrimMonad(..))
import Data.Binary (Binary(..))
import qualified Data.Binary as Binary
import qualified Data.ByteString.Lazy as BSL
import Data.Complex (Complex)
import Data.Constraint.Nat
import Eigen.Internal
( Elem
, Cast(..)
, natToInt
, Row(..)
, Col(..)
)
import qualified Eigen.Internal as Internal
import qualified Eigen.Matrix.Mutable as M
import qualified Data.List as List
import Data.Kind (Type)
import GHC.TypeLits (Nat, type (*), type (<=), KnownNat)
import Foreign.C.Types (CInt)
import Foreign.C.String (CString)
import Foreign.Marshal.Alloc (alloca)
import Foreign.Ptr (Ptr)
import Foreign.Storable (peek)
import qualified Data.Vector.Storable as VS
import qualified Data.Vector.Storable.Mutable as VSM
-- | A dense @n@-by-@m@ matrix whose dimensions are tracked at the type level.
--
-- The elements are held in a flat 'Vec' of length @n * m@. Every accessor in
-- this module addresses that vector column-major: the element at row @r@,
-- column @c@ lives at offset @c * n + r@.
newtype Matrix :: Nat -> Nat -> Type -> Type where
  Matrix :: Vec (n * m) a -> Matrix n m a
-- | A storable vector of elements kept in their C representation (@'C' a@).
--
-- The length index @n@ is a phantom: nothing in this type ties it to the
-- actual length of the wrapped 'VS.Vector'.
-- NOTE(review): deliberately kept in GADT syntax; rewriting it in Haskell 98
-- syntax would change the order of the constructor's type variables seen by
-- TypeApplications (currently @a@ comes before @n@).
newtype Vec :: Nat -> Type -> Type where
  Vec :: VS.Vector (C a) -> Vec n a
-- | Renders a header line @Matrix <rows>x<cols>@, then one line per row with
-- the elements separated by tabs, and finally a trailing newline.
instance forall n m a. (Elem a, Show a, KnownNat n, KnownNat m) => Show (Matrix n m a) where
  show mat = header ++ "\n" ++ body ++ "\n"
    where
      header = "Matrix " ++ show (rows mat) ++ "x" ++ show (cols mat)
      body = List.intercalate "\n" [List.intercalate "\t" (List.map show row) | row <- toList mat]
-- | Binary serialisation.
--
-- Wire format, in order: the element-type magic code, the row count and the
-- column count (both as 'Int'), then the element vector.
--
-- 'get' reads back exactly what 'put' writes. It fails if the magic code does
-- not match the element type @a@, if the stored dimensions differ from the
-- type-level @n@ and @m@, or if the element count is not @n * m@.
instance forall n m a. (KnownNat n, KnownNat m, Elem a) => Binary (Matrix n m a) where
  put (Matrix (Vec vals)) = do
    put $ Internal.magicCode (undefined :: C a)
    put $ natToInt @n
    put $ natToInt @m
    put vals

  get = do
    code <- get
    when (code /= Internal.magicCode (undefined :: C a)) $
      fail "wrong matrix type"
    -- 'put' emits both dimensions before the payload. They must be consumed
    -- here; otherwise they would be misinterpreted as the start of the vector.
    rs <- get
    cs <- get
    when (rs /= natToInt @n || cs /= natToInt @m) $
      fail "wrong matrix dimensions"
    vals <- get
    when (VS.length vals /= rs * cs) $
      fail "wrong matrix size"
    pure (Matrix (Vec vals))
-- | Serialise a matrix to a lazy 'BSL.ByteString' via its 'Binary' instance.
encode :: (Elem a, KnownNat n, KnownNat m) => Matrix n m a -> BSL.ByteString
encode mat = Binary.encode mat

-- | Deserialise a matrix from a lazy 'BSL.ByteString' via its 'Binary'
-- instance. The result is not wrapped in 'Maybe': malformed input is
-- reported however 'Binary.decode' reports failures.
decode :: (Elem a, KnownNat n, KnownNat m) => BSL.ByteString -> Matrix n m a
decode bytes = Binary.decode bytes
-- | @n@-by-@m@ matrix of single-precision reals.
type MatrixXf (n :: Nat) (m :: Nat) = Matrix n m Float
-- | @n@-by-@m@ matrix of double-precision reals.
type MatrixXd (n :: Nat) (m :: Nat) = Matrix n m Double
-- | @n@-by-@m@ matrix of single-precision complex numbers.
type MatrixXcf (n :: Nat) (m :: Nat) = Matrix n m (Complex Float)
-- | @n@-by-@m@ matrix of double-precision complex numbers.
type MatrixXcd (n :: Nat) (m :: Nat) = Matrix n m (Complex Double)
-- | The 0×0 matrix; its backing vector holds no elements.
empty :: Elem a => Matrix 0 0 a
{-# INLINE empty #-}
empty = Matrix $ Vec VS.empty
-- | True only when the matrix has neither rows nor columns, i.e. is 0×0.
-- A 0×m or n×0 matrix (m, n > 0) is not considered null by this test.
null :: (Elem a, KnownNat n, KnownNat m) => Matrix n m a -> Bool
{-# INLINE null #-}
null mat = dims mat == (0, 0)
-- | Whether the matrix has as many rows as columns. The answer comes purely
-- from the type-level dimensions; the argument itself is never inspected.
square :: forall n m a. (Elem a, KnownNat n, KnownNat m) => Matrix n m a -> Bool
{-# INLINE square #-}
square mat = rows mat == cols mat
-- | A matrix with every element equal to the given value. The value is
-- converted to its C representation once and then replicated @n * m@ times.
constant :: forall n m a. (Elem a, KnownNat n, KnownNat m) => a -> Matrix n m a
{-# INLINE constant #-}
constant !val = withDims fill
  where
    !cval = toC val
    fill rs cs = VS.replicate (rs * cs) cval
-- | The matrix whose every element is @0@ (a 'constant' fill).
zero :: (Elem a, KnownNat n, KnownNat m) => Matrix n m a
{-# INLINE zero #-}
zero =
  constant 0

-- | The matrix whose every element is @1@ (a 'constant' fill).
ones :: (Elem a, KnownNat n, KnownNat m) => Matrix n m a
{-# INLINE ones #-}
ones =
  constant 1
-- | A matrix produced by allocating a fresh @n@-by-@m@ buffer and filling it
-- with the C routine 'Internal.identity' (presumably ones on the main
-- diagonal and zeros elsewhere — confirm against the C side).
identity :: forall n m a. (Elem a, KnownNat n, KnownNat m) => Matrix n m a
identity = Internal.performIO (M.new >>= fill)
  where
    fill :: M.IOMatrix n m a -> IO (Matrix n m a)
    fill buf = do
      Internal.call (M.unsafeWith buf Internal.identity)
      unsafeFreeze buf
-- | A matrix whose elements are filled in by the C routine 'Internal.random'.
-- It lives in 'IO' because each call may produce different contents; the
-- value distribution is whatever the C side uses (TODO document).
random :: forall n m a. (Elem a, KnownNat n, KnownNat m) => IO (Matrix n m a)
random = M.new >>= fill
  where
    fill :: M.IOMatrix n m a -> IO (Matrix n m a)
    fill buf = Internal.call (M.unsafeWith buf Internal.random) >> unsafeFreeze buf
-- | Build a matrix from a function of its row and column counts. The function
-- must return the column-major element vector; its length is not checked
-- against @rows * cols@.
withDims :: forall n m a. (Elem a, KnownNat n, KnownNat m) => (Int -> Int -> VS.Vector (C a)) -> Matrix n m a
{-# INLINE withDims #-}
withDims f = Matrix (Vec (f nRows nCols))
  where
    !nRows = natToInt @n
    !nCols = natToInt @m
-- | Number of rows, taken from the type-level @n@; the argument is ignored.
rows :: forall n m a. KnownNat n => Matrix n m a -> Int
{-# INLINE rows #-}
rows = const (natToInt @n)

-- | Number of columns, taken from the type-level @m@; the argument is ignored.
cols :: forall n m a. KnownNat m => Matrix n m a -> Int
{-# INLINE cols #-}
cols = const (natToInt @m)

-- | @(rows, cols)@ of the matrix, both taken from the type level.
dims :: forall n m a. (Elem a, KnownNat n, KnownNat m) => Matrix n m a -> (Int, Int)
{-# INLINE dims #-}
dims mat = (rows mat, cols mat)
-- | Operator alias for 'coeff': the element at the type-level row and column.
(!) :: forall n m a r c. (Elem a, KnownNat n, KnownNat r, KnownNat c, r <= n, c <= m) => Row r -> Col c -> Matrix n m a -> a
{-# INLINE (!) #-}
(!) row col mat = coeff row col mat
-- | The element at row @r@, column @c@. Both indices are 0-based and given at
-- the type level.
--
-- The constraints @r <= n@ and @c <= m@ still admit @r == n@ or @c == m@,
-- one past the last valid index. Such a request would silently read a
-- neighbouring element or read past the end of the buffer, so the index is
-- re-checked at runtime and an out-of-range request fails with 'error'.
coeff :: forall n m a r c. (Elem a, KnownNat n, KnownNat r, KnownNat c, r <= n, c <= m) => Row r -> Col c -> Matrix n m a -> a
{-# INLINE coeff #-}
coeff _ _ mat@(Matrix (Vec vals))
  -- row < n together with ix < n * m also implies col < m
  | row >= nRows || ix >= VS.length vals =
      error $ "Eigen.Matrix.coeff: index (" ++ show row ++ ", " ++ show col ++ ") is out of bounds"
  | otherwise = fromC $! VS.unsafeIndex vals ix
  where
    row = natToInt @r
    col = natToInt @c
    nRows = rows mat
    -- column-major layout, as written by 'generate'
    ix = col * nRows + row
-- | The element at the given 0-based row and column, with no bounds checking.
-- Out-of-range indices read arbitrary memory; callers must guarantee validity.
unsafeCoeff :: (Elem a, KnownNat n) => Int -> Int -> Matrix n m a -> a
{-# INLINE unsafeCoeff #-}
unsafeCoeff row col mat@(Matrix (Vec vals)) =
  let !offset = col * rows mat + row
      !cval = VS.unsafeIndex vals offset
  in fromC cval
-- | Build a matrix by evaluating @f row col@ at every 0-based position.
-- Results are written column-major into a freshly allocated storable buffer,
-- visiting rows in the outer loop and columns in the inner loop.
generate :: forall n m a. (Elem a, KnownNat n, KnownNat m) => (Int -> Int -> a) -> Matrix n m a
generate f = withDims build
  where
    build :: Int -> Int -> VS.Vector (C a)
    build rs cs = VS.create $ do
      buf <- VSM.new (rs * cs)
      forM_ [(r, c) | r <- [0 .. rs - 1], c <- [0 .. cs - 1]] $ \(r, c) ->
        VSM.write buf (c * rs + r) (toC $! f r c)
      pure buf
-- | Scalar reduction via the C routine 'Internal.sum' (expected: sum of all elements).
sum :: (Elem a, KnownNat n, KnownNat m) => Matrix n m a -> a
sum mat = _prop Internal.sum mat

-- | Scalar reduction via the C routine 'Internal.prod' (expected: product of all elements).
prod :: (Elem a, KnownNat n, KnownNat m) => Matrix n m a -> a
prod mat = _prop Internal.prod mat

-- | Scalar reduction via the C routine 'Internal.mean' (expected: mean of all elements).
mean :: (Elem a, KnownNat n, KnownNat m) => Matrix n m a -> a
mean mat = _prop Internal.mean mat

-- | Scalar reduction via the C routine 'Internal.trace' (expected: sum of the
-- main-diagonal elements). The type does not require a square matrix.
trace :: (Elem a, KnownNat n, KnownNat m) => Matrix n m a -> a
trace mat = _prop Internal.trace mat
-- | True when the predicate holds for every element (True for an empty matrix).
all :: (Elem a, KnownNat n, KnownNat m) => (a -> Bool) -> Matrix n m a -> Bool
all p (Matrix (Vec vals)) = VS.all (\x -> p (fromC x)) vals

-- | True when the predicate holds for at least one element.
any :: (Elem a, KnownNat n, KnownNat m) => (a -> Bool) -> Matrix n m a -> Bool
any p (Matrix (Vec vals)) = VS.any (\x -> p (fromC x)) vals

-- | Number of elements satisfying the predicate (strict left fold).
count :: (Elem a, KnownNat n, KnownNat m) => (a -> Bool) -> Matrix n m a -> Int
count p (Matrix (Vec vals)) = VS.foldl' step 0 vals
  where
    step acc x
      | p (fromC x) = acc + 1
      | otherwise = acc
-- | Scalar norms of the matrix. Each one is delegated, through '_prop', to the
-- identically named C routine in "Eigen.Internal".
norm, squaredNorm, blueNorm, hypotNorm :: (Elem a, KnownNat n, KnownNat m) => Matrix n m a -> a
norm mat = _prop Internal.norm mat
squaredNorm mat = _prop Internal.squaredNorm mat
blueNorm mat = _prop Internal.blueNorm mat
hypotNorm mat = _prop Internal.hypotNorm mat

-- | Determinant of a square matrix, computed by 'Internal.determinant'.
determinant :: forall n a. (Elem a, KnownNat n) => Matrix n n a -> a
determinant = _prop Internal.determinant
-- | Element-wise sum of two equally sized matrices (C routine 'Internal.add').
add :: (Elem a, KnownNat n, KnownNat m) => Matrix n m a -> Matrix n m a -> Matrix n m a
add = _binop Internal.add

-- | Element-wise difference of two equally sized matrices (C routine 'Internal.sub').
sub :: (Elem a, KnownNat n, KnownNat m) => Matrix n m a -> Matrix n m a -> Matrix n m a
sub = _binop Internal.sub

-- | Matrix product; the shared dimension @q@ is enforced by the types
-- (C routine 'Internal.mul').
mul :: (Elem a, KnownNat p, KnownNat q, KnownNat r) => Matrix p q a -> Matrix q r a -> Matrix p r a
mul = _binop Internal.mul
-- | Apply a function to every element.
map :: Elem a => (a -> a) -> Matrix n m a -> Matrix n m a
map f (Matrix (Vec vals)) = Matrix (Vec (VS.map step vals))
  where
    step = toC . f . fromC

-- | Like 'map', but the function also receives the 0-based row and column.
-- The flat column-major offset is split as @offset = col * rows + row@.
imap :: (Elem a, KnownNat n, KnownNat m) => (Int -> Int -> a -> a) -> Matrix n m a -> Matrix n m a
imap f (Matrix (Vec vals)) = withDims $ \rs _ ->
  let at ix x = let (c, r) = ix `divMod` rs in toC (f r c (fromC x))
  in VS.imap at vals
-- | Which part of a matrix 'triangularView' keeps; everything else becomes 0.
-- Constructor order is significant: it fixes the derived 'Enum' numbering.
data TriangularMode
  = Lower          -- ^ keep the diagonal and everything below it
  | Upper          -- ^ keep the diagonal and everything above it
  | StrictlyLower  -- ^ keep only the elements strictly below the diagonal
  | StrictlyUpper  -- ^ keep only the elements strictly above the diagonal
  | UnitLower      -- ^ like 'StrictlyLower', with 1 on the diagonal
  | UnitUpper      -- ^ like 'StrictlyUpper', with 1 on the diagonal
  deriving (Eq, Enum, Show, Read)
-- | Keep one triangular part of the matrix and zero the rest, as selected by
-- the 'TriangularMode'. The unit modes additionally put 1 on the diagonal.
triangularView :: (Elem a, KnownNat n, KnownNat m) => TriangularMode -> Matrix n m a -> Matrix n m a
triangularView mode = case mode of
  Lower -> zeroWhen (== LT)
  Upper -> zeroWhen (== GT)
  StrictlyLower -> keepWhen (== GT)
  StrictlyUpper -> keepWhen (== LT)
  UnitLower -> unitWith GT
  UnitUpper -> unitWith LT
  where
    -- zero the element when (row `compare` col) satisfies p, otherwise keep it
    zeroWhen p = imap $ \r c v -> if p (compare r c) then 0 else v
    -- keep the element when (row `compare` col) satisfies p, otherwise zero it
    keepWhen p = imap $ \r c v -> if p (compare r c) then v else 0
    -- 1 on the diagonal, keep the side given by o, zero the other side
    unitWith o = imap $ \r c v -> case compare r c of
      EQ -> 1
      other
        | other == o -> v
        | otherwise -> 0
-- | Zero every element for which the predicate is False; the shape is kept.
filter :: Elem a => (a -> Bool) -> Matrix n m a -> Matrix n m a
filter keep = map zeroUnless
  where
    zeroUnless x
      | keep x = x
      | otherwise = 0

-- | Like 'filter', but the predicate also receives the 0-based row and column.
ifilter :: (Elem a, KnownNat n, KnownNat m) => (Int -> Int -> a -> Bool) -> Matrix n m a -> Matrix n m a
ifilter keep = imap zeroUnless
  where
    zeroUnless r c x
      | keep r c x = x
      | otherwise = 0
-- | Total number of elements, @n * m@, read from the type level; the argument
-- is ignored.
length :: forall n m a r. (Elem a, KnownNat n, KnownNat m, r ~ (n * m), KnownNat r) => Matrix n m a -> Int
length = const (natToInt @r)

-- | Lazy left fold over the elements in storage (column-major) order.
foldl :: (Elem a, KnownNat n, KnownNat m) => (b -> a -> b) -> b -> Matrix n m a -> b
foldl f z (Matrix (Vec vals)) = VS.foldl (\acc -> f acc . fromC) z vals

-- | Strict left fold over the elements in storage (column-major) order.
foldl' :: Elem a => (b -> a -> b) -> b -> Matrix n m a -> b
foldl' f z (Matrix (Vec vals)) = VS.foldl' step z vals
  where
    step !acc x = f acc (fromC x)
-- | The main diagonal as a column vector of length @min n m@ ('Internal.diagonal').
diagonal :: (Elem a, KnownNat n, KnownNat m, r ~ Min n m, KnownNat r) => Matrix n m a -> Matrix r 1 a
diagonal mat = _unop Internal.diagonal mat

-- | Inverse of a square matrix ('Internal.inverse'); no invertibility check here.
inverse :: forall n a. (Elem a, KnownNat n) => Matrix n n a -> Matrix n n a
inverse mat = _unop Internal.inverse mat

-- | Adjoint (expected: conjugate transpose) via 'Internal.adjoint'.
adjoint :: (Elem a, KnownNat n, KnownNat m) => Matrix n m a -> Matrix m n a
adjoint mat = _unop Internal.adjoint mat

-- | Transpose via 'Internal.transpose'; the result type swaps the dimensions.
transpose :: (Elem a, KnownNat n, KnownNat m) => Matrix n m a -> Matrix m n a
transpose mat = _unop Internal.transpose mat

-- | Element-wise complex conjugate via 'Internal.conjugate'.
conjugate :: (Elem a, KnownNat n, KnownNat m) => Matrix n m a -> Matrix n m a
conjugate mat = _unop Internal.conjugate mat
-- | Rescale the matrix with the C routine 'Internal.normalize' (presumably to
-- unit norm — confirm). The routine runs on a thawed copy of the elements,
-- so the argument matrix itself is left untouched.
normalize :: forall n m a. (Elem a, KnownNat n, KnownNat m) => Matrix n m a -> Matrix n m a
normalize (Matrix (Vec vals)) = Internal.performIO $ do
  buf <- VS.thaw vals
  VSM.unsafeWith buf $ \ptr ->
    Internal.call (Internal.normalize ptr (toC nRows) (toC nCols))
  Matrix . Vec <$> VS.unsafeFreeze buf
  where
    !nRows = natToInt @n
    !nCols = natToInt @m
-- | Run an 'ST' mutation against a mutable view of the matrix and return the
-- result. 'VS.modify' decides whether to copy, so the input is not observably
-- changed.
modify :: (Elem a, KnownNat n, KnownNat m) => (forall s. M.MMatrix n m s a -> ST s ()) -> Matrix n m a -> Matrix n m a
modify f (Matrix (Vec vals)) =
  Matrix (Vec (VS.modify (\mv -> f (M.fromVector mv)) vals))
-- | Extract a @br@-by-@bc@ sub-matrix whose top-left corner sits at the
-- 0-based position (@sr@, @sc@) of the source matrix.
--
-- The type-level constraints bound the start and the size separately, but
-- not their sum. A block can therefore still reach past the source edge
-- (e.g. @sr = n@ with @br >= 1@). Reading it through 'unsafeCoeff' would
-- silently wrap into the next column or read past the end of the buffer, so
-- such a request is rejected at runtime with 'error'.
block :: forall sr sc br bc n m a.
  (Elem a, KnownNat sr, KnownNat sc, KnownNat br, KnownNat bc, KnownNat n, KnownNat m)
  => (sr <= n, sc <= m, br <= n, bc <= m)
  => Row sr
  -> Col sc
  -> Row br
  -> Col bc
  -> Matrix n m a
  -> Matrix br bc a
block _ _ _ _ mat
  | startRow + blockRows > rows mat || startCol + blockCols > cols mat =
      error $ "Eigen.Matrix.block: a " ++ show blockRows ++ "x" ++ show blockCols
        ++ " block at (" ++ show startRow ++ ", " ++ show startCol
        ++ ") does not fit in a " ++ show (rows mat) ++ "x" ++ show (cols mat) ++ " matrix"
  | otherwise = generate $ \row col -> unsafeCoeff (startRow + row) (startCol + col) mat
  where
    !startRow = natToInt @sr
    !startCol = natToInt @sc
    blockRows = natToInt @br
    blockCols = natToInt @bc
-- | Reinterpret a mutable matrix as an immutable one without copying. The
-- mutable matrix must not be written to afterwards.
unsafeFreeze :: (Elem a, KnownNat n, KnownNat m, PrimMonad p) => M.MMatrix n m (PrimState p) a -> p (Matrix n m a)
unsafeFreeze mmat = Matrix . Vec <$> VS.unsafeFreeze (M.vals mmat)

-- | Pass the raw element pointer, plus the row and column counts converted to
-- 'CInt', to an 'IO' action. The pointer must not escape the action.
unsafeWith :: (Elem a, KnownNat n, KnownNat m) => Matrix n m a -> (Ptr (C a) -> CInt -> CInt -> IO b) -> IO b
unsafeWith mat@(Matrix (Vec vals)) f = VS.unsafeWith vals (\ptr -> f ptr nRows nCols)
  where
    !nRows = toC $! rows mat
    !nCols = toC $! cols mat
-- | Run a C routine that writes a single scalar result for a matrix. The
-- routine gets a temporary output slot plus the matrix pointer and
-- dimensions. Its status string goes through 'Internal.call', and the slot is
-- read back and converted from its C representation.
_prop :: (Elem a, KnownNat n, KnownNat m) => (Ptr (C a) -> Ptr (C a) -> CInt -> CInt -> IO CString) -> Matrix n m a -> a
{-# INLINE _prop #-}
_prop f mat = fromC (Internal.performIO (alloca run))
  where
    run out = do
      Internal.call (unsafeWith mat (f out))
      peek out
-- | Run a C routine that combines two matrices into a freshly allocated
-- @n2@-by-@m2@ result. Every matrix is passed as (pointer, rows, cols), with
-- the result first, then the left operand, then the right operand.
_binop :: forall n m n1 m1 n2 m2 a. (Elem a, KnownNat n, KnownNat m, KnownNat n1, KnownNat m1, KnownNat n2, KnownNat m2)
  => (Ptr (C a) -> CInt -> CInt -> Ptr (C a) -> CInt -> CInt -> Ptr (C a) -> CInt -> CInt -> IO CString)
  -> Matrix n m a
  -> Matrix n1 m1 a
  -> Matrix n2 m2 a
{-# INLINE _binop #-}
_binop g lhs rhs = Internal.performIO (M.new >>= fillInto)
  where
    fillInto :: M.IOMatrix n2 m2 a -> IO (Matrix n2 m2 a)
    fillInto out = do
      M.unsafeWith out $ \pOut rOut cOut ->
        unsafeWith lhs $ \pL rL cL ->
          unsafeWith rhs $ \pR rR cR ->
            Internal.call (g pOut rOut cOut pL rL cL pR rR cR)
      unsafeFreeze out
-- | Run a C routine that maps one matrix to a freshly allocated @n1@-by-@m1@
-- result. Both matrices are passed as (pointer, rows, cols), with the result
-- first.
_unop :: forall n m n1 m1 a. (Elem a, KnownNat n, KnownNat m, KnownNat n1, KnownNat m1)
  => (Ptr (C a) -> CInt -> CInt -> Ptr (C a) -> CInt -> CInt -> IO CString)
  -> Matrix n m a
  -> Matrix n1 m1 a
{-# INLINE _unop #-}
_unop g src = Internal.performIO (M.new >>= fillInto)
  where
    fillInto :: M.IOMatrix n1 m1 a -> IO (Matrix n1 m1 a)
    fillInto out = do
      M.unsafeWith out $ \pOut rOut cOut ->
        unsafeWith src $ \pIn rIn cIn ->
          Internal.call (g pOut rOut cOut pIn rIn cIn)
      unsafeFreeze out
-- | The matrix as a list of rows. A 0×0 matrix gives @[]@. For other shapes
-- there is one (possibly empty) inner list per row, read from the
-- column-major storage.
toList :: (Elem a, KnownNat n, KnownNat m) => Matrix n m a -> [[a]]
{-# INLINE toList #-}
toList mat@(Matrix (Vec vals))
  | null mat = []
  | otherwise = List.map rowAt [0 .. nRows - 1]
  where
    !nRows = rows mat
    !nCols = cols mat
    rowAt r = [fromC (VS.unsafeIndex vals (c * nRows + r)) | c <- [0 .. nCols - 1]]
-- | Build a matrix from a list of rows.
--
-- Returns 'Nothing' unless the list has exactly @n@ rows and its longest row
-- has exactly @m@ elements. Shorter rows are padded with zeros.
--
-- An empty list is accepted for any 0-row shape: with no rows there is
-- nothing to carry the column count. This makes @fromList . toList@
-- round-trip for 0×m matrices, which 'toList' renders as @[]@.
fromList :: forall n m a. (Elem a, KnownNat n, KnownNat m) => [[a]] -> Maybe (Matrix n m a)
fromList list
  | myRows /= listRows = Nothing
  | listRows > 0 && myCols /= listCols = Nothing
  | otherwise = Just . Matrix . Vec $ VS.create $ do
      -- zero-filled so that short rows end up padded with 0
      vm <- VSM.replicate (myRows * myCols) (toC (0 :: a))
      forM_ (zip [0 ..] list) $ \(row, vals) ->
        forM_ (zip [0 ..] vals) $ \(col, val) ->
          VSM.write vm (col * myRows + row) (toC val)
      pure vm
  where
    myRows = natToInt @n
    myCols = natToInt @m
    listRows = List.length list
    listCols = List.foldl' max 0 (List.map List.length list)