module Data.Matrix.Dense.STBase
where
import Control.Monad
import Control.Monad.ST
import Data.Elem.BLAS( Elem, BLAS1, BLAS3 )
import Data.Tensor.Class
import Data.Tensor.Class.MTensor
import Data.Matrix.Class
import Data.Matrix.Dense.Base
import Data.Matrix.Dense.IOBase
import Data.Matrix.Herm
import Data.Matrix.TriBase
import Data.Vector.Dense.STBase
-- | A dense matrix that can be mutated inside the 'ST' monad whose state
-- thread is @s@.
--
-- The representation is simply an 'IOMatrix'; the @s@ parameter does not
-- occur on the right-hand side and exists only to tie values of this type
-- to one 'ST' computation.
newtype STMatrix s mn e =
    STMatrix (IOMatrix mn e)
-- | Run an 'ST' computation that produces an 'STMatrix' and return the
-- result as an immutable 'Matrix'.
--
-- The wrapped 'IOMatrix' is re-wrapped in 'Matrix' directly; no copy or
-- freeze call is made. The rank-2 argument type keeps the 'STMatrix' from
-- leaving the computation, which is what makes skipping the copy sound.
runSTMatrix :: (forall s . ST s (STMatrix s n e)) -> Matrix n e
runSTMatrix action =
    runST (fmap (\(STMatrix storage) -> Matrix storage) action)
-- | The vector-view type associated with @STMatrix s@ is @STVector s@, so
-- vector views of a matrix live in the same state thread as the matrix.
instance HasVectorView (STMatrix s) where
    type VectorView (STMatrix s) = STVector s
-- | Shape and bounds of an 'STMatrix' are those of the wrapped 'IOMatrix'.
instance Shaped (STMatrix s) (Int,Int) where
    bounds = \(STMatrix io) -> boundsIOMatrix io
    shape  = \(STMatrix io) -> shapeIOMatrix io
-- | Read-only element access for 'STMatrix'.
--
-- Every method unwraps the 'IOMatrix', calls the matching @*IOMatrix@
-- helper, and lifts the resulting 'IO' action into @'ST' s@ with
-- 'unsafeIOToST'.
instance (Elem e) => ReadTensor (STMatrix s) (Int,Int) e (ST s) where
    -- element count and single-element read
    getSize        (STMatrix m)    = unsafeIOToST (getSizeIOMatrix m)
    unsafeReadElem (STMatrix m) ij = unsafeIOToST (unsafeReadElemIOMatrix m ij)

    -- bulk listings; each primed method calls the primed IOMatrix helper
    getAssocs   (STMatrix m) = unsafeIOToST (getAssocsIOMatrix m)
    getAssocs'  (STMatrix m) = unsafeIOToST (getAssocsIOMatrix' m)
    getElems    (STMatrix m) = unsafeIOToST (getElemsIOMatrix m)
    getElems'   (STMatrix m) = unsafeIOToST (getElemsIOMatrix' m)
    getIndices  (STMatrix m) = unsafeIOToST (getIndicesIOMatrix m)
    getIndices' (STMatrix m) = unsafeIOToST (getIndicesIOMatrix' m)
-- | Mutating element access for 'STMatrix'.
--
-- Each method forwards to the matching @*IOMatrix@ function on the wrapped
-- matrix and lifts the 'IO' action into @'ST' s@ via 'unsafeIOToST'.
instance (BLAS1 e) => WriteTensor (STMatrix s) (Int,Int) e (ST s) where
    -- capacity and per-index writability
    getMaxSize    (STMatrix m)    = unsafeIOToST (getMaxSizeIOMatrix m)
    canModifyElem (STMatrix m) ij = unsafeIOToST (canModifyElemIOMatrix m ij)

    -- single-element updates at index ij
    unsafeWriteElem  (STMatrix m) ij x = unsafeIOToST (unsafeWriteElemIOMatrix m ij x)
    unsafeModifyElem (STMatrix m) ij g = unsafeIOToST (unsafeModifyElemIOMatrix m ij g)

    -- operations that take no index argument
    setZero       (STMatrix m) = unsafeIOToST (setZeroIOMatrix m)
    setConstant x (STMatrix m) = unsafeIOToST (setConstantIOMatrix x m)
    modifyWith g  (STMatrix m) = unsafeIOToST (modifyWithIOMatrix g m)
    doConj        (STMatrix m) = unsafeIOToST (doConjIOMatrix m)
    scaleBy k     (STMatrix m) = unsafeIOToST (scaleByIOMatrix k m)
    shiftBy k     (STMatrix m) = unsafeIOToST (shiftByIOMatrix k m)
-- | 'herm' on an 'STMatrix' applies 'herm' to the wrapped 'IOMatrix' and
-- wraps the result back up as an 'STMatrix'.
instance MatrixShaped (STMatrix s) where
    herm = \(STMatrix io) -> STMatrix (herm io)
-- | Matrix operations on a general 'STMatrix'.
--
-- Rows and columns are obtained through 'getRowsST' / 'getColsST' and
-- 'unsafeGetRowMatrix' / 'unsafeGetColMatrix'. The scaled
-- apply-and-accumulate products delegate to @gemv@ (matrix-vector) and
-- @gemm@ (matrix-matrix).
instance (BLAS3 e) => MMatrix (STMatrix s) e (ST s) where
    -- row / column access
    getRows      = getRowsST
    getCols      = getColsST
    unsafeGetRow = unsafeGetRowMatrix
    unsafeGetCol = unsafeGetColMatrix

    -- products
    unsafeDoSApplyAdd    = gemv
    unsafeDoSApplyAddMat = gemm
