-- | Operations on dense vectors: copying and swapping, reductions (sums of
-- absolute values, 2-norms, index of the largest-magnitude element, dot
-- products) and element-wise arithmetic, mostly implemented by calls into
-- BLAS routines.
--
-- Most operations come in three flavours, grouped in the export list below:
--
--   * pure functions on immutable 'Vector's (e.g. 'sumAbs', 'plus');
--
--   * @get@-prefixed 'IO' actions that read their inputs and return a result
--     or a freshly allocated vector (e.g. 'getSumAbs', 'getSum');
--
--   * in-place updates of an 'IOVector' (e.g. 'scaleBy', '(+=)').
--
-- The @unsafe@-prefixed functions and 'axpy' do not perform the
-- 'checkVecVecOp' dimension check that their checked counterparts do.
module Data.Vector.Dense.Operations (
-- * Copying and swapping
copyVector,
swapVectors,
-- * Pure reductions
sumAbs,
norm2,
whichMaxAbs,
(<.>),
-- * Reductions in IO
getSumAbs,
getNorm2,
getWhichMaxAbs,
getDot,
-- * Pure arithmetic
shift,
scale,
invScale,
add,
plus,
minus,
times,
divide,
-- * Arithmetic in IO returning freshly allocated vectors
getShifted,
getScaled,
getInvScaled,
getSum,
getDiff,
getProduct,
getRatio,
-- * In-place updates
doConj,
scaleBy,
shiftBy,
invScaleBy,
(+=),
(-=),
(*=),
(//=),
-- * Unchecked variants (no dimension check)
unsafeCopyVector,
unsafeSwapVectors,
axpy,
) where
import Control.Monad ( forM_ )
import Data.Vector.Dense.Internal
import BLAS.Tensor
import BLAS.Elem.Base ( Elem )
import qualified BLAS.Elem.Base as E
import Foreign ( Ptr )
import System.IO.Unsafe
import Unsafe.Coerce
import BLAS.Internal ( inlinePerformIO, checkVecVecOp )
import BLAS.C hiding ( copy, swap, iamax, conj, axpy, acxpy )
import qualified BLAS.C as BLAS
import qualified BLAS.C.Types as T
-- Operator fixities. The dot product, element-wise product/ratio and
-- (inverse) scaling bind like Prelude's '*' (infixl 7). Sum, difference and
-- shift bind like '+' (infixl 6). The in-place update operators bind very
-- loosely (infixl 1), so @y += a `plus` b@ parses as @y += (a `plus` b)@.
infixl 7 <.>
infixl 7 `times`, `divide`
infixl 7 `scale`, `invScale`
infixl 6 `plus`, `minus`, `shift`
infixl 1 +=, -=, *=, //=
-- | Copy the elements of the second vector into the first (mutable) one.
-- The two dimensions are validated with 'checkVecVecOp' before
-- 'unsafeCopyVector' does the work.
copyVector :: (BLAS1 e) => IOVector n e -> DVector t n e -> IO ()
copyVector dst src = do
    checkVecVecOp "copyVector" (dim dst) (dim src)
    unsafeCopyVector dst src
-- | Copy the elements of the second vector into the first (mutable) one,
-- without checking that the dimensions agree.
--
-- The equations dispatch as follows:
--
--   * both arguments are 'C' (conjugated) views: the conjugation cancels,
--     so the underlying storage is copied directly;
--
--   * both arguments are plain 'DV' views: a single 'BLAS.copy' call;
--
--   * any other combination: an element-by-element loop.
unsafeCopyVector :: (BLAS1 e) => IOVector n e -> DVector t n e -> IO ()
unsafeCopyVector (C x) (C y) =
    unsafeCopyVector x y
unsafeCopyVector x@(DV _ _ _ _) y@(DV _ _ _ _) =
    -- Arguments are passed as (y, x) so that y is read and x written,
    -- matching the fallback loop below. This assumes BLAS.copy follows the
    -- standard xCOPY (source, destination) argument order.
    call2 BLAS.copy y x
unsafeCopyVector x y =
    -- Valid indices are 0 .. dim x - 1. For an empty vector the range is
    -- empty. The original wrote `dim x 1`, which applies `dim` to two
    -- arguments and does not type-check.
    forM_ [0 .. dim x - 1] $ \i ->
        unsafeReadElem y i >>= unsafeWriteElem x i
-- | Exchange the contents of two mutable vectors. The dimensions are
-- validated with 'checkVecVecOp' first, then 'unsafeSwapVectors' does the
-- exchange.
swapVectors :: (BLAS1 e) => IOVector n e -> IOVector n e -> IO ()
swapVectors u v = do
    checkVecVecOp "swapVectors" (dim u) (dim v)
    unsafeSwapVectors u v
-- | Exchange the contents of two mutable vectors without checking that the
-- dimensions agree.
--
-- Two 'C' (conjugated) views swap their underlying storage, since the
-- conjugation is the same on both sides. Two plain 'DV' views use a single
-- 'BLAS.swap' call. Any other combination falls back to an element-wise
-- loop.
unsafeSwapVectors :: (BLAS1 e) => IOVector n e -> IOVector n e -> IO ()
unsafeSwapVectors (C x) (C y) =
    unsafeSwapVectors x y
unsafeSwapVectors x@(DV _ _ _ _) y@(DV _ _ _ _) =
    call2 BLAS.swap x y
unsafeSwapVectors x y =
    -- Valid indices are 0 .. dim x - 1 (empty range for an empty vector).
    -- The original wrote `dim x 1`, which does not type-check.
    forM_ [0 .. dim x - 1] $ \i -> do
        tmp <- unsafeReadElem x i
        unsafeReadElem y i >>= unsafeWriteElem x i
        unsafeWriteElem y i tmp
-- | Run the BLAS @asum@ reduction over the vector (dimension, element 0
-- and stride are supplied by 'call').
getSumAbs :: (BLAS1 e) => DVector t n e -> IO Double
getSumAbs v = call asum v
-- | Run the BLAS @nrm2@ reduction (Euclidean norm) over the vector.
getNorm2 :: (BLAS1 e) => DVector t n e -> IO Double
getNorm2 v = call nrm2 v
-- | Return the index reported by 'BLAS.iamax' together with the element
-- stored at that index.
--
-- The index is used directly with 'unsafeReadElem', so it is assumed to be
-- 0-based. An empty vector raises a 'userError' in 'IO'.
getWhichMaxAbs :: (BLAS1 e) => DVector t n e -> IO (Int, e)
getWhichMaxAbs x
    | dim x == 0 = ioError (userError "getWhichMaxAbs of an empty vector")
    | otherwise  = do
        idx <- call BLAS.iamax x
        val <- unsafeReadElem x idx
        return (idx, val)
-- | Dot product of two vectors. The dimensions are validated with
-- 'checkVecVecOp' (reported under the name @\"dot\"@), then
-- 'unsafeGetDot' computes the result.
getDot :: (BLAS1 e) => DVector s n e -> DVector t n e -> IO e
getDot x y = do
    checkVecVecOp "dot" (dim x) (dim y)
    unsafeGetDot x y
-- | Unchecked dot product. 'C'-wrapped (conjugated) views are reduced to
-- plain storage plus a choice between 'dotc' and 'dotu' and/or a final
-- 'E.conj'. Equation order matters: the catch-all for a conjugated first
-- argument must come last.
unsafeGetDot :: (BLAS1 e) => DVector s n e -> DVector t n e -> IO e
unsafeGetDot u@(DV _ _ _ _) v@(DV _ _ _ _) =
    call2 dotc u v
unsafeGetDot (C u@(DV _ _ _ _)) v@(DV _ _ _ _) =
    -- Conjugating an already-conjugated first operand cancels out.
    call2 dotu u v
unsafeGetDot u@(DV _ _ _ _) (C v@(DV _ _ _ _)) =
    fmap E.conj (call2 dotu u v)
unsafeGetDot u@(DV _ _ _ _) (C (C v)) =
    unsafeGetDot u v
unsafeGetDot (C u) v =
    fmap E.conj (unsafeGetDot u (conj v))
-- | Return a fresh copy of @x@ shifted by @k@ (see 'shiftBy'). The input
-- vector itself is not modified. 'unsafeCoerce' presumably only retags
-- the copy with the caller-chosen type parameter @r@.
getShifted :: (BLAS1 e) => e -> DVector t n e -> IO (DVector r n e)
getShifted k x =
    newCopy x >>= \fresh ->
        shiftBy k (unsafeThaw fresh) >> return (unsafeCoerce fresh)
-- | Return a fresh copy of @x@ scaled by @k@ (see 'scaleBy'). The input
-- vector itself is not modified.
getScaled :: (BLAS1 e) => e -> DVector t n e -> IO (DVector r n e)
getScaled k x =
    newCopy x >>= \fresh ->
        scaleBy k (unsafeThaw fresh) >> return (unsafeCoerce fresh)
-- | Return a fresh copy of @x@ divided by @k@ (see 'invScaleBy'). The
-- input vector itself is not modified.
getInvScaled :: (BLAS1 e) => e -> DVector t n e -> IO (DVector r n e)
getInvScaled k x =
    newCopy x >>= \fresh ->
        invScaleBy k (unsafeThaw fresh) >> return (unsafeCoerce fresh)
-- | Allocate a new vector holding @alpha * x + beta * y@, after validating
-- the dimensions with 'checkVecVecOp'.
getSum :: (BLAS1 e)
       => e -> DVector s n e -> e -> DVector t n e -> IO (DVector r n e)
getSum alpha x beta y = do
    checkVecVecOp "getSum" (dim x) (dim y)
    unsafeGetSum alpha x beta y
-- | Allocate a new vector holding @alpha * x + beta * y@, without checking
-- the dimensions. Equations are tried in order.
unsafeGetSum :: (BLAS1 e) => e -> DVector s n e -> e -> DVector t n e -> IO (DVector r n e)
-- When alpha is 1 but beta is not, swap the operands. The vector that gets
-- copied (and then scaled by 1, which 'scaleBy' skips) is the one with the
-- unit coefficient, so only the axpy pass remains. If beta is also 1, the
-- guard fails and matching falls through to the equations below.
unsafeGetSum 1 x beta y
    | beta /= 1 = unsafeGetSum beta y 1 x
-- Conjugated first operand: alpha*conj(x) + beta*y equals
-- conj(conj(alpha)*x + conj(beta)*conj(y)), so recurse on the underlying
-- storage and conjugate the result view.
unsafeGetSum alpha (C x) beta y = do
    s <- unsafeGetSum (E.conj alpha) x (E.conj beta) (conj y)
    return (conj s)
-- Plain first operand: copy y, scale the copy by beta in place, then
-- accumulate alpha*x into it with 'axpy'. 'unsafeCoerce' presumably only
-- retags the result with the caller-chosen type parameter r (TODO confirm).
unsafeGetSum alpha x@(DV _ _ _ _) beta y = do
    s <- newCopy y
    scaleBy beta (unsafeThaw s)
    axpy alpha x (unsafeThaw s)
    return (unsafeCoerce s)
-- | Allocate a new vector holding @x - y@, after validating the dimensions
-- with 'checkVecVecOp'.
--
-- Implemented as @1 * x + (-1) * y@. The original passed @(1)@ as the
-- coefficient of @y@, which computed @x + y@ instead of the difference.
getDiff :: (BLAS1 e) => DVector s n e -> DVector t n e -> IO (DVector r n e)
getDiff x y = do
    checkVecVecOp "getDiff" (dim x) (dim y)
    unsafeGetSum 1 x (-1) y
-- | Allocate a new vector holding the element-wise product of @x@ and @y@,
-- by copying @x@ and applying '(*=)' to the copy.
getProduct :: (BLAS2 e) => DVector s n e -> DVector t n e -> IO (DVector r n e)
getProduct x y = binaryOp "getProduct" (*=) x y
-- | Allocate a new vector holding the element-wise ratio of @x@ and @y@,
-- by copying @x@ and applying '(//=)' to the copy.
getRatio :: (BLAS2 e) => DVector s n e -> DVector t n e -> IO (DVector r n e)
getRatio x y = binaryOp "getRatio" (//=) x y
-- | Conjugate the vector's elements in place by running 'BLAS.conj' over
-- its storage (dimension, element 0 and stride supplied by 'call').
doConj :: (BLAS1 e) => IOVector n e -> IO ()
doConj v = call BLAS.conj v
-- | Add @alpha@ to the vector in place via 'modifyWith'. For a 'C'
-- (conjugated) view, @conj alpha@ is added to the underlying storage
-- instead, which has the same effect on the view.
shiftBy :: (BLAS1 e) => e -> IOVector n e -> IO ()
shiftBy alpha v = case v of
    C w -> shiftBy (E.conj alpha) w
    _   -> modifyWith (alpha +) v
-- | Multiply the vector by @k@ in place using BLAS @scal@. A factor of 1
-- returns immediately without touching the vector. A 'C' (conjugated) view
-- is scaled by scaling its storage with @conj k@.
scaleBy :: (BLAS1 e) => e -> IOVector n e -> IO ()
scaleBy k v
    | k == 1    = return ()
    | otherwise = case v of
        C w        -> scaleBy (E.conj k) w
        DV _ _ _ _ -> call (flip scal k) v
-- | Divide the vector by @k@ in place via 'modifyWith'. A 'C' (conjugated)
-- view divides its storage by @conj k@.
invScaleBy :: (BLAS1 e) => e -> IOVector n e -> IO ()
invScaleBy k v = case v of
    C w -> invScaleBy (E.conj k) w
    _   -> modifyWith (/ k) v
-- | In-place addition: @y += x@ accumulates @x@ into @y@ with 'axpy' (unit
-- coefficient), after validating the dimensions with 'checkVecVecOp'.
(+=) :: (BLAS1 e) => IOVector n e -> DVector t n e -> IO ()
y += x = do
    checkVecVecOp "(+=)" (dim y) (dim x)
    axpy 1 x y
-- | Unchecked @dst := a * src + dst@ (no dimension check).
--
-- Plain storage on both sides uses 'BLAS.axpy'. A conjugated source over
-- plain storage uses 'BLAS.acxpy'. A doubly conjugated source is unwrapped.
-- A conjugated destination is handled by conjugating the coefficient and
-- the source and updating the underlying storage.
axpy :: (BLAS1 e) => e -> DVector t n e -> IOVector n e -> IO ()
axpy a src@(DV _ _ _ _) dst@(DV _ _ _ _) =
    call2 (flip BLAS.axpy a) src dst
axpy a (C src@(DV _ _ _ _)) dst@(DV _ _ _ _) =
    call2 (flip BLAS.acxpy a) src dst
axpy a (C (C src)) dst =
    axpy a src dst
axpy a src (C dst) =
    axpy (E.conj a) (conj src) dst
-- | In-place subtraction: @y -= x@ overwrites @y@ with @y - x@, after
-- validating the dimensions with 'checkVecVecOp'.
--
-- Implemented as @y := (-1) * x + y@ via 'axpy'. The original passed a
-- coefficient of @(1)@, which made this operator add instead of subtract.
(-=) :: (BLAS1 e) => IOVector n e -> DVector t n e -> IO ()
y -= x = do
    checkVecVecOp "(-=)" (dim y) (dim x)
    axpy (-1) x y
-- | In-place element-wise product: @y *= x@ multiplies each element of @y@
-- by the matching element of @x@ (see 'timesEquals'). The dimensions are
-- validated with 'checkVecVecOp' first.
(*=) :: (BLAS2 e) => IOVector n e -> DVector t n e -> IO ()
y *= x = do
    checkVecVecOp "(*=)" (dim y) (dim x)
    timesEquals y x
-- | Unchecked in-place element-wise product @y *= x@.
--
-- @x@ is treated as the diagonal of a banded triangular matrix with zero
-- off-diagonals, and 'tbmv' multiplies @y@ by it. A conjugated @x@ is
-- handled with the conjugate-transpose flag. A conjugated @y@ is reduced to
-- its storage times @conj x@.
timesEquals :: (BLAS2 e) => IOVector n e -> DVector t n e -> IO ()
timesEquals y x = case (y, x) of
    (DV _ _ _ _, DV _ _ _ _)        -> call2 (diagTimes T.noTrans) x y
    (DV _ _ _ _, C x'@(DV _ _ _ _)) -> call2 (diagTimes T.conjTrans) x' y
    (DV _ _ _ _, C (C x'))          -> timesEquals y x'
    (C y', _)                       -> timesEquals y' (conj x)
  where
    -- Band width k = 0; the matrix argument is the first vector passed to
    -- call2, with its stride used as the leading dimension.
    diagTimes trans = flip (tbmv T.colMajor T.upper trans T.nonUnit) 0
-- | In-place element-wise division: @y //= x@ divides each element of @y@
-- by the matching element of @x@ (see 'divideEquals'). The dimensions are
-- validated with 'checkVecVecOp' first.
(//=) :: (BLAS2 e) => IOVector n e -> DVector t n e -> IO ()
y //= x = do
    checkVecVecOp "(//=)" (dim y) (dim x)
    divideEquals y x
-- | Unchecked in-place element-wise division @y //= x@.
--
-- @x@ is treated as the diagonal of a banded triangular matrix with zero
-- off-diagonals, and 'tbsv' solves against it. A conjugated @x@ is handled
-- with the conjugate-transpose flag. A conjugated @y@ is reduced to its
-- storage divided by @conj x@.
divideEquals :: (BLAS2 e) => IOVector n e -> DVector t n e -> IO ()
divideEquals y x = case (y, x) of
    (DV _ _ _ _, DV _ _ _ _)        -> call2 (diagSolve T.noTrans) x y
    (DV _ _ _ _, C x'@(DV _ _ _ _)) -> call2 (diagSolve T.conjTrans) x' y
    (DV _ _ _ _, C (C x'))          -> divideEquals y x'
    (C y', _)                       -> divideEquals y' (conj x)
  where
    -- Band width k = 0; the matrix argument is the first vector passed to
    -- call2, with its stride used as the leading dimension.
    diagSolve trans = flip (tbsv T.colMajor T.upper trans T.nonUnit) 0
-- | Invoke a BLAS-style routine on one vector, passing its dimension, a
-- pointer to element 0, and its stride.
call :: (Elem e) => (Int -> Ptr e -> Int -> IO a) -> DVector t n e -> IO a
call f x = unsafeWithElemPtr x 0 (\ptr -> f (dim x) ptr (strideOf x))
-- | Invoke a BLAS-style routine on two vectors, passing the dimension of
-- the FIRST vector followed by (pointer to element 0, stride) for each
-- vector in argument order.
call2 :: (Elem e)
      => (Int -> Ptr e -> Int -> Ptr e -> Int -> IO a)
      -> DVector s n e -> DVector t m e -> IO a
call2 f x y =
    unsafeWithElemPtr x 0 $ \ptrX ->
        unsafeWithElemPtr y 0 $ \ptrY ->
            f (dim x) ptrX (strideOf x) ptrY (strideOf y)
-- | Turn an in-place binary operator into one that returns a new vector.
-- It validates the dimensions under @name@, copies @x@, applies @f@ to the
-- copy with @y@, and returns the copy. The input @x@ is not modified.
binaryOp :: (BLAS1 e)
         => String -> (IOVector n e -> DVector t n e -> IO ())
         -> DVector s n e -> DVector t n e -> IO (DVector r n e)
binaryOp name f x y = do
    checkVecVecOp name (dim x) (dim y)
    result <- fmap unsafeThaw (newCopy x)
    f result y
    return $! unsafeCoerce result
-- | Pure counterpart of 'getSumAbs' for immutable vectors. It runs the IO
-- action with 'inlinePerformIO', presumably safe because a 'Vector' is
-- never mutated.
sumAbs :: (BLAS1 e) => Vector n e -> Double
sumAbs v = inlinePerformIO (getSumAbs v)
-- | Pure counterpart of 'getNorm2' for immutable vectors.
norm2 :: (BLAS1 e) => Vector n e -> Double
norm2 v = inlinePerformIO (getNorm2 v)
-- | Pure counterpart of 'getWhichMaxAbs'. Evaluating it on an empty vector
-- raises the 'userError' thrown by 'getWhichMaxAbs'.
whichMaxAbs :: (BLAS1 e) => Vector n e -> (Int, e)
whichMaxAbs v = inlinePerformIO (getWhichMaxAbs v)
-- | Pure dot product of two immutable vectors (see 'getDot').
(<.>) :: (BLAS1 e) => Vector n e -> Vector n e -> e
x <.> y = inlinePerformIO (getDot x y)
-- | Pure counterpart of 'getShifted': a new vector shifted by @k@.
shift :: (BLAS1 e) => e -> Vector n e -> Vector n e
shift k v = unsafePerformIO (getShifted k v)
-- | Pure counterpart of 'getScaled': a new vector scaled by @k@.
scale :: (BLAS1 e) => e -> Vector n e -> Vector n e
scale k v = unsafePerformIO (getScaled k v)
-- | Pure counterpart of 'getInvScaled': a new vector divided by @k@.
invScale :: (BLAS1 e) => e -> Vector n e -> Vector n e
invScale k v = unsafePerformIO (getInvScaled k v)
-- | Pure linear combination @alpha * x + beta * y@ (see 'getSum').
add :: (BLAS1 e) => e -> Vector n e -> e -> Vector n e -> Vector n e
add alpha x beta y = unsafePerformIO (getSum alpha x beta y)
-- | Pure element-wise sum: 'add' with both coefficients equal to 1.
plus :: (BLAS1 e) => Vector n e -> Vector n e -> Vector n e
plus x = add 1 x 1
-- | Pure wrapper around 'getDiff' (intended to be @x - y@).
minus :: (BLAS1 e) => Vector n e -> Vector n e -> Vector n e
minus x y = unsafePerformIO (getDiff x y)
-- | Pure element-wise product (see 'getProduct').
times :: (BLAS2 e) => Vector n e -> Vector n e -> Vector n e
times x y = unsafePerformIO (getProduct x y)
-- | Pure element-wise ratio (see 'getRatio').
divide :: (BLAS2 e) => Vector n e -> Vector n e -> Vector n e
divide x y = unsafePerformIO (getRatio x y)