{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeInType #-}
{-# LANGUAGE TypeOperators #-}

#if __GLASGOW_HASKELL__ >= 805
{-# LANGUAGE NoStarIsType #-}
{-# LANGUAGE ExplicitNamespaces #-}
#endif

module Eigen.Matrix.Mutable
  ( -- * Types
    MMatrix(..)
  , MMatrixXf
  , MMatrixXd
  , MMatrixXcf
  , MMatrixXcd
  , IOMatrix
  , STMatrix
    -- * Construction
  , new
  , replicate
    -- * Indexing
  , read
  , write
    -- * Modification
  , set
  , copy
    -- * Modification with Pointers
  , unsafeWith
    -- * Conversion to Vectors
  , vals
  , fromVector
  ) where

import Prelude hiding (replicate, read)

import Data.Complex (Complex)
import Data.Kind (Type)
import Foreign.C.Types (CInt)
import Foreign.Ptr (Ptr)
import GHC.Exts (RealWorld)
import GHC.TypeLits (Nat, type (*), type (<=), KnownNat)

import Control.Monad.Primitive (PrimMonad(..))
import qualified Data.Vector.Storable.Mutable as VSM

import Eigen.Internal
  ( Elem
  , C(..)
  , natToInt
  , Row(..)
  , Col(..)
  )

-- | A mutable matrix with @n@ rows and @m@ columns, state token @s@ and
-- element type @a@. It wraps a 'Vec' whose type-level length is @n * m@.
-- See 'Eigen.Matrix.Matrix' for details about matrix layout.
newtype MMatrix (n :: Nat) (m :: Nat) (s :: Type) (a :: Type)
  = MMatrix (Vec (n * m) s a)

-- | Used internally to track the size and corresponding C type of the matrix.
-- The length index @n@ is phantom: the payload is a mutable storable vector
-- of @C a@ values. This type is not part of the export list.
newtype Vec (n :: Nat) (s :: Type) (a :: Type)
  = Vec (VSM.MVector s (C a))

-- | Alias for a single precision mutable matrix.
type MMatrixXf n m s = MMatrix n m s Float

-- | Alias for a double precision mutable matrix.
type MMatrixXd n m s = MMatrix n m s Double

-- | Alias for a single precision mutable matrix of complex numbers.
type MMatrixXcf n m s = MMatrix n m s (Complex Float)

-- | Alias for a double precision mutable matrix of complex numbers.
type MMatrixXcd n m s = MMatrix n m s (Complex Double)

-- | A mutable matrix where the state token is specialised to 'RealWorld'.
type IOMatrix n m a = MMatrix n m RealWorld a

-- | A mutable matrix with an arbitrary state token. This type does not differ
-- from 'MMatrix', but might be desirable for readability.
type STMatrix n m s a = MMatrix n m s a

-- | Create a mutable matrix of the given size and fill it with 0 as an initial value.
new :: (PrimMonad p, Elem a, KnownNat n, KnownNat m) => p (MMatrix n m (PrimState p) a)
{-# INLINE new #-}
new = replicate 0

-- | Create a mutable matrix of the given size and fill it with an initial value.
-- The backing vector holds @n * m@ copies of the value converted with 'toC'.
replicate :: forall n m p a. (PrimMonad p, Elem a, KnownNat n, KnownNat m) => a -> p (MMatrix n m (PrimState p) a)
{-# INLINE replicate #-}
replicate !val = do
  let !mm_rows = natToInt @n
      !mm_cols = natToInt @m
      !cval = toC val
  _vals <- VSM.replicate (mm_rows * mm_cols) cval
  pure (MMatrix $! Vec $! _vals)

-- | Set all elements of the matrix to a given value.
set :: (PrimMonad p, Elem a) => MMatrix n m (PrimState p) a -> a -> p ()
{-# INLINE set #-}
set (MMatrix (Vec !vec)) !val =
  let !cval = toC val
  in VSM.set vec cval

-- | Copy a matrix. The contents of the second matrix (the source) are copied
-- into the first matrix (the target). The copy is delegated to
-- 'VSM.unsafeCopy', so the two matrices must not overlap in memory.
copy :: (PrimMonad p, Elem a) => MMatrix n m (PrimState p) a -> MMatrix n m (PrimState p) a -> p ()
{-# INLINE copy #-}
copy (MMatrix (Vec m1)) (MMatrix (Vec m2)) = VSM.unsafeCopy m1 m2

-- | Compute the offset of a zero-based (row, column) position in the backing
-- vector, where each column occupies @mm_rows@ consecutive slots.
--
-- The first argument is the name of the calling function, used only in the
-- error message. A row index that is not strictly smaller than the number of
-- rows is rejected with 'error', because it would otherwise silently alias an
-- element of the following column. The column index is not checked here;
-- callers rely on a bounds-checked vector access for that.
columnMajorIndex :: String -> Int -> Int -> Int -> Int
{-# INLINE columnMajorIndex #-}
columnMajorIndex fn mm_rows row col
  | row < mm_rows = col * mm_rows + row
  | otherwise = error $
      "Eigen.Matrix.Mutable." ++ fn ++ ": row index " ++ show row
        ++ " is out of bounds for a matrix with " ++ show mm_rows ++ " rows"

-- | Yield the element at the given position.
--
-- Row and column indices are zero-based. Note that the constraints @r <= n@
-- and @c <= m@ also admit @r == n@ and @c == m@, which lie outside the
-- matrix; such accesses are rejected at run time with an exception instead of
-- reading an unrelated element or memory past the end of the buffer.
read :: forall n m p a r c. (PrimMonad p, Elem a, KnownNat n, KnownNat r, KnownNat c, r <= n, c <= m)
  => Row r -> Col c -> MMatrix n m (PrimState p) a -> p a
{-# INLINE read #-}
read _ _ (MMatrix (Vec m)) =
  let !idx = columnMajorIndex "read" (natToInt @n) (natToInt @r) (natToInt @c)
  in -- 'VSM.read' is bounds-checked, which catches an out-of-range column.
     VSM.read m idx >>= \ !val ->
       let !cval = fromC val
       in pure cval

-- | Replace the element at the given position.
--
-- Row and column indices are zero-based. Note that the constraints @r <= n@
-- and @c <= m@ also admit @r == n@ and @c == m@, which lie outside the
-- matrix; such accesses are rejected at run time with an exception instead of
-- overwriting an unrelated element or memory past the end of the buffer.
write :: forall n m p a r c. (PrimMonad p, Elem a, KnownNat n, KnownNat r, KnownNat c, r <= n, c <= m)
  => Row r -> Col c -> MMatrix n m (PrimState p) a -> a -> p ()
{-# INLINE write #-}
write _ _ (MMatrix (Vec m)) !val =
  let !idx = columnMajorIndex "write" (natToInt @n) (natToInt @r) (natToInt @c)
      !cval = toC val
  in -- 'VSM.write' is bounds-checked, which catches an out-of-range column.
     VSM.write m idx cval

-- | Pass a pointer to the matrix's data to the IO action, together with the
-- number of rows and the number of columns (in that order).
-- Modifying data through the pointer is unsafe if the matrix
-- could have been frozen before the modification.
unsafeWith :: forall n m a b. (KnownNat n, KnownNat m, Elem a) => IOMatrix n m a -> (Ptr (C a) -> CInt -> CInt -> IO b) -> IO b
{-# INLINE unsafeWith #-}
unsafeWith (MMatrix (Vec m)) f =
  let !cmm_rows = toC $! natToInt @n
      !cmm_cols = toC $! natToInt @m
  in VSM.unsafeWith m $ \p -> f p cmm_rows cmm_cols

-- | Return a mutable storable 'VSM.MVector' of the corresponding C types to one's mutable matrix.
-- The vector is the matrix's own storage, not a copy.
vals :: MMatrix n m s a -> VSM.MVector s (C a)
{-# INLINE vals #-}
vals (MMatrix (Vec x)) = x

-- | Create a mutable matrix from a mutable storable 'VSM.MVector'.
--
-- The vector is wrapped as-is, without copying. Its length is NOT checked
-- against @n * m@; the caller is responsible for supplying a vector of
-- exactly that length.
fromVector :: Elem a => VSM.MVector s (C a) -> MMatrix n m s a
{-# INLINE fromVector #-}
fromVector x = MMatrix (Vec x)