module Data.Vector.Dense.STBase
where
import Control.Monad
import Control.Monad.ST
import Data.Elem.BLAS( Elem, BLAS1 )
import Data.Tensor.Class
import Data.Tensor.Class.MTensor
import Data.Vector.Dense.Base
import Data.Vector.Dense.IOBase
-- | A mutable dense vector that is used from the 'ST' monad.
--
-- The representation is just an 'IOVector'.  The state-thread parameter
-- @s@ does not occur on the right-hand side; it only ties a vector to the
-- 'ST' computation it belongs to.  The @n@ and @e@ parameters are passed
-- straight through to 'IOVector' (presumably a dimension tag and the
-- element type — confirm in "Data.Vector.Dense.IOBase").
newtype STVector s n e =
    STVector (IOVector n e)
-- | Run an 'ST' computation that builds an 'STVector' and hand back the
-- result as an immutable 'Vector'.
--
-- The wrapped 'IOVector' is re-wrapped with the 'Vector' constructor as-is;
-- no copy is made here.  The rank-2 argument type prevents the 'STVector'
-- itself from being returned out of the computation.
runSTVector :: (forall s . ST s (STVector s n e)) -> Vector n e
runSTVector action = runST (liftM wrap action)
  where
    -- Closed helper, so it stays polymorphic in the state thread.
    wrap (STVector v) = Vector v
-- | Shape queries are answered directly by the wrapped 'IOVector'.  They
-- are not monadic, so no lifting into 'ST' is involved.
instance Shaped (STVector s) Int where
    shape  = \(STVector v) -> shapeIOVector v
    bounds = \(STVector v) -> boundsIOVector v
-- | Read access to the elements of an 'STVector'.
--
-- Each method does three things:
--
-- * unwraps the 'STVector';
-- * runs the matching 'IOVector' operation;
-- * moves the resulting 'IO' action into @ST s@ with 'unsafeIOToST'.
--
-- NOTE(review): this is only sound if the 'IOVector' operations touch
-- nothing but the vector they are given — confirm in IOBase.
instance (Elem e) => ReadTensor (STVector s) Int e (ST s) where
    getSize (STVector v)        = unsafeIOToST (getSizeIOVector v)
    unsafeReadElem (STVector v) = unsafeIOToST . unsafeReadElemIOVector v
    getIndices (STVector v)     = unsafeIOToST (getIndicesIOVector v)
    getIndices' (STVector v)    = unsafeIOToST (getIndicesIOVector' v)
    getElems (STVector v)       = unsafeIOToST (getElemsIOVector v)
    getElems' (STVector v)      = unsafeIOToST (getElemsIOVector' v)
    getAssocs (STVector v)      = unsafeIOToST (getAssocsIOVector v)
    getAssocs' (STVector v)     = unsafeIOToST (getAssocsIOVector' v)
-- | Write access to an 'STVector'.
--
-- Every method forwards to the 'IOVector' function of the same name
-- (suffixed @IOVector@) and lifts the result into @ST s@ with
-- 'unsafeIOToST'.  Extra arguments (constants, element functions,
-- indices) are passed through unchanged.
instance (BLAS1 e) => WriteTensor (STVector s) Int e (ST s) where
    getMaxSize (STVector v)         = unsafeIOToST (getMaxSizeIOVector v)
    setZero (STVector v)            = unsafeIOToST (setZeroIOVector v)
    setConstant c (STVector v)      = unsafeIOToST (setConstantIOVector c v)
    canModifyElem (STVector v)      = unsafeIOToST . canModifyElemIOVector v
    unsafeWriteElem (STVector v) i  = unsafeIOToST . unsafeWriteElemIOVector v i
    unsafeModifyElem (STVector v) i = unsafeIOToST . unsafeModifyElemIOVector v i
    modifyWith g (STVector v)       = unsafeIOToST (modifyWithIOVector g v)
    doConj (STVector v)             = unsafeIOToST (doConjIOVector v)
    scaleBy k (STVector v)          = unsafeIOToST (scaleByIOVector k v)
    shiftBy k (STVector v)          = unsafeIOToST (shiftByIOVector k v)
-- | Structural operations on an 'STVector'.
--
-- None of these goes through 'unsafeIOToST'.  They inspect the wrapped
-- 'IOVector', or re-wrap the result of a pure 'IOVector' function.
-- 'conj' and the subvector view return a new 'STVector' built from the
-- corresponding 'IOVector' result (presumably sharing storage with the
-- original — confirm in IOBase).
instance (Elem e) => BaseVector (STVector s) e where
    dim      (STVector v) = dimIOVector v
    stride   (STVector v) = strideIOVector v
    conjEnum (STVector v) = conjEnumIOVector v
    conj     (STVector v) = STVector $ conjIOVector v

    -- Argument order follows the class method: presumably stride,
    -- vector, offset, length — NOTE(review): confirm against the class.
    unsafeSubvectorViewWithStride step (STVector v) off len =
        STVector $ unsafeSubvectorViewWithStrideIOVector step v off len

    unsafeVectorToIOVector (STVector v) = v
    unsafeIOVectorToVector v            = STVector v
-- | Vector-level read operations.
--
-- The wrapped 'IOVector' is handed to an 'IO' action, and that action is
-- lifted into @ST s@ via 'unsafeIOToST'.
instance (BLAS1 e) => ReadVector (STVector s) e (ST s) where
    unsafePerformIOWithVector (STVector v) act = unsafeIOToST (act v)
    freezeVector (STVector v)       = unsafeIOToST (freezeIOVector v)
    unsafeFreezeVector (STVector v) = unsafeIOToST (unsafeFreezeIOVector v)
-- | Allocation and conversion into 'STVector'.
--
-- Each method obtains an 'IOVector' from an 'IO' action, lifts the action
-- into @ST s@ with 'unsafeIOToST', and wraps the resulting vector in the
-- 'STVector' constructor.
instance (BLAS1 e) => WriteVector (STVector s) e (ST s) where
    newVector_ k             = liftM STVector (unsafeIOToST (newIOVector_ k))
    unsafeConvertIOVector io = unsafeIOToST (liftM STVector io)
    thawVector src           = liftM STVector (unsafeIOToST (thawIOVector src))
    unsafeThawVector src     =
        liftM STVector (unsafeIOToST (unsafeThawIOVector src))