{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeInType #-}
{-# LANGUAGE TypeOperators #-}

module Eigen.SparseMatrix.Mutable
  ( -- * Mutable SparseMatrix
    MSparseMatrix(..)
  , IOSparseMatrix
  , STSparseMatrix
  , new
  , reserve
    -- * SparseMatrix properties
  , rows
  , cols
  , innerSize
  , outerSize
  , nonZeros
    -- * SparseMatrix compression
  , compressed
  , compress
  , uncompress
    -- * Accessing SparseMatrix data
  , read
  , write
  , setZero
  , setIdentity
  ) where

import Prelude hiding (read)

import Control.Monad.Primitive (PrimMonad(..), unsafePrimToPrim)
import Data.Kind (Type)
import Foreign.C.String (CString)
import Foreign.ForeignPtr (ForeignPtr, withForeignPtr)
import Foreign.Marshal.Alloc (alloca)
import Foreign.Ptr (Ptr)
import Foreign.Storable (Storable(..))
import qualified Foreign.Concurrent as FC
import GHC.Exts (RealWorld)
import GHC.TypeLits (Nat, KnownNat, type (+), type (<=))

import Eigen.Internal
  ( Elem
  , Cast(..)
  , CSparseMatrix
  , CSparseMatrixPtr
  , natToInt
  , Row
  , Col
  )
import qualified Eigen.Internal as Internal

-- | Mutable sparse matrix. See 'Eigen.SparseMatrix.SparseMatrix' for
-- details about matrix layout.
--
-- The type parameters are the number of rows @n@, the number of columns
-- @m@, the state token @s@ and the element type @a@.
newtype MSparseMatrix :: Nat -> Nat -> Type -> Type -> Type where
  MSparseMatrix :: (ForeignPtr (CSparseMatrix a)) -> MSparseMatrix n m s a

-- | A sparse matrix where the state token is specialised to 'RealWorld'.
type IOSparseMatrix n m a = MSparseMatrix n m RealWorld a

-- | This type does not differ from 'MSparseMatrix', but might be desirable
-- for readability.
type STSparseMatrix n m s a = MSparseMatrix n m s a

-- | Create a new sparse matrix with the given size @rows x cols@, taken from
-- the type-level dimensions @n@ (rows) and @m@ (columns).
--
-- The C object is freed by a finalizer attached to the returned
-- 'ForeignPtr' (via "Foreign.Concurrent").
--
-- NOTE: the explicit @forall@ lists @m@ (columns) before @n@ (rows), so with
-- @TypeApplications@ the first applied type is the /column/ count:
-- @new \@3 \@4@ creates a 4x3 matrix.
new :: forall m n p a. (Elem a, KnownNat n, KnownNat m, PrimMonad p)
    => p (MSparseMatrix n m (PrimState p) a)
new = unsafePrimToPrim $ alloca $ \pm -> do
  let !c_rs = toC $! natToInt @n
      !c_cs = toC $! natToInt @m
  -- sparse_new writes the pointer to the freshly allocated matrix into pm.
  Internal.call $ Internal.sparse_new c_rs c_cs pm
  m <- peek pm
  fm <- FC.newForeignPtr m $ Internal.call $ Internal.sparse_free m
  pure $! MSparseMatrix fm

-- | Returns the number of rows of the matrix (the type-level @n@; the
-- underlying C object is not consulted).
rows :: forall n m s a. (Elem a, KnownNat n, KnownNat m)
     => MSparseMatrix n m s a -> Int
rows _ = natToInt @n

-- | Returns the number of columns of the matrix (the type-level @m@; the
-- underlying C object is not consulted).
cols :: forall n m s a. (Elem a, KnownNat n, KnownNat m)
     => MSparseMatrix n m s a -> Int
cols _ = natToInt @m

-- | Returns the number of rows (resp. columns) of the matrix if the storage
-- order is column major (resp. row major).
innerSize :: (Elem a, PrimMonad p) => MSparseMatrix n m (PrimState p) a -> p Int
innerSize = _prop Internal.sparse_innerSize (pure . fromC)

-- | Returns the number of columns (resp. rows) of the matrix if the storage
-- order is column major (resp. row major).
outerSize :: (Elem a, PrimMonad p) => MSparseMatrix n m (PrimState p) a -> p Int
outerSize = _prop Internal.sparse_outerSize (pure . fromC)

-- | Returns whether or not the matrix is in compressed form.
compressed :: (Elem a, PrimMonad p) => MSparseMatrix n m (PrimState p) a -> p Bool
compressed = _prop Internal.sparse_isCompressed (pure . (== 1))

-- | Turns the matrix into compressed format.
compress :: (Elem a, PrimMonad p) => MSparseMatrix n m (PrimState p) a -> p ()
compress = _inplace Internal.sparse_compressInplace

-- | Decompresses the matrix.
uncompress :: (Elem a, PrimMonad p) => MSparseMatrix n m (PrimState p) a -> p ()
uncompress = _inplace Internal.sparse_uncompressInplace

-- | Read the value of the matrix at the zero-based position given by the
-- type-level indices @r@ (row) and @c@ (column). Entries that are not
-- stored in the matrix read as @Scalar(0)@.
--
-- The constraints @(r + 1) <= n@ and @(c + 1) <= m@ reject out-of-range
-- indices at compile time (an index equal to the dimension is out of range).
read
  :: forall n m r c p a.
     ( Elem a, PrimMonad p
     , KnownNat n, KnownNat m, KnownNat r, KnownNat c
     , (r + 1) <= n, (c + 1) <= m
     )
  => Row r -> Col c -> MSparseMatrix n m (PrimState p) a -> p a
read _ _ (MSparseMatrix fm) =
  let !c_r = toC $! natToInt @r
      !c_c = toC $! natToInt @c
  in unsafePrimToPrim $ withForeignPtr fm $ \m -> alloca $ \px -> do
       -- sparse_coeff writes the coefficient into the out-pointer px.
       Internal.call $ Internal.sparse_coeff m c_r c_c px
       fromC <$> peek px

{- | Writes the value of the matrix at the zero-based position given by the
type-level indices @r@ (row) and @c@ (column).

This function turns the matrix into a non compressed form if that was not the
case.

This is a @O(log(nnz_j))@ operation (binary search) plus the cost of element
insertion if the element does not already exist.

Cost of element insertion is sorted insertion in @O(1)@ if the elements of each
inner vector are inserted in increasing inner index order, and in @O(nnz_j)@
for a random insertion.

The constraints @(r + 1) <= n@ and @(c + 1) <= m@ reject out-of-range indices
at compile time (an index equal to the dimension is out of range).
-}
write
  :: forall n m r c p a.
     ( Elem a, PrimMonad p
     , KnownNat n, KnownNat m, KnownNat r, KnownNat c
     , (r + 1) <= n, (c + 1) <= m
     )
  => MSparseMatrix n m (PrimState p) a -> Row r -> Col c -> a -> p ()
write (MSparseMatrix fm) _ _ x =
  let !c_r = toC $! natToInt @r
      !c_c = toC $! natToInt @c
  in unsafePrimToPrim $ withForeignPtr fm $ \m -> alloca $ \px -> do
       -- sparse_coeffRef stores a pointer to the coefficient's storage into
       -- px; the value is then written through that pointer.
       Internal.call $ Internal.sparse_coeffRef m c_r c_c px
       peek px >>= (`poke` toC x)

-- | Sets the matrix to the identity matrix.
setIdentity :: (Elem a, PrimMonad p) => MSparseMatrix n m (PrimState p) a -> p ()
setIdentity = _inplace Internal.sparse_setIdentity

-- | Remove all non-zeros, but keep allocated memory.
setZero :: (Elem a, PrimMonad p) => MSparseMatrix n m (PrimState p) a -> p ()
setZero = _inplace Internal.sparse_setZero

-- | Preallocates room for the given number of non-zeros. The matrix must be
-- in compressed mode.
reserve :: (Elem a, PrimMonad p) => MSparseMatrix n m (PrimState p) a -> Int -> p ()
reserve m s = _inplace (\p -> Internal.sparse_reserve p (toC s)) m

-- | Returns the number of nonzero coefficients.
nonZeros :: (Elem a, PrimMonad p) => MSparseMatrix n m (PrimState p) a -> p Int
nonZeros = _prop Internal.sparse_nonZeros (pure . fromC)

-- | Run a C operation on the underlying matrix pointer for its effect only.
-- The 'CString' it returns is handed to 'Internal.call' (presumably a
-- status/error message that @call@ checks; confirm in "Eigen.Internal").
_inplace :: (Elem a, PrimMonad p)
         => (Ptr (CSparseMatrix a) -> IO CString)
         -> MSparseMatrix n m (PrimState p) a
         -> p ()
_inplace f (MSparseMatrix fm) =
  unsafePrimToPrim $ withForeignPtr fm $ \m -> Internal.call $ f m

-- | Run a C getter @f@ that writes its result through an out-pointer, then
-- post-process the peeked value with @g@. As in '_inplace', @f@'s 'CString'
-- result is passed to 'Internal.call'.
_prop :: (Storable b, PrimMonad p)
      => (CSparseMatrixPtr a -> Ptr b -> IO CString)
      -> (b -> IO c)
      -> MSparseMatrix n m (PrimState p) a
      -> p c
_prop f g (MSparseMatrix fp) =
  unsafePrimToPrim $ withForeignPtr fp $ \p -> alloca $ \pq -> do
    Internal.call (f p pq)
    peek pq >>= g