module Data.Vector.Dense.IOBase
where
import Control.Monad
import Foreign
import System.IO.Unsafe
import BLAS.Internal ( clearArray )
import BLAS.Types( ConjEnum(..), flipConj )
import Data.Elem.BLAS ( Complex, Elem, BLAS1, conjugate )
import qualified Data.Elem.BLAS.Level1 as BLAS
import Data.Tensor.Class
import Data.Tensor.Class.MTensor
-- | A mutable dense vector stored in foreign memory, possibly a strided
-- and/or conjugated view of a larger array.
--
-- The type parameter @n@ appears in no field (phantom); presumably it is a
-- type-level dimension tag — NOTE(review): confirm against callers.
data IOVector n e =
    IOVector !ConjEnum       -- ^ conjugation flag; when 'Conj', element
                             --   reads and writes go through 'conjugate'
             !Int            -- ^ number of elements
             !(ForeignPtr e) -- ^ owning pointer; touched after accesses
                             --   so the storage stays alive
             !(Ptr e)        -- ^ pointer to element 0 of this view
             !Int            -- ^ distance between consecutive elements,
                             --   counted in elements (not bytes)
-- | View @len@ consecutive elements of a foreign array, beginning
-- @offset@ elements past its start, as an unconjugated stride-1 vector.
-- The view shares storage with the array; nothing is copied.
vectorViewArray :: (Elem e)
                => ForeignPtr e
                -> Int
                -> Int
                -> IOVector n e
vectorViewArray fptr offset len = vectorViewArrayWithStride 1 fptr offset len
-- | View @len@ elements of a foreign array as a vector. Element 0 sits
-- @offset@ elements past the start of the array, and consecutive elements
-- are @stride@ elements apart. The view starts unconjugated and shares
-- storage with the array.
vectorViewArrayWithStride :: (Elem e)
                          => Int
                          -> ForeignPtr e
                          -> Int
                          -> Int
                          -> IOVector n e
vectorViewArrayWithStride stride fptr offset len =
    IOVector NoConj len fptr first stride
  where
    first = unsafeForeignPtrToPtr fptr `advancePtr` offset
-- | Number of elements in the vector.
dimIOVector :: IOVector n e -> Int
dimIOVector v = case v of
    IOVector _ len _ _ _ -> len
-- | Distance, in elements, between consecutive entries of the vector.
strideIOVector :: IOVector n e -> Int
strideIOVector v = case v of
    IOVector _ _ _ _ step -> step
-- | The conjugation flag of the view.
conjEnumIOVector :: IOVector n e -> ConjEnum
conjEnumIOVector v = case v of
    IOVector flag _ _ _ _ -> flag
-- | 'True' exactly when the view's conjugation flag is 'Conj'.
isConjIOVector :: IOVector n e -> Bool
isConjIOVector = (== Conj) . conjEnumIOVector
-- | Flip the conjugation flag. The result views the same storage and no
-- element data is touched.
conjIOVector :: IOVector n e -> IOVector n e
conjIOVector (IOVector flag len fptr ptr step) =
    IOVector (flipConj flag) len fptr ptr step
-- | View @len@ elements of the parent, beginning at parent element @from@
-- and taking every @skip@-th parent element. The view shares storage and
-- the conjugation flag with the parent. No bounds checking is performed.
unsafeSubvectorViewWithStrideIOVector :: (Elem e) =>
    Int -> IOVector n e -> Int -> Int -> IOVector n' e
unsafeSubvectorViewWithStrideIOVector skip (IOVector flag _ fptr base step) from len =
    IOVector flag len fptr first (step * skip)
  where
    first = base `advancePtr` (step * from)
-- | Run an action on the pointer to element 0. Afterwards, touch the
-- owning 'ForeignPtr' so the storage is kept alive until the action has
-- returned.
withIOVector :: IOVector n e -> (Ptr e -> IO a) -> IO a
withIOVector (IOVector _ _ fptr ptr _) action =
    action ptr >>= \result -> touchForeignPtr fptr >> return result
