module Data.Vector.Dense.Internal (
Vector,
IOVector,
DVector(..),
module BLAS.Vector,
module BLAS.Tensor,
fromForeignPtr,
toForeignPtr,
newVector,
newVector_,
newListVector,
newBasis,
setBasis,
subvector,
subvectorWithStride,
coerceVector,
unsafeNewVector,
unsafeWithElemPtr,
unsafeSubvector,
unsafeSubvectorWithStride,
unsafeFreeze,
unsafeThaw,
) where
import Control.Monad
import Data.Ix
import Foreign
import System.IO.Unsafe
import Unsafe.Coerce
import Data.AEq
import BLAS.Access
import BLAS.Elem.Base ( Elem )
import qualified BLAS.Elem.Base as E
import BLAS.Vector hiding ( Vector )
import qualified BLAS.Vector as C
import BLAS.Tensor
import BLAS.Internal ( clearArray, inlinePerformIO, checkedSubvector,
checkedSubvectorWithStride )
import BLAS.C.Level1 ( BLAS1, copy )
-- | Dense vector view onto a 'ForeignPtr' buffer.
--
-- None of the type parameters @t@, @n@ or @e@ appears in the fields as
-- anything other than the element type of the buffer: @t@ is an access tag
-- (the aliases 'Vector' and 'IOVector' instantiate it with 'Imm' and 'Mut'),
-- @n@ is a dimension tag, and @e@ is the stored element type.
--
-- Element @i@ of the view is stored at buffer index
-- @offsetOf + i * strideOf@ (see 'indexOf').  When 'isConj' is 'True', the
-- instances in this module apply 'E.conj' to elements on read and on write.
data DVector t n e =
      DV { storageOf :: !(ForeignPtr e) -- ^ backing storage buffer
         , offsetOf  :: !Int            -- ^ buffer index of element 0
         , lengthOf  :: !Int            -- ^ number of elements in the view ('dim')
         , strideOf  :: !Int            -- ^ buffer distance between consecutive elements
         , isConj    :: !Bool           -- ^ whether elements are seen conjugated
         }
-- | Immutable dense vector (access tag 'Imm').
type Vector n e = DVector Imm n e

-- | Mutable dense vector (access tag 'Mut'); the 'MTensor' instance below
-- modifies it in 'IO'.
type IOVector n e = DVector Mut n e
-- | Reinterpret the dimension tag of a vector.
--
-- Only the phantom parameter changes.  Storage, offset, length, stride and
-- the conjugation flag are carried over unchanged.  Rebuilding the
-- constructor lets the type checker confirm this, instead of relying on
-- 'unsafeCoerce'.
coerceVector :: DVector t n e -> DVector t m e
coerceVector (DV buf off len inc conjugated) = DV buf off len inc conjugated
-- | Build a vector view from its raw parts, in this order: storage buffer,
-- buffer index of the first element, number of elements, stride between
-- elements, and conjugation flag.
fromForeignPtr :: ForeignPtr e -> Int -> Int -> Int -> Bool -> DVector t n e
fromForeignPtr storage off len inc conjugated =
    DV { storageOf = storage
       , offsetOf  = off
       , lengthOf  = len
       , strideOf  = inc
       , isConj    = conjugated
       }
-- | Split a vector view into its raw parts, in the same order that
-- 'fromForeignPtr' accepts them.
toForeignPtr :: DVector t n e -> (ForeignPtr e, Int, Int, Int, Bool)
toForeignPtr x =
    case x of
        DV { storageOf = buf, offsetOf = off, lengthOf = len
           , strideOf = inc, isConj = conjugated } ->
            (buf, off, len, inc, conjugated)
-- | @subvector x o n@: view of @n@ contiguous elements of @x@ starting at
-- index @o@.  The request goes through 'checkedSubvector' (given @dim x@),
-- which presumably validates it before calling 'unsafeSubvector'.
subvector :: DVector t n e -> Int -> Int -> DVector t m e
subvector x = checkedSubvector len (unsafeSubvector x)
  where
    len = dim x
