module Data.Vector.Dense.Class.Internal (
IOVector,
STVector,
unsafeIOVectorToSTVector,
unsafeSTVectorToIOVector,
BaseVector(..),
ReadVector,
WriteVector,
dim,
stride,
isConj,
coerceVector,
shapeVector,
boundsVector,
getSizeVector,
getAssocsVector,
getIndicesVector,
getElemsVector,
getAssocsVector',
getIndicesVector',
getElemsVector',
unsafeReadElemVector,
newVector_,
newZeroVector,
setZeroVector,
newConstantVector,
setConstantVector,
canModifyElemVector,
unsafeWriteElemVector,
modifyWithVector,
doConjVector,
scaleByVector,
shiftByVector,
newCopyVector,
unsafeCopyVector,
unsafeSwapVector,
unsafeAxpyVector,
unsafeMulVector,
unsafeDivVector,
withVectorPtr,
indexOfVector,
indicesVector,
vectorCall,
vectorCall2,
) where
import Control.Monad( forM_ )
import Control.Monad.ST
import Foreign
import Unsafe.Coerce
import BLAS.Internal ( clearArray, inlinePerformIO )
import BLAS.Elem
import qualified BLAS.C as BLAS
import BLAS.Tensor
import BLAS.UnsafeIOToM
import Data.Vector.Dense.Class.Internal.Base
-- | Vectors whose elements can be read in the monad @m@.  The class has no
-- methods of its own; it bundles 'BaseVector', 'UnsafeIOToM' and
-- 'ReadTensor' (indexed by 'Int') into a single constraint.
class (BaseVector x, UnsafeIOToM m, ReadTensor x Int m) => ReadVector x m
-- | Vectors whose elements can be written in the monad @m@.  Like
-- 'ReadVector' it has no methods; it adds 'WriteTensor' on top of
-- 'ReadVector'.  The functional dependencies @x -> m, m -> x@ tie each
-- mutable vector type to exactly one monad.  In this module, 'IOVector'
-- is paired with 'IO' and @'STVector' s@ with @'ST' s@.
class (ReadVector x m, WriteTensor x Int m) => WriteVector x m | x -> m, m -> x
-- | Change the type-level size tag @n@ of a vector.  The element type @e@
-- and the representation @x@ are the same on both sides; no data moves.
--
-- Safety: this relies on @n@ being a phantom parameter.  That holds for
-- 'IOVector' and 'STVector' as declared in this module.
-- NOTE(review): confirm the same for any other 'BaseVector' instance
-- before coercing it.
coerceVector :: (BaseVector x) => x n e -> x n' e
coerceVector v = unsafeCoerce v
-- | The shape of a vector is simply its length, as reported by 'dim'.
shapeVector :: (BaseVector x) => x n e -> Int
shapeVector v = dim v
-- | Inclusive index bounds of a vector: @(0, dim x - 1)@.
--
-- For an empty vector this is @(0, -1)@, which denotes an empty index
-- range.
boundsVector :: (BaseVector x) => x n e -> (Int,Int)
boundsVector x = (0, dim x - 1)
-- | Number of elements in the vector ('dim'), returned in @m@.
getSizeVector :: (ReadVector x m) => x n e -> m Int
getSizeVector x = return (dim x)
-- | All valid indices of the vector, as computed by 'indicesVector',
-- returned in @m@.
getIndicesVector :: (ReadVector x m) => x n e -> m [Int]
getIndicesVector x = return (indicesVector x)
-- | Primed counterpart of 'getIndicesVector'.  The index list does not
-- depend on element values, so it is the very same action.
getIndicesVector' :: (ReadVector x m) => x n e -> m [Int]
getIndicesVector' x = getIndicesVector x
-- | The vector's elements in index order.
--
-- This is built on 'getAssocsVector', so elements are read lazily, when
-- the returned list is forced.
getElemsVector :: (ReadVector x m, Elem e) => x n e -> m [e]
getElemsVector x = do
    assocs <- getAssocsVector x
    let (_, elems) = unzip assocs
    return elems