-- | Matrix operations on a Hermitian view of an 'STMatrix'.
--
-- Products delegate to @hemv'@ and @hemm'@, and rows/columns come from
-- 'getRowsST' / 'getColsST'. The single-row/column accessors are not
-- defined in this instance.
instance (BLAS3 e) => MMatrix (Herm (STMatrix s)) e (ST s) where
    -- row / column access
    getRows = getRowsST
    getCols = getColsST

    -- products
    unsafeDoSApplyAdd    = hemv'
    unsafeDoSApplyAddMat = hemm'
-- | Matrix operations on a triangular view of an 'STMatrix'.
--
-- The accumulating products use the @*TriMatrix@ helpers. The
-- underscore-suffixed products delegate straight to @trmv@ / @trmm@.
-- Rows and columns come from 'getRowsST' / 'getColsST'.
instance (BLAS3 e) => MMatrix (Tri (STMatrix s)) e (ST s) where
    -- row / column access
    getRows = getRowsST
    getCols = getColsST

    -- underscore-suffixed products
    unsafeDoSApply_    = trmv
    unsafeDoSApplyMat_ = trmm

    -- accumulating products
    unsafeDoSApplyAdd    = unsafeDoSApplyAddTriMatrix
    unsafeDoSApplyAddMat = unsafeDoSApplyAddMatTriMatrix
-- | Triangular solves on a triangular view of an 'STMatrix'.
--
-- The underscore-suffixed solves delegate straight to @trsv@ / @trsm@.
-- The other two use the @*TriMatrix@ solve helpers.
instance (BLAS3 e) => MSolve (Tri (STMatrix s)) e (ST s) where
    unsafeDoSSolve_    = trsv
    unsafeDoSSolveMat_ = trsm

    unsafeDoSSolve    = unsafeDoSSolveTriMatrix
    unsafeDoSSolveMat = unsafeDoSSolveMatTriMatrix
-- | Structural queries and views for 'STMatrix'.
--
-- Each method unwraps the 'STMatrix' / 'STVector' argument and applies the
-- corresponding operation to the underlying 'IOMatrix' / 'IOVector'. Any
-- matrix or vector result is re-wrapped in the @ST@ newtype constructors.
instance (Elem e) => BaseMatrix (STMatrix s) e where
    -- storage layout and flags
    ldaMatrix    (STMatrix m) = ldaMatrixIOMatrix m
    isHermMatrix (STMatrix m) = isHermMatrix m

    -- raw conversions between the ST wrapper and the IOMatrix it holds
    unsafeMatrixToIOMatrix (STMatrix m) = m
    unsafeIOMatrixToMatrix              = STMatrix

    -- sub-views of a matrix
    unsafeSubmatrixView (STMatrix m) ij mn =
        STMatrix $ unsafeSubmatrixViewIOMatrix m ij mn
    unsafeDiagView (STMatrix m) k = STVector $ unsafeDiagViewIOMatrix m k
    unsafeRowView  (STMatrix m) r = STVector $ unsafeRowViewIOMatrix m r
    unsafeColView  (STMatrix m) c = STVector $ unsafeColViewIOMatrix m c

    -- vector <-> matrix reinterpretations. The results come back wrapped in
    -- a monad (presumably Maybe -- TODO confirm), so the constructor is
    -- mapped over them with liftM.
    maybeViewMatrixAsVector (STMatrix m) =
        STVector `liftM` maybeViewMatrixAsVector m
    maybeViewVectorAsMatrix mn (STVector v) =
        STMatrix `liftM` maybeViewVectorAsIOMatrix mn v
    maybeViewVectorAsRow (STVector v) = STMatrix `liftM` maybeViewVectorAsRow v
    maybeViewVectorAsCol (STVector v) = STMatrix `liftM` maybeViewVectorAsCol v
-- | Freezing and raw 'IO' access for 'STMatrix'.
--
-- 'freezeMatrix' / 'unsafeFreezeMatrix' forward to 'freezeIOMatrix' /
-- 'unsafeFreezeIOMatrix'. 'unsafePerformIOWithMatrix' applies the given
-- 'IO'-producing function to the wrapped 'IOMatrix'. In every case the
-- 'IO' action is lifted into @'ST' s@ with 'unsafeIOToST'.
instance (BLAS3 e) => ReadMatrix (STMatrix s) e (ST s) where
    unsafeFreezeMatrix = \(STMatrix m) -> unsafeIOToST (unsafeFreezeIOMatrix m)
    freezeMatrix       = \(STMatrix m) -> unsafeIOToST (freezeIOMatrix m)

    unsafePerformIOWithMatrix (STMatrix m) act = unsafeIOToST (act m)
-- | Allocation, conversion and thawing for 'STMatrix'.
--
-- Each method runs the corresponding 'IO' action that yields an 'IOMatrix',
-- wraps its result in 'STMatrix', and lifts the whole action into
-- @'ST' s@ with 'unsafeIOToST'.
instance (BLAS3 e) => WriteMatrix (STMatrix s) e (ST s) where
    newMatrix_ mn            = unsafeIOToST (STMatrix `liftM` newIOMatrix_ mn)
    unsafeConvertIOMatrix io = unsafeIOToST (STMatrix `liftM` io)
    thawMatrix src           = unsafeIOToST (STMatrix `liftM` thawIOMatrix src)
    unsafeThawMatrix src     = unsafeIOToST (STMatrix `liftM` unsafeThawIOMatrix src)