module Data.Matrix.Banded.STBase
where
import Control.Monad
import Control.Monad.ST
import Data.Elem.BLAS( Elem, BLAS3 )
import Data.Matrix.Class
import Data.Matrix.Class.MMatrixBase
import Data.Matrix.Class.MSolveBase
import Data.Matrix.Herm
import Data.Matrix.Tri
import Data.Tensor.Class
import Data.Tensor.Class.MTensor
import Data.Matrix.Dense.STBase( STMatrix(..) )
import Data.Vector.Dense.STBase( STVector(..) )
import Data.Matrix.Banded.Base
import Data.Matrix.Banded.IOBase( IOBanded )
import qualified Data.Matrix.Banded.IOBase as IO
-- | A banded matrix for use in the 'ST' monad.  This is a thin @newtype@
-- over 'IOBanded': the state-thread parameter @s@ is phantom (it does not
-- occur in the wrapped value), while @np@ and @e@ are passed straight
-- through to 'IOBanded'.
newtype STBanded s np e
    = STBanded (IOBanded np e)
-- | Run an 'ST' computation that yields an 'STBanded' and hand the result
-- back as a 'Banded'.  The underlying 'IOBanded' is re-wrapped in the
-- 'Banded' constructor as-is; this function makes no copy of its own.
-- The rank-2 argument type keeps the state thread @s@ from escaping.
runSTBanded :: (forall s . ST s (STBanded s n e)) -> Banded n e
runSTBanded mx = runST $ do
    -- Matching on a newtype constructor cannot fail, so this bind is total.
    STBanded x <- mx
    return (Banded x)
-- Vector views taken from an 'STBanded' are 'STVector's tagged with the
-- same state thread @s@.
instance HasVectorView (STBanded s) where
    type VectorView (STBanded s) = STVector s
-- The dense storage type associated with an 'STBanded' is an 'STMatrix'
-- tagged with the same state thread @s@.
instance HasMatrixStorage (STBanded s) where
    type MatrixStorage (STBanded s) = STMatrix s
-- Shape queries are pure: each one unwraps the 'IOBanded' and delegates to
-- the matching function from the IO-level module.
instance Shaped (STBanded s) (Int,Int) where
    shape  (STBanded b) = IO.shapeIOBanded b
    bounds (STBanded b) = IO.boundsIOBanded b
-- 'herm' unwraps the 'IOBanded', applies the IO-level 'IO.hermIOBanded',
-- and wraps the result back up as an 'STBanded'.
instance MatrixShaped (STBanded s) where
    herm (STBanded b) = STBanded (IO.hermIOBanded b)
-- Read-only tensor operations for 'STBanded' in the 'ST' monad.  Every
-- method unwraps the underlying 'IOBanded', calls the IO-level function of
-- the corresponding name, and lifts the resulting IO action into 'ST' with
-- 'unsafeIOToST'.
-- NOTE(review): the soundness of 'unsafeIOToST' here presumably rests on
-- the wrapped 'IOBanded' only being reachable through the @s@-tagged
-- 'STBanded' -- confirm against the export list.
instance (BLAS3 e) => ReadTensor (STBanded s) (Int,Int) e (ST s) where
    getSize        (STBanded b)   = unsafeIOToST (IO.getSizeIOBanded b)
    getAssocs      (STBanded b)   = unsafeIOToST (IO.getAssocsIOBanded b)
    getIndices     (STBanded b)   = unsafeIOToST (IO.getIndicesIOBanded b)
    getElems       (STBanded b)   = unsafeIOToST (IO.getElemsIOBanded b)
    getAssocs'     (STBanded b)   = unsafeIOToST (IO.getAssocsIOBanded' b)
    getIndices'    (STBanded b)   = unsafeIOToST (IO.getIndicesIOBanded' b)
    getElems'      (STBanded b)   = unsafeIOToST (IO.getElemsIOBanded' b)
    unsafeReadElem (STBanded b) i = unsafeIOToST (IO.unsafeReadElemIOBanded b i)
-- Mutating tensor operations for 'STBanded' in the 'ST' monad.  Each method
-- unwraps the underlying 'IOBanded', forwards its arguments unchanged to
-- the IO-level function of the corresponding name, and lifts the resulting
-- IO action into 'ST' with 'unsafeIOToST'.
instance (BLAS3 e) => WriteTensor (STBanded s) (Int,Int) e (ST s) where
    setConstant k   (STBanded b)     = unsafeIOToST (IO.setConstantIOBanded k b)
    setZero         (STBanded b)     = unsafeIOToST (IO.setZeroIOBanded b)
    modifyWith f    (STBanded b)     = unsafeIOToST (IO.modifyWithIOBanded f b)
    unsafeWriteElem (STBanded b) i e = unsafeIOToST (IO.unsafeWriteElemIOBanded b i e)
    canModifyElem   (STBanded b) i   = unsafeIOToST (IO.canModifyElemIOBanded b i)