-- | The vector's elements in index order.
--
-- This is built on 'getAssocsVector'', so all elements are read inside
-- @m@ before the list is returned.
getElemsVector' :: (ReadVector x m, Elem e) => x n e -> m [e]
getElemsVector' x = do
    assocs <- getAssocsVector' x
    let (_, elems) = unzip assocs
    return elems
-- | Index/element pairs of the vector, in index order, read lazily.
--
-- The list is produced by pure code that peeks at the underlying buffer
-- via 'inlinePerformIO'.  Each element is read when its list cell is
-- forced, not when the monadic action runs.  A caller that writes to the
-- vector before consuming the list may therefore observe the new values.
-- 'getAssocsVector'' reads every element inside @m@ instead.
getAssocsVector :: (ReadVector x m, Elem e) => x n e -> m [(Int,e)]
getAssocsVector x
    -- Conjugated view: read the underlying data, then conjugate each
    -- element on the way out.
    | isConj x =
        getAssocsVector (conj x)
        >>= return . map (\(i,e) -> (i,conj e))
    | otherwise =
        let (f,p,n,incX,_) = arrayFromVector x
        in return $ go n f incX p 0
  where
    -- go n f incX ptr i: list cells for indices i .. n-1, where ptr points
    -- at element i.  The ForeignPtr f is carried along and touched once
    -- the end is reached, so it stays alive until the whole list has been
    -- forced.
    go !n !f !incX !ptr !i
        | i >= n =
            inlinePerformIO $ do
                touchForeignPtr f
                return []
        | otherwise =
            -- e is forced before the cell is built, so the peek happens
            -- exactly when this cell is demanded.
            let e = inlinePerformIO $ peek ptr
                ptr' = ptr `advancePtr` incX
                i' = i + 1
                ies = go n f incX ptr' i'
            in e `seq` ((i,e):ies)
-- | Index/element pairs of the vector, in index order.
--
-- Unlike 'getAssocsVector', every element is read inside @m@ before the
-- list is returned, so the result is a snapshot of the vector.
getAssocsVector' :: (ReadVector x m, Elem e) => x n e -> m [(Int,e)]
getAssocsVector' x
    | isConj x = do
        -- Read the underlying data, then conjugate each element.
        assocs <- getAssocsVector' (conj x)
        return [ (i, conj e) | (i, e) <- assocs ]
    | otherwise =
        unsafeIOToM $ withVectorPtr x $ \base ->
            mapM (readAt base) [0 .. dim x - 1]
  where
    incX = stride x
    -- Element i lives i strides past the base pointer.
    readAt base i = do
        e <- peekElemOff base (i * incX)
        return (i, e)
-- | Read the element at index @i@ without bounds checking.
--
-- For a conjugated view, the stored value is read and then conjugated.
unsafeReadElemVector :: (ReadVector x m, Elem e) => x n e -> Int -> m e
unsafeReadElemVector x i =
    if isConj x
        then do
            stored <- unsafeReadElemVector (conj x) i
            return (conj stored)
        else unsafeIOToM $ withVectorPtr x $ \p ->
            peekElemOff p (indexOfVector x i)
-- | Allocate a fresh vector of length @n@ with stride 1 and the
-- conjugation flag set to 'False'.
--
-- The storage comes from 'mallocForeignPtrArray' and is not initialised.
-- A negative @n@ is reported through 'fail'.
--
-- NOTE(review): on GHC >= 8.8 'fail' belongs to 'MonadFail', which these
-- constraints do not provide; confirm this compiles with the project's
-- compiler.
newVector_ :: (WriteVector x m, Elem e) => Int -> m (x n e)
newVector_ n
    | n >= 0 = unsafeIOToM $ do
        fptr <- mallocForeignPtrArray n
        let raw = unsafeForeignPtrToPtr fptr
        return $ vectorViewArray fptr raw n 1 False
    | otherwise =
        fail $ "Tried to create a vector with `" ++ show n ++ "' elements."
