module Data.Vector.Dense.Base
where
import Control.Monad
import Control.Monad.ST
import Data.AEq
import Foreign
import Unsafe.Coerce
import BLAS.Internal( checkBinaryOp, clearArray, inlinePerformIO,
checkedSubvector, checkedSubvectorWithStride, checkVecVecOp )
import BLAS.Types( ConjEnum(..) )
import Data.Elem.BLAS ( Complex, Elem, BLAS1, conjugate )
import qualified Data.Elem.BLAS.Level1 as BLAS
import Data.Tensor.Class
import Data.Tensor.Class.ITensor
import Data.Tensor.Class.MTensor
import Data.Vector.Dense.IOBase
-- Dot-product operator '<.>': left-associative, same precedence as '*'.
infixl 7 <.>

-- | Immutable dense vector: a newtype wrapper around an 'IOVector'.  The
-- functions below move between the two representations, either through a
-- copy or (the @unsafe@ variants) by sharing the same storage.
newtype Vector n e = Vector (IOVector n e)

-- | Freeze a mutable vector through a fresh copy made by 'newCopyIOVector'.
freezeIOVector :: (BLAS1 e) => IOVector n e -> IO (Vector n e)
freezeIOVector = liftM Vector . newCopyIOVector

-- | Thaw an immutable vector through a fresh copy made by 'newCopyIOVector'.
thawIOVector :: (BLAS1 e) => Vector n e -> IO (IOVector n e)
thawIOVector v = case v of
    Vector storage -> newCopyIOVector storage

-- | Freeze without copying: the result shares storage with the argument, so
-- the argument should not be mutated afterwards.
unsafeFreezeIOVector :: IOVector n e -> IO (Vector n e)
unsafeFreezeIOVector storage = return (Vector storage)

-- | Thaw without copying: writes through the result are visible in the
-- original immutable vector.
unsafeThawIOVector :: Vector n e -> IO (IOVector n e)
unsafeThawIOVector = \(Vector storage) -> return storage
-- | Structural operations shared by every dense vector type.
--
-- @x n e@ is a vector of element type @e@.  @n@ is a type-level tag that
-- 'coerceVector' rewrites with 'unsafeCoerce', so it presumably carries no
-- runtime data (TODO confirm against the 'IOVector' definition).
class (Shaped x Int, Elem e) => BaseVector x e where
    -- | Number of logical elements.
    dim :: x n e -> Int
    -- | Distance, in elements, between consecutive entries in storage; it is
    -- passed as the BLAS increment by the helpers in this module.
    stride :: x n e -> Int
    -- | Whether the vector is a conjugated view; defaults to comparing
    -- 'conjEnum' with 'Conj'.
    isConj :: x n e -> Bool
    isConj x = conjEnum x == Conj
    -- | The conjugation flag of the view.
    conjEnum :: x n e -> ConjEnum
    -- | The same storage viewed with the conjugation flag toggled (callers in
    -- this module rely on @conj@ of a conjugated view being unconjugated).
    conj :: x n e -> x n e
    -- | Change only the type-level tag.
    coerceVector :: x n e -> x n' e
    coerceVector = unsafeCoerce
    -- | Unchecked view with the given stride; the two trailing 'Int's are
    -- presumably the offset and length (named @o@ and @n@ in the 'Vector'
    -- instance).
    unsafeSubvectorViewWithStride :: Int -> x n e -> Int -> Int -> x n' e
    -- | The underlying 'IOVector', shared, not copied.
    unsafeVectorToIOVector :: x n e -> IOVector n e
    -- | Wrap an 'IOVector' without copying.
    unsafeIOVectorToVector :: IOVector n e -> x n e

-- | Vectors whose contents can be read in monad @m@.
class (BaseVector x e, BLAS1 e, Monad m, ReadTensor x Int e m) =>
        ReadVector x e m where
    -- | Run an 'IO' action on the underlying 'IOVector' and lift the result
    -- into @m@.  Unsafe: the action may mutate storage the vector shares.
    unsafePerformIOWithVector :: x n e -> (IOVector n e -> IO a) -> m a
    -- | Copy into a new immutable 'Vector'.
    freezeVector :: x n e -> m (Vector n e)
    -- | Convert to an immutable 'Vector' without copying.
    unsafeFreezeVector :: x n e -> m (Vector n e)

-- | Vectors that can be allocated and written in monad @m@.
class (ReadVector x e m, WriteTensor x Int e m) => WriteVector x e m where
    -- | Run an 'IO' action that builds an 'IOVector' and present the result
    -- as @x@ (for 'IOVector' itself this is 'id').
    unsafeConvertIOVector :: IO (IOVector n e) -> m (x n e)
    -- | Allocate a vector of the given dimension.  Callers in this module
    -- overwrite every element, so treat the contents as uninitialized.
    newVector_ :: Int -> m (x n e)
    -- | Copy an immutable vector into new mutable storage.
    thawVector :: Vector n e -> m (x n e)
    -- | Reuse an immutable vector's storage as a mutable vector (no copy).
    unsafeThawVector :: Vector n e -> m (x n e)
