-- | Operations on dense vectors, implemented on top of the BLAS bindings
-- imported from "BLAS.C".
--
-- Naming conventions used in this module:
--
-- * @getFoo@ functions run in 'IO' and either read a vector or return a
--   freshly allocated result.
--
-- * In-place operations ('doConj', 'scaleBy', '(+=)', 'axpy', ...) mutate
--   an 'IOVector'.
--
-- * Pure functions ('norm2', 'plus', 'minus', ...) wrap the @get@
--   functions with 'unsafePerformIO' or 'inlinePerformIO'.
--
-- * @unsafe@ functions are what the checked functions call after
--   'checkVecVecOp' has compared the vector dimensions.
module Data.Vector.Dense.Operations (
    -- * Copying and swapping
    copyVector,
    swapVectors,

    -- * Pure read-only operations
    sumAbs,
    norm2,
    whichMaxAbs,
    (<.>),

    -- * Read-only operations in IO
    getSumAbs,
    getNorm2,
    getWhichMaxAbs,
    getDot,

    -- * Pure arithmetic
    shift,
    scale,
    invScale,
    add,
    plus,
    minus,
    times,
    divide,

    -- * Arithmetic returning new vectors in IO
    getConj,
    getShifted,
    getScaled,
    getInvScaled,
    getSum,
    getDiff,
    getProduct,
    getRatio,

    -- * In-place arithmetic
    doConj,
    scaleBy,
    shiftBy,
    invScaleBy,
    (+=),
    (-=),
    (*=),
    (//=),
    axpy,

    -- * Unchecked variants
    unsafeCopyVector,
    unsafeSwapVectors,
    unsafeGetDot,
    unsafeAxpy,
    unsafePlusEquals,
    unsafeMinusEquals,
    unsafeTimesEquals,
    unsafeDivideEquals,
    ) where
import Control.Monad ( forM_ )
import Data.Vector.Dense.Internal
import BLAS.Tensor
import BLAS.Elem.Base ( Elem )
import qualified BLAS.Elem.Base as E
import Foreign ( Ptr )
import System.IO.Unsafe
import Unsafe.Coerce
import BLAS.Internal ( inlinePerformIO, checkVecVecOp )
import BLAS.C hiding ( copy, swap, iamax, conj, axpy, acxpy )
import qualified BLAS.C as BLAS
import qualified BLAS.C.Types as T
-- Multiplicative operations share the precedence of '*' (7) and additive
-- ones that of '+' (6). The in-place update operators bind very loosely
-- (level 1, like '>>='), so their right operand can be a full expression.
infixl 7 <.>
infixl 7 `times`, `divide`
infixl 7 `scale`, `invScale`
infixl 6 `plus`, `minus`, `shift`
infixl 1 +=, -=, *=, //=
-- | @copyVector dst src@ overwrites @dst@ with the elements of @src@,
-- after 'checkVecVecOp' has verified that the dimensions agree.
copyVector :: (BLAS1 e) => IOVector n e -> DVector t n e -> IO ()
copyVector dst src =
    checkVecVecOp "copyVector" (dim dst) (dim src) (unsafeCopyVector dst src)
-- | Copy the elements of @y@ into @x@ without checking dimensions.
--
-- * Both vectors conjugate-flagged: drop the flag on both and copy.
--
-- * Exactly one flagged: copy element by element through
--   'unsafeReadElem' / 'unsafeWriteElem'.
--
-- * Neither flagged: a single BLAS @copy@ call, with @y@ passed first
--   (the source operand in standard BLAS @copy@).
unsafeCopyVector :: (BLAS1 e) => IOVector n e -> DVector t n e -> IO ()
unsafeCopyVector x y
    | isConj x && isConj y =
        unsafeCopyVector (conj x) (conj y)
    | isConj x || isConj y =
        -- Indices 0 .. dim x - 1. An empty vector yields an empty range,
        -- so nothing is written.
        forM_ [0 .. dim x - 1] $ \i ->
            unsafeReadElem y i >>= unsafeWriteElem x i
    | otherwise =
        call2 BLAS.copy y x
-- | Exchange the contents of two mutable vectors. The dimensions are
-- checked first by 'checkVecVecOp'.
swapVectors :: (BLAS1 e) => IOVector n e -> IOVector n e -> IO ()
swapVectors u v = checkVecVecOp "swapVectors" du dv (unsafeSwapVectors u v)
  where
    du = dim u
    dv = dim v
-- | Exchange the contents of @x@ and @y@ without checking dimensions.
--
-- * Both conjugate-flagged: drop the flag on both and swap.
--
-- * Exactly one flagged: swap element by element through a temporary.
--
-- * Neither flagged: a single BLAS @swap@ call.
unsafeSwapVectors :: (BLAS1 e) => IOVector n e -> IOVector n e -> IO ()
unsafeSwapVectors x y
    | isConj x && isConj y =
        unsafeSwapVectors (conj x) (conj y)
    | isConj x || isConj y =
        -- Indices 0 .. dim x - 1. An empty vector yields an empty range.
        forM_ [0 .. dim x - 1] $ \i -> do
            tmp <- unsafeReadElem x i
            unsafeReadElem y i >>= unsafeWriteElem x i
            unsafeWriteElem y i tmp
    | otherwise =
        call2 BLAS.swap x y
-- | Run the BLAS @asum@ kernel over a vector and return its result.
--
-- NOTE(review): for complex element types, reference BLAS @asum@ sums
-- @|re| + |im|@ rather than the modulus. Confirm what "BLAS.C" does.
getSumAbs :: (BLAS1 e) => DVector t n e -> IO Double
getSumAbs v = call asum v
-- | 2-norm of a vector, computed by the BLAS @nrm2@ kernel.
getNorm2 :: (BLAS1 e) => DVector t n e -> IO Double
getNorm2 v = call nrm2 v
-- | Index and value of the element picked by the BLAS @iamax@ kernel
-- (the element of largest magnitude).
--
-- Throws a 'userError' for an empty vector. The index returned by @iamax@
-- is passed straight to 'unsafeReadElem', so it is presumably 0-based in
-- this binding (TODO confirm against "BLAS.C").
getWhichMaxAbs :: (BLAS1 e) => DVector t n e -> IO (Int, e)
getWhichMaxAbs x
    | dim x == 0 = ioError (userError "getWhichMaxAbs of an empty vector")
    | otherwise  = do
        i <- call BLAS.iamax x
        fmap ((,) i) (unsafeReadElem x i)
-- | Dot product of two vectors, after checking their dimensions.
-- See 'unsafeGetDot' for the conjugation convention.
getDot :: (BLAS1 e) => DVector s n e -> DVector t n e -> IO e
getDot x y = checkVecVecOp "dot" nx ny (unsafeGetDot x y)
  where
    nx = dim x
    ny = dim y
-- | Unchecked dot product. With no conjugate flags it is BLAS @dotc@,
-- which conjugates the left operand.
--
-- The flags decide two things:
--
-- * the kernel: @dotc@ when both flags agree, @dotu@ when they differ;
--
-- * the result: conjugated whenever @y@ is flagged.
--
-- This matches the original four-way case table (F,F)->dotc,
-- (T,F)->dotu, (F,T)->conj.dotu, (T,T)->conj.dotc. The derivation
-- presumably assumes a flagged vector stores the conjugate of its
-- logical elements (TODO confirm in Data.Vector.Dense.Internal).
unsafeGetDot :: (BLAS1 e) => DVector s n e -> DVector t n e -> IO e
unsafeGetDot x y = finish (call2 kernel x y)
  where
    kernel
        | isConj x == isConj y = dotc
        | otherwise            = dotu
    finish
        | isConj y  = fmap E.conj
        | otherwise = id
-- | 'Double'-specialised dot product that always uses @dotc@ and ignores
-- the conjugate flags. For real data the conjugation is presumably the
-- identity, so this is the plain dot product.
-- Not exported, and not referenced in the code visible here.
unsafeGetDotDouble :: DVector s n Double -> DVector t n Double -> IO Double
unsafeGetDotDouble = call2 dotc
-- | Allocate a new vector holding the complex conjugate of @x@.
--
-- If @x@ carries the conjugate flag, copying its unflagged view is
-- enough. Otherwise the copy is conjugated in place with 'doConj'.
getConj :: (BLAS1 e) => DVector t n e -> IO (DVector r n e)
getConj x
    | isConj x  = fmap unsafeCoerce (newCopy (conj x))
    | otherwise = do
        fresh <- newCopy x
        doConj (unsafeThaw fresh)
        return (unsafeCoerce fresh)
-- | Allocate a copy of @x@ with @k@ added to every element ('shiftBy').
getShifted :: (BLAS1 e) => e -> DVector t n e -> IO (DVector r n e)
getShifted k x =
    newCopy x >>= \fresh -> do
        shiftBy k (unsafeThaw fresh)
        return (unsafeCoerce fresh)
-- | Allocate a copy of @x@ with every element multiplied by @k@
-- ('scaleBy').
getScaled :: (BLAS1 e) => e -> DVector t n e -> IO (DVector r n e)
getScaled k x =
    newCopy x >>= \fresh ->
        scaleBy k (unsafeThaw fresh) >> return (unsafeCoerce fresh)
-- | Allocate a copy of @x@ with every element divided by @k@
-- ('invScaleBy').
getInvScaled :: (BLAS1 e) => e -> DVector t n e -> IO (DVector r n e)
getInvScaled k x = do
    result <- newCopy x
    invScaleBy k (unsafeThaw result)
    return (unsafeCoerce result)
-- | Allocate @alpha * x + beta * y@, after checking that @x@ and @y@ have
-- the same dimension.
getSum :: (BLAS1 e) => e -> DVector s n e -> e -> DVector t n e -> IO (DVector r n e)
getSum alpha x beta y = checkVecVecOp "getSum" nx ny combined
  where
    nx       = dim x
    ny       = dim y
    combined = unsafeGetSum alpha x beta y
-- | Unchecked @alpha * x + beta * y@ into a freshly allocated vector.
--
-- Strategy: copy @y@, scale the copy by @beta@, then add @alpha * x@ with
-- an axpy. When @alpha == 1@ and @beta /= 1@, the operands are swapped so
-- the scaling step becomes @scaleBy 1@, which 'scaleBy' turns into a
-- no-op. If @x@ carries the conjugate flag, the conjugated problem is
-- solved and the result is re-flagged, using
-- @conj (a*x + b*y) = conj a * conj x + conj b * conj y@.
unsafeGetSum :: (BLAS1 e) => e -> DVector s n e -> e -> DVector t n e -> IO (DVector r n e)
unsafeGetSum 1 x beta y
    | beta /= 1 = unsafeGetSum beta y 1 x
unsafeGetSum alpha x beta y
    | isConj x = do
        s <- unsafeGetSum (E.conj alpha) (conj x) (E.conj beta) (conj y)
        return (conj s)
    | otherwise = do
        s <- newCopy y
        scaleBy beta (unsafeThaw s)
        -- Callers ('getSum', 'getDiff') have already checked dimensions, so
        -- use the unchecked axpy rather than re-checking here.
        unsafeAxpy alpha x (unsafeThaw s)
        return (unsafeCoerce s)
-- | Allocate @x - y@, after checking that the dimensions agree.
--
-- Implemented as @1 * x + (-1) * y@ via 'unsafeGetSum'. The previous code
-- passed @beta = 1@ and therefore computed @x + y@.
getDiff :: (BLAS1 e) => DVector s n e -> DVector t n e -> IO (DVector r n e)
getDiff x y =
    checkVecVecOp "getDiff" (dim x) (dim y) $ unsafeGetSum 1 x (-1) y
-- | Allocate a copy of @x@ and apply '(*=)' to it with @y@. Dimensions
-- are checked by 'binaryOp' under the name @getProduct@.
getProduct :: (BLAS2 e) => DVector s n e -> DVector t n e -> IO (DVector r n e)
getProduct x y = binaryOp "getProduct" (*=) x y
-- | Allocate a copy of @x@ and apply '(//=)' to it with @y@. Dimensions
-- are checked by 'binaryOp' under the name @getRatio@.
getRatio :: (BLAS2 e) => DVector s n e -> DVector t n e -> IO (DVector r n e)
getRatio x y = binaryOp "getRatio" (//=) x y
-- | Conjugate the stored elements of a mutable vector in place, using the
-- @conj@ kernel from "BLAS.C".
doConj :: (BLAS1 e) => IOVector n e -> IO ()
doConj v = call BLAS.conj v
-- | Add @alpha@ to every element of a mutable vector, in place.
-- A conjugate-flagged vector is handled by shifting its unflagged view by
-- @conj alpha@.
shiftBy :: (BLAS1 e) => e -> IOVector n e -> IO ()
shiftBy alpha x
    | not (isConj x) = modifyWith (alpha +) x
    | otherwise      = shiftBy (E.conj alpha) (conj x)
-- | Multiply every element of a mutable vector by @k@, in place, using
-- BLAS @scal@.
--
-- Scaling by 1 returns immediately. A conjugate-flagged vector is scaled
-- through its unflagged view by @conj k@.
scaleBy :: (BLAS1 e) => e -> IOVector n e -> IO ()
scaleBy k x
    | k == 1    = return ()
    | isConj x  = scaleBy (E.conj k) (conj x)
    | otherwise = call (\n p inc -> scal n k p inc) x
-- | Divide every element of a mutable vector by @k@, in place. A
-- conjugate-flagged vector is divided through its unflagged view by
-- @conj k@.
invScaleBy :: (BLAS1 e) => e -> IOVector n e -> IO ()
invScaleBy k x =
    if isConj x
        then invScaleBy (E.conj k) (conj x)
        else modifyWith (/ k) x
-- | In-place @y := alpha * x + y@, after checking that the dimensions
-- agree.
axpy :: (BLAS1 e) => e -> DVector t n e -> IOVector n e -> IO ()
axpy alpha x y =
    checkVecVecOp "axpy" (dim x) (dim y) (unsafeAxpy alpha x y)
-- | Unchecked in-place @y := alpha * x + y@.
--
-- * @y@ conjugate-flagged: update its unflagged view instead, using
--   @conj (alpha*x + y) = conj alpha * conj x + conj y@. This assumes a
--   flagged vector stores the conjugate of its logical elements
--   (TODO confirm in Data.Vector.Dense.Internal).
--
-- * @x@ flagged: use the @acxpy@ kernel, presumably
--   @y += alpha * conj x@ on the stored data.
--
-- * Neither flagged: plain BLAS @axpy@.
unsafeAxpy :: (BLAS1 e) => e -> DVector t n e -> IOVector n e -> IO ()
unsafeAxpy alpha x y
    | isConj y =
        -- 'conj' only toggles the flag, presumably leaving the dimensions
        -- unchanged, so the unchecked recursion is safe. Calling the
        -- checked 'axpy' here would repeat the dimension check.
        unsafeAxpy (E.conj alpha) (conj x) (conj y)
    | isConj x =
        call2 (flip BLAS.acxpy alpha) x y
    | otherwise =
        call2 (flip BLAS.axpy alpha) x y
-- | @y += x@ adds @x@ to @y@ in place, after a dimension check.
(+=) :: (BLAS1 e) => IOVector n e -> DVector t n e -> IO ()
y += x = checkVecVecOp "(+=)" (dim y) (dim x) (unsafePlusEquals y x)
-- | Unchecked in-place @y := y + x@, an axpy with @alpha = 1@.
--
-- Uses 'unsafeAxpy', like 'unsafeMinusEquals'. The checked 'axpy' would
-- repeat the dimension check that '(+=)' has already performed.
unsafePlusEquals :: (BLAS1 e) => IOVector n e -> DVector t n e -> IO ()
unsafePlusEquals y x = unsafeAxpy 1 x y
-- | @y -= x@ runs 'unsafeMinusEquals' in place, after a dimension check.
(-=) :: (BLAS1 e) => IOVector n e -> DVector t n e -> IO ()
y -= x = checkVecVecOp "(-=)" (dim y) (dim x) (unsafeMinusEquals y x)
-- | Unchecked in-place @y := y - x@, an axpy with @alpha = -1@.
-- The previous code passed @alpha = 1@ and therefore added @x@ instead of
-- subtracting it.
unsafeMinusEquals :: (BLAS1 e) => IOVector n e -> DVector t n e -> IO ()
unsafeMinusEquals y x = unsafeAxpy (-1) x y
-- | @y *= x@ runs 'unsafeTimesEquals' in place, after a dimension check.
(*=) :: (BLAS2 e) => IOVector n e -> DVector t n e -> IO ()
y *= x = checkVecVecOp "(*=)" (dim y) (dim x) (unsafeTimesEquals y x)
-- | Unchecked in-place element-wise product of @y@ with @x@.
--
-- The product is expressed as a BLAS @tbmv@ call with band width 0, which
-- makes the banded upper-triangular matrix diagonal. @x@'s storage serves
-- as the band and @x@'s stride as its leading dimension. A conjugate-flagged
-- @x@ selects the conjugate-transpose variant. A flagged @y@ is handled by
-- working on the unflagged views of both vectors.
--
-- NOTE(review): assumes "BLAS.C"'s @tbmv@ takes
-- @(order, uplo, trans, diag, n, k, a, lda, x, incx)@ like CBLAS.
unsafeTimesEquals :: (BLAS2 e) => IOVector n e -> DVector t n e -> IO ()
unsafeTimesEquals y x
    | isConj y  = unsafeTimesEquals (conj y) (conj x)
    | otherwise = call2 diagMul x y
  where
    trans = if isConj x then T.conjTrans else T.noTrans
    diagMul n = tbmv T.colMajor T.upper trans T.nonUnit n 0
-- | @y //= x@ runs 'unsafeDivideEquals' in place, after a dimension check.
(//=) :: (BLAS2 e) => IOVector n e -> DVector t n e -> IO ()
y //= x = checkVecVecOp "(//=)" (dim y) (dim x) (unsafeDivideEquals y x)
-- | Unchecked in-place element-wise division of @y@ by @x@.
--
-- Expressed as a BLAS @tbsv@ triangular solve against a band-width-0
-- (diagonal) matrix, whose band is @x@'s storage with @x@'s stride as the
-- leading dimension. A conjugate-flagged @x@ selects the
-- conjugate-transpose solve. A flagged @y@ is handled through the
-- unflagged views of both vectors.
--
-- NOTE(review): assumes "BLAS.C"'s @tbsv@ takes
-- @(order, uplo, trans, diag, n, k, a, lda, x, incx)@ like CBLAS.
unsafeDivideEquals :: (BLAS2 e) => IOVector n e -> DVector t n e -> IO ()
unsafeDivideEquals y x
    | isConj y  = unsafeDivideEquals (conj y) (conj x)
    | otherwise = call2 diagSolve x y
  where
    trans
        | isConj x  = T.conjTrans
        | otherwise = T.noTrans
    diagSolve n = tbsv T.colMajor T.upper trans T.nonUnit n 0
-- | Run a kernel of shape @(n, ptr, inc)@ on one vector. @n@ is the
-- vector's 'dim', @ptr@ points at element 0 (via 'unsafeWithElemPtr'),
-- and @inc@ is its 'strideOf'.
call :: (Elem e) => (Int -> Ptr e -> Int -> IO a) -> DVector t n e -> IO a
call f x = unsafeWithElemPtr x 0 (\px -> f (dim x) px (strideOf x))
-- | Run a kernel of shape @(n, px, incx, py, incy)@ on two vectors.
-- The length passed is the dimension of the FIRST vector only, so callers
-- are responsible for making sure the second vector is at least as long.
call2 :: (Elem e) =>
    (Int -> Ptr e -> Int -> Ptr e -> Int -> IO a)
    -> DVector s n e -> DVector t m e -> IO a
call2 f x y =
    unsafeWithElemPtr x 0 $ \px ->
        unsafeWithElemPtr y 0 $ \py ->
            f (dim x) px (strideOf x) py (strideOf y)
-- | Lift an in-place binary update into a function that returns a new
-- vector. After checking dimensions under @name@, it copies @x@, applies
-- @f@ to the mutable copy with @y@, and returns the copy, forced to WHNF.
binaryOp :: (BLAS1 e) => String -> (IOVector n e -> DVector t n e -> IO ())
    -> DVector s n e -> DVector t n e -> IO (DVector r n e)
binaryOp name f x y = checkVecVecOp name (dim x) (dim y) body
  where
    body = do
        scratch <- fmap unsafeThaw (newCopy x)
        f scratch y
        return $! unsafeCoerce scratch
-- | Pure form of 'getSumAbs'.
sumAbs :: (BLAS1 e) => Vector n e -> Double
sumAbs v = inlinePerformIO (getSumAbs v)
-- | Pure form of 'getNorm2'.
norm2 :: (BLAS1 e) => Vector n e -> Double
norm2 v = inlinePerformIO (getNorm2 v)
-- | Pure form of 'getWhichMaxAbs'. Forcing the result for an empty vector
-- throws the 'userError' raised there.
whichMaxAbs :: (BLAS1 e) => Vector n e -> (Int, e)
whichMaxAbs v = inlinePerformIO (getWhichMaxAbs v)
-- | Pure dot product, a wrapper around 'getDot'.
(<.>) :: (BLAS1 e) => Vector n e -> Vector n e -> e
x <.> y = inlinePerformIO (getDot x y)
-- | Pure form of 'getShifted': add @k@ to every element.
shift :: (BLAS1 e) => e -> Vector n e -> Vector n e
shift k v = unsafePerformIO (getShifted k v)
-- | Pure form of 'getScaled': multiply every element by @k@.
scale :: (BLAS1 e) => e -> Vector n e -> Vector n e
scale k v = unsafePerformIO (getScaled k v)
-- | Pure form of 'getInvScaled': divide every element by @k@.
invScale :: (BLAS1 e) => e -> Vector n e -> Vector n e
invScale k v = unsafePerformIO (getInvScaled k v)
-- | Pure form of 'getSum': @alpha * x + beta * y@.
add :: (BLAS1 e) => e -> Vector n e -> e -> Vector n e -> Vector n e
add a u b v = unsafePerformIO (getSum a u b v)
-- | Element-wise sum, expressed as 'add' with both coefficients equal to 1.
plus :: (BLAS1 e) => Vector n e -> Vector n e -> Vector n e
plus u v = add 1 u 1 v
-- | Pure wrapper around 'getDiff'.
minus :: (BLAS1 e) => Vector n e -> Vector n e -> Vector n e
minus u v = unsafePerformIO (getDiff u v)
-- | Pure wrapper around 'getProduct'.
times :: (BLAS2 e) => Vector n e -> Vector n e -> Vector n e
times u v = unsafePerformIO (getProduct u v)
-- | Pure wrapper around 'getRatio'.
divide :: (BLAS2 e) => Vector n e -> Vector n e -> Vector n e
divide u v = unsafePerformIO (getRatio u v)