-- | Allocate a vector of length @n@ and fill it with zeros.
newZeroVector :: (WriteVector y m, Elem e) => Int -> m (y n e)
newZeroVector n =
    newVector_ n >>= \v -> setZeroVector v >> return v
-- | Overwrite every element with zero.
--
-- A vector with stride 1 hands its whole buffer to 'clearArray'.  Any
-- other stride goes through 'setConstantVector'.  The conjugation flag
-- can be ignored because zero is its own conjugate.
setZeroVector :: (WriteVector y m, Elem e) => y n e -> m ()
setZeroVector x
    | stride x /= 1 = setConstantVector 0 x
    | otherwise =
        unsafeIOToM $ withVectorPtr x $ \p -> clearArray p (dim x)
-- | Allocate a vector of length @n@ with every element set to @e@.
newConstantVector :: (WriteVector y m, Elem e) => Int -> e -> m (y n e)
newConstantVector n e =
    newVector_ n >>= \v -> setConstantVector e v >> return v
-- | Set every element of the vector to @e@.
--
-- For a conjugated view, the conjugate of @e@ is stored, so reading
-- through the view yields @e@.
setConstantVector :: (WriteVector y m, Elem e) => e -> y n e -> m ()
setConstantVector e x
    | isConj x  = setConstantVector (conj e) (conj x)
    | otherwise = unsafeIOToM $ withVectorPtr x $ go (dim x)
  where
    -- The stride is loop-invariant, so compute it once.
    incX = stride x
    -- go k ptr: write e into the k elements starting at ptr, stepping by
    -- the vector's stride.
    go !k !ptr
        | k <= 0    = return ()
        | otherwise = do
            poke ptr e
            go (k - 1) (ptr `advancePtr` incX)
-- | Every element of a dense vector is writable, so this always answers
-- 'True', whatever the vector or index.
canModifyElemVector :: (WriteVector y m) => y n e -> Int -> m Bool
canModifyElemVector = \_ _ -> return True
-- | Write @e@ at index @i@ without bounds checking.
--
-- For a conjugated view, the conjugate of @e@ is stored.
unsafeWriteElemVector :: (WriteVector y m, Elem e) => y n e -> Int -> e -> m ()
unsafeWriteElemVector x i e =
    unsafeIOToM $ withVectorPtr x $ \p ->
        pokeElemOff p (indexOfVector x i) stored
  where
    stored | isConj x  = conj e
           | otherwise = e
-- | Replace each element @v@ with @f v@, in place.
--
-- For a conjugated view, @f@ is wrapped as @conj . f . conj@ so that it
-- acts on the values seen through the view.
modifyWithVector :: (WriteVector y m, Elem e) => (e -> e) -> y n e -> m ()
modifyWithVector f x
    | isConj x  = modifyWithVector (conj . f . conj) (conj x)
    | otherwise = unsafeIOToM $ withVectorPtr x $ \ptr ->
        go (dim x) ptr
  where
    incX = stride x
    -- go k ptr: update the k elements starting at ptr, stepping by incX.
    go !k !ptr
        | k <= 0    = return ()
        | otherwise = do
            peek ptr >>= poke ptr . f
            go (k - 1) (ptr `advancePtr` incX)
-- | Allocate a new vector holding a copy of @x@.
--
-- A conjugated source has its underlying data copied, and the result is
-- returned as a conjugated view, so the copy reads the same as @x@.
newCopyVector :: (BLAS1 e, ReadVector x m, WriteVector y m) =>
    x n e -> m (y n e)
newCopyVector x
    | not (isConj x) = do
        dst <- newVector_ (dim x)
        unsafeCopyVector dst x
        return dst
    | otherwise = do
        dst <- newCopyVector (conj x)
        return (conj dst)