-- | Create a vector of dimension @n@ that is zero except at the listed
-- @(index, value)@ pairs.  Pairs are written in list order, so a later pair
-- overwrites an earlier one with the same index.  Each index is checked
-- against @[0, n)@; an out-of-range index raises an error via 'fail' in the
-- underlying IO action.
newVector :: (WriteVector x e m) => Int -> [(Int,e)] -> m (x n e)
newVector n ies = do
    x <- newZeroVector n
    unsafePerformIOWithVector x $ \io ->
        withIOVector io $ \ptr ->
            mapM_ (store ptr) ies
    return x
  where
    store ptr (i, e) = do
        unless (0 <= i && i < n) $ fail $
            "Index `" ++ show i ++
            "' is invalid for a vector with dimension `" ++
            show n ++ "'"
        pokeElemOff ptr i e

-- | Like 'newVector' but without bounds checking: an out-of-range index is
-- passed straight to 'pokeElemOff'.
unsafeNewVector :: (WriteVector x e m) => Int -> [(Int,e)] -> m (x n e)
unsafeNewVector n ies = do
    x <- newZeroVector n
    unsafePerformIOWithVector x $ \io ->
        withIOVector io $ \ptr ->
            mapM_ (uncurry (pokeElemOff ptr)) ies
    return x
-- | Create a vector of dimension @n@ from a list.  Only the first @n@
-- elements are used, and a shorter list is padded with zeros.
newListVector :: (WriteVector x e m) => Int -> [e] -> m (x n e)
newListVector n es = do
    x <- newVector_ n
    let padded = take n (es ++ repeat 0)
    unsafePerformIOWithVector x $ \io ->
        withIOVector io $ \ptr -> pokeArray ptr padded
    return x

-- | Create a vector of dimension @n@ whose storage is cleared with
-- 'clearArray'.
newZeroVector :: (WriteVector x e m) => Int -> m (x n e)
newZeroVector n =
    newVector_ n >>= \x ->
        unsafePerformIOWithVector x (\io -> withIOVector io (`clearArray` n))
            >> return x

-- | Zero an existing vector in place via 'setZeroIOVector'.
setZeroVector :: (WriteVector x e m) => x n e -> m ()
setZeroVector = flip unsafePerformIOWithVector setZeroIOVector

-- | Create a vector of dimension @n@ with every element equal to @e@.
newConstantVector :: (WriteVector x e m) => Int -> e -> m (x n e)
newConstantVector n e = do
    x <- newVector_ n
    let fill ptr = pokeArray ptr (replicate n e)
    unsafePerformIOWithVector x $ \io -> withIOVector io fill
    return x

-- | Overwrite every element of an existing vector with @e@ via
-- 'setConstantIOVector'.
setConstantVector :: (WriteVector x e m) => e -> x n e -> m ()
setConstantVector e x = unsafePerformIOWithVector x (setConstantIOVector e)
-- | Allocate a vector of dimension @n@ that is zero everywhere except for a
-- @1@ at index @i@.
--
-- The index is validated against @[0, n)@ before anything is allocated.
-- An out-of-range index fails (via 'fail' in @m@) with the same message
-- 'setBasisVector' uses.  The check is needed because the write below uses
-- the unchecked 'pokeElemOff'.
newBasisVector :: (WriteVector x e m) => Int -> Int -> m (x n e)
newBasisVector n i
    | i < 0 || i >= n =
        fail $ "tried to set a vector of dimension `" ++ show n ++ "'"
            ++ " to basis vector `" ++ show i ++ "'"
    | otherwise = do
        x <- newZeroVector n
        unsafePerformIOWithVector x $ \x' ->
            withIOVector x' $ \p ->
                pokeElemOff p i 1
        return x
-- | Overwrite @x@ with the @i@-th standard basis vector: zero everywhere and
-- @1@ at index @i@.  Fails (via 'fail') when @i@ is outside @[0, dim x)@.
setBasisVector :: (WriteVector x e m) => Int -> x n e -> m ()
setBasisVector i x
    | 0 <= i && i < n = do
        setZeroVector x
        unsafeWriteElem x i 1
    | otherwise =
        fail $ "tried to set a vector of dimension `" ++ show n ++ "'"
            ++ " to basis vector `" ++ show i ++ "'"
  where
    n = dim x
-- | Allocate a new vector with the same logical contents as @x@.
--
-- A conjugated @x@ is not conjugated numerically.  Its unconjugated view is
-- copied, and the copy is then marked conjugated again with 'conj'.  In the
-- plain case, the (possibly strided) storage of @x@ is copied with BLAS
-- @copy@ into the new vector using increment 1.
newCopyVector :: (ReadVector x e m, WriteVector y e m) =>
    x n e -> m (y n e)
newCopyVector x
    | not (isConj x) = do
        y <- newVector_ (dim x)
        unsafePerformIOWithVector y $ \dst ->
            withIOVector (unsafeVectorToIOVector x) $ \srcPtr ->
                withIOVector dst $ \dstPtr ->
                    BLAS.copy (dim x) srcPtr (stride x) dstPtr 1
        return y
    | otherwise = liftM conj (newCopyVector (conj x))