-- | Unit-stride subvector with no bounds checking; this is
-- 'unsafeSubvectorWithStride' with a stride of 1.
unsafeSubvector :: DVector t n e -> Int -> Int -> DVector t m e
unsafeSubvector x = unsafeSubvectorWithStride 1 x
-- | @subvectorWithStride s x o n@: view of @n@ elements of @x@ starting at
-- index @o@ and stepping @s@ elements at a time.  The request goes through
-- 'checkedSubvectorWithStride' (given @s@ and @dim x@) before the view is
-- built by 'unsafeSubvectorWithStride'.
subvectorWithStride :: Int -> DVector t n e -> Int -> Int -> DVector t m e
subvectorWithStride s x =
    let unchecked = unsafeSubvectorWithStride s x
    in  checkedSubvectorWithStride s (dim x) unchecked
-- | Strided subvector with no bounds checking.
--
-- The new view shares @x@'s buffer and conjugation flag.  It starts at the
-- buffer position of element @o@ of @x@, holds @n@ elements, and its stride
-- is @s@ times the stride of @x@.
unsafeSubvectorWithStride :: Int -> DVector t n e -> Int -> Int -> DVector t m e
unsafeSubvectorWithStride s x o n = fromForeignPtr buf start n inc conjugated
  where
    buf        = storageOf x
    start      = indexOf x o
    inc        = s * strideOf x
    conjugated = isConj x
-- | Allocate a fresh, unconjugated, unit-stride vector of length @n@.
-- The contents are whatever 'mallocForeignPtrArray' returns (not
-- initialised).  A negative length raises an 'IOError' before any
-- allocation.
newVector_ :: (Elem e) => Int -> IO (DVector t n e)
newVector_ n = do
    when (n < 0) $
        ioError $ userError $
            "Tried to create a vector with `" ++ show n ++ "' elements."
    storage <- mallocForeignPtrArray n
    return (fromForeignPtr storage 0 n 1 False)
-- | Allocate a vector of length @n@ and fill it from the first @n@ list
-- elements.
--
-- NOTE(review): if the list has fewer than @n@ elements, the remaining
-- slots keep the uninitialised contents from 'newVector_'.
newListVector :: (Elem e) => Int -> [e] -> IO (DVector t n e)
newListVector n es = do
    x <- newVector_ n
    let writeList ptr = pokeArray ptr (take n es)
    withForeignPtr (storageOf x) writeList
    return x
-- | Pure construction of an immutable vector from a list, by running
-- 'newListVector' under 'unsafePerformIO'.  The buffer is fresh and never
-- exposed as mutable, so the result can be treated as immutable.
listVector :: (Elem e) => Int -> [e] -> Vector n e
listVector n es =
    let fresh = newListVector n es
    in  unsafeFreeze (unsafePerformIO fresh)
-- | Allocate a zeroed vector of length @n@, then store each
-- @(index, value)@ pair with the bounds-checked 'writeElem'.
newVector :: (BLAS1 e) => Int -> [(Int,e)] -> IO (DVector t n e)
newVector n ies = newVectorHelp writeElem n ies
-- | Like 'newVector', but stores the pairs with 'unsafeWriteElem', which
-- performs no index checking.
unsafeNewVector :: (BLAS1 e) => Int -> [(Int,e)] -> IO (DVector t n e)
unsafeNewVector n ies = newVectorHelp unsafeWriteElem n ies
-- | Shared implementation of 'newVector' and 'unsafeNewVector'.
--
-- Allocates a zeroed mutable vector of length @n@, applies @set@ to every
-- @(index, value)@ pair in order, and returns the vector retagged with the
-- caller's access tag @t@.  Retagging rebuilds the view from its raw parts,
-- so the type checker confirms that only the phantom tag changes, instead
-- of relying on 'unsafeCoerce'.
newVectorHelp :: (BLAS1 e) =>
       (IOVector n e -> Int -> e -> IO ())
    -> Int -> [(Int,e)] -> IO (DVector t n e)