-- | Copy the elements of @x@ into @y@.  There is no dimension check; @y@
-- is assumed to have at least @dim x@ elements.
--
-- * Both conjugated: copy the underlying data directly.
-- * Exactly one conjugated: copy element by element through
--   'unsafeReadElem' / 'unsafeWriteElem', which apply each view's
--   conjugation.
-- * Neither conjugated: delegate to BLAS @copy@.
unsafeCopyVector :: (BLAS1 e, WriteVector y m, ReadVector x m) =>
    y n e -> x n e -> m ()
unsafeCopyVector y x
    | isConj x && isConj y =
        unsafeCopyVector (conj y) (conj x)
    | isConj x || isConj y =
        forM_ [0 .. dim x - 1] $ \i ->
            unsafeReadElem x i >>= unsafeWriteElem y i
    | otherwise =
        vectorCall2 BLAS.copy x y
-- | Swap the contents of two vectors.  There is no dimension check; both
-- are assumed to have @dim x@ elements.
--
-- * Both conjugated: swap the underlying data directly.
-- * Exactly one conjugated: swap element by element through
--   'unsafeReadElem' / 'unsafeWriteElem', which apply each view's
--   conjugation.
-- * Neither conjugated: delegate to BLAS @swap@.
unsafeSwapVector :: (BLAS1 e, WriteVector y m) =>
    y n e -> y n e -> m ()
unsafeSwapVector x y
    | isConj x && isConj y =
        unsafeSwapVector (conj x) (conj y)
    | isConj x || isConj y =
        forM_ [0 .. dim x - 1] $ \i -> do
            tmp <- unsafeReadElem x i
            unsafeReadElem y i >>= unsafeWriteElem x i
            unsafeWriteElem y i tmp
    | otherwise =
        vectorCall2 BLAS.swap x y
-- | Conjugate every stored element in place via BLAS-level @conj@.
--
-- Only the length, data pointer and stride are passed to the routine;
-- the view's conjugation flag is left untouched.
doConjVector :: (WriteVector y m, BLAS1 e) => y n e -> m ()
doConjVector x = vectorCall BLAS.conj x
-- | Multiply every element by @k@, in place.
--
-- Scaling by 1 returns immediately without touching the data.  For a
-- conjugated view, the underlying data is scaled by @conj k@.
scaleByVector :: (WriteVector y m, BLAS1 e) => e -> y n e -> m ()
scaleByVector k x
    | k == 1    = return ()
    | isConj x  = scaleByVector (conj k) (conj x)
    | otherwise = vectorCall (\n p inc -> BLAS.scal n k p inc) x
-- | Add @k@ to every element, in place.
--
-- For a conjugated view, the underlying data is shifted by @conj k@.
shiftByVector :: (WriteVector y m, Elem e) => e -> y n e -> m ()
shiftByVector k x =
    if isConj x
        then shiftByVector (conj k) (conj x)
        else modifyWithVector (\v -> k + v) x
-- | @y := alpha * x + y@, in place, with no dimension check.
--
-- Conjugation is pushed onto the underlying data:
--
-- * If @y@ is conjugated, all three operands are conjugated and the call
--   is repeated.
-- * Otherwise, if only @x@ is conjugated, BLAS @acxpy@ is used instead of
--   @axpy@.
unsafeAxpyVector :: (ReadVector x m, WriteVector y m, BLAS1 e) =>
    e -> x n e -> y n e -> m ()
unsafeAxpyVector alpha x y
    | isConj y  = unsafeAxpyVector (conj alpha) (conj x) (conj y)
    | isConj x  = vectorCall2 (\n -> BLAS.acxpy n alpha) x y
    | otherwise = vectorCall2 (\n -> BLAS.axpy n alpha) x y
-- | Element-wise @y := y * x@, in place, with no dimension check.
--
-- If @y@ is conjugated, both operands are conjugated and the call is
-- repeated.  Otherwise BLAS @cmul@ is used when @x@ is conjugated and
-- @mul@ when it is not.
unsafeMulVector :: (WriteVector y m, ReadVector x m, BLAS1 e) =>
    y n e -> x n e -> m ()