-- | Like 'newCopyVector', but a conjugated @x@ is materialised instead of
-- being re-marked as conjugated.  Its unconjugated view is copied, and the
-- copy's storage is then conjugated in place with @BLAS.vconj@.
newCopyVector' :: (ReadVector x e m, WriteVector y e m) => x n e -> m (y n e)
newCopyVector' x
    | isConj x = do
        y <- newCopyVector (conj x)
        unsafePerformIOWithVector y $ \buf ->
            withIOVector buf $ \ptr -> BLAS.vconj (dim x) ptr 1
        return y
    | otherwise = newCopyVector x
-- | Copy the logical contents of @x@ into @y@ (destination first), after
-- checking the shapes with 'checkBinaryOp'.
copyVector :: (WriteVector y e m, ReadVector x e m) =>
    y n e -> x n e -> m ()
copyVector y x = checkBinaryOp (shape x) (shape y) $ unsafeCopyVector y x

-- | Unchecked copy of the logical contents of @x@ into @y@.
--
-- 'vectorCall' / 'vectorCall2' operate on raw storage and ignore the
-- conjugation flag, so the flags are handled explicitly here.
--
-- NOTE(review): @y@ is written through, yet only a 'ReadVector' constraint
-- is required; the checked 'copyVector' supplies 'WriteVector'.
unsafeCopyVector :: (ReadVector y e m, ReadVector x e m) =>
    y n e -> x n e -> m ()
unsafeCopyVector y x
    -- Toggling the flag on both sides leaves "y := x" unchanged and makes
    -- the source unconjugated.
    | isConj x =
        unsafeCopyVector (conj y) (conj x)
    -- Plain source, conjugated destination: copy raw storage, then conjugate
    -- y's storage so its logical values equal x's.  The order matters.
    | isConj y = do
        vectorCall2 BLAS.copy x y
        vectorCall BLAS.vconj y
    -- Neither side conjugated: a plain strided copy.
    | otherwise =
        vectorCall2 BLAS.copy x y

-- | Exchange the contents of @x@ and @y@ after checking their shapes.
swapVector :: (WriteVector x e m, WriteVector y e m) =>
    x n e -> y n e -> m ()
swapVector x y = checkBinaryOp (shape x) (shape y) $ unsafeSwapVector x y

-- | Unchecked exchange of the logical contents of @x@ and @y@.
unsafeSwapVector :: (WriteVector x e m, WriteVector y e m) =>
    x n e -> y n e -> m ()
unsafeSwapVector x y
    -- Swapping commutes with conjugating both sides; flip both flags so that
    -- x becomes unconjugated.
    | isConj x =
        unsafeSwapVector (conj x) (conj y)
    -- Only y is conjugated: swap raw storage, then conjugate both buffers so
    -- each side ends up holding the other's original logical values.
    | isConj y = do
        vectorCall2 BLAS.swap x y
        vectorCall BLAS.vconj x
        vectorCall BLAS.vconj y
    -- Neither side conjugated: a plain strided swap.
    | otherwise =
        vectorCall2 BLAS.swap x y
-- | Bounds-checked unit-stride view into @x@; the checking is delegated to
-- 'checkedSubvector' with @dim x@ as the limit.
subvectorView :: (BaseVector x e) =>
    x n e -> Int -> Int -> x n' e
subvectorView x o k = checkedSubvector (dim x) (unsafeSubvectorView x) o k

-- | Unchecked unit-stride view: 'unsafeSubvectorViewWithStride' with stride 1.
unsafeSubvectorView :: (BaseVector x e) =>
    x n e -> Int -> Int -> x n' e
unsafeSubvectorView x = unsafeSubvectorViewWithStride 1 x

-- | Bounds-checked view with stride @s@, validated by
-- 'checkedSubvectorWithStride' against @dim x@.
subvectorViewWithStride :: (BaseVector x e) =>
    Int -> x n e -> Int -> Int -> x n' e
subvectorViewWithStride s x =
    let limit = dim x
        view = unsafeSubvectorViewWithStride s x
    in checkedSubvectorWithStride s limit view
-- | New vector holding the conjugate of @x@; @x@ is copied first, so it is
-- left untouched.
getConjVector :: (ReadVector x e m, WriteVector y e m) =>
    x n e -> m (y n e)
getConjVector x = getUnaryVectorOp doConjVector x

-- | Apply 'doConjIOVector' to the underlying storage of the vector.
doConjVector :: (WriteVector y e m) => y n e -> m ()
doConjVector = flip unsafePerformIOWithVector doConjIOVector

-- | New vector holding @e@ times @x@ (@x@ is copied first).
getScaledVector :: (ReadVector x e m, WriteVector y e m) =>
    e -> x n e -> m (y n e)
getScaledVector e x = getUnaryVectorOp (scaleByVector e) x

-- | Apply 'scaleByIOVector' with factor @k@ to the vector's storage.
scaleByVector :: (WriteVector y e m) => e -> y n e -> m ()
scaleByVector k = flip unsafePerformIOWithVector (scaleByIOVector k)

-- | New vector holding @x@ shifted by @e@ (@x@ is copied first).
getShiftedVector :: (ReadVector x e m, WriteVector y e m) =>
    e -> x n e -> m (y n e)
getShiftedVector e x = getUnaryVectorOp (shiftByVector e) x

