{-# LANGUAGE MultiParamTypeClasses, FlexibleContexts, FlexibleInstances #-}
{-# OPTIONS_HADDOCK hide #-}
-----------------------------------------------------------------------------
-- |
-- Module     : Data.Vector.Dense.Base
-- Copyright  : Copyright (c) , Patrick Perry
-- License    : BSD3
-- Maintainer : Patrick Perry
-- Stability  : experimental
--

module Data.Vector.Dense.Base
    where

import Control.Monad
import Control.Monad.ST
import Data.AEq
import Foreign
import Unsafe.Coerce

import BLAS.Internal( checkBinaryOp, clearArray, inlinePerformIO,
    checkedSubvector, checkedSubvectorWithStride, checkVecVecOp )
import BLAS.Types( ConjEnum(..) )

import Data.Elem.BLAS ( Complex, Elem, BLAS1, conjugate )
import qualified Data.Elem.BLAS.Level1 as BLAS

import Data.Tensor.Class
import Data.Tensor.Class.ITensor
import Data.Tensor.Class.MTensor

import Data.Vector.Dense.IOBase

infixl 7 <.>

-- | Immutable dense vectors.  The type arguments are as follows:
--
--     * @n@: a phantom type for the dimension of the vector
--
--     * @e@: the element type of the vector.  Only certain element types
--       are supported.
--
newtype Vector n e = Vector (IOVector n e)

-- | Build an immutable vector from a fresh copy of a mutable one.
freezeIOVector :: (BLAS1 e) => IOVector n e -> IO (Vector n e)
freezeIOVector x = liftM Vector (newCopyIOVector x)

-- | Build a mutable vector from a fresh copy of an immutable one.
thawIOVector :: (BLAS1 e) => Vector n e -> IO (IOVector n e)
thawIOVector (Vector x) = newCopyIOVector x

-- | Wrap a mutable vector as an immutable one without copying.
unsafeFreezeIOVector :: IOVector n e -> IO (Vector n e)
unsafeFreezeIOVector x = return (Vector x)

-- | Unwrap an immutable vector as a mutable one without copying.
unsafeThawIOVector :: Vector n e -> IO (IOVector n e)
unsafeThawIOVector (Vector x) = return x

-- | Common functionality for all vector types.
class (Shaped x Int, Elem e) => BaseVector x e where
    -- | Get the dimension (length) of the vector.
    dim :: x n e -> Int

    -- | Get the memory stride (in elements) between consecutive elements.
    stride :: x n e -> Int

    -- | Indicate whether or not internally the vector stores the complex
    -- conjugates of its elements.  The default compares 'conjEnum'
    -- against 'Conj'.
    isConj :: x n e -> Bool
    isConj v = conjEnum v == Conj
    {-# INLINE isConj #-}

    -- | Get the storage type.
    conjEnum :: x n e -> ConjEnum

    -- | Get a view into the complex conjugate of a vector.
    conj :: x n e -> x n e

    -- | Cast the shape type of the vector.  The default is an
    -- 'unsafeCoerce'.
    coerceVector :: x n e -> x n' e
    coerceVector = unsafeCoerce
    {-# INLINE coerceVector #-}

    -- | Unchecked variant of 'subvectorViewWithStride'; the arguments are
    -- passed in the same order.
    unsafeSubvectorViewWithStride :: Int -> x n e -> Int -> Int -> x n' e

    -- | Unsafe cast from a vector to an 'IOVector'.
    unsafeVectorToIOVector :: x n e -> IOVector n e

    -- | Unsafe cast from an 'IOVector' to a vector.
    unsafeIOVectorToVector :: IOVector n e -> x n e


-- | Vectors that can be read in a monad.
class (BaseVector x e, BLAS1 e, Monad m, ReadTensor x Int e m) => ReadVector x e m where
    -- | Cast the vector to an 'IOVector', perform an @IO@ action, and
    -- convert the @IO@ action to an action in the monad @m@.  This
    -- operation is /very/ unsafe.
    unsafePerformIOWithVector :: x n e -> (IOVector n e -> IO a) -> m a

    -- | Convert a mutable vector to an immutable one by taking a complete
    -- copy of it.
    freezeVector :: x n e -> m (Vector n e)

    -- | Like 'freezeVector'; the instances in this module implement it
    -- without copying.
    unsafeFreezeVector :: x n e -> m (Vector n e)


-- | Vectors that can be created or modified in a monad.
class (ReadVector x e m, WriteTensor x Int e m) => WriteVector x e m where
    -- | Unsafely convert an 'IO' action that creates an 'IOVector' into
    -- an action in @m@ that creates a vector.
    unsafeConvertIOVector :: IO (IOVector n e) -> m (x n e)

    -- | Creates a new vector of the given length.  The elements will be
    -- uninitialized.
    newVector_ :: Int -> m (x n e)

    -- | Convert an immutable vector to a mutable one by taking a complete
    -- copy of it.
    thawVector :: Vector n e -> m (x n e)

    -- | Like 'thawVector'; the instances in this module implement it
    -- without copying.
    unsafeThawVector :: Vector n e -> m (x n e)


-- | Creates a new vector with the given association list.  Unspecified
-- indices will get initialized to zero.
newVector :: (WriteVector x e m) => Int -> [(Int,e)] -> m (x n e)
newVector n ies = do
    x <- newZeroVector n
    unsafePerformIOWithVector x $ \x' ->
        withIOVector x' $ \p ->
            forM_ ies $ \(i,e) -> do
                -- Reject out-of-range indices before touching memory.
                when (i < 0 || i >= n) $ fail $
                    "Index `" ++ show i ++
                        "' is invalid for a vector with dimension `" ++ show n ++
                        "'"
                pokeElemOff p i e
    return x
{-# INLINE newVector #-}

-- | Same as 'newVector', but the indices are not range-checked.
unsafeNewVector :: (WriteVector x e m) => Int -> [(Int,e)] -> m (x n e)
unsafeNewVector n ies = do
    x <- newZeroVector n
    unsafePerformIOWithVector x $ \x' ->
        withIOVector x' $ \p ->
            forM_ ies $ \(i,e) ->
                pokeElemOff p i e
    return x
{-# INLINE unsafeNewVector #-}

-- | Creates a new vector of the given dimension with the given elements.
-- If the list has length less than the passed-in dimension, the remaining
-- elements of the vector are set to zero.  List elements beyond the
-- dimension are ignored.
newListVector :: (WriteVector x e m) => Int -> [e] -> m (x n e)
newListVector n es = do
    x <- newVector_ n
    unsafePerformIOWithVector x $ \x' ->
        withIOVector x' $ \p ->
            -- Pad with zeros, then keep exactly @n@ values.
            pokeArray p (take n (es ++ repeat 0))
    return x
{-# INLINE newListVector #-}

-- | Create a zero vector of the specified length.
newZeroVector :: (WriteVector x e m) => Int -> m (x n e)
newZeroVector n = do
    x <- newVector_ n
    unsafePerformIOWithVector x $ \x' ->
        withIOVector x' $ \p ->
            clearArray p n
    return x
{-# INLINE newZeroVector #-}

-- | Set every element in the vector to zero.
setZeroVector :: (WriteVector x e m) => x n e -> m ()
setZeroVector x = unsafePerformIOWithVector x setZeroIOVector
{-# INLINE setZeroVector #-}

-- | Create a vector with every element initialized to the same value.
newConstantVector :: (WriteVector x e m) => Int -> e -> m (x n e)
newConstantVector n e = do
    x <- newVector_ n
    unsafePerformIOWithVector x $ \x' ->
        withIOVector x' $ \p ->
            pokeArray p (replicate n e)
    return x
{-# INLINE newConstantVector #-}

-- | Set every element in the vector to a constant.
setConstantVector :: (WriteVector x e m) => e -> x n e -> m ()
setConstantVector e x =
    unsafePerformIOWithVector x $ setConstantIOVector e
{-# INLINE setConstantVector #-}

-- | @newBasisVector n i@ creates a vector of length @n@ that is all zero
-- except for at position @i@, where it is equal to one.  The action fails
-- if @i@ is negative or not less than @n@.
newBasisVector :: (WriteVector x e m) => Int -> Int -> m (x n e)
newBasisVector n i = do
    x <- newZeroVector n
    unsafePerformIOWithVector x $ \x' -> do
        -- Without this check an out-of-range index would write outside
        -- the vector's storage.
        when (i < 0 || i >= n) $ fail $
            "tried to create a basis vector of dimension `" ++ show n ++ "'"
            ++ " with a one at position `" ++ show i ++ "'"
        withIOVector x' $ \p ->
            pokeElemOff p i 1
    return x
{-# INLINE newBasisVector #-}

-- | @setBasisVector i x@ sets the @i@th coordinate of @x@ to @1@, and all
-- other coordinates to @0@.  If the vector has been scaled, it is possible
-- that @readVector x i@ will not return exactly @1@.  See 'setElem'.
-- Fails if @i@ is not a valid index of @x@.
setBasisVector :: (WriteVector x e m) => Int -> x n e -> m ()
setBasisVector i x
    | i < 0 || i >= dim x =
        fail $ "tried to set a vector of dimension `" ++ show (dim x) ++ "'"
               ++ " to basis vector `" ++ show i ++ "'"
    | otherwise = do
        setZeroVector x
        unsafeWriteElem x i 1
{-# INLINE setBasisVector #-}

-- | Creates a new vector by copying another one.  The copy has the same
-- 'isConj' flag as the argument.
newCopyVector :: (ReadVector x e m, WriteVector y e m) =>
    x n e -> m (y n e)
newCopyVector x
    -- Copy the underlying storage, then view the copy as conjugated.
    | isConj x =
        liftM conj (newCopyVector (conj x))
    | otherwise = do
        y <- newVector_ n
        unsafePerformIOWithVector y $ \y' ->
            withIOVector (unsafeVectorToIOVector x) $ \pX ->
            withIOVector y' $ \pY ->
                BLAS.copy n pX incX pY 1
        return y
  where
    n    = dim x
    incX = stride x
{-# INLINE newCopyVector #-}

-- | Creates a new vector by copying another one.  The returned vector
-- is guaranteed not to be a view into another vector.  That is, the
-- returned vector will have @isConj@ to be @False@.
newCopyVector' :: (ReadVector x e m, WriteVector y e m) =>
    x n e -> m (y n e)
newCopyVector' x
    | isConj x = do
        -- Copy the raw storage, then conjugate the copy in place so the
        -- result holds the logical values without a conjugate flag.
        y <- newCopyVector (conj x)
        unsafePerformIOWithVector y $ \y' ->
            withIOVector y' $ \pY ->
                BLAS.vconj (dim x) pY 1
        return y
    | otherwise =
        newCopyVector x
{-# INLINE newCopyVector' #-}

-- | @copyVector dst src@ replaces the values in @dst@ with those in
-- @src@.  The operands must be the same shape.
copyVector :: (WriteVector y e m, ReadVector x e m) =>
    y n e -> x n e -> m ()
copyVector y x =
    checkBinaryOp (shape x) (shape y) (unsafeCopyVector y x)
{-# INLINE copyVector #-}

-- | Same as 'copyVector', but the shapes are not checked.
unsafeCopyVector :: (ReadVector y e m, ReadVector x e m) =>
    y n e -> x n e -> m ()
unsafeCopyVector y x
    | isConj x =
        unsafeCopyVector (conj y) (conj x)
    | isConj y =
        -- Only the destination is conjugated: copy, then conjugate it.
        vectorCall2 BLAS.copy x y >> vectorCall BLAS.vconj y
    | otherwise =
        vectorCall2 BLAS.copy x y
{-# INLINE unsafeCopyVector #-}

-- | Swap the values stored in two vectors.
swapVector :: (WriteVector x e m, WriteVector y e m) =>
    x n e -> y n e -> m ()
swapVector x y =
    checkBinaryOp (shape x) (shape y) (unsafeSwapVector x y)
{-# INLINE swapVector #-}

-- | Same as 'swapVector', but the shapes are not checked.
unsafeSwapVector :: (WriteVector x e m, WriteVector y e m) =>
    x n e -> y n e -> m ()
unsafeSwapVector x y
    | isConj x =
        unsafeSwapVector (conj x) (conj y)
    | isConj y = do
        -- Exactly one operand is conjugated: swap, then conjugate both.
        vectorCall2 BLAS.swap x y
        vectorCall BLAS.vconj x
        vectorCall BLAS.vconj y
    | otherwise =
        vectorCall2 BLAS.swap x y
{-# INLINE unsafeSwapVector #-}

-- | @subvectorView x o n@ creates a subvector view of @x@ starting at index @o@
-- and having length @n@.
subvectorView :: (BaseVector x e) =>
    x n e -> Int -> Int -> x n' e
subvectorView x =
    checkedSubvector (dim x) (unsafeSubvectorView x)
{-# INLINE subvectorView #-}

-- | Unchecked variant of 'subvectorView'; it uses a stride of @1@.
unsafeSubvectorView :: (BaseVector x e) =>
    x n e -> Int -> Int -> x n' e
unsafeSubvectorView x o n = unsafeSubvectorViewWithStride 1 x o n
{-# INLINE unsafeSubvectorView #-}

-- | @subvectorViewWithStride s x o n@ creates a subvector view of @x@ starting
-- at index @o@, having length @n@ and stride @s@.
subvectorViewWithStride :: (BaseVector x e) =>
    Int -> x n e -> Int -> Int -> x n' e
subvectorViewWithStride s x =
    checkedSubvectorWithStride s (dim x) (unsafeSubvectorViewWithStride s x)
{-# INLINE subvectorViewWithStride #-}

-- | Get a new vector whose elements are the conjugates of the elements
-- of the given vector.
getConjVector :: (ReadVector x e m, WriteVector y e m) =>
    x n e -> m (y n e)
getConjVector x = getUnaryVectorOp doConjVector x
{-# INLINE getConjVector #-}

-- | Conjugate every element of the vector.
doConjVector :: (WriteVector y e m) => y n e -> m ()
doConjVector x = unsafePerformIOWithVector x doConjIOVector
{-# INLINE doConjVector #-}

-- | Get a new vector by scaling the elements of another vector
-- by a given value.
getScaledVector :: (ReadVector x e m, WriteVector y e m) =>
    e -> x n e -> m (y n e)
getScaledVector e x = getUnaryVectorOp (scaleByVector e) x
{-# INLINE getScaledVector #-}

-- | Scale every element by the given value.
scaleByVector :: (WriteVector y e m) => e -> y n e -> m ()
scaleByVector k x = unsafePerformIOWithVector x (scaleByIOVector k)
{-# INLINE scaleByVector #-}

-- | Get a new vector by shifting the elements of another vector
-- by a given value.
getShiftedVector :: (ReadVector x e m, WriteVector y e m) =>
    e -> x n e -> m (y n e)
getShiftedVector e x = getUnaryVectorOp (shiftByVector e) x
{-# INLINE getShiftedVector #-}

-- | Add the given value to every element.
shiftByVector :: (WriteVector y e m) => e -> y n e -> m ()
shiftByVector k x = unsafePerformIOWithVector x (shiftByIOVector k)
{-# INLINE shiftByVector #-}

-- | @getAddVector x y@ creates a new vector equal to the sum @x+y@.  The
-- operands must have the same dimension.
getAddVector :: (ReadVector x e m, ReadVector y e m, WriteVector z e m) =>
    x n e -> y n e -> m (z n e)
getAddVector x y = checkVectorOp2 unsafeGetAddVector x y
{-# INLINE getAddVector #-}

-- | Same as 'getAddVector', but the dimensions are not checked.
unsafeGetAddVector :: (ReadVector x e m, ReadVector y e m, WriteVector z e m) =>
    x n e -> y n e -> m (z n e)
unsafeGetAddVector x y = unsafeGetBinaryVectorOp unsafeAddVector x y
{-# INLINE unsafeGetAddVector #-}

-- | @addVector y x@ replaces @y@ with @y+x@.
addVector :: (WriteVector y e m, ReadVector x e m) =>
    y n e -> x n e -> m ()
addVector y x =
    checkBinaryOp (dim y) (dim x) (unsafeAddVector y x)
{-# INLINE addVector #-}

-- | Same as 'addVector', but the dimensions are not checked.
-- Implemented as an axpy with coefficient @1@.
unsafeAddVector :: (WriteVector y e m, ReadVector x e m) =>
    y n e -> x n e -> m ()
unsafeAddVector y x = unsafeAxpyVector 1 x y
{-# INLINE unsafeAddVector #-}

-- | @getSubVector x y@ creates a new tensor equal to the difference @x-y@.
-- The operands must have the same dimension.
getSubVector :: (ReadVector x e m, ReadVector y e m, WriteVector z e m) =>
    x n e -> y n e -> m (z n e)
getSubVector x y = checkVectorOp2 unsafeGetSubVector x y
{-# INLINE getSubVector #-}

-- | Same as 'getSubVector', but the dimensions are not checked.
unsafeGetSubVector :: (ReadVector x e m, ReadVector y e m, WriteVector z e m) =>
    x n e -> y n e -> m (z n e)
unsafeGetSubVector x y = unsafeGetBinaryVectorOp unsafeSubVector x y
{-# INLINE unsafeGetSubVector #-}

-- | @subVector y x@ replaces @y@ with @y-x@.
subVector :: (WriteVector y e m, ReadVector x e m) =>
    y n e -> x n e -> m ()
subVector y x =
    checkBinaryOp (dim y) (dim x) (unsafeSubVector y x)
{-# INLINE subVector #-}

-- | Same as 'subVector', but the dimensions are not checked.
-- Implemented as an axpy with coefficient @-1@.
unsafeSubVector :: (WriteVector y e m, ReadVector x e m) =>
    y n e -> x n e -> m ()
unsafeSubVector y x = unsafeAxpyVector (-1) x y
{-# INLINE unsafeSubVector #-}

-- | @axpyVector alpha x y@ replaces @y@ with @alpha * x + y@.
axpyVector :: (ReadVector x e m, WriteVector y e m) =>
    e -> x n e -> y n e -> m ()
axpyVector alpha x y =
    checkBinaryOp (shape x) (shape y) (unsafeAxpyVector alpha x y)
{-# INLINE axpyVector #-}

-- | Same as 'axpyVector', but the shapes are not checked.
unsafeAxpyVector :: (ReadVector x e m, ReadVector y e m) =>
    e -> x n e -> y n e -> m ()
unsafeAxpyVector alpha x y
    -- Conjugate the whole update so the destination is unconjugated.
    | isConj y =
        unsafeAxpyVector (conjugate alpha) (conj x) (conj y)
    | isConj x =
        vectorCall2 (\n -> BLAS.acxpy n alpha) x y
    | otherwise =
        vectorCall2 (\n -> BLAS.axpy n alpha) x y
{-# INLINE unsafeAxpyVector #-}

-- | @getMulVector x y@ creates a new vector equal to the elementwise product
-- @x*y@.  The operands must have the same dimension.
getMulVector :: (ReadVector x e m, ReadVector y e m, WriteVector z e m) =>
    x n e -> y n e -> m (z n e)
getMulVector x y = checkVectorOp2 unsafeGetMulVector x y
{-# INLINE getMulVector #-}

-- | Same as 'getMulVector', but the dimensions are not checked.
unsafeGetMulVector :: (ReadVector x e m, ReadVector y e m, WriteVector z e m) =>
    x n e -> y n e -> m (z n e)
unsafeGetMulVector x y = unsafeGetBinaryVectorOp unsafeMulVector x y
{-# INLINE unsafeGetMulVector #-}

-- | @mulVector y x@ replaces @y@ with @y*x@.
mulVector :: (WriteVector y e m, ReadVector x e m) =>
    y n e -> x n e -> m ()
mulVector y x =
    checkBinaryOp (shape y) (shape x) (unsafeMulVector y x)
{-# INLINE mulVector #-}

-- | Same as 'mulVector', but the shapes are not checked.
unsafeMulVector :: (WriteVector y e m, ReadVector x e m) =>
    y n e -> x n e -> m ()
unsafeMulVector y x
    | isConj y =
        unsafeMulVector (conj y) (conj x)
    | isConj x =
        vectorCall2 BLAS.vcmul x y
    | otherwise =
        vectorCall2 BLAS.vmul x y
{-# INLINE unsafeMulVector #-}

-- | @getDivVector x y@ creates a new vector equal to the elementwise
-- ratio @x/y@.  The operands must have the same shape.
getDivVector :: (ReadVector x e m, ReadVector y e m, WriteVector z e m) =>
    x n e -> y n e -> m (z n e)
getDivVector x y = checkVectorOp2 unsafeGetDivVector x y
{-# INLINE getDivVector #-}

-- | Same as 'getDivVector', but the dimensions are not checked.
unsafeGetDivVector :: (ReadVector x e m, ReadVector y e m, WriteVector z e m) =>
    x n e -> y n e -> m (z n e)
unsafeGetDivVector x y = unsafeGetBinaryVectorOp unsafeDivVector x y
{-# INLINE unsafeGetDivVector #-}

-- | @divVector y x@ replaces @y@ with @y/x@.
divVector :: (WriteVector y e m, ReadVector x e m) =>
    y n e -> x n e -> m ()
divVector y x =
    checkBinaryOp (shape y) (shape x) (unsafeDivVector y x)
{-# INLINE divVector #-}

-- | Same as 'divVector', but the shapes are not checked.
unsafeDivVector :: (WriteVector y e m, ReadVector x e m) =>
    y n e -> x n e -> m ()
unsafeDivVector y x
    | isConj y =
        unsafeDivVector (conj y) (conj x)
    | isConj x =
        vectorCall2 BLAS.vcdiv x y
    | otherwise =
        vectorCall2 BLAS.vdiv x y
{-# INLINE unsafeDivVector #-}

-- | Gets the sum of the absolute values of the vector entries.
getSumAbs :: (ReadVector x e m) => x n e -> m Double
getSumAbs x = vectorCall BLAS.asum x
{-# INLINE getSumAbs #-}

-- | Gets the 2-norm of a vector.
getNorm2 :: (ReadVector x e m) => x n e -> m Double
getNorm2 x = vectorCall BLAS.nrm2 x
{-# INLINE getNorm2 #-}

-- | Gets the index and norm of the element with maximum magnitude.  This is
-- undefined if any of the elements are @NaN@.  It will throw an exception if
-- the dimension of the vector is 0.
getWhichMaxAbs :: (ReadVector x e m) => x n e -> m (Int, e)
getWhichMaxAbs x
    | dim x == 0 =
        fail "getWhichMaxAbs of an empty vector"
    | otherwise = do
        i <- vectorCall BLAS.iamax x
        e <- unsafeReadElem x i
        return (i,e)
{-# INLINE getWhichMaxAbs #-}

-- | Computes the dot product of two vectors.
getDot :: (ReadVector x e m, ReadVector y e m) =>
    x n e -> y n e -> m e
getDot x y =
    checkVecVecOp "getDot" (dim x) (dim y) (unsafeGetDot x y)
{-# INLINE getDot #-}

-- | Same as 'getDot', but the dimensions are not checked.  The storage
-- types of both operands are passed on to the BLAS routine.
unsafeGetDot :: (ReadVector x e m, ReadVector y e m) =>
    x n e -> y n e -> m e
unsafeGetDot x y =
    vectorCall2 (BLAS.dot cx cy) x y
  where
    cx = conjEnum x
    cy = conjEnum y
{-# INLINE unsafeGetDot #-}

-- 'IOVector' is the underlying representation, so the casts are identities.
instance (Elem e) => BaseVector IOVector e where
    dim = dimIOVector
    {-# INLINE dim #-}
    stride = strideIOVector
    {-# INLINE stride #-}
    conjEnum = conjEnumIOVector
    {-# INLINE conjEnum #-}
    conj = conjIOVector
    {-# INLINE conj #-}
    unsafeSubvectorViewWithStride = unsafeSubvectorViewWithStrideIOVector
    {-# INLINE unsafeSubvectorViewWithStride #-}
    unsafeVectorToIOVector = id
    {-# INLINE unsafeVectorToIOVector #-}
    unsafeIOVectorToVector = id
    {-# INLINE unsafeIOVectorToVector #-}

instance (BLAS1 e) => ReadVector IOVector e IO where
    unsafePerformIOWithVector x f = f x
    {-# INLINE unsafePerformIOWithVector #-}
    freezeVector = freezeIOVector
    {-# INLINE freezeVector #-}
    unsafeFreezeVector = unsafeFreezeIOVector
    {-# INLINE unsafeFreezeVector #-}

instance (BLAS1 e) => WriteVector IOVector e IO where
    newVector_ = newIOVector_
    {-# INLINE newVector_ #-}
    unsafeConvertIOVector = id
    {-# NOINLINE unsafeConvertIOVector #-}
    thawVector = thawIOVector
    {-# INLINE thawVector #-}
    unsafeThawVector = unsafeThawIOVector
    {-# INLINE unsafeThawVector #-}

-- | Create a vector with the given dimension and elements.  The elements
-- given in the association list must all have unique indices, otherwise
-- the result is undefined.
vector :: (BLAS1 e) => Int -> [(Int, e)] -> Vector n e
vector n ies =
    unsafePerformIO (newVector n ies >>= unsafeFreezeIOVector)
{-# NOINLINE vector #-}

-- Same as 'vector', but does not range-check the indices.
unsafeVector :: (BLAS1 e) => Int -> [(Int, e)] -> Vector n e
unsafeVector n ies =
    unsafePerformIO (unsafeNewVector n ies >>= unsafeFreezeIOVector)
{-# NOINLINE unsafeVector #-}

-- | Create a vector of the given dimension with elements initialized
-- to the values from the list.  @listVector n es@ is equivalent to
-- @vector n (zip [0..(n-1)] es)@.  It is implemented with 'newListVector',
-- which sets the remaining elements to zero when @length es@ is less than
-- @n@ and ignores list elements beyond the first @n@.
listVector :: (BLAS1 e) => Int -> [e] -> Vector n e
listVector n es = Vector (unsafePerformIO (newListVector n es))
{-# NOINLINE listVector #-}

-- | Copy the vector and overwrite the listed positions; the writes go
-- through 'writeElem'.
replaceVector :: (BLAS1 e) => Vector n e -> [(Int,e)] -> Vector n e
replaceVector (Vector x) ies =
    unsafePerformIO $ do
        y <- newCopyVector x
        forM_ ies $ \(i,e) ->
            writeElem y i e
        return (Vector y)
{-# NOINLINE replaceVector #-}

-- | Same as 'replaceVector', but the writes go through 'unsafeWriteElem'.
unsafeReplaceVector :: (BLAS1 e) => Vector n e -> [(Int,e)] -> Vector n e
unsafeReplaceVector (Vector x) ies =
    unsafePerformIO $ do
        y <- newCopyVector x
        forM_ ies $ \(i,e) ->
            unsafeWriteElem y i e
        return (Vector y)
{-# NOINLINE unsafeReplaceVector #-}

-- | @zeroVector n@ creates a vector of dimension @n@ with all values
-- set to zero.
zeroVector :: (BLAS1 e) => Int -> Vector n e
zeroVector n =
    unsafePerformIO (newZeroVector n >>= unsafeFreezeIOVector)
{-# NOINLINE zeroVector #-}

-- | @constantVector n e@ creates a vector of dimension @n@ with all values
-- set to @e@.
constantVector :: (BLAS1 e) => Int -> e -> Vector n e
constantVector n e =
    unsafePerformIO (newConstantVector n e >>= unsafeFreezeIOVector)
{-# NOINLINE constantVector #-}

-- | @basisVector n i@ creates a vector of dimension @n@ with zeros
-- everywhere but position @i@, where there is a one.
basisVector :: (BLAS1 e) => Int -> Int -> Vector n e
basisVector n i =
    unsafePerformIO (newBasisVector n i >>= unsafeFreezeIOVector)
{-# NOINLINE basisVector #-}

-- | @subvector x o n@ creates a subvector of @x@ starting at index @o@
-- and having length @n@.
subvector :: (BLAS1 e) => Vector n e -> Int -> Int -> Vector n' e
subvector x o n = subvectorView x o n
{-# INLINE subvector #-}

-- | Unchecked variant of 'subvector'.
unsafeSubvector :: (BLAS1 e) => Vector n e -> Int -> Int -> Vector n' e
unsafeSubvector x o n = unsafeSubvectorView x o n
{-# INLINE unsafeSubvector #-}

-- | Unchecked variant of 'subvectorWithStride'.
unsafeSubvectorWithStride :: (Elem e) =>
    Int -> Vector n e -> Int -> Int -> Vector n' e
unsafeSubvectorWithStride s x o n = unsafeSubvectorViewWithStride s x o n
{-# INLINE unsafeSubvectorWithStride #-}

-- | @subvectorWithStride s x o n@ creates a subvector of @x@ starting
-- at index @o@, having length @n@ and stride @s@.
subvectorWithStride :: (BLAS1 e) =>
    Int -> Vector n e -> Int -> Int -> Vector n' e
subvectorWithStride s x o n = subvectorViewWithStride s x o n
{-# INLINE subvectorWithStride #-}

-- | Size of the vector, as reported by the underlying 'IOVector'.
sizeVector :: Vector n e -> Int
sizeVector (Vector x) = sizeIOVector x
{-# INLINE sizeVector #-}

-- | Indices of the vector, as reported by the underlying 'IOVector'.
indicesVector :: Vector n e -> [Int]
indicesVector (Vector x) = indicesIOVector x
{-# INLINE indicesVector #-}

-- | Lazily list the elements of the vector by walking its storage one
-- stride at a time.
elemsVector :: (Elem e) => Vector n e -> [e]
elemsVector x
    | isConj x  = map conjugate (elemsVector (conj x))
    | otherwise =
        case x of
            Vector (IOVector _ n f p incX) ->
                let end = p `advancePtr` (n*incX)
                    go q
                        -- At the end, touch the foreign pointer so the
                        -- storage is kept alive until the walk finishes.
                        | q == end = inlinePerformIO $ do
                            io <- touchForeignPtr f
                            io `seq` return []
                        | otherwise =
                            let e  = inlinePerformIO (peek q)
                                es = go (q `advancePtr` incX)
                            -- Force the element before exposing the cell.
                            in e `seq` (e:es)
                in go p
{-# SPECIALIZE INLINE elemsVector :: Vector n Double -> [Double] #-}
{-# SPECIALIZE INLINE elemsVector :: Vector n (Complex Double) -> [Complex Double] #-}

-- | Pair each index with its element.
assocsVector :: (Elem e) => Vector n e -> [(Int,e)]
assocsVector x = zip (indicesVector x) (elemsVector x)
{-# INLINE assocsVector #-}

-- | Read the element at the given index without a range check.
unsafeAtVector :: (Elem e) => Vector n e -> Int -> e
unsafeAtVector x i
    | isConj x  = conjugate (unsafeAtVector (conj x) i)
    | otherwise =
        case x of
            Vector (IOVector _ _ f p inc) ->
                inlinePerformIO $ do
                    e  <- peekElemOff p (i*inc)
                    io <- touchForeignPtr f
                    e `seq` io `seq` return e
{-# INLINE unsafeAtVector #-}

-- | Apply a function to every element, building the result with
-- 'listVector'.
tmapVector :: (BLAS1 e) => (e -> e) -> Vector n e -> Vector n e
tmapVector f x = listVector (dim x) (map f (elemsVector x))
{-# INLINE tmapVector #-}

-- | Combine two vectors elementwise.  Calls 'error' when the dimensions
-- differ.
tzipWithVector :: (BLAS1 e) =>
    (e -> e -> e) -> Vector n e -> Vector n e -> Vector n e
tzipWithVector f x y
    | dim y /= n =
        error ("tzipWith: vector lengths differ; first has length `" ++
                show n ++ "' and second has length `" ++
                show (dim y) ++ "'")
    | otherwise =
        listVector n (zipWith f (elems x) (elems y))
  where
    n = dim x
{-# INLINE tzipWithVector #-}

-- | Immutable counterpart of 'getScaledVector'.
scaleVector :: (BLAS1 e) => e -> Vector n e -> Vector n e
scaleVector e (Vector x) =
    unsafePerformIO (getScaledVector e x >>= unsafeFreezeIOVector)
{-# NOINLINE scaleVector #-}

-- | Immutable counterpart of 'getShiftedVector'.
shiftVector :: (BLAS1 e) => e -> Vector n e -> Vector n e
shiftVector e (Vector x) =
    unsafePerformIO (getShiftedVector e x >>= unsafeFreezeIOVector)
{-# NOINLINE shiftVector #-}

-- | Compute the sum of absolute values of entries in the vector.
sumAbs :: (BLAS1 e) => Vector n e -> Double
sumAbs (Vector x) = unsafePerformIO (getSumAbs x)
{-# NOINLINE sumAbs #-}

-- | Compute the 2-norm of a vector.
norm2 :: (BLAS1 e) => Vector n e -> Double
norm2 (Vector x) = unsafePerformIO (getNorm2 x)
{-# NOINLINE norm2 #-}

-- | Get the index and norm of the element with maximum absolute value.
-- Not valid if any of the vector entries are @NaN@.  Raises an exception
-- if the vector has length @0@.
whichMaxAbs :: (BLAS1 e) => Vector n e -> (Int, e)
whichMaxAbs (Vector x) = unsafePerformIO (getWhichMaxAbs x)
{-# NOINLINE whichMaxAbs #-}

-- | Compute the dot product of two vectors.
(<.>) :: (BLAS1 e) => Vector n e -> Vector n e -> e
(<.>) x y = unsafePerformIO (getDot x y)
{-# NOINLINE (<.>) #-}

-- | Same as '<.>', but the dimensions are not checked.
unsafeDot :: (BLAS1 e) => Vector n e -> Vector n e -> e
unsafeDot x y = unsafePerformIO (unsafeGetDot x y)
{-# NOINLINE unsafeDot #-}

instance Shaped Vector Int where
    shape (Vector x) = shapeIOVector x
    {-# INLINE shape #-}
    bounds (Vector x) = boundsIOVector x
    {-# INLINE bounds #-}

instance (BLAS1 e) => ITensor Vector Int e where
    (//) = replaceVector
    {-# INLINE (//) #-}
    unsafeReplace = unsafeReplaceVector
    {-# INLINE unsafeReplace #-}
    unsafeAt = unsafeAtVector
    {-# INLINE unsafeAt #-}
    size = sizeVector
    {-# INLINE size #-}
    elems = elemsVector
    {-# INLINE elems #-}
    indices = indicesVector
    {-# INLINE indices #-}
    assocs = assocsVector
    {-# INLINE assocs #-}
    tmap = tmapVector
    {-# INLINE tmap #-}
    (*>) = scaleVector
    {-# INLINE (*>) #-}
    shift = shiftVector
    {-# INLINE shift #-}

-- Reads of an immutable vector are pure, so every method just wraps the
-- corresponding 'ITensor' function in 'return'.
instance (BLAS1 e, Monad m) => ReadTensor Vector Int e m where
    getSize = return . size
    {-# INLINE getSize #-}
    getAssocs = return . assocs
    {-# INLINE getAssocs #-}
    getIndices = return . indices
    {-# INLINE getIndices #-}
    getElems = return . elems
    {-# INLINE getElems #-}
    getAssocs' = return . assocs
    {-# INLINE getAssocs' #-}
    getIndices' = return . indices
    {-# INLINE getIndices' #-}
    getElems' = return . elems
    {-# INLINE getElems' #-}
    unsafeReadElem x i = return (unsafeAt x i)
    {-# INLINE unsafeReadElem #-}

-- Every method delegates to the wrapped 'IOVector'.
instance (Elem e) => BaseVector Vector e where
    dim (Vector x) = dimIOVector x
    {-# INLINE dim #-}
    stride (Vector x) = strideIOVector x
    {-# INLINE stride #-}
    conjEnum (Vector x) = conjEnumIOVector x
    {-# INLINE conjEnum #-}
    conj (Vector x) = Vector (conjIOVector x)
    {-# INLINE conj #-}
    unsafeSubvectorViewWithStride s (Vector x) o n =
        Vector (unsafeSubvectorViewWithStrideIOVector s x o n)
    {-# INLINE unsafeSubvectorViewWithStride #-}
    unsafeVectorToIOVector (Vector x) = x
    {-# INLINE unsafeVectorToIOVector #-}
    unsafeIOVectorToVector = Vector
    {-# INLINE unsafeIOVectorToVector #-}

instance (BLAS1 e) => ReadVector Vector e IO where
    unsafePerformIOWithVector (Vector x) f = f x
    {-# INLINE unsafePerformIOWithVector #-}
    freezeVector (Vector x) = freezeIOVector x
    {-# INLINE freezeVector #-}
    unsafeFreezeVector = return
    {-# INLINE unsafeFreezeVector #-}

instance (BLAS1 e) => ReadVector Vector e (ST s) where
    unsafePerformIOWithVector (Vector x) f = unsafeIOToST (f x)
    {-# INLINE unsafePerformIOWithVector #-}
    freezeVector (Vector x) = unsafeIOToST (freezeIOVector x)
    {-# INLINE freezeVector #-}
    unsafeFreezeVector = return
    {-# INLINE unsafeFreezeVector #-}

instance (Elem e, Show e) => Show (Vector n e) where
    show x
        | isConj x  = "conj (" ++ show (conj x) ++ ")"
        | otherwise = "listVector " ++ show (dim x) ++ " " ++ show (elemsVector x)

instance (BLAS1 e) => Eq (Vector n e) where
    (==) = compareVectorWith (==)

instance (BLAS1 e) => AEq (Vector n e) where
    (===) = compareVectorWith (===)
    (~==) = compareVectorWith (~==)

-- | Compare two vectors elementwise with the given predicate.  Vectors of
-- different dimension are never equal.
compareVectorWith :: (Elem e) =>
    (e -> e -> Bool) -> Vector n e -> Vector n e -> Bool
compareVectorWith cmp x y
    -- Both conjugated: compare the underlying vectors instead.
    | isConj x && isConj y =
        compareVectorWith cmp (conj x) (conj y)
    | otherwise =
        dim x == dim y
            && and (zipWith cmp (elemsVector x) (elemsVector y))

instance (BLAS1 e) => Num (Vector n e) where
    (+) x y = unsafePerformIO (getAddVector x y >>= unsafeFreezeIOVector)
    {-# NOINLINE (+) #-}
    (-) x y = unsafePerformIO (getSubVector x y >>= unsafeFreezeIOVector)
    {-# NOINLINE (-) #-}
    (*) x y = unsafePerformIO (getMulVector x y >>= unsafeFreezeIOVector)
    {-# NOINLINE (*) #-}
    negate = ((-1) *>)
    {-# INLINE negate #-}
    abs = tmap abs
    signum = tmap signum
    -- Literals become vectors of dimension 1.
    fromInteger n = listVector 1 [fromInteger n]

instance (BLAS1 e) => Fractional (Vector n e) where
    (/) x y = unsafePerformIO (getDivVector x y >>= unsafeFreezeIOVector)
    {-# NOINLINE (/) #-}
    recip = tmap recip
    fromRational q = listVector 1 [fromRational q]

instance (BLAS1 e, Floating e) => Floating (Vector n e) where
    pi    = listVector 1 [pi]
    exp   = tmap exp
    sqrt  = tmap sqrt
    log   = tmap log
    (**)  = tzipWithVector (**)
    sin   = tmap sin
    cos   = tmap cos
    tan   = tmap tan
    asin  = tmap asin
    acos  = tmap acos
    atan  = tmap atan
    sinh  = tmap sinh
    cosh  = tmap cosh
    tanh  = tmap tanh
    asinh = tmap asinh
    acosh = tmap acosh
    atanh = tmap atanh

-- | Run a routine that takes a length, a pointer and a stride on the
-- storage of a vector.
vectorCall :: (ReadVector x e m) =>
    (Int -> Ptr e -> Int -> IO a)
        ->  x n e -> m a
vectorCall f x =
    unsafePerformIOWithVector x $ \x' ->
        withIOVector x' $ \pX ->
            f (dimIOVector x') pX (strideIOVector x')
{-# INLINE vectorCall #-}

-- | Two-vector version of 'vectorCall'.  The length passed to the routine
-- is the dimension of the first vector.
vectorCall2 :: (ReadVector x e m, ReadVector y f m) =>
       (Int -> Ptr e -> Int -> Ptr f -> Int -> IO a)
    -> x n e -> y n' f -> m a
vectorCall2 f x y =
    unsafePerformIOWithVector x $ \x' ->
        let y' = unsafeVectorToIOVector y
        in withIOVector x' $ \pX ->
           withIOVector y' $ \pY ->
               f (dimIOVector x') pX (strideIOVector x')
                                  pY (strideIOVector y')
{-# INLINE vectorCall2 #-}

-- | Wrap a binary vector function with a 'checkBinaryOp' test on the
-- dimensions of its operands.
checkVectorOp2 :: (BaseVector x e, BaseVector y f) =>
    (x n e -> y n f -> a) ->
        x n e -> y n f -> a
checkVectorOp2 f x y =
    checkBinaryOp (dim x) (dim y) (f x y)
{-# INLINE checkVectorOp2 #-}

-- | Copy the argument, apply an in-place operation to the copy, and
-- return the copy.
getUnaryVectorOp :: (ReadVector x e m, WriteVector y e m) =>
    (y n e -> m ()) -> x n e -> m (y n e)
getUnaryVectorOp f x =
    newCopyVector x >>= \y ->
        f y >> return y
{-# INLINE getUnaryVectorOp #-}

-- | Copy the first operand, apply an in-place binary operation to the
-- copy and the second operand, and return the copy.
unsafeGetBinaryVectorOp ::
    (WriteVector z e m, ReadVector x e m, ReadVector y e m) =>
    (z n e -> y n e -> m ()) ->
        x n e -> y n e -> m (z n e)
unsafeGetBinaryVectorOp f x y =
    newCopyVector x >>= \z ->
        f z y >> return z
{-# INLINE unsafeGetBinaryVectorOp #-}