unsafeMulVector y x
    | isConj y  = unsafeMulVector (conj y) (conj x)
    | otherwise =
        vectorCall2 (if isConj x then BLAS.cmul else BLAS.mul) x y
-- | Element-wise @y := y / x@, in place, with no dimension check.
--
-- If @y@ is conjugated, both operands are conjugated and the call is
-- repeated.  Otherwise BLAS @cdiv@ is used when @x@ is conjugated and
-- @div@ when it is not.
unsafeDivVector :: (WriteVector y m, ReadVector x m, BLAS1 e) =>
    y n e -> x n e -> m ()
unsafeDivVector y x
    | isConj y  = unsafeDivVector (conj y) (conj x)
    | otherwise =
        vectorCall2 (if isConj x then BLAS.cdiv else BLAS.div) x y
-- | Offset, in elements from the vector's data pointer, of logical index
-- @i@: the index scaled by the stride.
indexOfVector :: (BaseVector x) => x n e -> Int -> Int
indexOfVector x i = stride x * i
-- | All valid indices of a vector, @[0 .. dim x - 1]@.  The list is empty
-- when the vector has no elements.
indicesVector :: (BaseVector x) => x n e -> [Int]
indicesVector x = [0 .. dim x - 1]
-- | Run a BLAS-style routine on a single vector, lifted into @m@ with
-- 'unsafeIOToM'.
--
-- The routine receives the length ('dim'), the data pointer from
-- 'withVectorPtr', and the stride.  The conjugation flag is not passed,
-- so callers must handle it themselves.
vectorCall :: (ReadVector x m) =>
    (Int -> Ptr e -> Int -> IO a)
    -> x n e -> m a
vectorCall f x =
    unsafeIOToM $ withVectorPtr x $ \pX -> f (dim x) pX (stride x)
-- | Run a BLAS-style routine on two vectors, lifted into @m@ with
-- 'unsafeIOToM'.
--
-- The routine receives the length of @x@, then the data pointer and
-- stride of @x@, then those of @y@.  The length of @y@ is not checked.
-- Conjugation flags are not passed, so callers must handle them.
vectorCall2 :: (ReadVector x m, ReadVector y m) =>
    (Int -> Ptr e -> Int -> Ptr f -> Int -> IO a)
    -> x n e -> y n' f -> m a
vectorCall2 f x y =
    unsafeIOToM $
        withVectorPtr x $ \pX ->
        withVectorPtr y $ \pY ->
            f (dim x) pX (stride x) pY (stride y)
-- | A mutable dense vector used in 'IO'.
--
-- @n@ is a phantom size tag (no field mentions it); @e@ is the element
-- type.  The field order matches the tuple returned by 'arrayFromVector'.
data IOVector n e =
    DV !(ForeignPtr e)  -- owning ForeignPtr for the storage
       !(Ptr e)         -- data pointer of this view (presumably element 0 -- TODO confirm)
       !Int             -- number of elements
       !Int             -- stride between consecutive elements, in elements
       !Bool            -- conjugation flag
-- | A mutable dense vector used in @'ST' s@.
--
-- It is a newtype over 'IOVector' with an extra state-thread parameter
-- @s@.  Convert with 'unsafeIOVectorToSTVector' and
-- 'unsafeSTVectorToIOVector'.
newtype STVector s n e = ST (IOVector n e)
-- | View an 'IOVector' as an 'STVector'.  No copy is made; the result
-- shares storage with its argument.
--
-- NOTE(review): presumably "unsafe" because the caller must keep the
-- shared storage from being mutated outside the intended state thread.
unsafeIOVectorToSTVector :: IOVector n e -> STVector s n e
unsafeIOVectorToSTVector v = ST v
-- | View an 'STVector' as an 'IOVector' by unwrapping the newtype.  No
-- copy is made; the result shares storage with its argument.
unsafeSTVectorToIOVector :: STVector s n e -> IOVector n e
unsafeSTVectorToIOVector v = case v of ST inner -> inner
-- | 'IOVector' is the concrete representation.  Viewing an array builds a
-- 'DV' from its parts, and 'arrayFromVector' reads them back out.
instance BaseVector IOVector where
    arrayFromVector (DV fptr p n inc c) = (fptr, p, n, inc, c)
    vectorViewArray fptr p n inc c      = DV fptr p n inc c