-- Pure structural operations on 'STBanded'.  Each method strips the
-- 'STBanded' / 'STMatrix' / 'STVector' newtype from its arguments,
-- delegates to the IO-level function of the corresponding name, and
-- re-wraps whatever comes back in the matching ST-tagged newtype.
instance (Elem e) => BaseBanded (STBanded s) e where
    -- Plain queries: unwrap and forward.
    numLower     (STBanded b) = IO.numLowerIOBanded b
    numUpper     (STBanded b) = IO.numUpperIOBanded b
    bandwidths   (STBanded b) = IO.bandwidthsIOBanded b
    ldaBanded    (STBanded b) = IO.ldaIOBanded b
    isHermBanded (STBanded b) = IO.isHermIOBanded b

    -- Conversions to and from dense storage; the wrapper is mapped over
    -- the result with 'liftM'.
    maybeMatrixStorageFromBanded (STBanded b) =
        STMatrix `liftM` IO.maybeMatrixStorageFromIOBanded b
    maybeBandedFromMatrixStorage mn kl (STMatrix m) =
        STBanded `liftM` IO.maybeIOBandedFromMatrixStorage mn kl m

    -- Conversions to and from a vector view.
    viewVectorAsBanded mn (STVector v) =
        STBanded (IO.viewVectorAsIOBanded mn v)
    maybeViewBandedAsVector (STBanded b) =
        STVector `liftM` IO.maybeViewIOBandedAsVector b

    -- Diagonal, row and column views.  The row/column variants return a
    -- triple whose middle component is the vector view; only that
    -- component is re-wrapped, the outer two are passed through untouched.
    unsafeDiagViewBanded (STBanded b) i =
        STVector (IO.unsafeDiagViewIOBanded b i)
    unsafeRowViewBanded (STBanded b) i =
        wrapView (IO.unsafeRowViewIOBanded b i)
      where
        wrapView (nb, v, na) = (nb, STVector v, na)
    unsafeColViewBanded (STBanded b) j =
        wrapView (IO.unsafeColViewIOBanded b j)
      where
        wrapView (nb, v, na) = (nb, STVector v, na)

    -- Raw coercions between the ST and IO representations.
    unsafeIOBandedToBanded = STBanded
    unsafeBandedToIOBanded (STBanded b) = b
-- Readable-banded operations in the 'ST' monad.  Each method unwraps the
-- 'IOBanded', runs the IO-level action on it, and lifts that action into
-- 'ST' with 'unsafeIOToST'.
instance (BLAS3 e) => ReadBanded (STBanded s) e (ST s) where
    -- Apply the caller-supplied IO function to the underlying 'IOBanded'.
    unsafePerformIOWithBanded (STBanded b) f = unsafeIOToST (f b)
    freezeBanded              (STBanded b)   = unsafeIOToST (freezeIOBanded b)
    unsafeFreezeBanded        (STBanded b)   = unsafeIOToST (unsafeFreezeIOBanded b)
-- Matrix operations for a plain 'STBanded' in the 'ST' monad.  Every
-- method is bound directly to an existing function; no wrapping is done
-- in this instance.
instance (BLAS3 e) => MMatrix (STBanded s) e (ST s) where
    -- Row and column access.
    getRows              = getRowsST
    getCols              = getColsST
    unsafeGetRow         = unsafeGetRowBanded
    unsafeGetCol         = unsafeGetColBanded

    -- Scaled apply-and-add against a vector / a matrix
    -- (presumably the BLAS general-band routines -- confirm).
    unsafeDoSApplyAdd    = gbmv
    unsafeDoSApplyAddMat = gbmm
-- Matrix operations for a 'Herm'-wrapped 'STBanded' in the 'ST' monad.
-- Every method is bound directly to an existing function.
instance (BLAS3 e) => MMatrix (Herm (STBanded s)) e (ST s) where
    -- Row and column access.
    getRows              = getRowsST
    getCols              = getColsST

    -- Scaled apply-and-add against a vector / a matrix
    -- (presumably the BLAS Hermitian-band routines -- confirm).
    unsafeDoSApplyAdd    = hbmv
    unsafeDoSApplyAddMat = hbmm
-- Matrix operations for a 'Tri'-wrapped 'STBanded' in the 'ST' monad.
-- Every method is bound directly to an existing function; the primed
-- functions back the apply-and-add methods, the unprimed ones back the
-- underscore-suffixed methods.
instance (BLAS3 e) => MMatrix (Tri (STBanded s)) e (ST s) where
    -- Row and column access.
    getRows              = getRowsST
    getCols              = getColsST

    -- Underscore-suffixed apply methods (vector / matrix).
    unsafeDoSApply_      = tbmv
    unsafeDoSApplyMat_   = tbmm

    -- Apply-and-add methods (vector / matrix).
    unsafeDoSApplyAdd    = tbmv'
    unsafeDoSApplyAddMat = tbmm'
-- Solve operations for a 'Tri'-wrapped 'STBanded' in the 'ST' monad.
-- Every method is bound directly to an existing function; the unprimed
-- functions back the underscore-suffixed methods, the primed ones back
-- the remaining two.
instance (BLAS3 e) => MSolve (Tri (STBanded s)) e (ST s) where
    -- Underscore-suffixed solve methods (vector / matrix).
    unsafeDoSSolve_    = tbsv
    unsafeDoSSolveMat_ = tbsm

    -- Non-underscore solve methods (vector / matrix).
    unsafeDoSSolve     = tbsv'
    unsafeDoSSolveMat  = tbsm'