{-# LANGUAGE BangPatterns, FlexibleInstances, MultiParamTypeClasses #-}
{-# OPTIONS_GHC -fglasgow-exts #-}
-----------------------------------------------------------------------------
-- |
-- Module     : Data.Vector.Dense.Internal
-- Copyright  : Copyright (c) 2008, Patrick Perry
-- License    : BSD3
-- Maintainer : Patrick Perry
-- Stability  : experimental
--

module Data.Vector.Dense.Internal (
    -- * Vector data types
    Vector,
    IOVector,
    DVector(..),

    module BLAS.Vector,
    module BLAS.Tensor,

    -- * Conversion to and from @ForeignPtr@s.
    fromForeignPtr,
    toForeignPtr,

    -- * Creating new vectors
    newVector,
    newVector_,
    newListVector,

    -- * Special vectors
    newBasis,
    setBasis,

    -- * Vector views
    subvector,
    subvectorWithStride,

    -- * Casting vectors
    coerceVector,

    -- * Unsafe operations
    unsafeNewVector,
    unsafeWithElemPtr,
    unsafeSubvector,
    unsafeSubvectorWithStride,
    unsafeFreeze,
    unsafeThaw,
    ) where

import Control.Monad
import Data.Ix
import Foreign
import System.IO.Unsafe
import Unsafe.Coerce

import Data.AEq

import BLAS.Access
import BLAS.Elem.Base ( Elem )
import qualified BLAS.Elem.Base as E
import BLAS.Vector hiding ( Vector )
import qualified BLAS.Vector as C
import BLAS.Tensor
import BLAS.Internal ( clearArray, inlinePerformIO, checkedSubvector,
                       checkedSubvectorWithStride )
import BLAS.C.Level1 ( BLAS1, copy )

-- | A dense vector.  @t@ is a type that will usually be @Imm@ or @Mut@.
-- @n@ is a phantom type for the dimension of the vector, and @e@ is the
-- element type.  A @DVector@ @x@ stores @dim x@ elements.  Indices into
-- the vector are @0@-based.
data DVector t n e =
      DV { storageOf :: {-# UNPACK #-} !(ForeignPtr e) -- ^ a pointer to the storage region
         , offsetOf  :: {-# UNPACK #-} !Int            -- ^ an offset (in elements, not bytes) to the first element in the vector.
         , lengthOf  :: {-# UNPACK #-} !Int            -- ^ the length of the vector
         , strideOf  :: {-# UNPACK #-} !Int            -- ^ the stride (in elements, not bytes) between elements.
         , isConj    :: {-# UNPACK #-} !Bool           -- ^ indicates whether or not the vector is conjugated
         }

type Vector n e = DVector Imm n e
type IOVector n e = DVector Mut n e

-- | Cast the phantom length type.  The result is the same view of the same
-- storage; only the type-level length tag changes.
--
-- The view is rebuilt from its fields rather than passed through
-- 'unsafeCoerce', so the type checker verifies that only the phantom
-- parameter changes.
coerceVector :: DVector t n e -> DVector t m e
coerceVector (DV f o len s c) = DV f o len s c
{-# INLINE coerceVector #-}

-- | @fromForeignPtr fptr offset n inc c@ creates a vector view of a
-- region in memory starting at the given offset and having dimension @n@,
-- with a stride of @inc@, and with @isConj@ set to @c@.  No arguments are
-- checked.
fromForeignPtr :: ForeignPtr e -> Int -> Int -> Int -> Bool -> DVector t n e
fromForeignPtr = DV
{-# INLINE fromForeignPtr #-}

-- | Gets the tuple @(fptr,offset,n,inc,c)@, where @n@ is the dimension and
-- @inc@ is the stride of the vector, and @c@ indicates whether or not the
-- vector is conjugated.  This is the inverse of 'fromForeignPtr'.
toForeignPtr :: DVector t n e -> (ForeignPtr e, Int, Int, Int, Bool)
toForeignPtr (DV f o n s c) = (f, o, n, s, c)
{-# INLINE toForeignPtr #-}

-- | @subvector x o n@ creates a subvector view of @x@ starting at index @o@
-- and having length @n@.  The arguments are range-checked against @dim x@.
subvector :: DVector t n e -> Int -> Int -> DVector t m e
subvector x = checkedSubvector (dim x) (unsafeSubvector x)

-- | Same as 'subvector' but arguments are not range-checked.
unsafeSubvector :: DVector t n e -> Int -> Int -> DVector t m e
unsafeSubvector = unsafeSubvectorWithStride 1

-- | @subvectorWithStride s x o n@ creates a subvector view of @x@ starting
-- at index @o@, having length @n@ and stride @s@.  The arguments are
-- range-checked against @dim x@.
subvectorWithStride :: Int -> DVector t n e -> Int -> Int -> DVector t m e
subvectorWithStride s x =
    checkedSubvectorWithStride s (dim x) (unsafeSubvectorWithStride s x)

-- | Same as 'subvectorWithStride' but arguments are not range-checked.
-- The view shares storage with @x@.  Its offset is the storage index of
-- element @o@ of @x@, its stride is @s@ times the stride of @x@, and it
-- inherits the conjugation flag of @x@.
unsafeSubvectorWithStride :: Int -> DVector t n e -> Int -> Int -> DVector t m e
unsafeSubvectorWithStride s x o n =
    let f  = storageOf x
        o' = indexOf x o
        n' = n
        s' = s * (strideOf x)
        c  = isConj x
    in fromForeignPtr f o' n' s' c

-- | Creates a new vector of the given length.
-- The elements will be uninitialized.  Fails with an 'IOError' (built with
-- 'userError') if the requested length is negative.
newVector_ :: (Elem e) => Int -> IO (DVector t n e)
newVector_ n
    | n < 0 =
        ioError $ userError $
            "Tried to create a vector with `" ++ show n ++ "' elements."
    | otherwise = do
        arr <- mallocForeignPtrArray n
        -- Fresh storage: offset 0, contiguous (stride 1), not conjugated.
        return $ fromForeignPtr arr 0 n 1 False

-- | Creates a new vector of the given dimension with the given elements.
-- If the list has length less than the passed-in dimension, the tail of
-- the vector will be uninitialized.  List elements beyond the dimension are
-- ignored.
newListVector :: (Elem e) => Int -> [e] -> IO (DVector t n e)
newListVector n es = do
    x <- newVector_ n
    -- 'newVector_' always produces offset 0 and stride 1, so writing
    -- contiguously from the start of the storage fills elements 0, 1, ...
    withForeignPtr (storageOf x) $ flip pokeArray $ take n es
    return x

-- | @listVector n es@ is equivalent to @vector n (zip [0..(n-1)] es)@, except
-- that the result is undefined if @length es@ is less than @n@.
listVector :: (Elem e) => Int -> [e] -> Vector n e
listVector n es = unsafeFreeze $ unsafePerformIO $ newListVector n es
{-# NOINLINE listVector #-}

-- | Creates a new vector with the given association list.  Unspecified
-- indices will get initialized to zero.  Indices are range-checked.
newVector :: (BLAS1 e) => Int -> [(Int,e)] -> IO (DVector t n e)
newVector = newVectorHelp writeElem

-- | Same as 'newVector' but indices are not range-checked.
unsafeNewVector :: (BLAS1 e) => Int -> [(Int,e)] -> IO (DVector t n e)
unsafeNewVector = newVectorHelp unsafeWriteElem

-- Shared implementation of 'newVector' and 'unsafeNewVector'.  It allocates
-- a zeroed mutable vector of length @n@, applies @set@ to each
-- @(index, value)@ pair in order (later pairs overwrite earlier ones), and
-- returns the vector at the caller's requested access type.
newVectorHelp :: (BLAS1 e) =>
    (IOVector n e -> Int -> e -> IO ()) -> Int -> [(Int,e)] -> IO (DVector t n e)
newVectorHelp set n ies = do
    x <- newZero n
    mapM_ (uncurry $ set x) ies
    -- Re-tag the access type by rebuilding the view from its components.
    -- Unlike 'unsafeCoerce', this is checked by the type checker.
    let (f, o, n', s, c) = toForeignPtr x
    return (fromForeignPtr f o n' s c)

-- | @newBasis n i@ creates a vector of length @n@ that is all zero except
-- at position @i@, where it is equal to one.  Fails in the same way as
-- 'setBasis' if @i@ is not a valid index.
newBasis :: (BLAS1 e) => Int -> Int -> IO (IOVector n e)
newBasis n i = do
    x <- newVector_ n
    setBasis i x
    return x

-- | @setBasis i x@ sets the @i@th coordinate of @x@ to @1@, and all other
-- coordinates to @0@.  Throws an 'IOError' if @i@ is not in the range
-- @[0, dim x - 1]@.
setBasis :: (BLAS1 e) => Int -> IOVector n e -> IO ()
setBasis i x
    | i < 0 || i >= dim x =
        ioError $ userError $
            "tried to set a vector of dimension `" ++ show (dim x) ++ "'"
            ++ " to basis vector `" ++ show i ++ "'"
    | otherwise = do
        setZero x
        unsafeWriteElem x i 1

-- | Storage index (in elements) of logical element @i@, i.e.
-- @offsetOf x + i * strideOf x@.  No bounds check is performed.
indexOf :: DVector t n e -> Int -> Int
indexOf x i = offsetOf x + i * strideOf x
{-# INLINE indexOf #-}

-- | Evaluate a function with a pointer to the value stored at the given
-- index.  Note that the value may need to be conjugated before using it;
-- see 'isConj'.  The index is not range-checked.  The pointer should only
-- be used inside the callback, because 'withForeignPtr' keeps the storage
-- alive only for that long.
unsafeWithElemPtr :: (Elem e) => DVector t n e -> Int -> (Ptr e -> IO a) -> IO a
unsafeWithElemPtr x i f =
    withForeignPtr (storageOf x) $ \ptr ->
        let elemPtr = ptr `advancePtr` (indexOf x i)
        in f elemPtr
{-# INLINE unsafeWithElemPtr #-}

-- | Cast the access type to @Imm@.  No copy is made: the result shares the
-- argument's storage, so any later mutation through the argument is visible
-- through the result.
unsafeFreeze :: DVector t n e -> Vector n e
unsafeFreeze (DV f o n s c) = DV f o n s c
{-# INLINE unsafeFreeze #-}

-- | Cast the access type to @Mut@.  No copy is made: writes through the
-- result modify the argument's storage.
unsafeThaw :: DVector t n e -> IOVector n e
unsafeThaw (DV f o n s c) = DV f o n s c
{-# INLINE unsafeThaw #-}

instance C.Vector (DVector t) where
    dim = lengthOf
    {-# INLINE dim #-}

    -- Conjugation only flips the flag; the stored data is untouched.
    conj x = let c' = (not . isConj) x
             in x { isConj=c' }
    {-# INLINE conj #-}

instance Tensor (DVector t n) Int e where
    shape = dim
    {-# INLINE shape #-}

    bounds x = (0, dim x - 1)
    {-# INLINE bounds #-}

instance (BLAS1 e) => ITensor (DVector Imm n) Int e where
    size = dim

    indices = range . bounds
    {-# INLINE indices #-}

    -- Immutable vectors are read through the mutable (IO) accessors.  This
    -- relies on the storage never changing once frozen.
    elems = inlinePerformIO . getElems . unsafeThaw

    assocs = inlinePerformIO . getAssocs . unsafeThaw

    unsafeAt x = inlinePerformIO . unsafeReadElem (unsafeThaw x)
    {-# INLINE unsafeAt #-}

    amap f x = listVector (dim x) (map f $ elems x)

    (//) = replaceHelp writeElem

    unsafeReplace = replaceHelp unsafeWriteElem

-- | Shared implementation of '(//)' and 'unsafeReplace'.  It copies @x@,
-- applies @set@ to each @(index, value)@ pair in order on the copy, and
-- freezes the result.  The original vector is not modified.
replaceHelp :: (BLAS1 e) =>
    (IOVector n e -> Int -> e -> IO ())
    -> Vector n e -> [(Int, e)] -> Vector n e
replaceHelp set x ies =
    unsafeFreeze $ unsafePerformIO $ do
        y <- newCopy (unsafeThaw x)
        mapM_ (uncurry $ set y) ies
        return y
{-# NOINLINE replaceHelp #-}

instance (BLAS1 e) => IDTensor (DVector Imm n) Int e where
    zero n = unsafeFreeze $ unsafePerformIO $ newZero n
    {-# NOINLINE zero #-}

    constant n e = unsafeFreeze $ unsafePerformIO $ newConstant n e
    {-# NOINLINE constant #-}

    -- Error message names this method (it previously said "amap2").
    azipWith f x y
        | dim y /= n =
            error ("azipWith: vector lengths differ; first has length `"
                   ++ show n ++ "' and second has length `"
                   ++ show (dim y) ++ "'")
        | otherwise =
            listVector n (zipWith f (elems x) (elems y))
      where
        n = dim x

instance (BLAS1 e) => RTensor (DVector t n) Int e IO where
    getSize = return . dim

    -- A conjugated vector is copied by copying its unconjugated view and
    -- flipping the flag back.  Otherwise BLAS 'copy' is used with the
    -- source stride and the fresh (stride-1) destination.
    newCopy x
        | isConj x = newCopy (conj x) >>= return . conj
        | otherwise = do
            y <- newVector_ (dim x)
            unsafeWithElemPtr x 0 $ \pX ->
                unsafeWithElemPtr y 0 $ \pY ->
                    let n    = dim x
                        incX = strideOf x
                        incY = strideOf y
                    in copy n pX incX pY incY >> return y

    getIndices = return . indices . unsafeFreeze
    {-# INLINE getIndices #-}

    -- Values are stored unconjugated; conjugate on read when flagged.
    unsafeReadElem x i
        | isConj x = unsafeReadElem (conj x) i >>= return . E.conj
        | otherwise = withForeignPtr (storageOf x) $ \ptr ->
                          peekElemOff ptr (indexOf x i)

    -- NOTE(review): the list is produced lazily and each element is peeked
    -- only when its cons cell is forced.  For a mutable vector, writes made
    -- after 'getAssocs' returns but before the list is consumed are visible
    -- in the result.  Confirm no caller relies on snapshot semantics.
    getAssocs x
        | isConj x =
            getAssocs (conj x) >>= return . map (\(i,e) -> (i,E.conj e))
        | otherwise =
            let (f,o,n,incX,_) = toForeignPtr x
                ptr = (unsafeForeignPtrToPtr f) `advancePtr` o
            in return $ go n f incX ptr 0
      where
        go !n !f !incX !ptr !i
            | i >= n =
                -- This is very important since we are doing unsafe IO.
                -- Otherwise, the DVector might get discarded and the
                -- memory freed before all of the elements are read.
                inlinePerformIO $ do
                    touchForeignPtr f
                    return []
            | otherwise =
                let e    = inlinePerformIO $ peek ptr
                    ptr' = ptr `advancePtr` incX
                    i'   = i + 1
                    ies  = go n f incX ptr' i'
                in e `seq` ((i,e):ies)
    {-# NOINLINE getAssocs #-}

instance (BLAS1 e) => RDTensor (DVector t n) Int e IO where
    newZero n = newVector_ n >>= (\x -> setZero (unsafeThaw x) >> return x)

    newConstant n e = newVector_ n >>= (\x -> setConstant e (unsafeThaw x) >> return x)

instance (BLAS1 e) => MTensor (DVector Mut n) Int e IO where
    -- Contiguous vectors are cleared in one call.  Strided ones fall back to
    -- writing 0 element by element.
    setZero x
        | strideOf x == 1 = unsafeWithElemPtr x 0 $ flip clearArray (dim x)
        | otherwise       = setConstant 0 x

    setConstant e x
        | isConj x  = setConstant (E.conj e) (conj x)
        | otherwise = unsafeWithElemPtr x 0 $ go (dim x)
      where
        go !n !ptr
            | n <= 0    = return ()
            | otherwise = let ptr' = ptr `advancePtr` (strideOf x)
                              n'   = n - 1
                          in poke ptr e >> go n' ptr'

    -- Values are stored unconjugated; conjugate on write when flagged.
    unsafeWriteElem x i e =
        let e' = if isConj x then E.conj e else e
        in withForeignPtr (storageOf x) $ \ptr ->
               pokeElemOff ptr (indexOf x i) e'

    canModifyElem x i = return $ inRange (bounds x) i
    {-# INLINE canModifyElem #-}

    -- For a conjugated view, conjugate around @f@ so that @f@ sees and
    -- produces logical (conjugated) values.
    modifyWith f x
        | isConj x  = modifyWith (E.conj . f . E.conj) (conj x)
        | otherwise = withForeignPtr (storageOf x) $ \ptr ->
                          go (dim x) (ptr `advancePtr` offsetOf x)
      where
        go !n !ptr
            | n <= 0    = return ()
            | otherwise = do
                peek ptr >>= poke ptr . f
                go (n-1) (ptr `advancePtr` incX)
        incX = strideOf x

-- | Shared implementation of the 'Eq' and 'AEq' instances.  The vectors
-- are equal under @cmp@ when they have the same dimension and @cmp@ holds
-- elementwise.  When both are conjugated, their unconjugated views are
-- compared instead.
compareHelp :: (BLAS1 e) => (e -> e -> Bool) -> Vector n e -> Vector n e -> Bool
compareHelp cmp x y
    | isConj x && isConj y = compareHelp cmp (conj x) (conj y)
    | otherwise = (dim x == dim y) && (and $ zipWith cmp (elems x) (elems y))

instance (BLAS1 e, Eq e) => Eq (DVector Imm n e) where
    (==) = compareHelp (==)

instance (BLAS1 e, AEq e) => AEq (DVector Imm n e) where
    (===) = compareHelp (===)
    (~==) = compareHelp (~==)

instance (BLAS1 e, Show e) => Show (DVector Imm n e) where
    show x
        | isConj x  = "conj (" ++ show (conj x) ++ ")"
        | otherwise = "listVector " ++ show (dim x) ++ " " ++ show (elems x)