-- | Apply 'shiftByIOVector' with offset @k@ to the vector's storage.
shiftByVector :: (WriteVector y e m) => e -> y n e -> m ()
shiftByVector k = flip unsafePerformIOWithVector (shiftByIOVector k)
-- | New vector holding @x + y@, after checking that the dimensions agree.
getAddVector ::
    (ReadVector x e m, ReadVector y e m, WriteVector z e m) =>
    x n e -> y n e -> m (z n e)
getAddVector x y = checkVectorOp2 unsafeGetAddVector x y

-- | Unchecked 'getAddVector': copy @x@, then add @y@ into the copy.
unsafeGetAddVector ::
    (ReadVector x e m, ReadVector y e m, WriteVector z e m) =>
    x n e -> y n e -> m (z n e)
unsafeGetAddVector x y = unsafeGetBinaryVectorOp unsafeAddVector x y

-- | In-place @y := y + x@, after checking that the dimensions agree.
addVector :: (WriteVector y e m, ReadVector x e m) =>
    y n e -> x n e -> m ()
addVector y x = checkBinaryOp (dim y) (dim x) (unsafeAddVector y x)

-- | Unchecked in-place @y := y + x@: an axpy with coefficient 1.
unsafeAddVector :: (WriteVector y e m, ReadVector x e m) =>
    y n e -> x n e -> m ()
unsafeAddVector y x = unsafeAxpyVector 1 x y
-- | New vector holding @x - y@, after checking that the dimensions agree.
getSubVector ::
    (ReadVector x e m, ReadVector y e m, WriteVector z e m) =>
    x n e -> y n e -> m (z n e)
getSubVector = checkVectorOp2 unsafeGetSubVector

-- | Unchecked 'getSubVector': copy @x@, then subtract @y@ from the copy.
unsafeGetSubVector ::
    (ReadVector x e m, ReadVector y e m, WriteVector z e m) =>
    x n e -> y n e -> m (z n e)
unsafeGetSubVector = unsafeGetBinaryVectorOp unsafeSubVector

-- | In-place @y := y - x@, after checking that the dimensions agree.
subVector :: (WriteVector y e m, ReadVector x e m) =>
    y n e -> x n e -> m ()
subVector y x = checkBinaryOp (dim y) (dim x) $ unsafeSubVector y x

-- | Unchecked in-place @y := y - x@, expressed as an axpy with coefficient
-- @-1@.  A coefficient of @1@ would add instead, duplicating
-- 'unsafeAddVector'.
unsafeSubVector :: (WriteVector y e m, ReadVector x e m) =>
    y n e -> x n e -> m ()
unsafeSubVector y x = unsafeAxpyVector (-1) x y
-- | In-place @y := alpha * x + y@, after checking the shapes with
-- 'checkBinaryOp'.
axpyVector :: (ReadVector x e m, WriteVector y e m) =>
    e -> x n e -> y n e -> m ()
axpyVector alpha x y =
    checkBinaryOp (shape x) (shape y) $ unsafeAxpyVector alpha x y

-- | Unchecked in-place @y := alpha * x + y@.
--
-- NOTE(review): @y@ is written through, yet only a 'ReadVector' constraint
-- is required; the checked 'axpyVector' supplies 'WriteVector'.
unsafeAxpyVector :: (ReadVector x e m, ReadVector y e m) =>
    e -> x n e -> y n e -> m ()
unsafeAxpyVector alpha x y
    -- Conjugating both sides of the update gives
    -- conj y := conjugate alpha * conj x + conj y, whose destination is
    -- unconjugated.
    | isConj y =
        unsafeAxpyVector (conjugate alpha) (conj x) (conj y)
    -- Only x is conjugated: BLAS.acxpy presumably adds alpha times the
    -- conjugate of x's raw storage (TODO confirm in Data.Elem.BLAS.Level1).
    -- 'flip' puts alpha after the length argument that 'vectorCall2' passes
    -- first.
    | isConj x =
        vectorCall2 (flip BLAS.acxpy alpha) x y
    -- Neither conjugated: a plain BLAS axpy.
    | otherwise =
        vectorCall2 (flip BLAS.axpy alpha) x y
-- | New vector holding the elementwise product @x * y@, after checking that
-- the dimensions agree.
getMulVector ::
    (ReadVector x e m, ReadVector y e m, WriteVector z e m) =>
    x n e -> y n e -> m (z n e)
getMulVector x y = checkVectorOp2 unsafeGetMulVector x y

-- | Unchecked 'getMulVector': copy @x@, then multiply the copy by @y@.
unsafeGetMulVector ::
    (ReadVector x e m, ReadVector y e m, WriteVector z e m) =>
    x n e -> y n e -> m (z n e)
unsafeGetMulVector x y = unsafeGetBinaryVectorOp unsafeMulVector x y

-- | In-place elementwise @y := y * x@, after checking the shapes.
mulVector :: (WriteVector y e m, ReadVector x e m) =>
    y n e -> x n e -> m ()
mulVector y x = checkBinaryOp (shape y) (shape x) (unsafeMulVector y x)

