module Data.Vector.Dense.Internal (
Vector(..),
dim,
coerceVector,
module BLAS.Tensor.Base,
module BLAS.Conj,
vector,
listVector,
unsafeVector,
module BLAS.Tensor.Immutable,
zeroVector,
constantVector,
basisVector,
subvector,
subvectorWithStride,
unsafeSubvector,
unsafeSubvectorWithStride,
sumAbs,
norm2,
whichMaxAbs,
(<.>),
stride,
isConj,
) where
import Data.AEq
import System.IO.Unsafe
import BLAS.Conj
import BLAS.Tensor.Base
import BLAS.Tensor.Immutable
import BLAS.Elem ( BLAS1 )
import BLAS.Internal ( inlinePerformIO )
import BLAS.UnsafeIOToM
import Data.Vector.Dense.IO
import Data.Vector.Dense.Class
-- | The dot product '<.>' binds like multiplication: left-associative at
-- precedence 7, the same fixity the Haskell Report gives to '*'.
infixl 7 <.>

-- | An immutable dense vector, represented as a plain wrapper around a
-- mutable 'IOVector' (the constructor 'V' adds no copy or extra state).
--
-- Immutability is a convention of this module: every function here that
-- wraps an 'IOVector' with 'V' either allocates it fresh or copies it first,
-- and never writes to it afterwards.
--
-- Type parameters:
--
--   * @n@ - passed straight through to 'IOVector'; it appears to be a
--     phantom dimension tag (NOTE(review): confirm in "Data.Vector.Dense.IO").
--
--   * @e@ - the element type.
newtype Vector n e = V (IOVector n e)
-- | Reinterpret a mutable 'IOVector' as an immutable 'Vector' without
-- copying. Because no copy is made, any later write to the 'IOVector' would
-- be visible through the returned 'Vector'; callers must stop mutating it.
unsafeFreezeIOVector :: IOVector n e -> Vector n e
unsafeFreezeIOVector ioVec = V ioVec
-- | Expose the 'IOVector' underlying a 'Vector' without copying. Writing
-- through the result would silently mutate the supposedly immutable
-- 'Vector', so the result must only be read (or copied first).
unsafeThawIOVector :: Vector n e -> IOVector n e
unsafeThawIOVector frozen =
    case frozen of
        V ioVec -> ioVec
-- | Apply a function defined on 'IOVector's to the storage underlying a
-- 'Vector'.
liftVector :: (IOVector n e -> a) -> Vector n e -> a
liftVector g = g . unsafeThawIOVector
-- | Two-argument version of 'liftVector': apply a binary function on
-- 'IOVector's to the storage underlying two 'Vector's.
liftVector2 ::
    (IOVector n e -> IOVector n e -> a) ->
    Vector n e -> Vector n e -> a
liftVector2 g (V a) (V b) = g a b
-- | Run an 'IO' action on the storage underlying a 'Vector' and return its
-- result as a pure value, via 'unsafePerformIO'.
--
-- This is only referentially transparent if the action does not mutate the
-- input vector (or anything else observable); callers must guarantee that.
unsafeLiftVector :: (IOVector n e -> IO a) -> Vector n e -> a
unsafeLiftVector act v = unsafePerformIO (liftVector act v)
-- | Two-argument version of 'unsafeLiftVector': run an 'IO' action on the
-- storage underlying two 'Vector's and return the result purely via
-- 'unsafePerformIO'. The action must not mutate either input.
unsafeLiftVector2 ::
    (IOVector n e -> IOVector n e -> IO a) ->
    Vector n e -> Vector n e -> a