-- | Allocate a fresh, uninitialised, unconjugated stride-1 vector of @n@
-- elements. Fails in 'IO' when @n@ is negative.
newIOVector_ :: (Elem e) => Int -> IO (IOVector n e)
newIOVector_ n
    | n >= 0 = do
        fptr <- mallocForeignPtrArray n
        return (IOVector NoConj n fptr (unsafeForeignPtrToPtr fptr) 1)
    | otherwise =
        fail ("Tried to create a vector with `" ++ show n ++ "' elements.")
-- | Copy the vector into freshly allocated stride-1 storage using BLAS
-- @copy@. The copy keeps the source's conjugation flag, so reading it
-- yields the same values as reading the source. Both owning pointers are
-- touched after the copy.
newCopyIOVector :: (BLAS1 e) => IOVector n e -> IO (IOVector n e)
newCopyIOVector (IOVector flag len src srcPtr srcStep) = do
    fresh <- newIOVector_ len
    case fresh of
        IOVector _ _ dst dstPtr _ -> do
            BLAS.copy len srcPtr srcStep dstPtr 1
            touchForeignPtr src
            touchForeignPtr dst
            return (IOVector flag len dst dstPtr 1)
-- | The shape of a vector is just its dimension.
shapeIOVector :: IOVector n e -> Int
shapeIOVector x = dimIOVector x
-- | Inclusive index bounds @(0, n - 1)@ of a vector with @n@ elements.
-- An empty vector yields @(0, -1)@, i.e. an empty range.
boundsIOVector :: IOVector n e -> (Int,Int)
boundsIOVector x = (0, dimIOVector x - 1)
-- | Number of stored elements; for a dense vector this is its dimension.
sizeIOVector :: IOVector n e -> Int
sizeIOVector x = dimIOVector x
-- | 'sizeIOVector' lifted into 'IO'.
getSizeIOVector :: IOVector n e -> IO Int
getSizeIOVector x = return (sizeIOVector x)
-- | Every element is stored, so the maximum size equals the size.
getMaxSizeIOVector :: IOVector n e -> IO Int
getMaxSizeIOVector x = getSizeIOVector x
-- | All valid indices in increasing order, @[0 .. n - 1]@. The list is
-- empty for a zero-length vector.
indicesIOVector :: IOVector n e -> [Int]
indicesIOVector x = [0 .. n - 1]
  where
    n = dimIOVector x
-- | 'indicesIOVector' lifted into 'IO'.
getIndicesIOVector :: IOVector n e -> IO [Int]
getIndicesIOVector x = return (indicesIOVector x)
-- | Strict variant of 'getIndicesIOVector'. The indices do not depend on
-- the vector's contents, so it simply delegates.
getIndicesIOVector' :: IOVector n e -> IO [Int]
getIndicesIOVector' x = getIndicesIOVector x
-- | Read the elements lazily. Each element is peeked only when its list
-- cell is demanded (via 'unsafeInterleaveIO'), so writes made before a
-- cell is forced can show up in it. The owning 'ForeignPtr' is touched
-- once the end of the list is reached. A conjugated view is read raw and
-- then mapped through 'conjugate'.
--
-- NOTE(review): the walk stops when the cursor equals the pointer @n@
-- strides past element 0, so a stride of 0 yields @[]@ for any @n@.
getElemsIOVector :: (Elem e) => IOVector n e -> IO [e]
getElemsIOVector (IOVector Conj n f p incX) =
    liftM (map conjugate) (getElemsIOVector (IOVector NoConj n f p incX))
getElemsIOVector (IOVector NoConj n f p incX) = walk p
  where
    stop = p `advancePtr` (n * incX)
    walk q
        | q == stop = touchForeignPtr f >> return []
        | otherwise = unsafeInterleaveIO $ do
            x  <- peek q
            xs <- walk (q `advancePtr` incX)
            return (x : xs)
-- | Read all elements strictly, returned in index order. The walk starts
-- at the last element and steps backwards by the stride, consing each
-- element onto an accumulator, so no final 'reverse' is needed. The owning
-- 'ForeignPtr' is touched after the last read. A conjugated view is read
-- raw and then mapped through 'conjugate'.
--
-- NOTE(review): like 'getElemsIOVector', termination compares pointers,
-- so a stride of 0 yields @[]@ for any @n@.
getElemsIOVector' :: (Elem e) => IOVector n e -> IO [e]
getElemsIOVector' (IOVector Conj n f p incX) = do
    es <- getElemsIOVector' (IOVector NoConj n f p incX)
    return $ map conjugate es
