{-# LANGUAGE MultiParamTypeClasses, FlexibleInstances, TypeFamilies, Rank2Types #-} {-# OPTIONS_HADDOCK hide #-} ----------------------------------------------------------------------------- -- | -- Module : Data.Matrix.Dense.STBase -- Copyright : Copyright (c) , Patrick Perry -- License : BSD3 -- Maintainer : Patrick Perry -- Stability : experimental -- module Data.Matrix.Dense.STBase where import Control.Monad import Control.Monad.ST import Data.Elem.BLAS( Elem, BLAS1, BLAS3 ) import Data.Tensor.Class import Data.Tensor.Class.MTensor import Data.Matrix.Class import Data.Matrix.Dense.Base import Data.Matrix.Dense.IOBase import Data.Matrix.Herm import Data.Matrix.TriBase import Data.Vector.Dense.STBase -- | Dense matrix in the 'ST' monad. The type arguments are as follows: -- -- * @s@: the state variable argument for the 'ST' type -- -- * @np@: a phantom type for the shape of the matrix. Most functions -- will demand that this be specified as a pair. When writing a function -- signature, you should always prefer @STMatrix s (n,p) e@ to -- @STMatrix s np e@. -- -- * @e@: the element type of the matrix. Only certain element types -- are supported. -- newtype STMatrix s np e = STMatrix (IOMatrix np e) -- | A safe way to create and work with a mutable matrix before returning -- an immutable matrix for later perusal. This function avoids copying -- the matrix before returning it - it uses unsafeFreezeMatrix internally, -- but this wrapper is a safe interface to that function. runSTMatrix :: (forall s . ST s (STMatrix s n e)) -> Matrix n e runSTMatrix mx = runST $ mx >>= \(STMatrix x) -> return (Matrix x) instance HasVectorView (STMatrix s) where type VectorView (STMatrix s) = STVector s instance Shaped (STMatrix s) (Int,Int) where shape (STMatrix a) = shapeIOMatrix a {-# INLINE shape #-} bounds (STMatrix a) = boundsIOMatrix a {-# INLINE bounds #-} instance (Elem e) => ReadTensor (STMatrix s) (Int,Int) e (ST s) where getSize (STMatrix a) = unsafeIOToST $ getSizeIOMatrix a {-# INLINE getSize #-} unsafeReadElem (STMatrix a) i = unsafeIOToST $ unsafeReadElemIOMatrix a i {-# INLINE unsafeReadElem #-} getIndices (STMatrix a) = unsafeIOToST $ getIndicesIOMatrix a {-# INLINE getIndices #-} getIndices' (STMatrix a) = unsafeIOToST $ getIndicesIOMatrix' a {-# INLINE getIndices' #-} getElems (STMatrix a) = unsafeIOToST $ getElemsIOMatrix a {-# INLINE getElems #-} getElems' (STMatrix a) = unsafeIOToST $ getElemsIOMatrix' a {-# INLINE getElems' #-} getAssocs (STMatrix a) = unsafeIOToST $ getAssocsIOMatrix a {-# INLINE getAssocs #-} getAssocs' (STMatrix a) = unsafeIOToST $ getAssocsIOMatrix' a {-# INLINE getAssocs' #-} instance (BLAS1 e) => WriteTensor (STMatrix s) (Int,Int) e (ST s) where getMaxSize (STMatrix a) = unsafeIOToST $ getMaxSizeIOMatrix a {-# INLINE getMaxSize #-} setZero (STMatrix a) = unsafeIOToST $ setZeroIOMatrix a {-# INLINE setZero #-} setConstant e (STMatrix a) = unsafeIOToST $ setConstantIOMatrix e a {-# INLINE setConstant #-} canModifyElem (STMatrix a) i = unsafeIOToST $ canModifyElemIOMatrix a i {-# INLINE canModifyElem #-} unsafeWriteElem (STMatrix a) i e = unsafeIOToST $ unsafeWriteElemIOMatrix a i e {-# INLINE unsafeWriteElem #-} unsafeModifyElem (STMatrix a) i f = unsafeIOToST $ unsafeModifyElemIOMatrix a i f {-# INLINE unsafeModifyElem #-} modifyWith f (STMatrix a) = unsafeIOToST $ modifyWithIOMatrix f a {-# INLINE modifyWith #-} doConj (STMatrix a) = unsafeIOToST $ doConjIOMatrix a {-# INLINE doConj #-} scaleBy k (STMatrix a) = unsafeIOToST $ scaleByIOMatrix k a {-# INLINE scaleBy #-} shiftBy k (STMatrix a) = unsafeIOToST $ shiftByIOMatrix k a {-# INLINE shiftBy #-} instance MatrixShaped (STMatrix s) where herm (STMatrix a) = STMatrix (herm a) {-# INLINE herm #-} instance (BLAS3 e) => MMatrix (STMatrix s) e (ST s) where unsafeDoSApplyAdd = gemv {-# INLINE unsafeDoSApplyAdd #-} unsafeDoSApplyAddMat = gemm {-# INLINE unsafeDoSApplyAddMat #-} unsafeGetRow = unsafeGetRowMatrix {-# INLINE unsafeGetRow #-} unsafeGetCol = unsafeGetColMatrix {-# INLINE unsafeGetCol #-} getRows = getRowsST {-# INLINE getRows #-} getCols = getColsST {-# INLINE getCols #-} instance (BLAS3 e) => MMatrix (Herm (STMatrix s)) e (ST s) where unsafeDoSApplyAdd = hemv' {-# INLINE unsafeDoSApplyAdd #-} unsafeDoSApplyAddMat = hemm' {-# INLINE unsafeDoSApplyAddMat #-} getRows = getRowsST {-# INLINE getRows #-} getCols = getColsST {-# INLINE getCols #-} instance (BLAS3 e) => MMatrix (Tri (STMatrix s)) e (ST s) where unsafeDoSApplyAdd = unsafeDoSApplyAddTriMatrix {-# INLINE unsafeDoSApplyAdd #-} unsafeDoSApplyAddMat = unsafeDoSApplyAddMatTriMatrix {-# INLINE unsafeDoSApplyAddMat #-} unsafeDoSApply_ = trmv {-# INLINE unsafeDoSApply_ #-} unsafeDoSApplyMat_ = trmm {-# INLINE unsafeDoSApplyMat_ #-} getRows = getRowsST {-# INLINE getRows #-} getCols = getColsST {-# INLINE getCols #-} instance (BLAS3 e) => MSolve (Tri (STMatrix s)) e (ST s) where unsafeDoSSolve = unsafeDoSSolveTriMatrix {-# INLINE unsafeDoSSolve #-} unsafeDoSSolveMat = unsafeDoSSolveMatTriMatrix {-# INLINE unsafeDoSSolveMat #-} unsafeDoSSolve_ = trsv {-# INLINE unsafeDoSSolve_ #-} unsafeDoSSolveMat_ = trsm {-# INLINE unsafeDoSSolveMat_ #-} instance (Elem e) => BaseMatrix (STMatrix s) e where ldaMatrix (STMatrix a) = ldaMatrixIOMatrix a {-# INLINE ldaMatrix #-} isHermMatrix (STMatrix a) = isHermMatrix a {-# INLINE isHermMatrix #-} unsafeSubmatrixView (STMatrix a) ij mn = STMatrix (unsafeSubmatrixViewIOMatrix a ij mn) {-# INLINE unsafeSubmatrixView #-} unsafeDiagView (STMatrix a) i = STVector (unsafeDiagViewIOMatrix a i) {-# INLINE unsafeDiagView #-} unsafeRowView (STMatrix a) i = STVector (unsafeRowViewIOMatrix a i) {-# INLINE unsafeRowView #-} unsafeColView (STMatrix a) i = STVector (unsafeColViewIOMatrix a i) {-# INLINE unsafeColView #-} maybeViewMatrixAsVector (STMatrix a) = liftM STVector (maybeViewMatrixAsVector a) {-# INLINE maybeViewMatrixAsVector #-} maybeViewVectorAsMatrix mn (STVector x) = liftM STMatrix $ maybeViewVectorAsIOMatrix mn x {-# INLINE maybeViewVectorAsMatrix #-} maybeViewVectorAsRow (STVector x) = liftM STMatrix (maybeViewVectorAsRow x) {-# INLINE maybeViewVectorAsRow #-} maybeViewVectorAsCol (STVector x) = liftM STMatrix (maybeViewVectorAsCol x) {-# INLINE maybeViewVectorAsCol #-} unsafeIOMatrixToMatrix = STMatrix {-# INLINE unsafeIOMatrixToMatrix #-} unsafeMatrixToIOMatrix (STMatrix a) = a {-# INLINE unsafeMatrixToIOMatrix #-} instance (BLAS3 e) => ReadMatrix (STMatrix s) e (ST s) where unsafePerformIOWithMatrix (STMatrix a) f = unsafeIOToST $ f a {-# INLINE unsafePerformIOWithMatrix #-} freezeMatrix (STMatrix a) = unsafeIOToST $ freezeIOMatrix a {-# INLINE freezeMatrix #-} unsafeFreezeMatrix (STMatrix a) = unsafeIOToST $ unsafeFreezeIOMatrix a {-# INLINE unsafeFreezeMatrix #-} instance (BLAS3 e) => WriteMatrix (STMatrix s) e (ST s) where newMatrix_ = unsafeIOToST . liftM STMatrix . newIOMatrix_ {-# INLINE newMatrix_ #-} unsafeConvertIOMatrix = unsafeIOToST . liftM STMatrix {-# INLINE unsafeConvertIOMatrix #-} thawMatrix = unsafeIOToST . liftM STMatrix . thawIOMatrix {-# INLINE thawMatrix #-} unsafeThawMatrix = unsafeIOToST . liftM STMatrix . unsafeThawIOMatrix {-# INLINE unsafeThawMatrix #-}