unsafeLiftVector2 act (V a) (V b) = unsafePerformIO (act a b)
-- | Like 'unsafeLiftVector' but runs the action with 'inlinePerformIO'
-- (from "BLAS.Internal").
--
-- NOTE(review): 'inlinePerformIO' is presumably an even less guarded
-- variant of 'unsafePerformIO' (inlinable, no protection against duplicated
-- evaluation). Only pass cheap, read-only actions here; confirm its
-- definition.
inlineLiftVector :: (IOVector n e -> IO a) -> Vector n e -> a
inlineLiftVector act v = inlinePerformIO (liftVector act v)
-- | Build a vector of dimension @n@ from a list of @(index, element)@
-- associations by running 'newVector' and freezing the result.
--
-- NOTE(review): the value of entries not mentioned in the list, and how bad
-- indices are reported, are decided by 'newVector'; confirm there.
vector :: (BLAS1 e) => Int -> [(Int, e)] -> Vector n e
vector n ies = unsafePerformIO $ do
    fresh <- newVector n ies
    return (unsafeFreezeIOVector fresh)
-- | Same as 'vector' but builds the storage with 'unsafeNewVector'.
--
-- NOTE(review): presumably this skips the index bounds checks that 'vector'
-- performs; confirm in "Data.Vector.Dense.IO".
unsafeVector :: (BLAS1 e) => Int -> [(Int, e)] -> Vector n e
unsafeVector n ies =
    unsafePerformIO (fmap unsafeFreezeIOVector (unsafeNewVector n ies))
-- | Build a vector of dimension @n@ from a list of elements in index order,
-- via 'newListVector'.
--
-- NOTE(review): behavior when the list length differs from @n@ is decided
-- by 'newListVector'; confirm there.
listVector :: (BLAS1 e) => Int -> [e] -> Vector n e
listVector n es = unsafePerformIO $ do
    fresh <- newListVector n es
    return (unsafeFreezeIOVector fresh)
-- | The vector of dimension @n@ produced by 'newZeroVector' (the zero
-- vector).
zeroVector :: (BLAS1 e) => Int -> Vector n e
zeroVector n = unsafePerformIO (fmap unsafeFreezeIOVector (newZeroVector n))
-- | A vector of dimension @n@ filled with the value @e@, built by
-- 'newConstantVector'.
constantVector :: (BLAS1 e) => Int -> e -> Vector n e
constantVector n e = unsafePerformIO $ do
    fresh <- newConstantVector n e
    return (unsafeFreezeIOVector fresh)
-- | The vector of dimension @n@ built by 'newBasisVector' for index @i@.
--
-- NOTE(review): presumably the standard basis vector with a one at index
-- @i@ and zeros elsewhere. Confirm the indexing convention (0-based?) in
-- 'newBasisVector'.
basisVector :: (BLAS1 e) => Int -> Int -> Vector n e
basisVector n i = unsafePerformIO (fmap unsafeFreezeIOVector (newBasisVector n i))
-- | Sum of the absolute values of the elements, as computed by
-- 'getSumAbs'. The result is always a 'Double', whatever the element type.
--
-- NOTE(review): for complex elements BLAS-style @asum@ sums @|re| + |im|@;
-- confirm which definition 'getSumAbs' uses.
sumAbs :: (BLAS1 e) => Vector n e -> Double
sumAbs (V storage) = unsafePerformIO (getSumAbs storage)
-- | Euclidean (2-)norm of the vector, as computed by 'getNorm2'. The result
-- is always a 'Double', whatever the element type.
norm2 :: (BLAS1 e) => Vector n e -> Double
norm2 (V storage) = unsafePerformIO (getNorm2 storage)
-- | Index and value pair reported by 'getWhichMaxAbs'.
--
-- NOTE(review): presumably the element of largest magnitude (BLAS
-- @i?amax@). Tie-breaking and behavior on an empty vector are decided by
-- 'getWhichMaxAbs'; confirm there.
whichMaxAbs :: (BLAS1 e) => Vector n e -> (Int, e)
whichMaxAbs (V storage) = unsafePerformIO (getWhichMaxAbs storage)
-- | Inner product of two vectors, computed by 'getDot'.
--
-- NOTE(review): whether the first argument is conjugated for complex
-- element types, and how mismatched dimensions are reported, are decided by
-- 'getDot'; confirm there.
(<.>) :: (BLAS1 e) => Vector n e -> Vector n e -> e
x <.> y = unsafeLiftVector2 getDot x y
-- | Shape queries on a 'Vector' are answered by its underlying 'IOVector'.
instance BaseTensor Vector Int where
    shape  (V storage) = shape storage
    bounds (V storage) = bounds storage