-- | Unchecked in-place elementwise @y := y * x@.  A conjugated destination
-- is handled by conjugating both operands.  A conjugated source selects the
-- conjugating kernel @vcmul@ instead of @vmul@.
unsafeMulVector :: (WriteVector y e m, ReadVector x e m) =>
    y n e -> x n e -> m ()
unsafeMulVector y x =
    if isConj y
        then unsafeMulVector (conj y) (conj x)
        else vectorCall2 (if isConj x then BLAS.vcmul else BLAS.vmul) x y
-- | New vector holding the elementwise quotient @x / y@, after checking that
-- the dimensions agree.
getDivVector ::
    (ReadVector x e m, ReadVector y e m, WriteVector z e m) =>
    x n e -> y n e -> m (z n e)
getDivVector x y = checkVectorOp2 unsafeGetDivVector x y

-- | Unchecked 'getDivVector': copy @x@, then divide the copy by @y@.
unsafeGetDivVector ::
    (ReadVector x e m, ReadVector y e m, WriteVector z e m) =>
    x n e -> y n e -> m (z n e)
unsafeGetDivVector x y = unsafeGetBinaryVectorOp unsafeDivVector x y

-- | In-place elementwise @y := y / x@, after checking the shapes.
divVector :: (WriteVector y e m, ReadVector x e m) =>
    y n e -> x n e -> m ()
divVector y x = checkBinaryOp (shape y) (shape x) (unsafeDivVector y x)

-- | Unchecked in-place elementwise @y := y / x@.  A conjugated destination
-- is handled by conjugating both operands.  A conjugated source selects the
-- conjugating kernel @vcdiv@ instead of @vdiv@.
unsafeDivVector :: (WriteVector y e m, ReadVector x e m) =>
    y n e -> x n e -> m ()
unsafeDivVector y x =
    if isConj y
        then unsafeDivVector (conj y) (conj x)
        else vectorCall2 (if isConj x then BLAS.vcdiv else BLAS.vdiv) x y
-- | Result of BLAS @asum@ over the vector's storage.
getSumAbs :: (ReadVector x e m) => x n e -> m Double
getSumAbs x = vectorCall BLAS.asum x

-- | Result of BLAS @nrm2@ over the vector's storage.
getNorm2 :: (ReadVector x e m) => x n e -> m Double
getNorm2 x = vectorCall BLAS.nrm2 x

-- | Index chosen by BLAS @iamax@ (presumably the element of largest
-- magnitude) together with the element read at that index.  Fails on an
-- empty vector.
getWhichMaxAbs :: (ReadVector x e m) => x n e -> m (Int, e)
getWhichMaxAbs x
    | dim x == 0 = fail "getWhichMaxAbs of an empty vector"
    | otherwise = do
        i <- vectorCall BLAS.iamax x
        liftM ((,) i) (unsafeReadElem x i)

-- | Dot product after checking that the dimensions agree.
getDot :: (ReadVector x e m, ReadVector y e m) =>
    x n e -> y n e -> m e
getDot x y = checkVecVecOp "getDot" (dim x) (dim y) (unsafeGetDot x y)

-- | Unchecked dot product.  Both conjugation flags are passed to BLAS @dot@,
-- which accounts for them.
unsafeGetDot :: (ReadVector x e m, ReadVector y e m) =>
    x n e -> y n e -> m e
unsafeGetDot x y = vectorCall2 (BLAS.dot cx cy) x y
  where
    cx = conjEnum x
    cy = conjEnum y
-- | 'IOVector' is its own underlying representation: the conversions are
-- 'id', and every query forwards to the matching @...IOVector@ function.
instance (Elem e) => BaseVector IOVector e where
    unsafeIOVectorToVector = id
    unsafeVectorToIOVector = id
    unsafeSubvectorViewWithStride = unsafeSubvectorViewWithStrideIOVector
    conj = conjIOVector
    conjEnum = conjEnumIOVector
    stride = strideIOVector
    dim = dimIOVector

-- | Reading an 'IOVector' in 'IO' runs the action directly on it.
instance (BLAS1 e) => ReadVector IOVector e IO where
    unsafeFreezeVector = unsafeFreezeIOVector
    freezeVector = freezeIOVector
    unsafePerformIOWithVector x act = act x

-- | Allocation and thawing for 'IOVector' in 'IO'.
instance (BLAS1 e) => WriteVector IOVector e IO where
    unsafeThawVector = unsafeThawIOVector
    thawVector = thawIOVector
    unsafeConvertIOVector = id
    newVector_ = newIOVector_
-- | Pure, bounds-checked construction from @(index, value)@ pairs (see
-- 'newVector').
vector :: (BLAS1 e) => Int -> [(Int, e)] -> Vector n e
vector n ies = unsafePerformIO (newVector n ies >>= unsafeFreezeIOVector)

-- | Pure, unchecked construction from @(index, value)@ pairs (see
-- 'unsafeNewVector').
unsafeVector :: (BLAS1 e) => Int -> [(Int, e)] -> Vector n e
unsafeVector n ies =
    unsafePerformIO (unsafeNewVector n ies >>= unsafeFreezeIOVector)