getElemsIOVector' (IOVector NoConj n f p incX) =
    let -- one step *before* element 0: the sentinel for the backward walk
        end = p `advancePtr` (-incX)
        go p' es
            | p' == end = do
                touchForeignPtr f
                return es
            | otherwise = do
                e <- peek p'
                go (p' `advancePtr` (-incX)) (e : es)
    -- start at element n-1; for n == 0 this is already 'end', giving []
    in go (p `advancePtr` ((n - 1) * incX)) []
-- | Index/element pairs. Elements come from 'getElemsIOVector', so they
-- are read lazily.
getAssocsIOVector :: (Elem e) => IOVector n e -> IO [(Int,e)]
getAssocsIOVector x = do
    is <- getIndicesIOVector x
    es <- getElemsIOVector x
    return (zip is es)
-- | Index/element pairs. Elements come from 'getElemsIOVector'', so they
-- are read strictly.
getAssocsIOVector' :: (Elem e) => IOVector n e -> IO [(Int,e)]
getAssocsIOVector' x = do
    is <- getIndicesIOVector' x
    es <- getElemsIOVector' x
    return (zip is es)
-- | Read element @i@ without bounds checking. The owning 'ForeignPtr' is
-- touched after the read, and the value is conjugated when the view is
-- conjugated.
unsafeReadElemIOVector :: (Elem e) => IOVector n e -> Int -> IO e
unsafeReadElemIOVector (IOVector flag _ fptr ptr step) i = do
    raw <- peekElemOff ptr (i * step)
    touchForeignPtr fptr
    return (if flag == Conj then conjugate raw else raw)
-- | Every element of an 'IOVector' is writable, so this always answers
-- 'True'.
canModifyElemIOVector :: IOVector n e -> Int -> IO Bool
canModifyElemIOVector = \_ _ -> return True
-- | Write element @i@ without bounds checking. On a conjugated view the
-- value is conjugated before it is stored, mirroring the conjugation
-- applied on reads. The owning 'ForeignPtr' is touched afterwards.
unsafeWriteElemIOVector :: (Elem e) => IOVector n e -> Int -> e -> IO ()
unsafeWriteElemIOVector (IOVector flag _ fptr ptr step) i e = do
    pokeElemOff ptr (i * step) stored
    touchForeignPtr fptr
  where
    stored | flag == Conj = conjugate e
           | otherwise    = e
-- | Apply @g@ to element @i@ in place, without bounds checking. On a
-- conjugated view @g@ is wrapped in 'conjugate' on both sides, so it
-- operates on the values the view exposes. The owning 'ForeignPtr' is
-- touched afterwards.
unsafeModifyElemIOVector :: (Elem e) => IOVector n e -> Int -> (e -> e) -> IO ()
unsafeModifyElemIOVector (IOVector flag _ fptr ptr step) i g = do
    old <- peek slot
    poke slot (g' old)
    touchForeignPtr fptr
  where
    slot = ptr `advancePtr` (i * step)
    g' | flag == Conj = conjugate . g . conjugate
       | otherwise    = g
-- | Exchange elements @i1@ and @i2@ without bounds checking. A swap moves
-- raw values, so the conjugation flag is irrelevant and ignored.
unsafeSwapElemsIOVector :: (Elem e) => IOVector n e -> Int -> Int -> IO ()
unsafeSwapElemsIOVector (IOVector _ _ fptr ptr step) i1 i2 = do
    a <- peek pa
    b <- peek pb
    poke pb a
    poke pa b
    touchForeignPtr fptr
  where
    pa = ptr `advancePtr` (i1 * step)
    pb = ptr `advancePtr` (i2 * step)
-- | Apply @g@ to every element in place, walking from element 0 along the
-- stride. The owning 'ForeignPtr' is touched at the end. On a conjugated
-- view @g@ is wrapped in 'conjugate' on both sides.
--
-- NOTE(review): the loop stops when the cursor equals the pointer @n@
-- strides past element 0, so a stride of 0 visits no elements.
modifyWithIOVector :: (Elem e) => (e -> e) -> IOVector n e -> IO ()
modifyWithIOVector g (IOVector flag len fptr start step) = loop start
  where
    h | flag == Conj = conjugate . g . conjugate
      | otherwise    = g
    finish = start `advancePtr` (len * step)
    loop q
        | q == finish = touchForeignPtr fptr
        | otherwise   = do
            peek q >>= poke q . h
            loop (q `advancePtr` step)
-- | Zero every element. A stride-1 vector is cleared with a single
-- 'clearArray' call, followed by a touch of the owning pointer. Any other
-- stride falls back to writing 0 element by element via
-- 'setConstantIOVector'.
setZeroIOVector :: (Elem e) => IOVector n e -> IO ()
setZeroIOVector x@(IOVector _ n f p incX) =
    if incX /= 1
        then setConstantIOVector 0 x
        else do
            clearArray p n
            touchForeignPtr f
-- | Store @e@ in every element. Writing 0 to a stride-1 vector is
-- delegated to 'setZeroIOVector'. Otherwise elements are poked one by one
-- along the stride, conjugating @e@ first on a conjugated view, and the
-- owning 'ForeignPtr' is touched at the end.
setConstantIOVector :: (Elem e) => e -> IOVector n e -> IO ()
setConstantIOVector e x
    | e == 0 && strideIOVector x == 1 = setZeroIOVector x
    | otherwise = case x of
        IOVector c n f p incX ->
            let e'   = if c == Conj then conjugate e else e
                stop = p `advancePtr` (n * incX)
                fill q
                    | q == stop = touchForeignPtr f
                    | otherwise = poke q e' >> fill (q `advancePtr` incX)
            in fill p
-- | Conjugate the stored elements in place with BLAS @vconj@, then touch
-- the owning pointer. The view's conjugation flag is left unchanged.
doConjIOVector :: (BLAS1 e) => IOVector n e -> IO ()
doConjIOVector (IOVector _ len fptr ptr step) = do
    BLAS.vconj len ptr step
    touchForeignPtr fptr
-- | Multiply every element by @k@ in place with BLAS @scal@. Scaling by 1
-- is a no-op. On a conjugated view the raw data is scaled by
-- @conjugate k@, so values read through the view come out scaled by @k@.
scaleByIOVector :: (BLAS1 e) => e -> IOVector n e -> IO ()
scaleByIOVector k x
    | k == 1    = return ()
    | otherwise = case x of
        IOVector c n f p incX -> do
            BLAS.scal n (if c == Conj then conjugate k else k) p incX
            touchForeignPtr f
-- | Add @k@ to every element in place. A conjugated view is handled by
-- flipping to the unconjugated view of the same storage and adding
-- @conjugate k@ there.
shiftByIOVector :: (Elem e) => e -> IOVector n e -> IO ()
shiftByIOVector k x =
    if isConjIOVector x
        then shiftByIOVector (conjugate k) (conjIOVector x)
        else modifyWithIOVector (k +) x
-- | A vector is indexed by a single 'Int'; its shape is its dimension.
instance Shaped IOVector Int where
    bounds x = boundsIOVector x
    shape x  = shapeIOVector x
-- | Read access delegates to the conjugation-aware @*IOVector@ helpers.
-- Primed methods are the strict variants.
instance (Elem e) => ReadTensor IOVector Int e IO where
    getSize        = getSizeIOVector
    getIndices     = getIndicesIOVector
    getIndices'    = getIndicesIOVector'
    getElems       = getElemsIOVector
    getElems'      = getElemsIOVector'
    getAssocs      = getAssocsIOVector
    getAssocs'     = getAssocsIOVector'
    unsafeReadElem = unsafeReadElemIOVector
-- | Write access delegates to the @*IOVector@ helpers. Bulk operations
-- ('doConj', 'scaleBy') go through BLAS level 1.
instance (BLAS1 e) => WriteTensor IOVector Int e IO where
    getMaxSize       = getMaxSizeIOVector
    canModifyElem    = canModifyElemIOVector
    unsafeWriteElem  = unsafeWriteElemIOVector
    unsafeModifyElem = unsafeModifyElemIOVector
    setZero          = setZeroIOVector
    setConstant      = setConstantIOVector
    modifyWith       = modifyWithIOVector
    shiftBy          = shiftByIOVector
    scaleBy          = scaleByIOVector
    doConj           = doConjIOVector