-- | Immutable-tensor operations on 'Vector'.
instance ITensor Vector Int where
    -- Both replacement operators copy the vector and write the updates into
    -- the copy ('replaceHelp'); they differ only in the element writer used.
    (//)          = replaceHelp writeElem
    unsafeReplace = replaceHelp unsafeWriteElem

    -- Element and structure queries read the underlying storage through
    -- 'inlinePerformIO' (via 'inlineLiftVector').
    unsafeAt x i = inlineLiftVector (`unsafeReadElem` i) x
    size         = inlineLiftVector getSize
    elems        = inlineLiftVector getElems
    indices      = inlineLiftVector getIndices
    assocs       = inlineLiftVector getAssocs

    -- Pure elementwise map: rebuild a vector of the same dimension from the
    -- mapped element list.
    tmap f x = listVector (dim x) [ f el | el <- elems x ]

    -- Scaling and shifting produce a new 'IOVector' from the input's
    -- storage, which is then frozen.
    -- NOTE(review): presumably multiply / add @k@ to every element; confirm
    -- in 'getScaledVector' / 'getShiftedVector'.
    k *> x    = unsafeFreezeIOVector (unsafeLiftVector (getScaledVector k) x)
    shift k x = unsafeFreezeIOVector (unsafeLiftVector (getShiftedVector k) x)
-- | Shared implementation of '//' and 'unsafeReplace'. It copies the input
-- vector with 'newCopyVector', applies each @(index, element)@ update to the
-- copy with the given writer, and freezes the copy. The input vector is
-- never written to.
replaceHelp :: (BLAS1 e) =>
    (IOVector n e -> Int -> e -> IO ()) ->
    Vector n e -> [(Int, e)] -> Vector n e
replaceHelp set (V original) ies = unsafePerformIO $ do
    copy <- newCopyVector original
    mapM_ (uncurry (set copy)) ies
    return (V copy)
-- | A pure 'Vector' can be read in any monad. Each getter simply 'return's
-- the corresponding pure 'ITensor' query. The primed getters are answered
-- exactly like their unprimed counterparts.
instance (Monad m) => ReadTensor Vector Int m where
    getSize    x = return (size x)
    getAssocs  x = return (assocs x)
    getIndices x = return (indices x)
    getElems   x = return (elems x)

    getAssocs'  x = return (assocs x)
    getIndices' x = return (indices x)
    getElems'   x = return (elems x)

    unsafeReadElem x i = return (unsafeAt x i)
-- | Raw-storage views of a 'Vector'. Both directions delegate unchanged to
-- the 'IOVector' instance and only add or remove the 'V' wrapper.
instance BaseVector Vector where
    vectorViewArray f o n s c = V (vectorViewArray f o n s c)
    arrayFromVector (V storage) = arrayFromVector storage
-- | 'Vector' supports the 'ReadVector' interface in any 'UnsafeIOToM'
-- monad. No methods are overridden, so the class defaults are used.
instance UnsafeIOToM m => ReadVector Vector m
-- | Arithmetic on vectors.
--
-- '+' and '-' delegate to the 'getAddVector' and 'getSubVector' kernels on
-- the underlying storage. '*' delegates to 'getMulVector' (presumably an
-- elementwise product; confirm). 'abs' and 'signum' map over the elements,
-- and 'negate' scales by @-1@.
--
-- 'fromInteger' produces a vector of dimension 1. Whether such a literal can
-- be combined with a longer vector depends on how the kernels treat
-- mismatched dimensions.
-- NOTE(review): confirm the get*Vector kernels reject mismatched lengths.
instance (BLAS1 e) => Num (Vector n e) where
    x + y = unsafeFreezeIOVector (unsafeLiftVector2 getAddVector x y)
    x - y = unsafeFreezeIOVector (unsafeLiftVector2 getSubVector x y)
    x * y = unsafeFreezeIOVector (unsafeLiftVector2 getMulVector x y)

    -- Scale by -1. Scaling by 1 would make 'negate' the identity.
    negate = ((-1) *>)

    abs    = tmap abs
    signum = tmap signum

    fromInteger = constantVector 1 . fromInteger
-- | Division delegates to the 'getDivVector' kernel, and 'recip' maps over
-- the elements. Like 'fromInteger', 'fromRational' yields a vector of
-- dimension 1.
instance (BLAS1 e) => Fractional (Vector n e) where
    x / y          = unsafeFreezeIOVector (unsafeLiftVector2 getDivVector x y)
    recip v        = tmap recip v
    fromRational r = constantVector 1 (fromRational r)
-- | Elementwise transcendental functions.
--
-- 'pi' is a vector of dimension 1. '**' combines two vectors pointwise via
-- 'tzipWith', which raises an error on mismatched dimensions. Every unary
-- method maps the scalar function over the elements with 'tmap'.
instance (BLAS1 e, Floating e) => Floating (Vector n e) where
    pi = constantVector 1 pi

    x ** y = tzipWith (**) x y

    exp   v = tmap exp v
    sqrt  v = tmap sqrt v
    log   v = tmap log v

    sin   v = tmap sin v
    cos   v = tmap cos v
    tan   v = tmap tan v
    asin  v = tmap asin v
    acos  v = tmap acos v
    atan  v = tmap atan v

    sinh  v = tmap sinh v
    cosh  v = tmap cosh v
    tanh  v = tmap tanh v
    asinh v = tmap asinh v
    acosh v = tmap acosh v
    atanh v = tmap atanh v
-- | Combine two vectors of equal dimension pointwise with @f@, producing a
-- fresh vector.
--
-- Calls 'error' when the dimensions differ; the message reports both
-- lengths.
tzipWith :: (BLAS1 e) =>
    (e -> e -> e) -> Vector n e -> Vector n e -> Vector n e
tzipWith f x y
    | lenY == lenX = listVector lenX (zipWith f (elems x) (elems y))
    | otherwise    = error mismatch
  where
    lenX = dim x
    lenY = dim y
    mismatch =
        "tzipWith: vector lengths differ; first has length `"
            ++ show lenX
            ++ "' and second has length `"
            ++ show lenY
            ++ "'"
-- | A non-conjugated vector is rendered as @listVector n [elements]@. A
-- vector flagged by 'isConj' is rendered as @conj (...)@ around the
-- rendering of its 'conj'.
instance (BLAS1 e, Show e) => Show (Vector n e) where
    show x
        | not (isConj x) =
            "listVector " ++ show (dim x) ++ " " ++ show (elems x)
        | otherwise =
            "conj (" ++ show (conj x) ++ ")"
-- | Two vectors are equal when they have the same dimension and equal
-- elements (see 'compareHelp').
instance (BLAS1 e, Eq e) => Eq (Vector n e) where
    x == y = compareHelp (==) x y
-- | Exact ('===') and approximate ('~==') equality, lifted elementwise
-- through 'compareHelp'.
instance (BLAS1 e, AEq e) => AEq (Vector n e) where
    x === y = compareHelp (===) x y
    x ~== y = compareHelp (~==) x y
-- | Compare two vectors with an element predicate.
--
-- When both vectors are flagged by 'isConj', their 'conj's are compared
-- instead. Otherwise the vectors must have equal 'dim', and every pair of
-- corresponding elements must satisfy @cmp@. The element walk stops at the
-- first failing pair.
compareHelp :: (BLAS1 e) =>
    (e -> e -> Bool) ->
    Vector n e -> Vector n e -> Bool
compareHelp cmp x y
    | bothConj  = compareHelp cmp (conj x) (conj y)
    | otherwise = sameDim && allMatch
  where
    bothConj = isConj x && isConj y
    sameDim  = dim x == dim y
    allMatch = and (zipWith cmp (elems x) (elems y))