module BLAS.Matrix.Mutable (
getRow,
getCol,
getRows,
getCols,
getRows',
getCols',
getApply,
getSApply,
getApplyMat,
getSApplyMat,
doApply,
doSApplyAdd,
doApply_,
doSApply_,
doApplyMat,
doSApplyAddMat,
doApplyMat_,
doSApplyMat_,
MMatrix(..),
unsafeGetApply,
unsafeDoApply,
unsafeDoApply_,
unsafeGetApplyMat,
unsafeDoApplyMat,
unsafeDoApplyMat_,
) where
import Control.Monad( liftM )
import Control.Monad.ST( ST )
import BLAS.Elem
import BLAS.Internal( checkSquare, checkMatVecMult, checkMatVecMultAdd,
checkMatMatMult, checkMatMatMultAdd, checkedRow, checkedCol )
import BLAS.UnsafeIOToM
import BLAS.Matrix.Base
import Data.Vector.Dense.Class
import Data.Matrix.Dense.Internal( Matrix )
import Data.Matrix.Dense.Class.Internal hiding ( BaseMatrix )
class (BaseMatrix a, BLAS1 e, Monad m) => MMatrix a e m where
unsafeGetSApply :: (ReadVector x m, WriteVector y m) =>
e -> a (k,l) e -> x l e -> m (y k e)
unsafeGetSApply alpha a x = do
y <- newVector_ (numRows a)
unsafeDoSApplyAdd alpha a x 0 y
return y
unsafeGetSApplyMat :: (ReadMatrix b x m, WriteMatrix c y m) =>
e -> a (r,s) e -> b (s,t) e -> m (c (r,t) e)
unsafeGetSApplyMat alpha a b = do
c <- newMatrix_ (numRows a, numCols b)
unsafeDoSApplyAddMat alpha a b 0 c
return c
unsafeDoSApplyAdd :: (ReadVector x m, WriteVector y m) =>
e -> a (k,l) e -> x l e -> e -> y k e -> m ()
unsafeDoSApplyAdd alpha a x beta y = do
y' <- unsafeGetSApply alpha a x
scaleBy beta y
unsafeAxpyVector 1 y' y
unsafeDoSApplyAddMat :: (ReadMatrix b x m, WriteMatrix c y m) =>
e -> a (r,s) e -> b (s,t) e -> e -> c (r,t) e -> m ()
unsafeDoSApplyAddMat alpha a b beta c = do
c' <- unsafeGetSApplyMat alpha a b
scaleBy beta c
unsafeAxpyMatrix 1 c' c
unsafeDoSApply_ :: (WriteVector y m) =>
e -> a (n,n) e -> y n e -> m ()
unsafeDoSApply_ alpha a x = do
y <- newVector_ (dim x)
unsafeDoSApplyAdd alpha a x 0 y
unsafeCopyVector x y
unsafeDoSApplyMat_ :: (WriteMatrix b y m) =>
e -> a (k,k) e -> b (k,l) e -> m ()
unsafeDoSApplyMat_ alpha a b = do
c <- newMatrix_ (shape b)
unsafeDoSApplyAddMat alpha a b 0 c
unsafeCopyMatrix b c
unsafeGetRow :: (WriteVector x m) => a (k,l) e -> Int -> m (x l e)
unsafeGetRow a i = do
e <- newBasisVector (numRows a) i
liftM conj $ unsafeGetApply (herm a) e
unsafeGetCol :: (WriteVector x m) => a (k,l) e -> Int -> m (x k e)
unsafeGetCol a j = do
e <- newBasisVector (numCols a) j
unsafeGetApply a e
getRow :: (MMatrix a e m, WriteVector x m) => a (k,l) e -> Int -> m (x l e)
getRow a = checkedRow (shape a) (unsafeGetRow a)
getCol :: (MMatrix a e m, WriteVector x m) => a (k,l) e -> Int -> m (x k e)
getCol a = checkedCol (shape a) (unsafeGetCol a)
getRows :: (MMatrix a e m, WriteVector x m) =>
a (k,l) e -> m [x l e]
getRows = unsafeInterleaveM . getRows'
getCols :: (MMatrix a e m, WriteVector x m) =>
a (k,l) e -> m [x k e]
getCols = unsafeInterleaveM . getCols'
getRows' :: (MMatrix a e m, WriteVector x m) => a (k,l) e -> m [x l e]
getRows' a = mapM (unsafeGetRow a) [0..numRows a 1]
getCols' :: (MMatrix a e m, WriteVector x m) => a (k,l) e -> m [x k e]
getCols' a = mapM (unsafeGetCol a) [0..numCols a 1]
getSApply :: (MMatrix a e m, ReadVector x m, WriteVector y m) =>
e -> a (k,l) e -> x l e -> m (y k e)
getSApply k a x =
checkMatVecMult (shape a) (dim x) $
unsafeGetSApply k a x
getSApplyMat :: (MMatrix a e m, ReadMatrix b x m, WriteMatrix c y m) =>
e -> a (r,s) e -> b (s,t) e -> m (c (r,t) e)
getSApplyMat k a b =
checkMatMatMult (shape a) (shape b) $
unsafeGetSApplyMat k a b
doSApplyAdd :: (MMatrix a e m, ReadVector x m, WriteVector y m) =>
e -> a (k,l) e -> x l e -> e -> y k e -> m ()
doSApplyAdd alpha a x beta y =
checkMatVecMultAdd (shape a) (dim x) (dim y) $
unsafeDoSApplyAdd alpha a x beta y
doSApplyAddMat :: (MMatrix a e m, ReadMatrix b x m, WriteMatrix c y m) =>
e -> a (r,s) e -> b (s,t) e -> e -> c (r,t) e -> m ()
doSApplyAddMat alpha a b beta c =
checkMatMatMultAdd (shape a) (shape b) (shape c)
unsafeDoSApplyAddMat alpha a b beta c
getApply :: (MMatrix a e m, ReadVector x m, WriteVector y m) =>
a (k,l) e -> x l e -> m (y k e)
getApply a x =
checkMatVecMult (shape a) (dim x) $ do
unsafeGetApply a x
getApplyMat :: (MMatrix a e m, ReadMatrix b x m, WriteMatrix c y m) =>
a (r,s) e -> b (s,t) e -> m (c (r,t) e)
getApplyMat a b =
checkMatMatMult (shape a) (shape b) $
unsafeGetApplyMat a b
doSApply_ :: (MMatrix a e m, WriteVector y m) =>
e -> a (n,n) e -> y n e -> m ()
doSApply_ alpha a x =
checkSquare (shape a) $
checkMatVecMult (shape a) (dim x) $
unsafeDoSApply_ alpha a x
doSApplyMat_ :: (MMatrix a e m, WriteMatrix b y m) =>
e -> a (s,s) e -> b (s,t) e -> m ()
doSApplyMat_ alpha a b =
checkSquare (shape a) $
checkMatMatMult (shape a) (shape b) $
unsafeDoSApplyMat_ alpha a b
unsafeGetApply :: (MMatrix a e m, ReadVector x m, WriteVector y m) =>
a (k,l) e -> x l e -> m (y k e)
unsafeGetApply = unsafeGetSApply 1
unsafeGetApplyMat :: (MMatrix a e m, ReadMatrix b x m, WriteMatrix c y m) =>
a (r,s) e -> b (s,t) e -> m (c (r,t) e)
unsafeGetApplyMat = unsafeGetSApplyMat 1
doApply :: (MMatrix a e m, ReadVector x m, WriteVector y m) =>
a (k,l) e -> x l e -> y k e -> m ()
doApply a x y =
checkMatVecMultAdd (numRows a, numCols a) (dim x) (dim y) $
unsafeDoApply a x y
doApplyMat :: (MMatrix a e m, ReadMatrix b x m, WriteMatrix c y m) =>
a (r,s) e -> b (s,t) e -> c (r,t) e -> m ()
doApplyMat a b c =
checkMatMatMultAdd (shape a) (shape b) (shape c) $
unsafeDoApplyMat a b c
unsafeDoApply :: (MMatrix a e m, ReadVector x m, WriteVector y m) =>
a (k,l) e -> x l e -> y k e -> m ()
unsafeDoApply a x y = unsafeDoSApplyAdd 1 a x 0 y
unsafeDoApplyMat :: (MMatrix a e m, ReadMatrix b x m, WriteMatrix c y m) =>
a (r,s) e -> b (s,t) e -> c (r,t) e -> m ()
unsafeDoApplyMat a b c = unsafeDoSApplyAddMat 1 a b 0 c
doApply_ :: (MMatrix a e m, WriteVector y m) =>
a (n,n) e -> y n e -> m ()
doApply_ a x =
checkSquare (shape a) $
checkMatVecMult (shape a) (dim x) $
unsafeDoApply_ a x
doApplyMat_ :: (MMatrix a e m, WriteMatrix b y m) =>
a (s,s) e -> b (s,t) e -> m ()
doApplyMat_ a b =
checkSquare (shape a) $
checkMatMatMult (shape a) (shape b) $
unsafeDoApplyMat_ a b
unsafeDoApply_ :: (MMatrix a e m, WriteVector y m) =>
a (n,n) e -> y n e -> m ()
unsafeDoApply_ a x =
unsafeDoSApply_ 1 a x
unsafeDoApplyMat_ :: (MMatrix a e m, WriteMatrix b y m) =>
a (s,s) e -> b (s,t) e -> m ()
unsafeDoApplyMat_ a b =
unsafeDoSApplyMat_ 1 a b
instance (BLAS3 e) => MMatrix IOMatrix e IO where
unsafeDoSApplyAdd = gemv
unsafeDoSApplyAddMat = gemm
unsafeGetRow = unsafeGetRowMatrix
unsafeGetCol = unsafeGetColMatrix
instance (BLAS3 e) => MMatrix (STMatrix s) e (ST s) where
unsafeDoSApplyAdd = gemv
unsafeDoSApplyAddMat = gemm
unsafeGetRow = unsafeGetRowMatrix
unsafeGetCol = unsafeGetColMatrix
instance (BLAS3 e, UnsafeIOToM m) => MMatrix Matrix e m where
unsafeDoSApplyAdd = gemv
unsafeDoSApplyAddMat = gemm
unsafeGetRow = unsafeGetRowMatrix
unsafeGetCol = unsafeGetColMatrix