module Data.Matrix.Dense.Operations (
copyMatrix,
swapMatrices,
apply,
applyMat,
sapply,
sapplyMat,
getApply,
getApplyMat,
getSApply,
getSApplyMat,
shift,
scale,
invScale,
add,
plus,
minus,
times,
divide,
getShifted,
getScaled,
getInvScaled,
getSum,
getDiff,
getProduct,
getRatio,
doConj,
shiftBy,
scaleBy,
invScaleBy,
axpy,
(+=),
(-=),
(*=),
(//=),
gemv,
gemm,
) where
import System.IO.Unsafe
import Unsafe.Coerce
import BLAS.Internal ( checkMatMatOp, checkMatVecMult, checkMatMatMult )
import Data.Matrix.Dense.Internal
import Data.Vector.Dense.Internal hiding ( unsafeWithElemPtr, unsafeThaw,
unsafeFreeze )
import qualified Data.Vector.Dense.Operations as V
import qualified Data.Vector.Dense.Internal as V
import BLAS.C ( CBLASTrans, colMajor, noTrans, conjTrans )
import qualified BLAS.C as BLAS
import BLAS.Elem ( BLAS1, BLAS2, BLAS3 )
import qualified BLAS.Elem as E
-- Operator precedences. The product-like operations share the precedence
-- of (*), 'shift' shares the precedence of (+), and the in-place update
-- operators bind very loosely, so @a += k `scale` b@ parses as
-- @a += (k `scale` b)@.
infixl 7 `apply`
infixl 7 `applyMat`
infixl 7 `scale`, `invScale`
infixl 6 `shift`
infixl 1 +=, -=
infixl 1 *=, //=
-- | Copy the elements of the second matrix into the first (mutable) one.
-- The two shapes are passed to 'checkMatMatOp' before any data is copied.
copyMatrix :: (BLAS1 e) => IOMatrix (m,n) e -> DMatrix t (m,n) e -> IO ()
copyMatrix dst src = do
    checkMatMatOp "copyMatrix" (shape dst) (shape src)
    unsafeCopyMatrix dst src
-- | Unchecked variant of 'copyMatrix'. It lifts 'V.unsafeCopyVector' to
-- matrices with 'liftV2', and the shapes are assumed to agree.
unsafeCopyMatrix :: (BLAS1 e) => IOMatrix (m,n) e -> DMatrix t (m,n) e -> IO ()
unsafeCopyMatrix dst src = liftV2 V.unsafeCopyVector dst src
-- | Exchange the contents of two mutable matrices. The shapes are passed to
-- 'checkMatMatOp' before anything is swapped.
swapMatrices :: (BLAS1 e) => IOMatrix (m,n) e -> IOMatrix (m,n) e -> IO ()
swapMatrices x y = do
    checkMatMatOp "swapMatrices" (shape x) (shape y)
    unsafeSwapMatrices x y
-- | Unchecked variant of 'swapMatrices'. It lifts 'V.unsafeSwapVectors' to
-- matrices with 'liftV2', and the shapes are assumed to agree.
unsafeSwapMatrices :: (BLAS1 e) => IOMatrix (m,n) e -> IOMatrix (m,n) e -> IO ()
unsafeSwapMatrices x y = liftV2 V.unsafeSwapVectors x y
-- | Allocate and return @a * x@. This is 'getSApply' with a scale factor of 1.
getApply :: (BLAS2 e) => DMatrix s (m,n) e -> DVector t n e -> IO (DVector r m e)
getApply mat vec = getSApply 1 mat vec
-- | Allocate and return @alpha * a * x@. The matrix shape and vector length
-- are passed to 'checkMatVecMult' before the product is computed.
getSApply :: (BLAS2 e) => e -> DMatrix s (m,n) e -> DVector t n e -> IO (DVector r m e)
getSApply alpha mat vec = do
    checkMatVecMult (shape mat) (V.dim vec)
    unsafeGetSApply alpha mat vec
-- | Unchecked worker for 'getSApply'. It allocates a zero vector with
-- @numRows a@ entries and accumulates @alpha * a * x@ into it with 'gemv'
-- (using @beta = 0@). The shapes are assumed to be compatible.
unsafeGetSApply :: (BLAS2 e) => e -> DMatrix s (m,n) e -> DVector t n e -> IO (DVector r m e)
unsafeGetSApply alpha mat vec =
    V.newZero (numRows mat) >>= \out ->
        gemv alpha mat vec 0 out >> return (unsafeCoerce out)
-- | Allocate and return the matrix product @a * b@. This is 'getSApplyMat'
-- with a scale factor of 1.
getApplyMat :: (BLAS3 e) => DMatrix s (m,k) e -> DMatrix t (k,n) e -> IO (DMatrix r (m,n) e)
getApplyMat lhs rhs = getSApplyMat 1 lhs rhs
-- | Allocate and return @alpha * a * b@. The two shapes are passed to
-- 'checkMatMatMult' before the product is computed.
getSApplyMat :: (BLAS3 e) => e -> DMatrix s (m,k) e -> DMatrix t (k,n) e -> IO (DMatrix r (m,n) e)
getSApplyMat alpha lhs rhs = do
    checkMatMatMult (shape lhs) (shape rhs)
    unsafeGetSApplyMat alpha lhs rhs
-- | Unchecked worker for 'getSApplyMat'. It allocates a zero matrix of shape
-- @(numRows a, numCols b)@ and accumulates @alpha * a * b@ into it with
-- 'gemm' (using @beta = 0@). The shapes are assumed to be compatible.
unsafeGetSApplyMat :: (BLAS3 e) => e -> DMatrix s (m,k) e -> DMatrix t (k,n) e -> IO (DMatrix r (m,n) e)
unsafeGetSApplyMat alpha lhs rhs =
    newZero (numRows lhs, numCols rhs) >>= \out ->
        gemm alpha lhs rhs 0 out >> return (unsafeCoerce out)
-- | Return a fresh copy of the matrix shifted by @k@ (via 'shiftBy'). The
-- argument is not modified.
getShifted :: (BLAS1 e) => e -> DMatrix t (m,n) e -> IO (DMatrix r (m,n) e)
getShifted k mat = unaryOp (\m' -> shiftBy k m') mat
-- | Return a fresh copy of the matrix scaled by @k@ (via 'scaleBy'). The
-- argument is not modified.
getScaled :: (BLAS1 e) => e -> DMatrix t (m,n) e -> IO (DMatrix r (m,n) e)
getScaled k mat = unaryOp (\m' -> scaleBy k m') mat
-- | Return a fresh copy of the matrix inverse-scaled by @k@ (via
-- 'invScaleBy'). The argument is not modified.
getInvScaled :: (BLAS1 e) => e -> DMatrix t (m,n) e -> IO (DMatrix r (m,n) e)
getInvScaled k mat = unaryOp (\m' -> invScaleBy k m') mat
-- | Allocate and return @alpha * a + beta * b@. The two shapes are passed to
-- 'checkMatMatOp' before the sum is computed.
getSum :: (BLAS1 e) => e -> DMatrix s (m,n) e -> e -> DMatrix t (m,n) e -> IO (DMatrix r (m,n) e)
getSum alpha a beta b = do
    checkMatMatOp "getSum" (shape a) (shape b)
    unsafeGetSum alpha a beta b
-- | Unchecked worker for 'getSum' that computes @alpha * a + beta * b@ into
-- newly allocated storage.
--
-- If @a@ is a hermitian view, the identity
-- @alpha*a + beta*b == herm (conj alpha * herm a + conj beta * herm b)@
-- reduces the computation to the plain case. Otherwise @a@ is copied and
-- scaled by @alpha@, and then @beta * b@ is added into the copy with 'axpy'.
unsafeGetSum :: (BLAS1 e) => e -> DMatrix s (m,n) e -> e -> DMatrix t (m,n) e -> IO (DMatrix r (m,n) e)
unsafeGetSum alpha a@(H _) beta b =
    fmap herm (unsafeGetSum (E.conj alpha) (herm a) (E.conj beta) (herm b))
unsafeGetSum alpha a@DM{} beta b =
    getScaled alpha a >>= \acc ->
        axpy beta b (unsafeThaw acc) >> return (unsafeCoerce acc)
-- | Allocate and return the element-wise difference @a - b@.
--
-- The two shapes are passed to 'checkMatMatOp' first, labelled
-- @"minus"@ after the pure wrapper built on this function. The difference
-- is computed as @1 * a + (-1) * b@ by 'unsafeGetSum'. The coefficient on
-- @b@ must be @-1@; a coefficient of @1@ would silently return @a + b@.
getDiff :: (BLAS1 e) => DMatrix s (m,n) e -> DMatrix t (m,n) e -> IO (DMatrix r (m,n) e)
getDiff a b = do
    checkMatMatOp "minus" (shape a) (shape b)
    unsafeGetSum 1 a (-1) b
-- | Allocate and return the element-wise product of two same-shaped
-- matrices. @a@ is copied, and the copy is updated in place with '(*=)'.
getProduct :: (BLAS2 e) => DMatrix s (m,n) e -> DMatrix t (m,n) e -> IO (DMatrix r (m,n) e)
getProduct x y = binaryOp "times" (*=) x y
-- | Allocate and return the element-wise quotient of two same-shaped
-- matrices. @a@ is copied, and the copy is updated in place with '(//=)'.
getRatio :: (BLAS2 e) => DMatrix s (m,n) e -> DMatrix t (m,n) e -> IO (DMatrix r (m,n) e)
getRatio x y = binaryOp "getRatio" (//=) x y
-- | Conjugate the elements of a mutable matrix in place and return it.
--
-- A hermitian view is handled by conjugating the matrix it wraps and
-- re-wrapping the result in 'H'. A plain matrix has 'V.doConj' applied
-- through 'liftV', and the same matrix is returned.
doConj :: (BLAS1 e) => IOMatrix (m,n) e -> IO (IOMatrix (m,n) e)
doConj (H inner) = fmap H (doConj inner)
doConj mat@DM{} =
    liftV (\v -> V.doConj v >> return ()) mat >> return mat
-- | Scale a mutable matrix in place by @k@. 'V.scaleBy' is lifted with
-- 'liftV', and its result is discarded.
--
-- NOTE(review): unlike 'doConj', this does not special-case hermitian ('H')
-- views. Confirm that 'liftV' accounts for conjugation when @k@ is complex.
scaleBy :: (BLAS1 e) => e -> IOMatrix (m,n) e -> IO ()
scaleBy k mat = liftV (\v -> V.scaleBy k v >> return ()) mat
-- | Shift a mutable matrix in place by @k@, lifting 'V.shiftBy' with 'liftV'.
--
-- NOTE(review): hermitian ('H') views are not special-cased here. Confirm
-- that 'liftV' handles them for complex @k@.
shiftBy :: (BLAS1 e) => e -> IOMatrix (m,n) e -> IO ()
shiftBy k mat = liftV (\v -> V.shiftBy k v) mat
-- | Inverse-scale a mutable matrix in place by @k@, lifting 'V.invScaleBy'
-- with 'liftV'.
--
-- NOTE(review): hermitian ('H') views are not special-cased here. Confirm
-- that 'liftV' handles them for complex @k@.
invScaleBy :: (BLAS1 e) => e -> IOMatrix (m,n) e -> IO ()
invScaleBy k mat = liftV (\v -> V.invScaleBy k v) mat
-- | @axpy alpha x y@ accumulates @alpha * x@ into the mutable matrix @y@ by
-- lifting 'V.axpy' with 'liftV2'. The shapes are not checked.
axpy :: (BLAS1 e) => e -> DMatrix t (m,n) e -> IOMatrix (m,n) e -> IO ()
axpy alpha x y = liftV2 (V.axpy alpha) x y
-- | In-place element-wise addition of the right operand into the left one
-- (the vector '(V.+=)' lifted with 'liftV2').
(+=) :: (BLAS1 e) => IOMatrix (m,n) e -> DMatrix t (m,n) e -> IO ()
dst += src = liftV2 (V.+=) dst src
-- | In-place element-wise subtraction of the right operand from the left one
-- (the vector '(V.-=)' lifted with 'liftV2').
(-=) :: (BLAS1 e) => IOMatrix (m,n) e -> DMatrix t (m,n) e -> IO ()
dst -= src = liftV2 (V.-=) dst src
-- | In-place element-wise multiplication of the left operand by the right one
-- (the vector '(V.*=)' lifted with 'liftV2').
(*=) :: (BLAS2 e) => IOMatrix (m,n) e -> DMatrix t (m,n) e -> IO ()
dst *= src = liftV2 (V.*=) dst src
-- | In-place element-wise division of the left operand by the right one
-- (the vector '(V.//=)' lifted with 'liftV2').
(//=) :: (BLAS2 e) => IOMatrix (m,n) e -> DMatrix t (m,n) e -> IO ()
dst //= src = liftV2 (V.//=) dst src
-- | The CBLAS transpose flag to use for a matrix. A hermitian view is passed
-- to BLAS as its underlying storage with 'conjTrans'; anything else uses
-- 'noTrans'.
blasTransOf :: DMatrix t (m,n) e -> CBLASTrans
blasTransOf a
    | isHerm a  = conjTrans
    | otherwise = noTrans
-- | Swap a @(rows, cols)@ pair. Used to obtain the stored dimensions of a
-- hermitian view.
flipShape :: (Int,Int) -> (Int,Int)
flipShape dims = case dims of
    (rows, cols) -> (cols, rows)
-- | Matrix-vector multiply-accumulate in place: @y := alpha * a * x + beta * y@,
-- carried out by the BLAS @gemv@ routine.
--
-- * If @a@ has no rows or no columns, nothing is done. In particular @y@ is
--   /not/ scaled by @beta@ when @a@ has zero columns, which mirrors the
--   quick return of reference BLAS @xGEMV@.
-- * If @y@ is a conjugated view, @y@ is conjugated in place ('V.doConj'),
--   the product is accumulated into the view @V.conj y@, and @y@ is then
--   conjugated back.
-- * If @x@ is a conjugated view, it is copied ('V.newCopy'), the copy is
--   conjugated in place, and the product uses @V.conj@ of the copy. The
--   caller's @x@ is never written.
-- * A hermitian view @a@ is handed to BLAS with 'conjTrans' (see
--   'blasTransOf'), so the dimensions passed are those of the stored matrix.
gemv :: (BLAS2 e) => e -> DMatrix s (m,n) e -> DVector t n e -> e -> IOVector m e -> IO ()
gemv alpha a x beta y
    | numRows a == 0 || numCols a == 0 =
        return ()
    | V.isConj y = do
        -- Discard 'V.doConj' results explicitly so this branch has type
        -- @IO ()@ whatever 'V.doConj' returns (elsewhere its result is
        -- dropped with @>> return ()@).
        _ <- V.doConj y
        gemv alpha a x beta (V.conj y)
        _ <- V.doConj y
        return ()
    | V.isConj x = do
        x' <- V.newCopy (V.unsafeThaw x)
        _ <- V.doConj x'
        gemv alpha a (V.conj x') beta y
    | otherwise =
        let order  = colMajor
            transA = blasTransOf a
            -- BLAS wants the dimensions of the stored matrix. For a
            -- hermitian view these are the transpose of the logical shape.
            (m,n)  = if isHerm a then flipShape (shape a) else shape a
            ldA    = ldaOf a
            incX   = V.strideOf x
            incY   = V.strideOf y
        in unsafeWithElemPtr a (0,0) $ \pA ->
           V.unsafeWithElemPtr x 0 $ \pX ->
           V.unsafeWithElemPtr y 0 $ \pY ->
               BLAS.gemv order transA m n alpha pA ldA pX incX beta pY incY
-- | Matrix-matrix multiply-accumulate in place: @c := alpha * a * b + beta * c@,
-- carried out by the BLAS @gemm@ routine.
--
-- * If the result is empty (@a@ has no rows or @b@ has no columns), nothing
--   is done.
-- * A hermitian view @c@ is handled via
--   @herm c := conj alpha * herm b * herm a + conj beta * herm c@.
-- * If the inner dimension (@numCols a@) is zero, @a * b@ is the zero
--   matrix, so the update reduces to scaling @c@ by @beta@. Reference BLAS
--   @xGEMM@ does the same for @k = 0@. BLAS is not called with the empty
--   operands in that case.
--   NOTE(review): 'scaleBy' presumably multiplies, so with @beta = 0@ any
--   NaN entries of @c@ survive, whereas reference @xGEMM@ overwrites them
--   with zero.
gemm :: (BLAS3 e) => e -> DMatrix s (m,k) e -> DMatrix t (k,n) e -> e -> IOMatrix (m,n) e -> IO ()
gemm alpha a b beta c
    | numRows a == 0 || numCols b == 0 = return ()
    | isHerm c = gemm (E.conj alpha) (herm b) (herm a) (E.conj beta) (herm c)
    -- Reached only when c is not a hermitian view (see the guard above).
    | numCols a == 0 = scaleBy beta c
    | otherwise =
        let order  = colMajor
            transA = blasTransOf a
            transB = blasTransOf b
            (m,n)  = shape c
            k      = numCols a
            ldA    = ldaOf a
            ldB    = ldaOf b
            ldC    = ldaOf c
        in unsafeWithElemPtr a (0,0) $ \pA ->
           unsafeWithElemPtr b (0,0) $ \pB ->
           unsafeWithElemPtr c (0,0) $ \pC ->
               BLAS.gemm order transA transB m n k alpha pA ldA pB ldB beta pC ldC
-- | Run an in-place matrix update on a fresh copy of the argument and return
-- that copy. The update is applied only to the copy, so the argument itself
-- is not modified.
unaryOp :: (BLAS1 e) => (IOMatrix (m,n) e -> IO ())
        -> DMatrix t (m,n) e -> IO (DMatrix r (m,n) e)
unaryOp update src =
    newCopy src >>= \fresh ->
        update (unsafeThaw fresh) >> return (unsafeCoerce fresh)
-- | Shape-checked out-of-place version of an in-place binary update.
-- The shapes are passed to 'checkMatMatOp' under @name@. Then the first
-- operand is copied, the update is applied to the copy with the second
-- operand, and the copy is returned. Neither operand is written.
binaryOp :: (BLAS1 e) => String -> (IOMatrix (m,n) e -> DMatrix t (m,n) e -> IO ())
         -> DMatrix s (m,n) e -> DMatrix t (m,n) e -> IO (DMatrix r (m,n) e)
binaryOp name update lhs rhs = do
    checkMatMatOp name (shape lhs) (shape rhs)
    result <- newCopy lhs
    update (unsafeThaw result) rhs
    return (unsafeCoerce result)
-- | Pure matrix-vector product @a * x@ ('sapply' with scale 1).
apply :: (BLAS2 e) => Matrix (m,n) e -> Vector n e -> Vector m e
apply mat vec = sapply 1 mat vec
-- | Pure scaled matrix-vector product @alpha * a * x@.
-- 'getSApply' writes only to storage it allocates itself (the result, plus
-- a scratch copy of @x@ when @x@ is conjugated), which is why it is wrapped
-- in 'unsafePerformIO'.
sapply :: (BLAS2 e) => e -> Matrix (m,n) e -> Vector n e -> Vector m e
sapply alpha mat vec = unsafePerformIO (getSApply alpha mat vec)
-- | Pure scaled matrix product @alpha * a * b@. 'getSApplyMat' writes only
-- to the result matrix it allocates, which is why it is wrapped in
-- 'unsafePerformIO'.
sapplyMat :: (BLAS3 e) => e -> Matrix (m,k) e -> Matrix (k,n) e -> Matrix (m,n) e
sapplyMat alpha lhs rhs = unsafePerformIO (getSApplyMat alpha lhs rhs)
-- | Pure matrix product @a * b@ ('sapplyMat' with scale 1).
applyMat :: (BLAS3 e) => Matrix (m,k) e -> Matrix (k,n) e -> Matrix (m,n) e
applyMat lhs rhs = sapplyMat 1 lhs rhs
-- | Pure scaling of a matrix by @k@. 'getScaled' modifies only a fresh copy.
scale :: (BLAS1 e) => e -> Matrix (m,n) e -> Matrix (m,n) e
scale factor mat = unsafePerformIO (getScaled factor mat)
-- | Pure shift of a matrix by @k@. 'getShifted' modifies only a fresh copy.
shift :: (BLAS1 e) => e -> Matrix (m,n) e -> Matrix (m,n) e
shift offset mat = unsafePerformIO (getShifted offset mat)
-- | Pure inverse scaling of a matrix by @k@. 'getInvScaled' modifies only a
-- fresh copy.
invScale :: (BLAS1 e) => e -> Matrix (m,n) e -> Matrix (m,n) e
invScale divisor mat = unsafePerformIO (getInvScaled divisor mat)
-- | Pure linear combination @alpha * a + beta * b@ of two same-shaped
-- matrices, computed by 'getSum' into fresh storage.
add :: (BLAS1 e) => e -> Matrix (m,n) e -> e -> Matrix (m,n) e -> Matrix (m,n) e
add wa x wb y = unsafePerformIO (getSum wa x wb y)
-- | Pure element-wise sum @a + b@ ('add' with both weights equal to 1).
plus :: (BLAS1 e) => Matrix (m,n) e -> Matrix (m,n) e -> Matrix (m,n) e
plus x y = add 1 x 1 y
-- | Pure element-wise difference of two same-shaped matrices, computed by
-- 'getDiff' into fresh storage.
minus :: (BLAS1 e) => Matrix (m,n) e -> Matrix (m,n) e -> Matrix (m,n) e
minus x y = unsafePerformIO (getDiff x y)
-- | Pure element-wise product of two same-shaped matrices ('getProduct').
times :: (BLAS2 e) => Matrix (m,n) e -> Matrix (m,n) e -> Matrix (m,n) e
times x y = unsafePerformIO (getProduct x y)
-- | Pure element-wise quotient of two same-shaped matrices ('getRatio').
divide :: (BLAS2 e) => Matrix (m,n) e -> Matrix (m,n) e -> Matrix (m,n) e
divide x y = unsafePerformIO (getRatio x y)