-- | 'STVector' wraps an 'IOVector', so both directions go through the
-- underlying 'DV' representation.
instance BaseVector (STVector s) where
    arrayFromVector v              = case v of ST inner -> arrayFromVector inner
    vectorViewArray fptr p n inc c = ST (DV fptr p n inc c)
-- | Shape and bounds of an 'IOVector', via the generic vector helpers.
instance BaseTensor IOVector Int where
    shape x  = shapeVector x
    bounds x = boundsVector x
-- | Shape and bounds of an 'STVector', via the generic vector helpers.
instance BaseTensor (STVector s) Int where
    shape x  = shapeVector x
    bounds x = boundsVector x
-- | Reading an 'IOVector' in 'IO'.  Every method forwards to the matching
-- @...Vector@ helper in this module.
instance ReadTensor IOVector Int IO where
    unsafeReadElem x i = unsafeReadElemVector x i
    getSize x          = getSizeVector x
    getIndices x       = getIndicesVector x
    getIndices' x      = getIndicesVector' x
    getElems x         = getElemsVector x
    getElems' x        = getElemsVector' x
    getAssocs x        = getAssocsVector x
    getAssocs' x       = getAssocsVector' x
-- | Reading an 'STVector' in @'ST' s@.  Every method forwards to the
-- matching @...Vector@ helper in this module.
instance ReadTensor (STVector s) Int (ST s) where
    unsafeReadElem x i = unsafeReadElemVector x i
    getSize x          = getSizeVector x
    getIndices x       = getIndicesVector x
    getIndices' x      = getIndicesVector' x
    getElems x         = getElemsVector x
    getElems' x        = getElemsVector' x
    getAssocs x        = getAssocsVector x
    getAssocs' x       = getAssocsVector' x
-- | 'IOVector' is readable in 'IO', and 'STVector' @s@ in 'ST' @s@.
-- 'ReadVector' has no methods, so these instances only record the pairing.
instance ReadVector IOVector IO where
instance ReadVector (STVector s) (ST s) where
-- | Writing an 'IOVector' in 'IO'.  Every method forwards to the matching
-- @...Vector@ helper in this module.
instance WriteTensor IOVector Int IO where
    canModifyElem x i     = canModifyElemVector x i
    unsafeWriteElem x i e = unsafeWriteElemVector x i e
    setZero x             = setZeroVector x
    setConstant e x       = setConstantVector e x
    modifyWith f x        = modifyWithVector f x
    doConj x              = doConjVector x
    scaleBy k x           = scaleByVector k x
    shiftBy k x           = shiftByVector k x
-- | Writing an 'STVector' in @'ST' s@.  Every method forwards to the
-- matching @...Vector@ helper in this module.
instance WriteTensor (STVector s) Int (ST s) where
    canModifyElem x i     = canModifyElemVector x i
    unsafeWriteElem x i e = unsafeWriteElemVector x i e
    setZero x             = setZeroVector x
    setConstant e x       = setConstantVector e x
    modifyWith f x        = modifyWithVector f x
    doConj x              = doConjVector x
    scaleBy k x           = scaleByVector k x
    shiftBy k x           = shiftByVector k x
-- | 'IOVector' is writable in 'IO', and 'STVector' @s@ in 'ST' @s@.
-- 'WriteVector' has no methods; its functional dependencies make these the
-- only monad for each vector type.
instance WriteVector IOVector IO where
instance WriteVector (STVector s) (ST s) where