{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE ScopedTypeVariables #-}
#if __GLASGOW_HASKELL__ >= 805
{-# LANGUAGE ExplicitNamespaces #-}
{-# LANGUAGE NoStarIsType #-}
#endif
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}

-- | Immutable dense matrices whose dimensions are tracked at the type level.
-- Coefficients live in a storable vector and most numeric operations are
-- delegated to the foreign routines exposed by "Eigen.Internal".
module Eigen.Matrix
  ( -- * Types
    Matrix(..)
  , Vec(..)
  , MatrixXf
  , MatrixXd
  , MatrixXcf
  , MatrixXcd
    -- * Common API
  , Elem
  , C
  , natToInt
  , Row(..)
  , Col(..)
    -- * Encode/Decode a Matrix
  , encode
  , decode
    -- * Querying a Matrix
  , null
  , square
  , rows
  , cols
  , dims
    -- * Constructing a Matrix
  , empty
  , constant
  , zero
  , ones
  , identity
  , random
  , diagonal
  , (!)
  , coeff
  , generate
  , sum
  , prod
  , mean
  , trace
  , all
  , any
  , count
  , norm
  , squaredNorm
  , blueNorm
  , hypotNorm
  , determinant
  , add
  , sub
  , mul
  , map
  , imap
  , TriangularMode(..)
  , triangularView
  , filter
  , ifilter
  , length
  , foldl
  , foldl'
  , inverse
  , adjoint
  , transpose
  , conjugate
  , normalize
  , modify
  , block
  , unsafeFreeze
  , unsafeWith
  , fromList
  , toList
  ) where

import Control.Monad (forM_, when)
import Control.Monad.ST (ST)
-- 'foldl'' is hidden too: Prelude exports it from GHC 9.10 on, which would
-- make the export of this module's own 'foldl'' ambiguous. Older GHCs only
-- emit a -Wdodgy-imports warning for hiding a name Prelude does not export.
import Prelude hiding (map, null, filter, length, foldl, foldl', any, all, sum)
import Control.Monad.Primitive (PrimMonad(..))
import Data.Binary (Binary(..))
import qualified Data.Binary as Binary
import qualified Data.ByteString.Lazy as BSL
import Data.Complex (Complex)
import Data.Constraint.Nat
import Eigen.Internal
  ( Elem
  , Cast(..)
  , natToInt
  , Row(..)
  , Col(..)
  )
import qualified Eigen.Internal as Internal
import qualified Eigen.Matrix.Mutable as M
import qualified Data.List as List
import Data.Kind (Type)
import GHC.TypeLits (Nat, type (*), type (<=), KnownNat)
import Foreign.C.Types (CInt)
import Foreign.C.String (CString)
import Foreign.Marshal.Alloc (alloca)
import Foreign.Ptr (Ptr)
import Foreign.Storable (peek)
import qualified Data.Vector.Storable as VS
import qualified Data.Vector.Storable.Mutable as VSM

-- | Matrix to be used in pure computations.
--
-- * Uses column-major memory layout.
--
-- * Has a copy-free FFI using the library.
--
newtype Matrix :: Nat -> Nat -> Type -> Type where
  Matrix :: Vec (n * m) a -> Matrix n m a

-- | Used internally to track the size and corresponding C type of the matrix.
newtype Vec :: Nat -> Type -> Type where
  Vec :: VS.Vector (C a) -> Vec n a

instance forall n m a. (Elem a, Show a, KnownNat n, KnownNat m) => Show (Matrix n m a) where
  show m = List.concat
    [ "Matrix ", show (rows m), "x", show (cols m)
    , "\n", List.intercalate "\n" $ List.map (List.intercalate "\t" . List.map show) $ toList m, "\n"
    ]

-- | Serialised layout: the element-type magic code, the row count, the column
-- count, then the coefficient vector. 'get' validates every header field
-- against the expected type before accepting the coefficients.
instance forall n m a. (KnownNat n, KnownNat m, Elem a) => Binary (Matrix n m a) where
  put (Matrix (Vec vals)) = do
    put $ Internal.magicCode (undefined :: C a)
    put $ natToInt @n
    put $ natToInt @m
    put vals
  get = do
    code <- get
    when (code /= Internal.magicCode (undefined :: C a)) $
      fail "wrong matrix type"
    -- The dimensions written by 'put' must be consumed here; previously they
    -- were skipped and misread as the beginning of the vector encoding.
    rs <- get
    cs <- get
    when (rs /= natToInt @n || cs /= natToInt @m) $
      fail "wrong matrix size"
    vals <- get
    when (VS.length vals /= natToInt @n * natToInt @m) $
      fail "wrong matrix size"
    pure (Matrix (Vec vals))

-- | Encode the matrix as a lazy bytestring.
encode :: (Elem a, KnownNat n, KnownNat m) => Matrix n m a -> BSL.ByteString
encode = Binary.encode

-- | Decode the matrix from a lazy bytestring produced by 'encode'.
--
-- Decoding fails (via 'Binary.decode') when the element type or the
-- dimensions stored in the input do not match the requested matrix type.
decode :: (Elem a, KnownNat n, KnownNat m) => BSL.ByteString -> Matrix n m a
decode = Binary.decode

-- | Alias for single precision matrix
type MatrixXf n m = Matrix n m Float
-- | Alias for double precision matrix
type MatrixXd n m = Matrix n m Double
-- | Alias for single precision matrix of complex numbers
type MatrixXcf n m = Matrix n m (Complex Float)
-- | Alias for double precision matrix of complex numbers
type MatrixXcd n m = Matrix n m (Complex Double)

-- | Construct an empty 0x0 matrix
empty :: Elem a => Matrix 0 0 a
{-# INLINE empty #-}
empty = Matrix (Vec VS.empty)

-- | Is matrix empty? This is 'True' only when both the row count and the
-- column count are zero.
null :: (Elem a, KnownNat n, KnownNat m) => Matrix n m a -> Bool
{-# INLINE null #-}
null m = cols m == 0 && rows m == 0

-- | Is matrix square?
square :: forall n m a. (Elem a, KnownNat n, KnownNat m) => Matrix n m a -> Bool
{-# INLINE square #-}
square _ = natToInt @n == natToInt @m

-- | Matrix where all coeffs are filled with the given value
constant :: forall n m a. (Elem a, KnownNat n, KnownNat m) => a -> Matrix n m a
{-# INLINE constant #-}
constant !val =
  -- Convert once, then replicate the C representation.
  let !cval = toC val
  in withDims $ \rs cs -> VS.replicate (rs * cs) cval

-- | Matrix where all coeffs are filled with 0
zero :: (Elem a, KnownNat n, KnownNat m) => Matrix n m a
{-# INLINE zero #-}
zero = constant 0

-- | Matrix where all coeffs are filled with 1
ones :: (Elem a, KnownNat n, KnownNat m) => Matrix n m a
{-# INLINE ones #-}
ones = constant 1

-- | The identity matrix (not necessarily square)
identity :: forall n m a. (Elem a, KnownNat n, KnownNat m) => Matrix n m a
identity = Internal.performIO $ do
  m :: M.IOMatrix n m a <- M.new
  Internal.call $ M.unsafeWith m Internal.identity
  unsafeFreeze m

-- | A matrix of the given size whose coefficients are filled in by the
-- library's random generator ('Internal.random').
random :: forall n m a. (Elem a, KnownNat n, KnownNat m) => IO (Matrix n m a)
random = do
  m :: M.IOMatrix n m a <- M.new
  Internal.call $ M.unsafeWith m Internal.random
  unsafeFreeze m

-- | Build a matrix from a function of its (rows, cols), which must return a
-- column-major coefficient vector of length @rows * cols@.
withDims :: forall n m a. (Elem a, KnownNat n, KnownNat m) => (Int -> Int -> VS.Vector (C a)) -> Matrix n m a
{-# INLINE withDims #-}
withDims f =
  let !r = natToInt @n
      !c = natToInt @m
  in Matrix $ Vec $ f r c

-- | The number of rows in the matrix
rows :: forall n m a. KnownNat n => Matrix n m a -> Int
{-# INLINE rows #-}
rows _ = natToInt @n

-- | The number of columns in the matrix
cols :: forall n m a. KnownNat m => Matrix n m a -> Int
{-# INLINE cols #-}
cols _ = natToInt @m

-- | Return Matrix size as a pair of (rows, cols)
dims :: forall n m a. (Elem a, KnownNat n, KnownNat m) => Matrix n m a -> (Int, Int)
{-# INLINE dims #-}
dims _ = (natToInt @n, natToInt @m)

-- | Return the value at the given position. Synonym for 'coeff'.
(!) :: forall n m a r c. (Elem a, KnownNat n, KnownNat r, KnownNat c, r <= n, c <= m) => Row r -> Col c -> Matrix n m a -> a
{-# INLINE (!) #-}
(!) = coeff

-- | Return the value at the given position.
--
-- The type-level constraints only establish @r <= n@ and @c <= m@, which still
-- admit the one-past-the-end indices @r == n@ / @c == m@. Those are rejected
-- here with 'error' instead of reading outside the coefficient buffer.
coeff :: forall n m a r c. (Elem a, KnownNat n, KnownNat r, KnownNat c, r <= n, c <= m) => Row r -> Col c -> Matrix n m a -> a
{-# INLINE coeff #-}
coeff _ _ m@(Matrix (Vec vals))
  -- Given row < rs, idx < length (= rs * cols) holds exactly when col < cols.
  | row >= rs || idx >= VS.length vals =
      error $ "Eigen.Matrix.coeff: index (" ++ show row ++ "," ++ show col ++ ") out of bounds"
  | otherwise = fromC $! VS.unsafeIndex vals idx
  where
    !row = natToInt @r
    !col = natToInt @c
    !rs = rows m
    -- Column-major layout: column @col@ starts at offset @col * rs@.
    !idx = col * rs + row

-- | Unchecked variant of 'coeff' taking run-time indices. Callers must
-- guarantee @0 <= row < rows@ and @0 <= col < cols@.
unsafeCoeff :: (Elem a, KnownNat n) => Int -> Int -> Matrix n m a -> a
{-# INLINE unsafeCoeff #-}
unsafeCoeff row col m@(Matrix (Vec vals)) = fromC $! VS.unsafeIndex vals $! col * rows m + row

-- | Given a generation function `f :: Int -> Int -> a`, construct a Matrix of
-- known size by calling @f row col@ for every position.
generate :: forall n m a. (Elem a, KnownNat n, KnownNat m) => (Int -> Int -> a) -> Matrix n m a
generate f = withDims $ \rs cs -> VS.create $ do
  vals :: VSM.MVector s (C a) <- VSM.new (rs * cs)
  forM_ [0 .. pred rs] $ \r ->
    forM_ [0 .. pred cs] $ \c ->
      VSM.write vals (c * rs + r) (toC $! f r c)
  pure vals

-- | The sum of all coefficients in the matrix
sum :: (Elem a, KnownNat n, KnownNat m) => Matrix n m a -> a
sum = _prop Internal.sum

-- | The product of all coefficients in the matrix
prod :: (Elem a, KnownNat n, KnownNat m) => Matrix n m a -> a
prod = _prop Internal.prod

-- | The arithmetic mean of all coefficients in the matrix
mean :: (Elem a, KnownNat n, KnownNat m) => Matrix n m a -> a
mean = _prop Internal.mean

-- | The trace of a matrix is the sum of the diagonal coefficients.
--
-- 'trace' m == 'sum' ('diagonal' m)
trace :: (Elem a, KnownNat n, KnownNat m) => Matrix n m a -> a
trace = _prop Internal.trace

-- | Given a predicate p, determine if all values in the Matrix satisfy p.
all :: (Elem a, KnownNat n, KnownNat m) => (a -> Bool) -> Matrix n m a -> Bool
all f (Matrix (Vec vals)) = VS.all (f . fromC) vals

-- | Given a predicate p, determine if any values in the Matrix satisfy p.
any :: (Elem a, KnownNat n, KnownNat m) => (a -> Bool) -> Matrix n m a -> Bool
any f (Matrix (Vec vals)) = VS.any (f . fromC) vals

-- | Given a predicate p, determine how many values in the Matrix satisfy p.
count :: (Elem a, KnownNat n, KnownNat m) => (a -> Bool) -> Matrix n m a -> Int
count f (Matrix (Vec vals)) = VS.foldl' step 0 vals
  where
    step !acc x = if f (fromC x) then acc + 1 else acc

norm, squaredNorm, blueNorm, hypotNorm :: (Elem a, KnownNat n, KnownNat m) => Matrix n m a -> a

{-| For vectors, the l2 norm, and for matrices the Frobenius norm.
In both cases, it is the square root of the sum of the squares of all the
matrix entries. For vectors, this is also equal to the square root of the dot
product of the vector with itself.
-}
norm = _prop Internal.norm

-- | For vectors, the squared l2 norm, and for matrices the squared Frobenius
-- norm. In both cases, it is the sum of the squares of all the matrix entries.
-- For vectors, this is also equal to the dot product of the vector with itself.
squaredNorm = _prop Internal.squaredNorm

-- | The l2 norm of the matrix using Blue's algorithm.
-- A Portable Fortran Program to Find the Euclidean Norm of a Vector,
-- ACM TOMS, Vol 4, Issue 1, 1978.
blueNorm = _prop Internal.blueNorm

-- | The l2 norm of the matrix avoiding underflow and overflow.
-- This version uses a concatenation of hypot calls, and it is very slow.
hypotNorm = _prop Internal.hypotNorm

-- | The determinant of a square matrix.
determinant :: forall n a. (Elem a, KnownNat n) => Matrix n n a -> a
determinant m = _prop Internal.determinant m

-- | Add two matrices of the same shape.
add :: (Elem a, KnownNat n, KnownNat m) => Matrix n m a -> Matrix n m a -> Matrix n m a
add m1 m2 = _binop Internal.add m1 m2

-- | Subtract the second matrix from the first; both have the same shape.
sub :: (Elem a, KnownNat n, KnownNat m) => Matrix n m a -> Matrix n m a -> Matrix n m a
sub m1 m2 = _binop Internal.sub m1 m2

-- | Multiply two matrices: a @p x q@ matrix times a @q x r@ matrix gives a
-- @p x r@ matrix.
mul :: (Elem a, KnownNat p, KnownNat q, KnownNat r) => Matrix p q a -> Matrix q r a -> Matrix p r a
mul m1 m2 = _binop Internal.mul m1 m2

{- | Apply a given function to each element of the matrix.

Here is an example of how to implement scalar matrix multiplication
('fromList' returns a 'Maybe', hence the pattern match):

>>> let Just a = fromList [[1,2],[3,4]] :: Maybe (MatrixXf 2 2)
>>> a
Matrix 2x2
1.0 2.0
3.0 4.0

>>> map (*10) a
Matrix 2x2
10.0 20.0
30.0 40.0
-}
map :: Elem a => (a -> a) -> Matrix n m a -> Matrix n m a
map f (Matrix (Vec vals)) = Matrix $ Vec $ VS.map (toC . f . fromC) vals

{- | Apply a given function to each element of the matrix, also passing the
element's row and column index.

Here is an example of how an upper triangular matrix can be implemented:

>>> let Just a = fromList [[1,2,3],[4,5,6],[7,8,9]] :: Maybe (MatrixXf 3 3)
>>> a
Matrix 3x3
1.0 2.0 3.0
4.0 5.0 6.0
7.0 8.0 9.0

>>> imap (\row col val -> if row <= col then val else 0) a
Matrix 3x3
1.0 2.0 3.0
0.0 5.0 6.0
0.0 0.0 9.0
-}
imap :: (Elem a, KnownNat n, KnownNat m) => (Int -> Int -> a -> a) -> Matrix n m a -> Matrix n m a
imap f (Matrix (Vec vals)) = withDims $ \rs _ ->
  -- Column-major storage: flat index n = col * rs + row.
  VS.imap (\n -> let (c, r) = divMod n rs in toC . f r c . fromC) vals

-- | Provide a view of the matrix for extraction of a subset.
data TriangularMode
  -- | View matrix as a lower triangular matrix.
  = Lower
  -- | View matrix as an upper triangular matrix.
  | Upper
  -- | View matrix as a lower triangular matrix with zeros on the diagonal.
  | StrictlyLower
  -- | View matrix as an upper triangular matrix with zeros on the diagonal.
  | StrictlyUpper
  -- | View matrix as a lower triangular matrix with ones on the diagonal.
  | UnitLower
  -- | View matrix as an upper triangular matrix with ones on the diagonal.
  | UnitUpper
  deriving (Eq, Enum, Show, Read)

-- | Triangular view extracted from the current matrix. Coefficients outside
-- the selected triangle are replaced by 0; the diagonal is kept, zeroed, or
-- set to 1 according to the 'TriangularMode'.
triangularView :: (Elem a, KnownNat n, KnownNat m) => TriangularMode -> Matrix n m a -> Matrix n m a
triangularView = \case
  Lower         -> imap $ \row col val -> case compare row col of { LT -> 0; _ -> val }
  Upper         -> imap $ \row col val -> case compare row col of { GT -> 0; _ -> val }
  StrictlyLower -> imap $ \row col val -> case compare row col of { GT -> val; _ -> 0 }
  StrictlyUpper -> imap $ \row col val -> case compare row col of { LT -> val; _ -> 0 }
  UnitLower     -> imap $ \row col val -> case compare row col of { GT -> val; LT -> 0; EQ -> 1 }
  UnitUpper     -> imap $ \row col val -> case compare row col of { LT -> val; GT -> 0; EQ -> 1 }

-- | Filter elements in the matrix. Filtered elements will be replaced by 0.
filter :: Elem a => (a -> Bool) -> Matrix n m a -> Matrix n m a
filter f = map (\x -> if f x then x else 0)

-- | Filter elements in the matrix with an indexed predicate. Filtered
-- elements will be replaced by 0.
ifilter :: (Elem a, KnownNat n, KnownNat m) => (Int -> Int -> a -> Bool) -> Matrix n m a -> Matrix n m a
ifilter f = imap (\r c x -> if f r c x then x else 0)

-- | The number of coefficients in the matrix (@n * m@).
length :: forall n m a r. (Elem a, KnownNat n, KnownNat m, r ~ (n * m), KnownNat r) => Matrix n m a -> Int
length _ = natToInt @r

-- | Left fold of a matrix, where accumulation is lazy.
foldl :: (Elem a, KnownNat n, KnownNat m) => (b -> a -> b) -> b -> Matrix n m a -> b
foldl f b (Matrix (Vec vals)) = VS.foldl (\a x -> f a (fromC x)) b vals

-- | Left fold of a matrix, where accumulation is strict.
foldl' :: Elem a => (b -> a -> b) -> b -> Matrix n m a -> b
foldl' f b (Matrix (Vec vals)) = VS.foldl' (\ !a x -> f a (fromC x)) b vals

-- | Return the diagonal of a matrix as a column vector of length @min n m@.
diagonal :: (Elem a, KnownNat n, KnownNat m, r ~ Min n m, KnownNat r) => Matrix n m a -> Matrix r 1 a
diagonal = _unop Internal.diagonal

{- | Inverse of the matrix

For small fixed sizes up to 4x4, this method uses cofactors.
In the general case, this method uses PartialPivLU decomposition
-}
inverse :: forall n a. (Elem a, KnownNat n) => Matrix n n a -> Matrix n n a
inverse = _unop Internal.inverse

-- | Adjoint of the matrix
adjoint :: (Elem a, KnownNat n, KnownNat m) => Matrix n m a -> Matrix m n a
adjoint = _unop Internal.adjoint

-- | Transpose of the matrix
transpose :: (Elem a, KnownNat n, KnownNat m) => Matrix n m a -> Matrix m n a
transpose = _unop Internal.transpose

-- | Conjugate of the matrix
conjugate :: (Elem a, KnownNat n, KnownNat m) => Matrix n m a -> Matrix n m a
conjugate = _unop Internal.conjugate

-- | Normalise the matrix by dividing it by its 'norm'
normalize :: forall n m a. (Elem a, KnownNat n, KnownNat m) => Matrix n m a -> Matrix n m a
normalize (Matrix (Vec vals)) = Internal.performIO $ do
  -- 'VS.thaw' copies, so the in-place C routine never touches the input.
  vals' <- VS.thaw vals
  VSM.unsafeWith vals' $ \p ->
    let !rs = natToInt @n
        !cs = natToInt @m
    in Internal.call $ Internal.normalize p (toC rs) (toC cs)
  Matrix . Vec <$> VS.unsafeFreeze vals'

-- | Apply a destructive operation to a matrix. The operation will be performed
-- in-place, if it is safe to do so - otherwise, it will create a copy of the
-- matrix.
modify :: (Elem a, KnownNat n, KnownNat m) => (forall s. M.MMatrix n m s a -> ST s ()) -> Matrix n m a -> Matrix n m a
modify f (Matrix (Vec vals)) = Matrix $ Vec $ VS.modify (f . M.fromVector) vals

-- | Extract rectangular block from matrix defined by startRow startCol
-- blockRows blockCols.
--
-- The type-level constraints do not ensure that the block fits inside the
-- matrix (e.g. @startRow + blockRows <= n@), so that is checked at run time
-- and a non-fitting block is rejected with 'error' rather than read out of
-- bounds.
block :: forall sr sc br bc n m a.
  (Elem a, KnownNat sr, KnownNat sc, KnownNat br, KnownNat bc, KnownNat n, KnownNat m)
  => (sr <= n, sc <= m, br <= n, bc <= m)
  => Row sr -- ^ starting row
  -> Col sc -- ^ starting col
  -> Row br -- ^ block of rows
  -> Col bc -- ^ block of cols
  -> Matrix n m a -- ^ extract from this
  -> Matrix br bc a -- ^ extraction
block _ _ _ _ m
  | startRow + blockRows > rows m || startCol + blockCols > cols m =
      error $ "Eigen.Matrix.block: " ++ show blockRows ++ "x" ++ show blockCols
           ++ " block at (" ++ show startRow ++ "," ++ show startCol
           ++ ") does not fit in a " ++ show (rows m) ++ "x" ++ show (cols m) ++ " matrix"
  | otherwise =
      generate $ \row col -> unsafeCoeff (startRow + row) (startCol + col) m
  where
    !startRow = natToInt @sr
    !startCol = natToInt @sc
    !blockRows = natToInt @br
    !blockCols = natToInt @bc

-- | Turn a mutable matrix into an immutable matrix without copying.
-- The mutable matrix should not be modified after this conversion.
unsafeFreeze :: (Elem a, KnownNat n, KnownNat m, PrimMonad p) => M.MMatrix n m (PrimState p) a -> p (Matrix n m a)
unsafeFreeze m = Matrix . Vec <$> VS.unsafeFreeze (M.vals m)

-- | Pass a pointer to the matrix's data, together with its row and column
-- counts, to the IO action. The data may not be modified through the pointer.
unsafeWith :: (Elem a, KnownNat n, KnownNat m) => Matrix n m a -> (Ptr (C a) -> CInt -> CInt -> IO b) -> IO b
unsafeWith m@(Matrix (Vec vals)) f = VS.unsafeWith vals $ \p ->
  let !rs = toC $! rows m
      !cs = toC $! cols m
  in f p rs cs

-- | Run a foreign routine that writes a single scalar result through its
-- first pointer argument, given the matrix buffer and its dimensions.
-- The returned 'CString' status is handed to 'Internal.call' (presumably it
-- raises on a reported error — see "Eigen.Internal").
_prop :: (Elem a, KnownNat n, KnownNat m) => (Ptr (C a) -> Ptr (C a) -> CInt -> CInt -> IO CString) -> Matrix n m a -> a
{-# INLINE _prop #-}
_prop f m = fromC $ Internal.performIO $ alloca $ \p -> do
  Internal.call $ unsafeWith m (f p)
  peek p

-- | Run a foreign binary operation: allocates a fresh result matrix of the
-- output shape and passes (result, left operand, right operand) buffers, each
-- followed by its row and column counts.
_binop :: forall n m n1 m1 n2 m2 a.
  (Elem a, KnownNat n, KnownNat m, KnownNat n1, KnownNat m1, KnownNat n2, KnownNat m2)
  => (Ptr (C a) -> CInt -> CInt -> Ptr (C a) -> CInt -> CInt -> Ptr (C a) -> CInt -> CInt -> IO CString)
  -> Matrix n m a
  -> Matrix n1 m1 a
  -> Matrix n2 m2 a
{-# INLINE _binop #-}
_binop g m1 m2 = Internal.performIO $ do
  m0 :: M.IOMatrix n2 m2 a <- M.new
  M.unsafeWith m0 $ \vals0 rows0 cols0 ->
    unsafeWith m1 $ \vals1 rows1 cols1 ->
      unsafeWith m2 $ \vals2 rows2 cols2 ->
        Internal.call $ g vals0 rows0 cols0 vals1 rows1 cols1 vals2 rows2 cols2
  unsafeFreeze m0

-- | Run a foreign unary operation: allocates a fresh result matrix of the
-- output shape and passes (result, operand) buffers, each followed by its
-- row and column counts.
_unop :: forall n m n1 m1 a.
  (Elem a, KnownNat n, KnownNat m, KnownNat n1, KnownNat m1)
  => (Ptr (C a) -> CInt -> CInt -> Ptr (C a) -> CInt -> CInt -> IO CString)
  -> Matrix n m a
  -> Matrix n1 m1 a
{-# INLINE _unop #-}
_unop g m1 = Internal.performIO $ do
  m0 :: M.IOMatrix n1 m1 a <- M.new
  M.unsafeWith m0 $ \vals0 rows0 cols0 ->
    unsafeWith m1 $ \vals1 rows1 cols1 ->
      Internal.call $ g vals0 rows0 cols0 vals1 rows1 cols1
  unsafeFreeze m0

-- | Convert a matrix to a list of rows.
toList :: (Elem a, KnownNat n, KnownNat m) => Matrix n m a -> [[a]]
{-# INLINE toList #-}
toList m@(Matrix (Vec vals))
  | null m = []
  | otherwise =
      [ [ fromC $ vals `VS.unsafeIndex` (col * _rows + row) | col <- [0 .. pred _cols] ]
      | row <- [0 .. pred _rows]
      ]
  where
    !_rows = rows m
    !_cols = cols m

-- | Convert a list of rows to a matrix. Returns 'Nothing' if the dimensions
-- of the list do not match that of the matrix: the number of rows must equal
-- @n@ and the longest row must have exactly @m@ elements. Shorter rows are
-- padded with 0.
--
-- A matrix with zero rows carries no column information in list form, so
-- the empty list is accepted for any @Matrix 0 m@ (this keeps
-- @fromList . toList@ a round trip for such matrices).
fromList :: forall n m a. (Elem a, KnownNat n, KnownNat m) => [[a]] -> Maybe (Matrix n m a)
fromList list
  | myRows /= listRows = Nothing
  | myRows /= 0 && myCols /= listCols = Nothing
  | otherwise = Just . Matrix . Vec $ VS.create $ do
      vm <- VSM.replicate (myRows * myCols) (toC (0 :: a))
      forM_ (zip [0..] list) $ \(row, vals) ->
        forM_ (zip [0..] vals) $ \(col, val) ->
          VSM.write vm (col * myRows + row) (toC val)
      pure vm
  where
    myRows = natToInt @n
    myCols = natToInt @m
    listRows = List.length list
    listCols = List.foldl' max 0 (List.map List.length list)