-- | Pure construction from a list, truncated or zero-padded to length @n@.
listVector :: (BLAS1 e) => Int -> [e] -> Vector n e
listVector n es = Vector . unsafePerformIO $ newListVector n es

-- | Copy @x@ and write the pairs into the copy with the checked 'writeElem';
-- @x@ itself is not modified.
replaceVector :: (BLAS1 e) => Vector n e -> [(Int,e)] -> Vector n e
replaceVector (Vector x) ies = unsafePerformIO $
    newCopyVector x >>= \y ->
        forM_ ies (uncurry (writeElem y)) >> return (Vector y)

-- | Like 'replaceVector' but writes with the unchecked 'unsafeWriteElem'.
unsafeReplaceVector :: (BLAS1 e) => Vector n e -> [(Int,e)] -> Vector n e
unsafeReplaceVector (Vector x) ies = unsafePerformIO $
    newCopyVector x >>= \y ->
        forM_ ies (uncurry (unsafeWriteElem y)) >> return (Vector y)

-- | Pure zero vector of dimension @n@.
zeroVector :: (BLAS1 e) => Int -> Vector n e
zeroVector n = unsafePerformIO (newZeroVector n >>= unsafeFreezeIOVector)

-- | Pure vector of dimension @n@ with every element equal to @e@.
constantVector :: (BLAS1 e) => Int -> e -> Vector n e
constantVector n e =
    unsafePerformIO (newConstantVector n e >>= unsafeFreezeIOVector)

-- | Pure @i@-th basis vector of dimension @n@ (see 'newBasisVector').
basisVector :: (BLAS1 e) => Int -> Int -> Vector n e
basisVector n i =
    unsafePerformIO (newBasisVector n i >>= unsafeFreezeIOVector)
-- | Bounds-checked unit-stride subvector of an immutable vector.
subvector :: (BLAS1 e) => Vector n e -> Int -> Int -> Vector n' e
subvector v = subvectorView v

-- | Unchecked unit-stride subvector of an immutable vector.
unsafeSubvector :: (BLAS1 e) => Vector n e -> Int -> Int -> Vector n' e
unsafeSubvector v = unsafeSubvectorView v

-- | Unchecked strided subvector of an immutable vector.
unsafeSubvectorWithStride :: (Elem e) =>
    Int -> Vector n e -> Int -> Int -> Vector n' e
unsafeSubvectorWithStride s = unsafeSubvectorViewWithStride s

-- | Bounds-checked strided subvector of an immutable vector.
subvectorWithStride :: (BLAS1 e) =>
    Int -> Vector n e -> Int -> Int -> Vector n' e
subvectorWithStride s = subvectorViewWithStride s

-- | Size reported by 'sizeIOVector' for the underlying storage.
sizeVector :: Vector n e -> Int
sizeVector = \(Vector io) -> sizeIOVector io

