{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
-----------------------------------------------------------------------------
-- |
-- Module     : Data.Vector.Dense.Internal
-- Copyright  : Copyright (c) 2008, Patrick Perry
-- License    : BSD3
-- Maintainer : Patrick Perry
-- Stability  : experimental
--

module Data.Vector.Dense.Internal (
    -- * The Vector type
    Vector(..),

    -- * Vector shape
    dim,
    coerceVector,
    module BLAS.Tensor.Base,

    -- * Conjugating vectors
    module BLAS.Conj,

    -- * Creating new vectors
    vector,
    listVector,
    unsafeVector,

    -- * Reading vector elements
    module BLAS.Tensor.Immutable,

    -- * Special vectors
    zeroVector,
    constantVector,
    basisVector,

    -- * Vector views
    subvector,
    subvectorWithStride,
    unsafeSubvector,
    unsafeSubvectorWithStride,

    -- * Vector properties
    sumAbs,
    norm2,
    whichMaxAbs,
    (<.>),

    -- * Low-level vector properties
    stride,
    isConj,
    ) where

-- Since base-4.8 the Prelude exports Applicative's '(*>)', which clashes
-- with the scaling operator '(*>)' of the 'ITensor' class (used unqualified
-- in the 'Num' instance's 'negate').  Hiding the Prelude one keeps that use
-- unambiguous; with an older Prelude that lacks '(*>)' this is at most a
-- dodgy-import warning.
import Prelude hiding ( (*>) )

import Data.AEq
import System.IO.Unsafe

import BLAS.Conj
import BLAS.Tensor.Base
import BLAS.Tensor.Immutable
import BLAS.Elem ( BLAS1 )
import BLAS.Internal ( inlinePerformIO )
import BLAS.UnsafeIOToM

import Data.Vector.Dense.IO
import Data.Vector.Dense.Class

infixl 7 <.>

-- | An immutable dense vector.  It is a thin wrapper around a mutable
-- 'IOVector'.  Immutability holds by convention: once an 'IOVector' has been
-- wrapped, nothing in this module writes to it again.
newtype Vector n e = V (IOVector n e)

-- | View an 'IOVector' as a 'Vector' without copying.  The caller must not
-- mutate the 'IOVector' afterwards.
unsafeFreezeIOVector :: IOVector n e -> Vector n e
unsafeFreezeIOVector = V

-- | Expose the 'IOVector' underlying a 'Vector' without copying.  The result
-- must not be mutated.
unsafeThawIOVector :: Vector n e -> IOVector n e
unsafeThawIOVector (V x) = x

-- | Lift a function on 'IOVector's to a function on 'Vector's.
liftVector :: (IOVector n e -> a) -> Vector n e -> a
liftVector f (V x) = f x
{-# INLINE liftVector #-}

-- | Lift a binary function on 'IOVector's to a function on 'Vector's.
liftVector2 :: (IOVector n e -> IOVector n e -> a) -> Vector n e -> Vector n e -> a
liftVector2 f x = liftVector (liftVector f x)
{-# INLINE liftVector2 #-}

-- | Run an 'IO' action on the underlying 'IOVector' and return its result
-- as a pure value.
--
-- This is only sound when the action does not mutate its argument and its
-- result does not depend on when it runs.  @NOINLINE@ stops GHC from
-- inlining, duplicating or sharing the 'unsafePerformIO' call.
unsafeLiftVector :: (IOVector n e -> IO a) -> Vector n e -> a
unsafeLiftVector f = unsafePerformIO . liftVector f
{-# NOINLINE unsafeLiftVector #-}

-- | Binary version of 'unsafeLiftVector'; the same soundness conditions apply
-- to both arguments.
unsafeLiftVector2 :: (IOVector n e -> IOVector n e -> IO a) -> Vector n e -> Vector n e -> a
unsafeLiftVector2 f x y = unsafePerformIO $ liftVector2 f x y
{-# NOINLINE unsafeLiftVector2 #-}

-- | Like 'unsafeLiftVector', but runs the action with 'inlinePerformIO' and
-- is allowed to inline.
--
-- NOTE(review): this assumes 'inlinePerformIO' is the cheaper, even less
-- safe variant of 'unsafePerformIO'; use it only for cheap read-only
-- accessors — confirm against "BLAS.Internal".
inlineLiftVector :: (IOVector n e -> IO a) -> Vector n e -> a
inlineLiftVector f = inlinePerformIO . liftVector f
{-# INLINE inlineLiftVector #-}

-- | Create a vector with the given dimension and elements.  The elements
-- given in the association list must all have unique indices, otherwise
-- the result is undefined.
vector :: (BLAS1 e) => Int -> [(Int, e)] -> Vector n e
vector n ies = unsafeFreezeIOVector $ unsafePerformIO $ newVector n ies
{-# NOINLINE vector #-}

-- | Same as 'vector', but does not range-check the indices.
unsafeVector :: (BLAS1 e) => Int -> [(Int, e)] -> Vector n e
unsafeVector n ies = unsafeFreezeIOVector $ unsafePerformIO $ unsafeNewVector n ies
{-# NOINLINE unsafeVector #-}

-- | Create a vector of the given dimension with elements initialized
-- to the values from the list.  @listVector n es@ is equivalent to
-- @vector n (zip [0..(n-1)] es)@, except that the result is undefined
-- if @length es@ is less than @n@.
listVector :: (BLAS1 e) => Int -> [e] -> Vector n e
listVector n es = unsafeFreezeIOVector $ unsafePerformIO $ newListVector n es
{-# NOINLINE listVector #-}

-- | @zeroVector n@ creates a vector of dimension @n@ with all values
-- set to zero.
zeroVector :: (BLAS1 e) => Int -> Vector n e
zeroVector n = unsafeFreezeIOVector $ unsafePerformIO $ newZeroVector n
{-# NOINLINE zeroVector #-}

-- | @constantVector n e@ creates a vector of dimension @n@ with all values
-- set to @e@.
constantVector :: (BLAS1 e) => Int -> e -> Vector n e
constantVector n e = unsafeFreezeIOVector $ unsafePerformIO $ newConstantVector n e
{-# NOINLINE constantVector #-}
-- | @basisVector n i@ creates a vector of dimension @n@ with zeros
-- everywhere but position @i@, where there is a one.
basisVector :: (BLAS1 e) => Int -> Int -> Vector n e
basisVector n i = unsafeFreezeIOVector $ unsafePerformIO $ newBasisVector n i
{-# NOINLINE basisVector #-}

-- | Compute the sum of absolute values of entries in the vector.
sumAbs :: (BLAS1 e) => Vector n e -> Double
sumAbs = unsafeLiftVector getSumAbs
{-# NOINLINE sumAbs #-}

-- | Compute the 2-norm of a vector.
norm2 :: (BLAS1 e) => Vector n e -> Double
norm2 = unsafeLiftVector getNorm2
{-# NOINLINE norm2 #-}

-- | Get the index and the value of the element with maximum absolute value.
-- Not valid if any of the vector entries are @NaN@.  Raises an exception if
-- the vector has length @0@.
whichMaxAbs :: (BLAS1 e) => Vector n e -> (Int, e)
whichMaxAbs = unsafeLiftVector getWhichMaxAbs
{-# NOINLINE whichMaxAbs #-}

-- | Compute the dot product of two vectors.
(<.>) :: (BLAS1 e) => Vector n e -> Vector n e -> e
(<.>) = unsafeLiftVector2 getDot
{-# NOINLINE (<.>) #-}

-- Shape queries are delegated to the underlying 'IOVector'.
instance BaseTensor Vector Int where
    shape  = liftVector shape
    bounds = liftVector bounds

instance ITensor Vector Int where
    -- Both replacement operators go through 'replaceHelp', which writes into
    -- a fresh copy.  '(//)' uses 'writeElem'; 'unsafeReplace' uses
    -- 'unsafeWriteElem'.
    (//)          = replaceHelp writeElem
    unsafeReplace = replaceHelp unsafeWriteElem

    unsafeAt x i = inlineLiftVector (flip unsafeReadElem i) x
    {-# INLINE unsafeAt #-}

    size    = inlineLiftVector getSize
    elems   = inlineLiftVector getElems
    indices = inlineLiftVector getIndices
    assocs  = inlineLiftVector getAssocs

    -- Map over the element list and rebuild a vector of the same dimension.
    tmap f x = listVector (dim x) (map f $ elems x)

    (*>) k x = unsafeFreezeIOVector $ unsafeLiftVector (getScaledVector k) x
    {-# NOINLINE (*>) #-}

    shift k x = unsafeFreezeIOVector $ unsafeLiftVector (getShiftedVector k) x
    {-# NOINLINE shift #-}

-- | Build a new vector from @x@ by applying each @(index, value)@ pair in
-- @ies@, in order, with the setter @set@.  All writes go to the copy made by
-- 'newCopyVector', never to @x@ itself.
replaceHelp :: (BLAS1 e)
            => (IOVector n e -> Int -> e -> IO ())
            -> Vector n e
            -> [(Int, e)]
            -> Vector n e
replaceHelp set x ies = unsafePerformIO $ do
    y <- newCopyVector (unsafeThawIOVector x)
    mapM_ (uncurry $ set y) ies
    return (unsafeFreezeIOVector y)
{-# NOINLINE replaceHelp #-}

-- | Reading an immutable vector works in any monad: each reader just
-- 'return's the corresponding pure 'ITensor' accessor.
instance (Monad m) => ReadTensor Vector Int m where
    getSize            = return . size
    getAssocs          = return . assocs
    getIndices         = return . indices
    getElems           = return . elems
    getAssocs'         = return . assocs
    getIndices'        = return . indices
    getElems'          = return . elems
    unsafeReadElem x i = return $ unsafeAt x i

instance BaseVector Vector where
    vectorViewArray f o n s c = V $ vectorViewArray f o n s c
    arrayFromVector           = liftVector arrayFromVector

-- All methods come from the class defaults.
instance (UnsafeIOToM m) => ReadVector Vector m where

-- | Binary operations are delegated to the corresponding @get*Vector@
-- helpers (presumably elementwise — confirm in "Data.Vector.Dense.Class").
-- 'negate' scales by @-1@; 'abs' and 'signum' map over the elements.
-- 'fromInteger' produces a vector of dimension 1.
instance (BLAS1 e) => Num (Vector n e) where
    (+) x y     = unsafeFreezeIOVector $ unsafeLiftVector2 getAddVector x y
    (-) x y     = unsafeFreezeIOVector $ unsafeLiftVector2 getSubVector x y
    (*) x y     = unsafeFreezeIOVector $ unsafeLiftVector2 getMulVector x y
    negate      = ((-1) *>)
    abs         = tmap abs
    signum      = tmap signum
    fromInteger = (constantVector 1) . fromInteger

-- | Division is delegated to 'getDivVector'; 'recip' maps over the
-- elements; 'fromRational' produces a vector of dimension 1.
instance (BLAS1 e) => Fractional (Vector n e) where
    (/) x y      = unsafeFreezeIOVector $ unsafeLiftVector2 getDivVector x y
    recip        = tmap recip
    fromRational = (constantVector 1) . fromRational

-- | Elementwise floating-point functions.  'pi' is a vector of dimension 1.
-- '(**)' goes through 'tzipWith', which fails on a dimension mismatch.
instance (BLAS1 e, Floating e) => Floating (Vector n e) where
    pi    = constantVector 1 pi
    exp   = tmap exp
    sqrt  = tmap sqrt
    log   = tmap log
    (**)  = tzipWith (**)
    sin   = tmap sin
    cos   = tmap cos
    tan   = tmap tan
    asin  = tmap asin
    acos  = tmap acos
    atan  = tmap atan
    sinh  = tmap sinh
    cosh  = tmap cosh
    tanh  = tmap tanh
    asinh = tmap asinh
    acosh = tmap acosh
    atanh = tmap atanh

-- | Combine two vectors elementwise with @f@.  Calls 'error' if the
-- dimensions differ.
tzipWith :: (BLAS1 e)
         => (e -> e -> e) -> Vector n e -> Vector n e -> Vector n e
tzipWith f x y
    | dim y /= n =
        error ("tzipWith: vector lengths differ; first has length `" ++
               show n ++ "' and second has length `" ++
               show (dim y) ++ "'")
    | otherwise =
        listVector n (zipWith f (elems x) (elems y))
  where
    n = dim x

-- | A conjugated vector is shown as @conj (...)@ around its conjugate.
-- Otherwise it is shown as a 'listVector' expression.
instance (BLAS1 e, Show e) => Show (Vector n e) where
    show x
        | isConj x  = "conj (" ++ show (conj x) ++ ")"
        | otherwise = "listVector " ++ show (dim x) ++ " " ++ show (elems x)

instance (BLAS1 e, Eq e) => Eq (Vector n e) where
    (==) = compareHelp (==)

instance (BLAS1 e, AEq e) => AEq (Vector n e) where
    (===) = compareHelp (===)
    (~==) = compareHelp (~==)

-- | Compare two vectors elementwise with @cmp@.  If both are conjugated,
-- compare their conjugates instead.  Otherwise the vectors must have equal
-- dimension and @cmp@ must hold for every pair of corresponding elements.
compareHelp :: (BLAS1 e)
            => (e -> e -> Bool) -> Vector n e -> Vector n e -> Bool
compareHelp cmp x y
    | isConj x && isConj y =
        compareHelp cmp (conj x) (conj y)
    | otherwise =
        (dim x == dim y) && (and $ zipWith cmp (elems x) (elems y))