newVectorHelp set n ies = do
    x <- newZero n
    mapM_ (uncurry $ set x) ies
    let (buf, off, len, inc, conjugated) = toForeignPtr x
    return (fromForeignPtr buf off len inc conjugated)
-- | Allocate a vector of length @n@ and turn it into basis vector @i@ with
-- 'setBasis'.  An out-of-range @i@ raises the 'IOError' from 'setBasis'.
newBasis :: (BLAS1 e) => Int -> Int -> IO (IOVector n e)
newBasis n i = newVector_ n >>= \x -> setBasis i x >> return x
-- | Overwrite @x@ with basis vector @i@: every element becomes 0 except
-- element @i@, which becomes 1.  An index outside @[0, dim x)@ raises an
-- 'IOError' and leaves @x@ untouched.
setBasis :: (BLAS1 e) => Int -> IOVector n e -> IO ()
setBasis i x
    | i >= 0 && i < dim x = do
        setZero x
        unsafeWriteElem x i 1
    | otherwise =
        ioError $ userError $
            "tried to set a vector of dimension `" ++ show (dim x) ++ "'"
            ++ " to basis vector `" ++ show i ++ "'"
-- | Buffer position of logical element @i@: the view's offset plus @i@
-- strides.
indexOf :: DVector t n e -> Int -> Int
indexOf x i =
    let base = offsetOf x
        step = strideOf x
    in  base + i * step
-- | Run @f@ with a raw pointer to logical element @i@ of @x@.  The index is
-- not checked.  The buffer is kept alive for the duration of @f@ by
-- 'withForeignPtr'; the pointer must not escape @f@.
unsafeWithElemPtr :: (Elem e) => DVector t n e -> Int -> (Ptr e -> IO a) -> IO a
unsafeWithElemPtr x i f =
    withForeignPtr (storageOf x) (f . (`advancePtr` indexOf x i))
-- | View any vector as immutable without copying.
--
-- The result shares its buffer with the argument, so later writes through
-- a mutable alias remain visible in it; that aliasing is what makes this
-- "unsafe".  Only the access tag changes.  Rebuilding the constructor
-- lets the type checker confirm this, instead of relying on 'unsafeCoerce'.
unsafeFreeze :: DVector t n e -> Vector n e
unsafeFreeze (DV buf off len inc conjugated) = DV buf off len inc conjugated
-- | View any vector as mutable without copying.
--
-- The result shares its buffer with the argument, so writing through it
-- also changes any immutable alias; that aliasing is what makes this
-- "unsafe".  Only the access tag changes.  Rebuilding the constructor lets
-- the type checker confirm this, instead of relying on 'unsafeCoerce'.
unsafeThaw :: DVector t n e -> IOVector n e
unsafeThaw (DV buf off len inc conjugated) = DV buf off len inc conjugated
-- | The dimension is the stored length.  Conjugation only flips the
-- 'isConj' flag; no elements are touched.
instance C.Vector (DVector t) where
    dim x = lengthOf x
    conj x = x { isConj = not (isConj x) }
-- | A vector is a one-dimensional tensor indexed by 'Int'.
instance Tensor (DVector t n) Int e where
    shape = dim
    -- Valid indices run from 0 to @dim x - 1@ inclusive; an empty vector
    -- gets the empty range @(0, -1)@.  (The previous text applied @dim x@
    -- to @1@, which does not type-check.)
    bounds x = (0, dim x - 1)