-- | Indices reported by 'indicesIOVector' for the underlying storage.
indicesVector :: Vector n e -> [Int]
indicesVector = \(Vector io) -> indicesIOVector io
-- | Logical elements of the vector, as a lazily produced list.
--
-- A conjugated view lists its unconjugated view and conjugates each value.
-- Otherwise the pointer walks from @p@ in steps of @incX@ until it reaches
-- @p + n*incX@.  Each element is read with 'inlinePerformIO' when its cons
-- cell is demanded.  'touchForeignPtr' runs when the end of the list is
-- reached, keeping the buffer alive until the whole list has been traversed.
--
-- NOTE(review): because reads are deferred, the result is only meaningful
-- while nothing mutates the storage (e.g. through 'unsafeThawIOVector')
-- before the list is consumed.
elemsVector :: (Elem e) => Vector n e -> [e]
elemsVector x
    | isConj x = (map conjugate . elemsVector . conj) x
    | otherwise = case x of { (Vector (IOVector _ n f p incX)) ->
        let end = p `advancePtr` (n*incX)
            go p' | p' == end = inlinePerformIO $ do
                        io <- touchForeignPtr f
                        io `seq` return []
                  | otherwise = let e  = inlinePerformIO (peek p')
                                    es = go (p' `advancePtr` incX)
                                in e `seq` (e:es)
        in go p }

-- | Each index from 'indicesVector' paired with the corresponding element
-- from 'elemsVector'.
assocsVector :: (Elem e) => Vector n e -> [(Int,e)]
assocsVector x = zip (indicesVector x) (elemsVector x)

-- | Element @i@ without bounds checking.  A conjugated view reads through
-- its unconjugated view and conjugates the value.  Otherwise the element is
-- read at offset @i*inc@ from the base pointer.  'touchForeignPtr' after the
-- read keeps the buffer alive until the peek has happened.
unsafeAtVector :: (Elem e) => Vector n e -> Int -> e
unsafeAtVector x i
    | isConj x = conjugate $ unsafeAtVector (conj x) i
    | otherwise = case x of { (Vector (IOVector _ _ f p inc)) ->
        inlinePerformIO $ do
            e <- peekElemOff p (i*inc)
            io <- touchForeignPtr f
            e `seq` io `seq` return e
        }
-- | Apply @f@ to every logical element, producing a new vector.
tmapVector :: (BLAS1 e) => (e -> e) -> Vector n e -> Vector n e
tmapVector f x = listVector (dim x) [f e | e <- elemsVector x]

-- | Combine two vectors of equal dimension elementwise with @f@.  Calls
-- 'error' when the dimensions differ.
tzipWithVector :: (BLAS1 e) =>
    (e -> e -> e) -> Vector n e -> Vector n e -> Vector n e
tzipWithVector f x y
    | n == dim y = listVector n (zipWith f (elems x) (elems y))
    | otherwise =
        error ("tzipWith: vector lengths differ; first has length `" ++
               show n ++ "' and second has length `" ++
               show (dim y) ++ "'")
  where
    n = dim x
-- | Pure @e@ times @x@ (computed on a copy).
scaleVector :: (BLAS1 e) => e -> Vector n e -> Vector n e
scaleVector e (Vector x) =
    unsafePerformIO (getScaledVector e x >>= unsafeFreezeIOVector)

-- | Pure @x@ shifted by @e@ (computed on a copy).
shiftVector :: (BLAS1 e) => e -> Vector n e -> Vector n e
shiftVector e (Vector x) =
    unsafePerformIO (getShiftedVector e x >>= unsafeFreezeIOVector)

-- | Pure version of 'getSumAbs'.
sumAbs :: (BLAS1 e) => Vector n e -> Double
sumAbs (Vector x) = unsafePerformIO (getSumAbs x)

-- | Pure version of 'getNorm2'.
norm2 :: (BLAS1 e) => Vector n e -> Double
norm2 (Vector x) = unsafePerformIO (getNorm2 x)

-- | Pure version of 'getWhichMaxAbs'; fails on an empty vector.
whichMaxAbs :: (BLAS1 e) => Vector n e -> (Int, e)
whichMaxAbs (Vector x) = unsafePerformIO (getWhichMaxAbs x)

-- | Dimension-checked dot product (see 'getDot').
(<.>) :: (BLAS1 e) => Vector n e -> Vector n e -> e
x <.> y = unsafePerformIO (getDot x y)

-- | Unchecked dot product (see 'unsafeGetDot').
unsafeDot :: (BLAS1 e) => Vector n e -> Vector n e -> e
unsafeDot x y = unsafePerformIO (unsafeGetDot x y)
-- | Shape queries forward to the underlying 'IOVector'.
instance Shaped Vector Int where
    bounds (Vector x) = boundsIOVector x
    shape (Vector x) = shapeIOVector x

-- | Pure tensor operations, each implemented by the matching @...Vector@
-- function in this module.
instance (BLAS1 e) => ITensor Vector Int e where
    size = sizeVector
    indices = indicesVector
    elems = elemsVector
    assocs = assocsVector
    unsafeAt = unsafeAtVector
    (//) = replaceVector
    unsafeReplace = unsafeReplaceVector
    tmap = tmapVector
    (*>) = scaleVector
    shift = shiftVector

-- | An immutable vector can be "read" in any monad by returning the pure
-- results.
instance (BLAS1 e, Monad m) => ReadTensor Vector Int e m where
    getSize = return . size
    getIndices = return . indices
    getIndices' = return . indices
    getElems = return . elems
    getElems' = return . elems
    getAssocs = return . assocs
    getAssocs' = return . assocs
    unsafeReadElem x = return . unsafeAt x
-- | Structural queries on an immutable vector unwrap the 'Vector' newtype and
-- forward to the 'IOVector' functions.
instance (Elem e) => BaseVector Vector e where
    unsafeIOVectorToVector = Vector
    unsafeVectorToIOVector (Vector x) = x
    unsafeSubvectorViewWithStride s (Vector x) o n =
        Vector $ unsafeSubvectorViewWithStrideIOVector s x o n
    conj (Vector x) = Vector $ conjIOVector x
    conjEnum (Vector x) = conjEnumIOVector x
    stride (Vector x) = strideIOVector x
    dim (Vector x) = dimIOVector x

-- | Reading an immutable vector in 'IO'; unsafe freezing is a no-op.
instance (BLAS1 e) => ReadVector Vector e IO where
    unsafeFreezeVector = return
    freezeVector (Vector x) = freezeIOVector x
    unsafePerformIOWithVector (Vector x) act = act x

-- | Reading an immutable vector in 'ST', lifting the IO work with
-- 'unsafeIOToST'.
instance (BLAS1 e) => ReadVector Vector e (ST s) where
    unsafeFreezeVector = return
    freezeVector (Vector x) = unsafeIOToST (freezeIOVector x)
    unsafePerformIOWithVector (Vector x) act = unsafeIOToST (act x)
-- | A conjugated view is shown as @conj (...)@ around its unconjugated view;
-- otherwise as a 'listVector' expression.
instance (Elem e, Show e) => Show (Vector n e) where
    show x =
        if isConj x
            then "conj (" ++ show (conj x) ++ ")"
            else "listVector " ++ show (dim x) ++ " " ++ show (elemsVector x)

-- | Elementwise equality, via 'compareVectorWith'.
instance (BLAS1 e) => Eq (Vector n e) where
    x == y = compareVectorWith (==) x y

-- | Elementwise exact and approximate equality, via 'compareVectorWith'.
instance (BLAS1 e) => AEq (Vector n e) where
    x === y = compareVectorWith (===) x y
    x ~== y = compareVectorWith (~==) x y

-- | True when the dimensions match and @cmp@ holds for every pair of
-- logical elements.  When both arguments are conjugated views, their
-- unconjugated views are compared instead.
compareVectorWith :: (Elem e) =>
    (e -> e -> Bool) ->
    Vector n e -> Vector n e -> Bool
compareVectorWith cmp x y
    | bothConj = compareVectorWith cmp (conj x) (conj y)
    | otherwise =
        dim x == dim y && and (zipWith cmp (elemsVector x) (elemsVector y))
  where
    bothConj = isConj x && isConj y
-- | Elementwise arithmetic on immutable vectors.
--
-- '+', '-' and '*' check that the dimensions agree (through 'getAddVector',
-- 'getSubVector' and 'getMulVector') and compute on a fresh copy of the
-- left operand.  'negate' scales by @-1@; 'abs' and 'signum' map over the
-- elements.  A numeric literal becomes a one-element vector, so it
-- presumably only combines with other dimension-1 vectors.
instance (BLAS1 e) => Num (Vector n e) where
    (+) x y = unsafePerformIO $ unsafeFreezeIOVector =<< getAddVector x y
    (-) x y = unsafePerformIO $ unsafeFreezeIOVector =<< getSubVector x y
    (*) x y = unsafePerformIO $ unsafeFreezeIOVector =<< getMulVector x y
    -- Same as @((-1) *>)@ from the 'ITensor' instance, but calling
    -- 'scaleVector' directly avoids any clash with Prelude's '*>'.
    negate = scaleVector (-1)
    abs = tmap abs
    signum = tmap signum
    fromInteger n = listVector 1 [fromInteger n]
-- | Elementwise division (dimension-checked via 'getDivVector'); 'recip' is
-- mapped over the elements, and a rational literal becomes a one-element
-- vector.
instance (BLAS1 e) => Fractional (Vector n e) where
    x / y = unsafePerformIO (getDivVector x y >>= unsafeFreezeIOVector)
    recip = tmap recip
    fromRational q = listVector 1 [fromRational q]

-- | Every unary function is mapped over the elements; '**' is combined
-- elementwise with 'tzipWithVector'; 'pi' is a one-element vector.
instance (BLAS1 e, Floating e) => Floating (Vector n e) where
    pi = listVector 1 [pi]
    (**) = tzipWithVector (**)

    exp = tmap exp
    log = tmap log
    sqrt = tmap sqrt

    sin = tmap sin
    cos = tmap cos
    tan = tmap tan
    asin = tmap asin
    acos = tmap acos
    atan = tmap atan

    sinh = tmap sinh
    cosh = tmap cosh
    tanh = tmap tanh
    asinh = tmap asinh
    acosh = tmap acosh
    atanh = tmap atanh
-- | Run a level-1 kernel @f n ptr inc@ on the raw storage of @x@.  The
-- conjugation flag is ignored; callers handle it.
vectorCall :: (ReadVector x e m)
           => (Int -> Ptr e -> Int -> IO a)
           -> x n e -> m a
vectorCall f x =
    unsafePerformIOWithVector x $ \io ->
        withIOVector io $ \ptr ->
            f (dimIOVector io) ptr (strideIOVector io)

-- | Run a two-vector kernel @f n px incX py incY@ on the raw storage of @x@
-- and @y@.  The length passed is @x@'s dimension; @y@'s dimension is not
-- checked here, and conjugation flags are ignored.
vectorCall2 :: (ReadVector x e m, ReadVector y f m)
            => (Int -> Ptr e -> Int -> Ptr f -> Int -> IO a)
            -> x n e -> y n' f -> m a
vectorCall2 f x y =
    let yio = unsafeVectorToIOVector y
        incY = strideIOVector yio
    in unsafePerformIOWithVector x $ \xio ->
           withIOVector xio $ \px ->
               withIOVector yio $ \py ->
                   f (dimIOVector xio) px (strideIOVector xio) py incY
-- | Apply @f@ to @x@ and @y@ after 'checkBinaryOp' has compared their
-- dimensions.
checkVectorOp2 :: (BaseVector x e, BaseVector y f) =>
    (x n e -> y n f -> a) ->
    x n e -> y n f -> a
checkVectorOp2 f x y = checkBinaryOp (dim x) (dim y) (f x y)

-- | Copy @x@ into a new vector, run the in-place operation @f@ on the copy,
-- and return the copy.
getUnaryVectorOp :: (ReadVector x e m, WriteVector y e m) =>
    (y n e -> m ()) -> x n e -> m (y n e)
getUnaryVectorOp f x =
    newCopyVector x >>= \y -> f y >> return y

-- | Copy @x@ into a new vector @z@, run the in-place operation @f z y@, and
-- return @z@.  No dimension check is performed.
unsafeGetBinaryVectorOp ::
    (WriteVector z e m, ReadVector x e m, ReadVector y e m) =>
    (z n e -> y n e -> m ()) ->
    x n e -> y n e -> m (z n e)
unsafeGetBinaryVectorOp f x y =
    newCopyVector x >>= \z -> f z y >> return z