{-# LANGUAGE BangPatterns, MultiParamTypeClasses, FunctionalDependencies, FlexibleContexts, FlexibleInstances #-}
{-# OPTIONS_HADDOCK hide, prune #-}
-----------------------------------------------------------------------------
-- |
-- Module     : Data.Vector.Dense.Class.Internal
-- Copyright  : Copyright (c) , Patrick Perry
-- License    : BSD3
-- Maintainer : Patrick Perry
-- Stability  : experimental
--

module Data.Vector.Dense.Class.Internal (
    -- * Vector types
    IOVector,
    STVector,
    unsafeIOVectorToSTVector,
    unsafeSTVectorToIOVector,

    -- * Vector type classes
    BaseVector(..),
    ReadVector,
    WriteVector,

    -- * Basic vector properties
    dim,
    stride,
    isConj,

    -- * Coercing the vector shape
    coerceVector,

    -- * BaseTensor functions
    shapeVector,
    boundsVector,

    -- * ReadTensor functions
    getSizeVector,
    getAssocsVector,
    getIndicesVector,
    getElemsVector,
    getAssocsVector',
    getIndicesVector',
    getElemsVector',
    unsafeReadElemVector,

    -- * WriteTensor functions
    newVector_,
    newZeroVector,
    setZeroVector,
    newConstantVector,
    setConstantVector,
    canModifyElemVector,
    unsafeWriteElemVector,
    modifyWithVector,

    -- * Numeric functions
    doConjVector,
    scaleByVector,
    shiftByVector,

    -- * CopyTensor functions
    newCopyVector,
    unsafeCopyVector,

    -- * SwapTensor functions
    unsafeSwapVector,

    -- * Numeric2 functions
    unsafeAxpyVector,
    unsafeMulVector,
    unsafeDivVector,

    -- * Utility functions
    withVectorPtr,
    indexOfVector,
    indicesVector,
    vectorCall,
    vectorCall2,

    ) where

import Control.Monad( forM_ )
import Control.Monad.ST
import Foreign
import Unsafe.Coerce

import BLAS.Internal ( clearArray, inlinePerformIO )
import BLAS.Elem
import qualified BLAS.C as BLAS
import BLAS.Tensor
import BLAS.UnsafeIOToM

import Data.Vector.Dense.Class.Internal.Base

-- | Vector types @x@ that can be read in the monad @m@.  The class declares
-- no methods of its own; it only bundles the 'BaseVector', 'UnsafeIOToM' and
-- 'ReadTensor' constraints under a single name.
class (BaseVector x, UnsafeIOToM m, ReadTensor x Int m) => ReadVector x m

-- | Vector types @x@ that can additionally be written in the monad @m@
-- (via 'WriteTensor').  The functional dependencies state that the vector
-- type and the monad determine each other.
class (ReadVector x m, WriteTensor x Int m) => WriteVector x m | x -> m, m -> x

-- | Cast the shape type of the vector.
coerceVector :: (BaseVector x) => x n e -> x n' e
-- Only the shape index @n@ changes between the argument and result types;
-- the value itself is passed through 'unsafeCoerce' untouched.
coerceVector = unsafeCoerce
{-# INLINE coerceVector #-}

-------------------------- BaseTensor functions -----------------------------

-- | The shape of a vector is its dimension.
shapeVector :: (BaseVector x) => x n e -> Int
shapeVector = dim
{-# INLINE shapeVector #-}

-- | The inclusive index range of a vector: @(0, dim x - 1)@.
boundsVector :: (BaseVector x) => x n e -> (Int,Int)
boundsVector v = (0, lastIndex)
  where
    lastIndex = dim v - 1
{-# INLINE boundsVector #-}

-------------------------- ReadTensor functions -----------------------------

-- | Monadic version of 'dim'.
getSizeVector :: (ReadVector x m) => x n e -> m Int
getSizeVector = return . dim
{-# INLINE getSizeVector #-}

-- | Monadic version of 'indicesVector'.
getIndicesVector :: (ReadVector x m) => x n e -> m [Int]
getIndicesVector = return . indicesVector
{-# INLINE getIndicesVector #-}

-- | Identical to 'getIndicesVector'.
getIndicesVector' :: (ReadVector x m) => x n e -> m [Int]
getIndicesVector' = getIndicesVector
{-# INLINE getIndicesVector' #-}

-- | The elements of the vector, taken from the result of 'getAssocsVector'.
getElemsVector :: (ReadVector x m, Elem e) => x n e -> m [e]
getElemsVector x = getAssocsVector x >>= \ies -> return (snd (unzip ies))

-- | The elements of the vector, taken from the result of 'getAssocsVector''.
getElemsVector' :: (ReadVector x m, Elem e) => x n e -> m [e]
getElemsVector' x = getAssocsVector' x >>= \ies -> return (snd (unzip ies))

-- | The @(index, element)@ pairs of the vector.  The list is built in pure
-- code: each element is fetched with 'inlinePerformIO' only when the
-- corresponding list cell is demanded.
getAssocsVector :: (ReadVector x m, Elem e) => x n e -> m [(Int,e)]
getAssocsVector x
    | isConj x = do
        -- Read the un-conjugated view, then conjugate every element.
        ies <- getAssocsVector (conj x)
        return (map (\(k,v) -> (k,conj v)) ies)
    | otherwise =
        let (owner,start,len,inc,_) = arrayFromVector x
        in return $ walk len owner inc start 0
  where
    walk !len !owner !inc !p !k
        | k >= len =
            -- This is very important since we are doing unsafe IO.
            -- Otherwise, the vector might get discarded and the memory
            -- freed before all of the elements are read.
            inlinePerformIO (touchForeignPtr owner >> return [])
        | otherwise =
            let v    = inlinePerformIO (peek p)
                rest = walk len owner inc (p `advancePtr` inc) (k + 1)
            in v `seq` ((k,v) : rest)
{-# INLINE getAssocsVector #-}

-- | The @(index, element)@ pairs of the vector.  Unlike 'getAssocsVector',
-- every element is read inside the 'withVectorPtr' action before the list
-- is returned.
getAssocsVector' :: (ReadVector x m, Elem e) => x n e -> m [(Int,e)]
getAssocsVector' x
    | isConj x = do
        ies <- getAssocsVector' (conj x)
        return (map (\(k,v) -> (k,conj v)) ies)
    | otherwise =
        unsafeIOToM $ withVectorPtr x $ \start ->
            collect (dim x) (stride x) start 0
  where
    collect !len !inc !p !k
        | k >= len  = return []
        | otherwise = do
            v    <- peek p
            rest <- collect len inc (p `advancePtr` inc) (k + 1)
            return ((k,v) : rest)

-- | Read the element at the given index.  No bounds check is performed.
unsafeReadElemVector :: (ReadVector x m, Elem e) => x n e -> Int -> m e
unsafeReadElemVector x i
    | isConj x = do
        v <- unsafeReadElemVector (conj x) i
        return (conj v)
    | otherwise =
        unsafeIOToM (withVectorPtr x (\p -> peekElemOff p (indexOfVector x i)))

------------------------- WriteTensor functions -----------------------------

-- | Allocate a vector with the given number of elements.  The storage comes
-- straight from 'mallocForeignPtrArray' and is not initialized; the result
-- has stride 1 and is not conjugated.  A negative length is reported with
-- 'fail'.
newVector_ :: (WriteVector x m, Elem e) => Int -> m (x n e)
newVector_ n
    | n >= 0 =
        unsafeIOToM $
            mallocForeignPtrArray n >>= \arr ->
                return (vectorViewArray arr (unsafeForeignPtrToPtr arr) n 1 False)
    | otherwise =
        fail $ "Tried to create a vector with `" ++ show n ++ "' elements."

-- | Allocate a vector of the given length and set every element to zero.
newZeroVector :: (WriteVector y m, Elem e) => Int -> m (y n e)
newZeroVector n =
    newVector_ n >>= \x ->
    setZeroVector x >>
    return x

-- | Overwrite every element with zero.  A stride-1 vector is cleared with
-- 'clearArray'; any other stride goes through 'setConstantVector'.
setZeroVector :: (WriteVector y m, Elem e) => y n e -> m ()
setZeroVector x
    | stride x == 1 =
        unsafeIOToM (withVectorPtr x (\p -> clearArray p (dim x)))
    | otherwise =
        setConstantVector 0 x

-- | Allocate a vector of the given length and set every element to the
-- given value.
newConstantVector :: (WriteVector y m, Elem e) => Int -> e -> m (y n e)
newConstantVector n e =
    newVector_ n >>= \x ->
    setConstantVector e x >>
    return x

-- | Set every element in the vector to a constant.
setConstantVector :: (WriteVector y m, Elem e) => e -> y n e -> m ()
setConstantVector e x
    | isConj x  = setConstantVector (conj e) (conj x)
    | otherwise = unsafeIOToM $ withVectorPtr x $ fill (dim x)
  where
    inc = stride x

    -- Store @e@ at the current pointer, then step by the stride until the
    -- remaining count reaches zero.
    fill !remaining !p
        | remaining <= 0 = return ()
        | otherwise      = do
            poke p e
            fill (remaining - 1) (p `advancePtr` inc)

-- | Every position of a dense vector is writable, so this is always 'True'.
canModifyElemVector :: (WriteVector y m) => y n e -> Int -> m Bool
canModifyElemVector _ _ = return True
{-# INLINE canModifyElemVector #-}

-- | Write an element at the given index.  No bounds check is performed.
-- For a conjugated vector the conjugate of the value is what gets stored.
unsafeWriteElemVector :: (WriteVector y m, Elem e) => y n e -> Int -> e -> m ()
unsafeWriteElemVector x i e =
    unsafeIOToM $ withVectorPtr x $ \p ->
        pokeElemOff p (indexOfVector x i) stored
  where
    stored
        | isConj x  = conj e
        | otherwise = e

-- | Replace each element by the result of applying the function to it.
modifyWithVector :: (WriteVector y m, Elem e) => (e -> e) -> y n e -> m ()
modifyWithVector f x
    | isConj x  = modifyWithVector (conj . f . conj) (conj x)
    | otherwise = unsafeIOToM $ withVectorPtr x $ update (dim x)
  where
    inc = stride x

    update !remaining !p
        | remaining <= 0 = return ()
        | otherwise      = do
            old <- peek p
            poke p (f old)
            update (remaining - 1) (p `advancePtr` inc)

------------------------- CopyTensor functions ------------------------------

-- | Creates a new vector by copying another one.
newCopyVector :: (BLAS1 e, ReadVector x m, WriteVector y m) => x n e -> m (y n e)
newCopyVector x
    | isConj x  = do
        -- Copy the un-conjugated view and conjugate the copy.
        y <- newCopyVector (conj x)
        return (conj y)
    | otherwise = do
        y <- newVector_ (dim x)
        unsafeCopyVector y x
        return y

-- | Same as 'copyVector' but the sizes of the arguments are not checked.
unsafeCopyVector :: (BLAS1 e, WriteVector y m, ReadVector x m) => y n e -> x n e -> m ()
unsafeCopyVector y x
    | srcConj && dstConj =
        unsafeCopyVector (conj y) (conj x)
    | srcConj || dstConj =
        -- Exactly one side is conjugated: copy element by element through
        -- the tensor read/write operations.
        forM_ [0 .. dim x - 1] $ \i -> do
            v <- unsafeReadElem x i
            unsafeWriteElem y i v
    | otherwise =
        vectorCall2 BLAS.copy x y
  where
    srcConj = isConj x
    dstConj = isConj y

------------------------- SwapTensor functions ------------------------------

-- | Same as 'swapVector' but the sizes of the arguments are not checked.
unsafeSwapVector :: (BLAS1 e, WriteVector y m) => y n e -> y n e -> m ()
unsafeSwapVector x y
    | xConj && yConj =
        unsafeSwapVector (conj x) (conj y)
    | xConj || yConj =
        -- Exactly one side is conjugated: exchange element by element
        -- through the tensor read/write operations.
        forM_ [0 .. dim x - 1] $ \i -> do
            xi <- unsafeReadElem x i
            yi <- unsafeReadElem y i
            unsafeWriteElem x i yi
            unsafeWriteElem y i xi
    | otherwise =
        vectorCall2 BLAS.swap x y
  where
    xConj = isConj x
    yConj = isConj y

--------------------------- Numeric functions -------------------------------

-- | Apply 'BLAS.conj' to the vector's storage.
doConjVector :: (WriteVector y m, BLAS1 e) => y n e -> m ()
doConjVector = vectorCall BLAS.conj

-- | Scale the vector by a constant using 'BLAS.scal'.  Scaling by @1@ does
-- nothing.
scaleByVector :: (WriteVector y m, BLAS1 e) => e -> y n e -> m ()
scaleByVector 1 _ = return ()
scaleByVector k x
    | isConj x  = scaleByVector (conj k) (conj x)
    | otherwise = vectorCall (\len -> BLAS.scal len k) x

-- | Add a constant to every element of the vector.
shiftByVector :: (WriteVector y m, Elem e) => e -> y n e -> m ()
shiftByVector k x
    | isConj x  = shiftByVector (conj k) (conj x)
    | otherwise = modifyWithVector (\v -> k + v) x

-------------------------- Numeric2 functions -------------------------------

-- | Call 'BLAS.axpy' on the two vectors, or 'BLAS.acxpy' when only @x@ is
-- conjugated.  A conjugated @y@ is handled by first conjugating the scalar
-- and both vectors.  Sizes are not checked.
unsafeAxpyVector :: (ReadVector x m, WriteVector y m, BLAS1 e) => e -> x n e -> y n e -> m ()
unsafeAxpyVector alpha x y
    | isConj y  = unsafeAxpyVector (conj alpha) (conj x) (conj y)
    | isConj x  = vectorCall2 (\len -> BLAS.acxpy len alpha) x y
    | otherwise = vectorCall2 (\len -> BLAS.axpy len alpha) x y

-- | Call 'BLAS.mul' on the two vectors, or 'BLAS.cmul' when only @x@ is
-- conjugated.  Sizes are not checked.
unsafeMulVector :: (WriteVector y m, ReadVector x m, BLAS1 e) => y n e -> x n e -> m ()
unsafeMulVector y x
    | isConj y  = unsafeMulVector (conj y) (conj x)
    | isConj x  = vectorCall2 BLAS.cmul x y
    | otherwise = vectorCall2 BLAS.mul x y

-- | Call 'BLAS.div' on the two vectors, or 'BLAS.cdiv' when only @x@ is
-- conjugated.  Sizes are not checked.
unsafeDivVector :: (WriteVector y m, ReadVector x m, BLAS1 e) => y n e -> x n e -> m ()
unsafeDivVector y x
    | isConj y  = unsafeDivVector (conj y) (conj x)
    | isConj x  = vectorCall2 BLAS.cdiv x y
    | otherwise = vectorCall2 BLAS.div x y

--------------------------- Utility functions -------------------------------

-- | Offset, in elements, of logical index @i@ from the vector's first
-- element: the index multiplied by the stride.
indexOfVector :: (BaseVector x) => x n e -> Int -> Int
indexOfVector x i = i * stride x
{-# INLINE indexOfVector #-}

-- | All valid indices of the vector, @0@ through @dim x - 1@.
indicesVector :: (BaseVector x) => x n e -> [Int]
indicesVector x = [0 .. dim x - 1]
{-# INLINE indicesVector #-}

-- | Run an IO routine that expects a length, an element pointer and a
-- stride on the storage of a vector.
vectorCall :: (ReadVector x m) => (Int -> Ptr e -> Int -> IO a) -> x n e -> m a
vectorCall f x =
    unsafeIOToM $ withVectorPtr x $ \pX -> f len pX incX
  where
    len  = dim x
    incX = stride x
{-# INLINE vectorCall #-}

-- | Two-vector version of 'vectorCall'.  The length handed to the routine
-- is the dimension of the first vector.
vectorCall2 :: (ReadVector x m, ReadVector y m) => (Int -> Ptr e -> Int -> Ptr f -> Int -> IO a) -> x n e -> y n' f -> m a
vectorCall2 f x y =
    unsafeIOToM $
        withVectorPtr x $ \pX ->
        withVectorPtr y $ \pY ->
            f len pX incX pY incY
  where
    len  = dim x
    incX = stride x
    incY = stride y
{-# INLINE vectorCall2 #-}

------------------------------------ Instances ------------------------------

-- | A dense vector usable in the 'IO' monad.
data IOVector n e =
      DV {-# UNPACK #-} !(ForeignPtr e) -- memory owner
         {-# UNPACK #-} !(Ptr e)        -- pointer to the first element
         {-# UNPACK #-} !Int            -- the length of the vector
         {-# UNPACK #-} !Int            -- the stride (in elements, not bytes) between elements.
         {-# UNPACK #-} !Bool           -- indicates whether or not the vector is conjugated

-- | A dense vector usable in the 'ST' monad; a newtype around 'IOVector'.
newtype STVector s n e = ST (IOVector n e)

-- | Wrap an 'IOVector' as an 'STVector' without copying.
unsafeIOVectorToSTVector :: IOVector n e -> STVector s n e
unsafeIOVectorToSTVector = ST

-- | Unwrap an 'STVector' to its underlying 'IOVector' without copying.
unsafeSTVectorToIOVector :: STVector s n e -> IOVector n e
unsafeSTVectorToIOVector (ST v) = v

instance BaseVector IOVector where
    vectorViewArray = DV
    {-# INLINE vectorViewArray #-}

    arrayFromVector (DV owner ptr len inc cnj) = (owner,ptr,len,inc,cnj)
    {-# INLINE arrayFromVector #-}

instance BaseVector (STVector s) where
    vectorViewArray owner ptr len inc cnj = ST (DV owner ptr len inc cnj)
    {-# INLINE vectorViewArray #-}

    arrayFromVector (ST v) = arrayFromVector v
    {-# INLINE arrayFromVector #-}

instance BaseTensor IOVector Int where
    bounds = boundsVector
    shape  = shapeVector

instance BaseTensor (STVector s) Int where
    bounds = boundsVector
    shape  = shapeVector

instance ReadTensor IOVector Int IO where
    getSize        = getSizeVector
    getAssocs      = getAssocsVector
    getIndices     = getIndicesVector
    getElems       = getElemsVector
    getAssocs'     = getAssocsVector'
    getIndices'    = getIndicesVector'
    getElems'      = getElemsVector'
    unsafeReadElem = unsafeReadElemVector

instance ReadTensor (STVector s) Int (ST s) where
    getSize        = getSizeVector
    getAssocs      = getAssocsVector
    getIndices     = getIndicesVector
    getElems       = getElemsVector
    getAssocs'     = getAssocsVector'
    getIndices'    = getIndicesVector'
    getElems'      = getElemsVector'
    unsafeReadElem = unsafeReadElemVector

instance ReadVector IOVector IO where

instance ReadVector (STVector s) (ST s) where

instance WriteTensor IOVector Int IO where
    setConstant     = setConstantVector
    setZero         = setZeroVector
    canModifyElem   = canModifyElemVector
    unsafeWriteElem = unsafeWriteElemVector
    modifyWith      = modifyWithVector
    doConj          = doConjVector
    scaleBy         = scaleByVector
    shiftBy         = shiftByVector

instance WriteTensor (STVector s) Int (ST s) where
    setConstant     = setConstantVector
    setZero         = setZeroVector
    canModifyElem   = canModifyElemVector
    unsafeWriteElem = unsafeWriteElemVector
    modifyWith      = modifyWithVector
    doConj          = doConjVector
    scaleBy         = scaleByVector
    shiftBy         = shiftByVector

instance WriteVector IOVector IO where

instance WriteVector (STVector s) (ST s) where