-- | Pure element access for immutable vectors.  Each read reuses the 'IO'
-- accessors through 'inlinePerformIO' on a thawed alias of the vector.
-- Updates and maps build new vectors; the original is never modified.
instance (BLAS1 e) => ITensor (DVector Imm n) Int e where
    size x = dim x
    indices x = range (bounds x)
    elems x = inlinePerformIO (getElems (unsafeThaw x))
    assocs x = inlinePerformIO (getAssocs (unsafeThaw x))
    unsafeAt x i = inlinePerformIO (unsafeReadElem (unsafeThaw x) i)
    amap f x = listVector (dim x) [ f e | e <- elems x ]
    (//) = replaceHelp writeElem
    unsafeReplace = replaceHelp unsafeWriteElem
-- | Shared implementation of '(//)' and 'unsafeReplace'.
--
-- Copies @x@ into a fresh buffer with 'newCopy', applies @set@ to each
-- @(index, value)@ pair in order, and freezes the copy.  @x@ itself is not
-- written.
replaceHelp :: (BLAS1 e) =>
       (IOVector n e -> Int -> e -> IO ())
    -> Vector n e -> [(Int, e)] -> Vector n e
replaceHelp set x ies = unsafeFreeze (unsafePerformIO build)
  where
    build = do
        y <- newCopy (unsafeThaw x)
        forM_ ies (uncurry (set y))
        return y
-- | Pure constructors for immutable vectors, each built through its 'IO'
-- counterpart under 'unsafePerformIO'.  'azipWith' requires both operands
-- to have the same dimension and calls 'error' otherwise.
instance (BLAS1 e) => IDTensor (DVector Imm n) Int e where
    zero n = unsafeFreeze (unsafePerformIO (newZero n))
    constant n e = unsafeFreeze (unsafePerformIO (newConstant n e))
    azipWith f x y
        | dim y == len =
            listVector len (zipWith f (elems x) (elems y))
        | otherwise =
            error ("amap2: vector lengths differ; first has length `" ++
                   show len ++ "' and second has length `" ++
                   show (dim y) ++ "'")
      where
        len = dim x
-- | Read operations in 'IO'.  The access tag @t@ is unconstrained, so these
-- work on both mutable and immutable vectors.
instance (BLAS1 e) => RTensor (DVector t n) Int e IO where
    getSize = return . dim

    -- Copy into a freshly allocated unit-stride vector (see 'newVector_').
    -- A conjugated view copies the raw underlying elements of the
    -- unconjugated view, then re-flags the copy with 'conj'.  Element values
    -- are never transformed, only the flag is carried over.  The actual
    -- transfer is the BLAS1 'copy' call with each vector's own stride.
    newCopy x
        | isConj x =
            newCopy (conj x) >>= return . conj
        | otherwise = do
            y <- newVector_ (dim x)
            unsafeWithElemPtr x 0 $ \pX ->
                unsafeWithElemPtr y 0 $ \pY ->
                    let n    = dim x
                        incX = strideOf x
                        incY = strideOf y
                    in copy n pX incX pY incY >>
                       return y

    getIndices = return . indices . unsafeFreeze

    -- Unchecked read of element i.  A conjugated view reads the raw element
    -- and applies 'E.conj' to it.
    unsafeReadElem x i
        | isConj x =
            unsafeReadElem (conj x) i >>= return . E.conj
        | otherwise =
            withForeignPtr (storageOf x) $ \ptr ->
                peekElemOff ptr (indexOf x i)

    -- (index, element) pairs as a lazily produced list.  Each element is
    -- peeked from the buffer only when its cons cell is demanded, and the
    -- element is forced (`seq`) at that point.  'unsafeForeignPtrToPtr' does
    -- not keep the buffer alive by itself, so 'go' calls 'touchForeignPtr'
    -- when the end of the list is reached.
    -- NOTE(review): reads are deferred, so on a mutable vector the list can
    -- observe writes made after 'getAssocs' has returned.
    getAssocs x
        | isConj x =
            getAssocs (conj x) >>= return . map (\(i,e) -> (i,E.conj e))
        | otherwise =
            let (f,o,n,incX,_) = toForeignPtr x
                ptr = (unsafeForeignPtrToPtr f) `advancePtr` o
            in return $ go n f incX ptr 0
      where
        -- n: element count; f: buffer (touched at the end); incX: stride;
        -- ptr: address of element i; i: current logical index.
        go !n !f !incX !ptr !i
            | i >= n =
                inlinePerformIO $ do
                    touchForeignPtr f
                    return []
            | otherwise =
                let e    = inlinePerformIO $ peek ptr
                    ptr' = ptr `advancePtr` incX
                    i'   = i + 1
                    ies  = go n f incX ptr' i'
                in e `seq` ((i,e):ies)
-- | Allocation of initialised vectors: allocate with 'newVector_', then
-- fill every element through a thawed alias.
instance (BLAS1 e) => RDTensor (DVector t n) Int e IO where
    newZero n = do
        x <- newVector_ n
        setZero (unsafeThaw x)
        return x
    newConstant n e = do
        x <- newVector_ n
        setConstant e (unsafeThaw x)
        return x
-- | In-place modification of mutable vectors.
instance (BLAS1 e) => MTensor (DVector Mut n) Int e IO where
    -- Zero every element.  A unit-stride view is cleared with a single
    -- 'clearArray' over its dim elements; strided views fall back to
    -- 'setConstant'.
    setZero x
        | strideOf x == 1 = unsafeWithElemPtr x 0 $ flip clearArray (dim x)
        | otherwise       = setConstant 0 x

    -- Store e in every element.  On a conjugated view, the conjugate of e
    -- is written through the unconjugated view, mirroring 'unsafeWriteElem'.
    -- The loop walks dim x elements, stepping by the view's stride.
    setConstant e x
        | isConj x  = setConstant (E.conj e) (conj x)
        | otherwise = unsafeWithElemPtr x 0 $ go (dim x)
      where
        incX = strideOf x
        go !n !ptr
            | n <= 0    = return ()
            | otherwise = poke ptr e >> go (n - 1) (ptr `advancePtr` incX)

    -- Unchecked write of element i, conjugating the value first when the
    -- view is conjugated.
    unsafeWriteElem x i e =
        let e' = if isConj x then E.conj e else e
        in withForeignPtr (storageOf x) $ \ptr ->
               pokeElemOff ptr (indexOf x i) e'

    canModifyElem x i = return $ inRange (bounds x) i

    -- Replace every element v with f v in place.  On a conjugated view, f is
    -- wrapped in 'E.conj' on both sides so that it sees and produces
    -- logical (conjugated) values while the raw buffer is updated.
    modifyWith f x
        | isConj x  = modifyWith (E.conj . f . E.conj) (conj x)
        | otherwise = withForeignPtr (storageOf x) $
                          \ptr -> go (dim x) (ptr `advancePtr` offsetOf x)
      where
        go !n !ptr
            | n <= 0    = return ()
            | otherwise = do
                peek ptr >>= poke ptr . f
                go (n - 1) (ptr `advancePtr` incX)
        incX = strideOf x
-- | Element-wise comparison of two immutable vectors with @cmp@.
--
-- The vectors must have the same dimension and @cmp@ must hold at every
-- position; evaluation stops at the first failure.  When both views are
-- conjugated, the flags are dropped and the raw elements are compared.
compareHelp :: (BLAS1 e) =>
    (e -> e -> Bool) -> Vector n e -> Vector n e -> Bool
compareHelp cmp x y
    | bothConj  = compareHelp cmp (conj x) (conj y)
    | otherwise = sameDim && all (uncurry cmp) (zip (elems x) (elems y))
  where
    bothConj = isConj x && isConj y
    sameDim  = dim x == dim y
-- | Exact element-wise equality (see 'compareHelp').
instance (BLAS1 e, Eq e) => Eq (DVector Imm n e) where
    x == y = compareHelp (==) x y
-- | Element-wise lifting of the element type's 'AEq' comparisons (see
-- 'compareHelp').
instance (BLAS1 e, AEq e) => AEq (DVector Imm n e) where
    x === y = compareHelp (===) x y
    x ~== y = compareHelp (~==) x y
-- | Renders as @listVector <dim> <elements>@.  A conjugated view is
-- wrapped as @conj (...)@ around the rendering of its unconjugated view.
instance (BLAS1 e, Show e) => Show (DVector Imm n e) where
    show x =
        if isConj x
            then "conj (" ++ show (conj x) ++ ")"
            else unwords ["listVector", show (dim x), show (elems x)]