module Data.Matrix.Dense.Base
where
import Control.Monad
import Control.Monad.ST
import Data.AEq
import Foreign
import System.IO.Unsafe
import Unsafe.Coerce
import BLAS.Internal( checkBinaryOp, checkedSubmatrix, checkedDiag,
checkedRow, checkedCol, inlinePerformIO )
import Data.Elem.BLAS( Elem, BLAS1, BLAS3, conjugate )
import qualified Data.Elem.BLAS.Level1 as BLAS
import qualified Data.Elem.BLAS.Level2 as BLAS
import qualified Data.Elem.BLAS.Level3 as BLAS
import Data.Tensor.Class
import Data.Tensor.Class.ITensor
import Data.Tensor.Class.MTensor
import Data.Matrix.Class
import Data.Matrix.Herm
import Data.Matrix.TriBase
import Data.Vector.Dense.IOBase
import Data.Vector.Dense.Base
import Data.Matrix.Dense.IOBase
-- | Immutable dense matrix: a newtype wrapper around an 'IOMatrix'.
-- The shape index @np@ is passed straight through to 'IOMatrix' (it is
-- written @(n,p)@ throughout this module). Values are treated as
-- immutable; 'unsafeThawIOMatrix' exposes the wrapped storage without
-- copying.
newtype Matrix np e = Matrix (IOMatrix np e)
-- | Snapshot a mutable matrix into an immutable 'Matrix'. The storage is
-- duplicated with 'newCopyIOMatrix', so later writes to the argument do
-- not affect the result.
freezeIOMatrix :: (BLAS1 e) => IOMatrix np e -> IO (Matrix np e)
freezeIOMatrix = liftM Matrix . newCopyIOMatrix
-- | Produce a mutable copy of an immutable matrix (via 'newCopyIOMatrix').
-- The result can be modified without affecting the argument.
thawIOMatrix :: (BLAS1 e) => Matrix np e -> IO (IOMatrix np e)
thawIOMatrix frozen = case frozen of
    Matrix storage -> newCopyIOMatrix storage
-- | Wrap a mutable matrix as a 'Matrix' without copying. The result shares
-- storage with the argument, so the argument must not be mutated
-- afterwards.
unsafeFreezeIOMatrix :: IOMatrix np e -> IO (Matrix np e)
unsafeFreezeIOMatrix storage = return (Matrix storage)
-- | Expose the storage of an immutable matrix as an 'IOMatrix' without
-- copying. Writes through the result are visible in the original 'Matrix'.
unsafeThawIOMatrix :: Matrix np e -> IO (IOMatrix np e)
unsafeThawIOMatrix (Matrix storage) =
    return storage
-- | Pure (non-monadic) operations shared by the dense-matrix
-- representations in this module: storage layout queries and views into
-- the storage. The @unsafe@ methods do no bounds checking; the checked
-- wrappers ('submatrixView', 'rowView', 'colView', 'diagView') are defined
-- separately.
class (HasVectorView a, Elem e, MatrixShaped a
      , BaseVector (VectorView a) e) => BaseMatrix a e where
    -- | Leading dimension of the underlying storage (passed as @lda@ to BLAS).
    ldaMatrix :: a (n,p) e -> Int
    -- | Whether the value is a Hermitian (conjugate-transposed) view of its
    -- storage.
    isHermMatrix :: a (n,p) e -> Bool
    -- | Change the shape index only. The default uses 'unsafeCoerce' and
    -- does not touch the data.
    coerceMatrix :: a np e -> a np' e
    coerceMatrix = unsafeCoerce
    -- | View of the submatrix with top-left corner @(i,j)@ and shape
    -- @(m,n)@ (arguments in that order). Not bounds checked.
    unsafeSubmatrixView :: a (n,p) e -> (Int,Int) -> (Int,Int) -> a (n',p') e
    -- | View of diagonal number @i@ (@0@ is the main diagonal). Not checked.
    unsafeDiagView :: a (n,p) e -> Int -> VectorView a k e
    -- | View of row @i@. Not checked.
    unsafeRowView :: a (n,p) e -> Int -> VectorView a p e
    -- | View of column @j@. Not checked.
    unsafeColView :: a (n,p) e -> Int -> VectorView a n e
    -- | View the whole matrix as one vector, or 'Nothing' when such a view
    -- is not possible (presumably non-contiguous storage — TODO confirm).
    maybeViewMatrixAsVector :: a (n,p) e -> Maybe (VectorView a np e)
    -- | View a vector as a one-row matrix, when possible.
    maybeViewVectorAsRow :: VectorView a p e -> Maybe (a (one,p) e)
    -- | View a vector as a one-column matrix, when possible.
    maybeViewVectorAsCol :: VectorView a n e -> Maybe (a (n,one) e)
    -- | View a vector as a matrix of the given shape, when possible.
    maybeViewVectorAsMatrix :: (Int,Int)
                            -> VectorView a np e
                            -> Maybe (a (n,p) e)
    -- | Reinterpret as an 'IOMatrix' sharing the same storage.
    unsafeMatrixToIOMatrix :: a (n,p) e -> IOMatrix (n,p) e
    -- | Reinterpret an 'IOMatrix' as this representation, sharing storage.
    unsafeIOMatrixToMatrix :: IOMatrix (n,p) e -> a (n,p) e
-- | Dense matrices whose elements can be read in the monad @m@.
class (BaseMatrix a e, BLAS3 e, ReadTensor a (Int,Int) e m
      , MMatrix a e m, MMatrix (Herm a) e m, MMatrix (Tri a) e m
      , MSolve (Tri a) e m
      , ReadVector (VectorView a) e m) => ReadMatrix a e m where
    -- | Run an 'IO' action on the underlying 'IOMatrix' and return its
    -- result in @m@. For 'IOMatrix' in 'IO' this is plain application.
    unsafePerformIOWithMatrix :: a (n,p) e -> (IOMatrix (n,p) e -> IO r) -> m r
    -- | Immutable snapshot of the matrix (the 'IOMatrix' instance copies).
    freezeMatrix :: a (n,p) e -> m (Matrix (n,p) e)
    -- | Immutable view sharing the storage. The argument must not be
    -- mutated afterwards.
    unsafeFreezeMatrix :: a (n,p) e -> m (Matrix (n,p) e)
-- | Dense matrices that can be created and modified in the monad @m@.
class (ReadMatrix a e m, WriteTensor a (Int,Int) e m
      , WriteVector (VectorView a) e m) =>
          WriteMatrix a e m where
    -- | Allocate a matrix of the given shape. The contents are presumably
    -- uninitialized; callers in this module always fill it afterwards.
    newMatrix_ :: (Int,Int) -> m (a (n,p) e)
    -- | Run an 'IO' action that builds an 'IOMatrix' and return the result
    -- as this representation in @m@ ('id' for 'IOMatrix' in 'IO').
    unsafeConvertIOMatrix :: IO (IOMatrix (n,p) e) -> m (a (n,p) e)
    -- | Mutable copy of an immutable matrix.
    thawMatrix :: Matrix (n,p) e -> m (a (n,p) e)
    -- | Mutable view of an immutable matrix, sharing its storage.
    unsafeThawMatrix :: Matrix (n,p) e -> m (a (n,p) e)
-- | Create an @(m,n)@ matrix that is zero except for the listed
-- @((i,j), value)@ entries. Entries are written in list order, so a later
-- entry for the same index wins. The first index outside the shape raises
-- an error (via 'fail' in 'IO').
newMatrix :: (WriteMatrix a e m) =>
    (Int,Int) -> [((Int,Int), e)] -> m (a (n,p) e)
newMatrix (m,n) ies = unsafeConvertIOMatrix $ do
    mat <- newZeroMatrix (m,n)
    withIOMatrix mat $ \ptr -> mapM_ (store ptr) ies
    return mat
  where
    valid (i,j) = 0 <= i && i < m && 0 <= j && j < n

    -- Column-major offset i + j*m, as used elsewhere in this module.
    store ptr ((i,j), e) = do
        unless (valid (i,j)) $ fail $
            "Index `" ++ show (i,j) ++
            "' is invalid for a matrix with shape `" ++ show (m,n) ++
            "'"
        pokeElemOff ptr (i + j*m) e
-- | Like 'newMatrix' but without index validation. Each entry is written at
-- offset @i + j*m@, and an index outside the shape is not detected.
unsafeNewMatrix :: (WriteMatrix a e m) =>
    (Int,Int) -> [((Int,Int), e)] -> m (a (n,p) e)
unsafeNewMatrix (m,n) ies = unsafeConvertIOMatrix $ do
    mat <- newZeroMatrix (m,n)
    withIOMatrix mat $ \ptr ->
        mapM_ (\((i,j),e) -> pokeElemOff ptr (i + j*m) e) ies
    return mat
-- | Create an @(m,n)@ matrix from elements listed in storage order
-- (column-major). Only the first @m*n@ elements are used; if the list is
-- shorter, the remaining entries stay zero.
newListMatrix :: (WriteMatrix a e m) => (Int,Int) -> [e] -> m (a (n,p) e)
newListMatrix (m,n) es = do
    mat <- newZeroMatrix (m,n)
    unsafePerformIOWithMatrix mat $ \io ->
        withIOMatrix io (\ptr -> pokeArray ptr (take (m*n) es))
    return mat
-- | Create an @(m,n)@ matrix whose columns are copies of the given vectors.
-- Only the first @n@ vectors are used, and any missing trailing columns
-- stay zero. Each vector is assumed to have dimension @m@; this is not
-- checked.
newColsMatrix :: (ReadVector x e m, WriteMatrix a e m) =>
    (Int,Int) -> [x n e] -> m (a (n,p) e)
newColsMatrix (m,n) cs = do
    a <- newZeroMatrix (m,n)
    -- Column indices run 0 .. n-1; zip truncates any surplus vectors.
    forM_ (zip [0..(n-1)] cs) $ \(j,c) ->
        unsafeCopyVector (unsafeColView a j) c
    return a
-- | Create an @(m,n)@ matrix whose rows are copies of the given vectors.
-- Only the first @m@ vectors are used, and any missing trailing rows stay
-- zero. Each vector is assumed to have dimension @n@; this is not checked.
newRowsMatrix :: (ReadVector x e m, WriteMatrix a e m) =>
    (Int,Int) -> [x p e] -> m (a (n,p) e)
newRowsMatrix (m,n) rs = do
    a <- newZeroMatrix (m,n)
    -- Row indices run 0 .. m-1; zip truncates any surplus vectors.
    forM_ (zip [0..(m-1)] rs) $ \(i,r) ->
        unsafeCopyVector (unsafeRowView a i) r
    return a
-- | A @(dim x, 1)@ matrix holding a copy of the vector as its only column.
newColMatrix :: (ReadVector x e m, WriteMatrix a e m) =>
    x n e -> m (a (n,one) e)
newColMatrix v =
    let shp = (dim v, 1)
    in newColsMatrix shp [v]
-- | A @(1, dim x)@ matrix holding a copy of the vector as its only row.
newRowMatrix :: (ReadVector x e m, WriteMatrix a e m) =>
    x p e -> m (a (one,p) e)
newRowMatrix v =
    let shp = (1, dim v)
    in newRowsMatrix shp [v]
-- | Allocate a matrix of the given shape with every entry set to zero.
newZeroMatrix :: (WriteMatrix a e m) => (Int,Int) -> m (a (n,p) e)
newZeroMatrix mn =
    newMatrix_ mn >>= \mat -> setZeroMatrix mat >> return mat
-- | Overwrite every entry with zero (delegates to 'setZeroIOMatrix').
setZeroMatrix :: (WriteMatrix a e m) => a (n,p) e -> m ()
setZeroMatrix mat = unsafePerformIOWithMatrix mat setZeroIOMatrix
-- | Allocate a matrix of the given shape with every entry equal to @e@.
newConstantMatrix :: (WriteMatrix a e m) => (Int,Int) -> e -> m (a (n,p) e)
newConstantMatrix mn e =
    newMatrix_ mn >>= \mat -> setConstantMatrix e mat >> return mat
-- | Overwrite every entry with @e@ (delegates to 'setConstantIOMatrix').
setConstantMatrix :: (WriteMatrix a e m) => e -> a (n,p) e -> m ()
setConstantMatrix e mat =
    unsafePerformIOWithMatrix mat (setConstantIOMatrix e)
-- | Allocate a matrix of the given shape with ones on the main diagonal
-- and zeros elsewhere.
newIdentityMatrix :: (WriteMatrix a e m) => (Int,Int) -> m (a (n,p) e)
newIdentityMatrix mn =
    newMatrix_ mn >>= \mat -> setIdentityMatrix mat >> return mat
-- | Zero the matrix, then set its main diagonal (diagonal 0) to one. This
-- also works for non-square shapes.
setIdentityMatrix :: (WriteMatrix a e m) => a (n,p) e -> m ()
setIdentityMatrix mat =
    setZeroMatrix mat >> setConstantVector 1 (unsafeDiagView mat 0)
-- | Copy a matrix into freshly allocated storage. For a Hermitian view, the
-- underlying matrix is copied and the copy is wrapped with 'herm' again, so
-- the result is also a Hermitian view.
newCopyMatrix :: (ReadMatrix a e m, WriteMatrix b e m) =>
    a (n,p) e -> m (b (n,p) e)
newCopyMatrix a =
    if isHermMatrix a
        then liftM herm $ newCopyMatrix (herm a)
        else newCopyMatrix' a
-- | Copy a matrix into freshly allocated storage using 'unsafeCopyMatrix'.
-- Unlike 'newCopyMatrix', this does not special-case Hermitian views.
newCopyMatrix' :: (ReadMatrix a e m, WriteMatrix b e m) =>
    a (n,p) e -> m (b (n,p) e)
newCopyMatrix' a =
    newMatrix_ (shape a) >>= \b -> unsafeCopyMatrix b a >> return b
-- | Copy the entries of @a@ into @b@ (destination first), after checking
-- with 'checkBinaryOp' that the shapes agree.
copyMatrix :: (WriteMatrix b e m, ReadMatrix a e m) =>
    b (n,p) e -> a (n,p) e -> m ()
copyMatrix b a =
    let action = unsafeCopyMatrix b a
    in checkBinaryOp (shape b) (shape a) action
-- | Copy @a@ into @b@ without a shape check, lifting 'unsafeCopyVector'
-- over the matrices with 'liftMatrix2'.
unsafeCopyMatrix :: (WriteMatrix b e m, ReadMatrix a e m) =>
    b (n,p) e -> a (n,p) e -> m ()
unsafeCopyMatrix b a = liftMatrix2 unsafeCopyVector b a
-- | Exchange the contents of two matrices after checking that their shapes
-- agree.
swapMatrix :: (WriteMatrix a e m, WriteMatrix b e m) =>
    a (n,p) e -> b (n,p) e -> m ()
swapMatrix a b =
    checkBinaryOp (shape b) (shape a) (unsafeSwapMatrix a b)
-- | Exchange the contents of two matrices without a shape check.
unsafeSwapMatrix :: (WriteMatrix a e m, WriteMatrix b e m) =>
    a (n,p) e -> b (n,p) e -> m ()
unsafeSwapMatrix a b = liftMatrix2 unsafeSwapVector a b
-- | Exchange rows @i@ and @j@. The row indices are checked through
-- 'rowView'. Nothing happens when @i == j@.
swapRows :: (WriteMatrix a e m) => a (n,p) e -> Int -> Int -> m ()
swapRows a i j
    | i == j    = return ()
    | otherwise = unsafeSwapVector (rowView a i) (rowView a j)
-- | Exchange columns @i@ and @j@. The column indices are checked through
-- 'colView'. Nothing happens when @i == j@.
swapCols :: (WriteMatrix a e m) => a (n,p) e -> Int -> Int -> m ()
swapCols a i j
    | i == j    = return ()
    | otherwise = unsafeSwapVector (colView a i) (colView a j)
-- | Exchange rows @i@ and @j@ without bounds checking. Nothing happens
-- when @i == j@.
unsafeSwapRows :: (WriteMatrix a e m) => a (n,p) e -> Int -> Int -> m ()
unsafeSwapRows a i j
    | i == j    = return ()
    | otherwise = unsafeSwapVector (unsafeRowView a i) (unsafeRowView a j)
-- | Exchange columns @i@ and @j@ without bounds checking. Nothing happens
-- when @i == j@.
unsafeSwapCols :: (WriteMatrix a e m) => a (n,p) e -> Int -> Int -> m ()
unsafeSwapCols a i j
    | i == j    = return ()
    | otherwise = unsafeSwapVector (unsafeColView a i) (unsafeColView a j)
-- | View of the submatrix with top-left corner @ij@ and shape @mn@. The
-- arguments are validated against the matrix shape by 'checkedSubmatrix'.
submatrixView :: (BaseMatrix a e) => a (n,p) e -> (Int,Int) -> (Int,Int) -> a (n',p') e
submatrixView a ij mn =
    checkedSubmatrix (shape a) (unsafeSubmatrixView a) ij mn
-- | Split a matrix into its first @m1@ rows and its remaining @m - m1@ rows,
-- as two views sharing the original storage. Both views are bounds
-- checked through 'submatrixView'.
splitRowsAt :: (BaseMatrix a e) =>
    Int -> a (n,p) e -> (a (n1,p) e, a (n2,p) e)
splitRowsAt m1 a = ( submatrixView a (0,0)  (m1,n)
                   , submatrixView a (m1,0) (m2,n)
                   )
  where
    (m,n) = shape a
    m2    = m - m1
-- | Like 'splitRowsAt', but without bounds checking: the views are built
-- with 'unsafeSubmatrixView'.
unsafeSplitRowsAt :: (BaseMatrix a e) =>
    Int -> a (n,p) e -> (a (n1,p) e, a (n2,p) e)
unsafeSplitRowsAt m1 a = ( unsafeSubmatrixView a (0,0)  (m1,n)
                         , unsafeSubmatrixView a (m1,0) (m2,n)
                         )
  where
    (m,n) = shape a
    m2    = m - m1
-- | Split a matrix into its first @n1@ columns and its remaining @n - n1@
-- columns, as two views sharing the original storage. Both views are
-- bounds checked through 'submatrixView'.
splitColsAt :: (BaseMatrix a e) =>
    Int -> a (n,p) e -> (a (n,p1) e, a (n,p2) e)
splitColsAt n1 a = ( submatrixView a (0,0)  (m,n1)
                   , submatrixView a (0,n1) (m,n2)
                   )
  where
    (m,n) = shape a
    n2    = n - n1
-- | Like 'splitColsAt', but without bounds checking: the views are built
-- with 'unsafeSubmatrixView'.
unsafeSplitColsAt :: (BaseMatrix a e) =>
    Int -> a (n,p) e -> (a (n,p1) e, a (n,p2) e)
unsafeSplitColsAt n1 a = ( unsafeSubmatrixView a (0,0)  (m,n1)
                         , unsafeSubmatrixView a (0,n1) (m,n2)
                         )
  where
    (m,n) = shape a
    n2    = n - n1
-- | Views of every row, from top to bottom. The list is empty for a matrix
-- with no rows.
rowViews :: (BaseMatrix a e) => a (n,p) e -> [VectorView a p e]
rowViews a = [ unsafeRowView a i | i <- [0 .. numRows a - 1] ]
-- | Views of every column, from left to right. The list is empty for a
-- matrix with no columns.
colViews :: (BaseMatrix a e) => a (n,p) e -> [VectorView a n e]
colViews a = [ unsafeColView a j | j <- [0 .. numCols a - 1] ]
-- | View of row @i@; the index is validated by 'checkedRow'.
rowView :: (BaseMatrix a e) => a (n,p) e -> Int -> VectorView a p e
rowView a i = checkedRow (shape a) (unsafeRowView a) i
-- | Copy of row @i@ as a new vector, without bounds checking.
unsafeGetRowMatrix :: (ReadMatrix a e m, WriteVector y e m) =>
    a (n,p) e -> Int -> m (y p e)
unsafeGetRowMatrix a = newCopyVector . unsafeRowView a
-- | View of column @j@; the index is validated by 'checkedCol'.
colView :: (BaseMatrix a e) => a (n,p) e -> Int -> VectorView a n e
colView a j = checkedCol (shape a) (unsafeColView a) j
-- | Copy of column @j@ as a new vector, without bounds checking.
unsafeGetColMatrix :: (ReadMatrix a e m, WriteVector y e m) =>
    a (n,p) e -> Int -> m (y n e)
unsafeGetColMatrix a = newCopyVector . unsafeColView a
-- | View of diagonal @i@; the index is validated by 'checkedDiag'.
diagView :: (BaseMatrix a e) => a (n,p) e -> Int -> VectorView a k e
diagView a i = checkedDiag (shape a) (unsafeDiagView a) i
-- | Copy of diagonal @i@ as a new vector; the index is validated by
-- 'checkedDiag'.
getDiag :: (ReadMatrix a e m, WriteVector y e m) =>
    a (n,p) e -> Int -> m (y k e)
getDiag a i = checkedDiag (shape a) (unsafeGetDiag a) i
-- | Copy of diagonal @i@ as a new vector, without bounds checking.
unsafeGetDiag :: (ReadMatrix a e m, WriteVector y e m) =>
    a (n,p) e -> Int -> m (y k e)
unsafeGetDiag a = newCopyVector . unsafeDiagView a
-- | Conjugate every entry in place.
doConjMatrix :: (WriteMatrix a e m) => a (n,p) e -> m ()
doConjMatrix a = liftMatrix doConjVector a
-- | A new matrix holding the entrywise conjugate of the argument.
getConjMatrix :: (ReadMatrix a e m, WriteMatrix b e m) =>
    a (n,p) e -> m (b (n,p) e)
getConjMatrix a = getUnaryMatrixOp doConjMatrix a
-- | Multiply every entry by @k@ in place.
scaleByMatrix :: (WriteMatrix a e m) => e -> a (n,p) e -> m ()
scaleByMatrix k a = liftMatrix (scaleByVector k) a
-- | A new matrix equal to the argument with every entry multiplied by @e@.
getScaledMatrix :: (ReadMatrix a e m, WriteMatrix b e m) =>
    e -> a (n,p) e -> m (b (n,p) e)
getScaledMatrix e a = getUnaryMatrixOp (scaleByMatrix e) a
-- | Add @k@ to every entry in place.
shiftByMatrix :: (WriteMatrix a e m) => e -> a (n,p) e -> m ()
shiftByMatrix k a = liftMatrix (shiftByVector k) a
-- | A new matrix equal to the argument with @e@ added to every entry.
getShiftedMatrix :: (ReadMatrix a e m, WriteMatrix b e m) =>
    e -> a (n,p) e -> m (b (n,p) e)
getShiftedMatrix e a = getUnaryMatrixOp (shiftByMatrix e) a
-- | @b := b + a@, after checking that the shapes agree.
addMatrix :: (WriteMatrix b e m, ReadMatrix a e m) =>
    b (n,p) e -> a (n,p) e -> m ()
addMatrix b a = checkBinaryOp (shape b) (shape a) go
  where
    go = unsafeAddMatrix b a
-- | @b := b + a@ without a shape check, done as an axpy with coefficient 1.
unsafeAddMatrix :: (WriteMatrix b e m, ReadMatrix a e m) =>
    b (n,p) e -> a (n,p) e -> m ()
unsafeAddMatrix b a = unsafeAxpyMatrix alpha a b
  where
    alpha = 1
-- | A new matrix @a + b@; the shapes are checked with 'checkMatrixOp2'.
getAddMatrix ::
    (ReadMatrix a e m, ReadMatrix b e m, WriteMatrix c e m) =>
    a (n,p) e -> b (n,p) e -> m (c (n,p) e)
getAddMatrix a b = checkMatrixOp2 unsafeGetAddMatrix a b
-- | A new matrix @a + b@, without a shape check.
unsafeGetAddMatrix ::
    (ReadMatrix a e m, ReadMatrix b e m, WriteMatrix c e m) =>
    a (n,p) e -> b (n,p) e -> m (c (n,p) e)
unsafeGetAddMatrix a b = unsafeGetBinaryMatrixOp unsafeAddMatrix a b
-- | @b := b - a@, after checking that the shapes agree.
subMatrix :: (WriteMatrix b e m, ReadMatrix a e m) =>
    b (n,p) e -> a (n,p) e -> m ()
subMatrix b a = checkBinaryOp (shape b) (shape a) go
  where
    go = unsafeSubMatrix b a
-- | @b := b - a@ without a shape check, done as an axpy with coefficient
-- @-1@ (@b := (-1) * a + b@).
unsafeSubMatrix :: (WriteMatrix b e m, ReadMatrix a e m) =>
    b (n,p) e -> a (n,p) e -> m ()
unsafeSubMatrix b a = unsafeAxpyMatrix (-1) a b
-- | A new matrix @a - b@; the shapes are checked with 'checkMatrixOp2'.
getSubMatrix ::
    (ReadMatrix a e m, ReadMatrix b e m, WriteMatrix c e m) =>
    a (n,p) e -> b (n,p) e -> m (c (n,p) e)
getSubMatrix a b = checkMatrixOp2 unsafeGetSubMatrix a b
-- | A new matrix @a - b@, without a shape check.
unsafeGetSubMatrix ::
    (ReadMatrix a e m, ReadMatrix b e m, WriteMatrix c e m) =>
    a (n,p) e -> b (n,p) e -> m (c (n,p) e)
unsafeGetSubMatrix a b = unsafeGetBinaryMatrixOp unsafeSubMatrix a b
-- | @y := alpha * x + y@, after checking that the shapes agree.
axpyMatrix :: (ReadMatrix a e m, WriteMatrix b e m) =>
    e -> a (n,p) e -> b (n,p) e -> m ()
axpyMatrix alpha x y = checkBinaryOp (shape x) (shape y) go
  where
    go = unsafeAxpyMatrix alpha x y
-- | @y := alpha * x + y@ without a shape check, lifting 'unsafeAxpyVector'
-- over the matrices.
unsafeAxpyMatrix :: (ReadMatrix a e m, WriteMatrix b e m) =>
    e -> a (n,p) e -> b (n,p) e -> m ()
unsafeAxpyMatrix alpha x y = liftMatrix2 (unsafeAxpyVector alpha) x y
-- | Entrywise @b := b * a@, after checking that the shapes agree.
mulMatrix :: (WriteMatrix b e m, ReadMatrix a e m) =>
    b (n,p) e -> a (n,p) e -> m ()
mulMatrix b a = checkBinaryOp (shape b) (shape a) go
  where
    go = unsafeMulMatrix b a
-- | Entrywise @b := b * a@ without a shape check.
unsafeMulMatrix :: (WriteMatrix b e m, ReadMatrix a e m) =>
    b (n,p) e -> a (n,p) e -> m ()
unsafeMulMatrix b a = liftMatrix2 unsafeMulVector b a
-- | A new matrix holding the entrywise product @a * b@; the shapes are
-- checked with 'checkMatrixOp2'.
getMulMatrix ::
    (ReadMatrix a e m, ReadMatrix b e m, WriteMatrix c e m) =>
    a (n,p) e -> b (n,p) e -> m (c (n,p) e)
getMulMatrix a b = checkMatrixOp2 unsafeGetMulMatrix a b
-- | A new matrix holding the entrywise product @a * b@, without a shape
-- check.
unsafeGetMulMatrix ::
    (ReadMatrix a e m, ReadMatrix b e m, WriteMatrix c e m) =>
    a (n,p) e -> b (n,p) e -> m (c (n,p) e)
unsafeGetMulMatrix a b = unsafeGetBinaryMatrixOp unsafeMulMatrix a b
-- | Entrywise @b := b / a@, after checking that the shapes agree.
divMatrix :: (WriteMatrix b e m, ReadMatrix a e m) =>
    b (n,p) e -> a (n,p) e -> m ()
divMatrix b a = checkBinaryOp (shape b) (shape a) go
  where
    go = unsafeDivMatrix b a
-- | Entrywise @b := b / a@ without a shape check.
unsafeDivMatrix :: (WriteMatrix b e m, ReadMatrix a e m) =>
    b (n,p) e -> a (n,p) e -> m ()
unsafeDivMatrix b a = liftMatrix2 unsafeDivVector b a
-- | A new matrix holding the entrywise quotient @a / b@; the shapes are
-- checked with 'checkMatrixOp2'.
getDivMatrix ::
    (ReadMatrix a e m, ReadMatrix b e m, WriteMatrix c e m) =>
    a (n,p) e -> b (n,p) e -> m (c (n,p) e)
getDivMatrix a b = checkMatrixOp2 unsafeGetDivMatrix a b
-- | A new matrix holding the entrywise quotient @a / b@, without a shape
-- check.
unsafeGetDivMatrix ::
    (ReadMatrix a e m, ReadMatrix b e m, WriteMatrix c e m) =>
    a (n,p) e -> b (n,p) e -> m (c (n,p) e)
unsafeGetDivMatrix a b = unsafeGetBinaryMatrixOp unsafeDivMatrix a b
-- | Matrix-like values that can be applied to vectors and matrices in the
-- monad @m@.
--
-- The defaults are mutually recursive in pairs. An instance must define at
-- least one of 'unsafeGetSApply' / 'unsafeDoSApplyAdd' and at least one of
-- 'unsafeGetSApplyMat' / 'unsafeDoSApplyAddMat'. 'getRows' and 'getCols'
-- have no defaults.
class (MatrixShaped a, BLAS3 e, Monad m) => MMatrix a e m where
    -- | New vector @alpha * a * x@. The default allocates with 'newVector_'
    -- and calls 'unsafeDoSApplyAdd' with @beta = 0@.
    unsafeGetSApply :: (ReadVector x e m, WriteVector y e m) =>
        e -> a (k,l) e -> x l e -> m (y k e)
    unsafeGetSApply alpha a x = do
        y <- newVector_ (numRows a)
        unsafeDoSApplyAdd alpha a x 0 y
        return y

    -- | New matrix @alpha * a * b@. The default allocates with 'newMatrix_'
    -- and calls 'unsafeDoSApplyAddMat' with @beta = 0@.
    unsafeGetSApplyMat :: (ReadMatrix b e m, WriteMatrix c e m) =>
        e -> a (r,s) e -> b (s,t) e -> m (c (r,t) e)
    unsafeGetSApplyMat alpha a b = do
        c <- newMatrix_ (numRows a, numCols b)
        unsafeDoSApplyAddMat alpha a b 0 c
        return c

    -- | @y := alpha * a * x + beta * y@. The default computes the product
    -- into a temporary with 'unsafeGetSApply', scales @y@ by @beta@, and
    -- adds the temporary.
    unsafeDoSApplyAdd :: (ReadVector x e m, WriteVector y e m) =>
        e -> a (k,l) e -> x l e -> e -> y k e -> m ()
    unsafeDoSApplyAdd alpha a x beta (y :: y k e) = do
        (y' :: y k e) <- unsafeGetSApply alpha a x
        scaleByVector beta y
        unsafeAxpyVector 1 y' y

    -- | @c := alpha * a * b + beta * c@, defaulted like 'unsafeDoSApplyAdd'.
    unsafeDoSApplyAddMat :: (ReadMatrix b e m, WriteMatrix c e m) =>
        e -> a (r,s) e -> b (s,t) e -> e -> c (r,t) e -> m ()
    unsafeDoSApplyAddMat alpha a b beta (c :: c (r,t) e) = do
        (c' :: c (r,t) e) <- unsafeGetSApplyMat alpha a b
        scaleByMatrix beta c
        unsafeAxpyMatrix 1 c' c

    -- | In-place @x := alpha * a * x@ for square @a@. The default computes
    -- into a fresh vector and then copies the result back into @x@.
    unsafeDoSApply_ :: (WriteVector y e m) =>
        e -> a (n,n) e -> y n e -> m ()
    unsafeDoSApply_ alpha a (x :: y n e) = do
        (y :: y n e) <- newVector_ (dim x)
        unsafeDoSApplyAdd alpha a x 0 y
        unsafeCopyVector x y

    -- | In-place @b := alpha * a * b@ for square @a@, defaulted like
    -- 'unsafeDoSApply_' (compute into a fresh matrix, then copy back).
    unsafeDoSApplyMat_ :: (WriteMatrix b e m) =>
        e -> a (k,k) e -> b (k,l) e -> m ()
    unsafeDoSApplyMat_ alpha a (b :: b (k,l) e) = do
        (c :: b (k,l) e) <- newMatrix_ (shape b)
        unsafeDoSApplyAddMat alpha a b 0 c
        unsafeCopyMatrix b c

    -- | Row @i@ as a new vector. The default applies @herm a@ to the basis
    -- vector @e_i@ and conjugates the resulting view. The helper exists only
    -- to bring the result vector type @x@ into scope for the annotations.
    unsafeGetRow :: (WriteVector x e m) => a (k,l) e -> Int -> m (x l e)
    unsafeGetRow (a :: a (k,l) e) i =
        let unsafeGetRowHelp :: (WriteVector x e m) => x l e -> m (x l e)
            unsafeGetRowHelp (_ :: x l e) = do
                (e :: x k e) <- newBasisVector (numRows a) i
                liftM conj $ unsafeGetSApply 1 (herm a) e
        in unsafeGetRowHelp undefined

    -- | Column @j@ as a new vector. The default applies @a@ to the basis
    -- vector @e_j@.
    unsafeGetCol :: (WriteVector x e m) => a (k,l) e -> Int -> m (x k e)
    unsafeGetCol (a :: a (k,l) e) j =
        let unsafeGetColHelp :: (WriteVector x e m) => x k e -> m (x k e)
            unsafeGetColHelp (_ :: x k e) = do
                (e :: x l e) <- newBasisVector (numCols a) j
                unsafeGetSApply 1 a e
        in unsafeGetColHelp undefined

    -- | All rows as new vectors. There is no default; see 'getRowsM'.
    getRows :: (WriteVector x e m) =>
        a (k,l) e -> m [x l e]

    -- | All columns as new vectors. There is no default; see 'getColsM'.
    getCols :: (WriteVector x e m) =>
        a (k,l) e -> m [x k e]
-- | Lazily produce the columns of a matrix. The supplied interleaving
-- function (for example 'unsafeInterleaveIO') defers computing each column,
-- and the rest of the list, until that part is demanded.
getColsM :: (MMatrix a e m, WriteVector x e m)
         => (forall b. m b -> m b)
         -> a (k,l) e -> m [x k e]
getColsM interleave a = collect 0
  where
    total = numCols a
    collect j
        | j == total = return []
        | otherwise  = interleave $
            liftM2 (:) (unsafeGetCol a j) (collect (j+1))
-- | 'getColsM' specialised to 'IO', using 'unsafeInterleaveIO'.
getColsIO :: (MMatrix a e IO, WriteVector x e IO)
          => a (k,l) e -> IO [x k e]
getColsIO a = getColsM unsafeInterleaveIO a
-- | 'getColsM' specialised to 'ST', using 'unsafeInterleaveST'.
getColsST :: (MMatrix a e (ST s), WriteVector x e (ST s))
          => a (k,l) e -> ST s [x k e]
getColsST a = getColsM unsafeInterleaveST a
-- | Lazily produce the rows of a matrix. The supplied interleaving function
-- defers computing each row, and the rest of the list, until that part is
-- demanded.
getRowsM :: (MMatrix a e m, WriteVector x e m)
         => (forall b. m b -> m b)
         -> a (k,l) e -> m [x l e]
getRowsM interleave a = collect 0
  where
    total = numRows a
    collect i
        | i == total = return []
        | otherwise  = interleave $
            liftM2 (:) (unsafeGetRow a i) (collect (i+1))
-- | 'getRowsM' specialised to 'IO', using 'unsafeInterleaveIO'.
getRowsIO :: (MMatrix a e IO, WriteVector x e IO)
          => a (k,l) e -> IO [x l e]
getRowsIO a = getRowsM unsafeInterleaveIO a
-- | 'getRowsM' specialised to 'ST', using 'unsafeInterleaveST'.
getRowsST :: (MMatrix a e (ST s), WriteVector x e (ST s))
          => a (k,l) e -> ST s [x l e]
getRowsST a = getRowsM unsafeInterleaveST a
-- | General matrix-vector product with accumulation:
-- @y := alpha * a * x + beta * y@. No dimension checks are done here.
--
-- The BLAS call depends on the conjugation flags of @x@ and @y@:
--
-- * If @a@ has no rows or no columns, @y@ is only scaled by @beta@.
--
-- * If @y@ is conjugated and @x@ is either conjugated or has unit stride,
--   the product is issued as a @gemm@ call with @m = 1@ and conjugated
--   scalars.
--
-- * Otherwise, if @y@ or @x@ is conjugated, 'doConjVector' is applied to
--   @y@ before and after a recursive call on @conj y@. The recursive call
--   falls into one of the other cases.
--
-- * In the plain case a single BLAS @gemv@ is issued. The transposition of
--   @a@ comes from 'transMatrix', and the dimensions are the stored ones
--   (flipped when @a@ is a Hermitian view).
gemv :: (ReadMatrix a e m, ReadVector x e m, WriteVector y e m) =>
    e -> a (k,l) e -> x l e -> e -> y k e -> m ()
gemv alpha a x beta y
    | numRows a == 0 || numCols a == 0 =
        scaleByVector beta y

    | isConj y && (isConj x || stride x == 1) =
        let transA = if isConj x then NoTrans else ConjTrans
            transB = transMatrix (herm a)
            m      = 1
            n      = dim y
            k      = dim x
            ldA    = stride x
            ldB    = ldaMatrix a
            ldC    = stride y
            alpha' = conjugate alpha
            beta'  = conjugate beta
            x'     = unsafeVectorToIOVector x
            y'     = unsafeVectorToIOVector y
        in
            withMatrixPtr a $ \pB ->
            withIOVector x' $ \pA ->
            withIOVector y' $ \pC ->
                BLAS.gemm transA transB m n k alpha' pA ldA pB ldB beta' pC ldC

    -- Reached when either (a) y is conjugated but x is unconjugated with a
    -- non-unit stride, or (b) x is conjugated and y is not.
    | isConj y || isConj x = do
        doConjVector y
        gemv alpha a x beta (conj y)
        doConjVector y

    | otherwise =
        let transA = transMatrix a
            (m,n)  = case (isHermMatrix a) of
                         False -> shape a
                         True  -> (flipShape . shape) a
            ldA    = ldaMatrix a
            incX   = stride x
            incY   = stride y
            x'     = unsafeVectorToIOVector x
            y'     = unsafeVectorToIOVector y
        in
            withMatrixPtr a $ \pA ->
            withIOVector x' $ \pX ->
            withIOVector y' $ \pY ->
                BLAS.gemv transA m n alpha pA ldA pX incX beta pY incY
  where
    -- Run a pointer-consuming IO action on a's storage, lifted into m.
    withMatrixPtr d f = unsafePerformIOWithMatrix d $ flip withIOMatrix f
-- | General matrix-matrix product with accumulation:
-- @c := alpha * a * b + beta * c@. No dimension checks are done here.
--
-- * If @a@ has no rows or columns, or @b@ has no columns, @c@ is only
--   scaled by @beta@.
-- * If @c@ is a Hermitian view, the conjugate-transposed problem
--   @herm c := conj alpha * herm b * herm a + conj beta * herm c@ is solved
--   instead (presumably so the BLAS call sees @c@ in its stored
--   orientation).
-- * Otherwise a single BLAS @gemm@ call is made, with the transpositions of
--   @a@ and @b@ taken from 'transMatrix'.
gemm :: (ReadMatrix a e m, ReadMatrix b e m, WriteMatrix c e m) =>
    e -> a (r,s) e -> b (s,t) e -> e -> c (r,t) e -> m ()
gemm alpha a b beta c
    | numRows a == 0 || numCols a == 0 || numCols b == 0 =
        scaleByMatrix beta c
    | isHermMatrix c = gemm (conjugate alpha) (herm b) (herm a) (conjugate beta) (herm c)
    | otherwise =
        let transA = transMatrix a
            transB = transMatrix b
            (m,n)  = shape c
            k      = numCols a
            ldA    = ldaMatrix a
            ldB    = ldaMatrix b
            ldC    = ldaMatrix c
        in
            -- b and c are reached through their IOMatrix views inside the IO
            -- action that a's unsafePerformIOWithMatrix lifts into m.
            withMatrixPtr a $ \pA ->
            withIOMatrix (unsafeMatrixToIOMatrix b) $ \pB ->
            withIOMatrix (unsafeMatrixToIOMatrix c) $ \pC ->
                BLAS.gemm transA transB m n k alpha pA ldA pB ldB beta pC ldC
  where
    withMatrixPtr d f = unsafePerformIOWithMatrix d $ flip withIOMatrix f
-- | Hermitian matrix-vector product with accumulation:
-- @y := alpha * h * x + beta * y@.
--
-- * If @h@ is empty, nothing is done (presumably @y@ is empty too).
-- * If @y@ is conjugated, 'doConjVector' is applied to @y@ before and
--   after a recursive call on @conj y@.
-- * If @x@ is conjugated, it is first copied with @newCopyVector'@ into a
--   vector of @y@'s type, and the call recurses on that copy.
-- * Otherwise a single BLAS @hemv@ call is made on the base matrix of
--   @h@. The stored triangle is flipped ('flipUpLo') when that base matrix
--   is itself a Hermitian view.
hemv :: (ReadMatrix a e m, ReadVector x e m, WriteVector y e m) =>
    e -> Herm a (k,k) e -> x k e -> e -> y k e -> m ()
hemv alpha h (x :: x k e) beta (y :: y k e)
    | numRows h == 0 =
        return ()
    | isConj y = do
        doConjVector y
        hemv alpha h x beta (conj y)
        doConjVector y
    | isConj x = do
        (x' :: y k e) <- newCopyVector' x
        hemv alpha h x' beta y
    | otherwise =
        let (u,a) = hermToBase h
            n     = numCols a
            u'    = case isHermMatrix a of
                        True  -> flipUpLo u
                        False -> u
            uploA = u'
            ldA   = ldaMatrix a
            incX  = stride x
            incY  = stride y
            x'    = unsafeVectorToIOVector x
            y'    = unsafeVectorToIOVector y
        in
            withMatrixPtr a $ \pA ->
            withIOVector x' $ \pX ->
            withIOVector y' $ \pY ->
                BLAS.hemv uploA n alpha pA ldA pX incX beta pY incY
  where
    withMatrixPtr d f = unsafePerformIOWithMatrix d $ flip withIOMatrix f
-- | Hermitian matrix-matrix product with accumulation:
-- @c := alpha * h * b + beta * c@.
--
-- * If @b@ has no rows or columns, or @c@ has no columns, nothing is done.
-- * If the base matrix of @h@, @b@ and @c@ do not all share the same
--   Hermitian-view flag, the product is computed one column at a time with
--   'hemv' over the 'colViews' of @b@ and @c@.
-- * Otherwise a single BLAS @hemm@ call is made. When the base matrix is a
--   Hermitian view, the call uses 'RightSide', the flipped triangle and
--   swapped dimensions.
hemm :: (ReadMatrix a e m, ReadMatrix b e m, WriteMatrix c e m) =>
    e -> Herm a (k,k) e -> b (k,l) e -> e -> c (k,l) e -> m ()
hemm alpha h b beta c
    | numRows b == 0 || numCols b == 0 || numCols c == 0 = return ()
    | (isHermMatrix a) /= (isHermMatrix c) || (isHermMatrix a) /= (isHermMatrix b) =
        zipWithM_ (\x y -> hemv alpha h x beta y) (colViews b) (colViews c)
    | otherwise =
        let (m,n) = shape c
            (side,u',m',n')
                = if isHermMatrix a
                      then (RightSide, flipUpLo u, n, m)
                      else (LeftSide, u, m, n)
            uploA = u'
            ldA   = ldaMatrix a
            ldB   = ldaMatrix b
            ldC   = ldaMatrix c
        in
            withMatrixPtr a $ \pA ->
            withIOMatrix (unsafeMatrixToIOMatrix b) $ \pB ->
            withIOMatrix (unsafeMatrixToIOMatrix c) $ \pC ->
                BLAS.hemm side uploA m' n' alpha pA ldA pB ldB beta pC ldC
  where
    withMatrixPtr d f = unsafePerformIOWithMatrix d $ flip withIOMatrix f
    -- Stored triangle and base matrix of the Hermitian view.
    (u,a) = hermToBase h
-- | 'hemv' for a Hermitian matrix whose shape index is written @(r,s)@.
-- The matrix and @y@ are re-indexed with 'coerceHerm' / 'coerceVector' to
-- fit 'hemv''s square signature.
hemv' :: (ReadMatrix a e m, ReadVector x e m, WriteVector y e m) =>
    e -> Herm a (r,s) e -> x s e -> e -> y r e -> m ()
hemv' alpha h x beta y = hemv alpha h' x beta y'
  where
    h' = coerceHerm h
    y' = coerceVector y
-- | 'hemm' for a Hermitian matrix whose shape index is written @(r,s)@.
-- The matrix and @c@ are re-indexed to fit 'hemm''s square signature.
hemm' :: (ReadMatrix a e m, ReadMatrix b e m, WriteMatrix c e m) =>
    e -> Herm a (r,s) e -> b (s,t) e -> e -> c (r,t) e -> m ()
hemm' alpha h b beta c = hemm alpha h' b beta c'
  where
    h' = coerceHerm h
    c' = coerceMatrix c
-- | @y := alpha * t * x + beta * y@ for a triangular matrix. When
-- @beta == 0@, the product is written straight into @y@. Otherwise it goes
-- into a temporary (initialised as a copy of @y@), @y@ is scaled by
-- @beta@, and the temporary is added to it.
unsafeDoSApplyAddTriMatrix :: (ReadMatrix a e m, MMatrix a e m,
                               ReadVector x e m, WriteVector y e m) =>
    e -> Tri a (k,l) e -> x l e -> e -> y k e -> m ()
unsafeDoSApplyAddTriMatrix alpha t x beta (y :: y k e)
    | beta == 0 = unsafeDoSApplyTriMatrix alpha t x y
    | otherwise = do
        (tx :: y k e) <- newCopyVector y
        unsafeDoSApplyTriMatrix alpha t x tx
        scaleByVector beta y
        unsafeAxpyVector 1 tx y
-- | @c := alpha * t * b + beta * c@ for a triangular matrix. When
-- @beta == 0@, the product is written straight into @c@. Otherwise it goes
-- into a temporary copy of @c@, @c@ is scaled by @beta@, and the temporary
-- is added to it.
unsafeDoSApplyAddMatTriMatrix :: (ReadMatrix a e m,
                                  ReadMatrix b e m, WriteMatrix c e m) =>
    e -> Tri a (r,s) e -> b (s,t) e -> e -> c (r,t) e -> m ()
unsafeDoSApplyAddMatTriMatrix alpha t b beta (c :: c (r,t) e)
    | beta == 0 = unsafeDoSApplyMatTriMatrix alpha t b c
    | otherwise = do
        (tb :: c (r,t) e) <- newCopyMatrix c
        unsafeDoSApplyMatTriMatrix alpha t b tb
        scaleByMatrix beta c
        unsafeAxpyMatrix 1 tb c
-- | @y := alpha * t * x@ for a possibly non-square triangular matrix @t@.
--
-- The base matrix is split with 'toLower' / 'toUpper'. Both splits appear
-- in the scrutinee, but laziness means only the one matching @u@ is forced.
--
-- * Lower, not tall: copy @x@ into @y@, then apply the square triangle in
--   place with 'trmv'.
-- * Lower, tall: the top block of @y@ gets the triangle applied to @x@; the
--   bottom block gets the rectangular remainder applied to @x@ with
--   @beta = 0@.
-- * Upper, not wide: copy @x@ into @y@, then apply 'trmv'.
-- * Upper, wide: copy the leading part of @x@ into @y@, apply the
--   triangle, then add the remainder applied to the trailing part of @x@
--   (@beta = 1@).
unsafeDoSApplyTriMatrix :: (ReadMatrix a e m, MMatrix a e m,
                            ReadVector x e m, WriteVector y e m) =>
    e -> Tri a (k,l) e -> x l e -> y k e -> m ()
unsafeDoSApplyTriMatrix alpha t x y =
    case (u, toLower d a, toUpper d a) of
        (Lower,Left t',_) -> do
            unsafeCopyVector y (coerceVector x)
            trmv alpha t' y
        (Lower,Right (t',r),_) -> do
            let y1 = unsafeSubvectorView y 0 (numRows t')
                y2 = unsafeSubvectorView y (numRows t') (numRows r)
            unsafeCopyVector y1 x
            trmv alpha t' y1
            unsafeDoSApplyAdd alpha r x 0 y2
        (Upper,_,Left t') -> do
            unsafeCopyVector (coerceVector y) x
            trmv alpha t' (coerceVector y)
        (Upper,_,Right (t',r)) ->
            let x1 = unsafeSubvectorView x 0 (numCols t')
                x2 = unsafeSubvectorView x (numCols t') (numCols r)
            in do
                unsafeCopyVector y x1
                trmv alpha t' y
                unsafeDoSApplyAdd alpha r x2 1 y
  where
    (u,d,a) = triToBase t
-- | @c := alpha * t * b@ for a possibly non-square triangular matrix @t@.
-- This is the matrix analogue of 'unsafeDoSApplyTriMatrix': the same four
-- cases, with 'trmm' on the square part and 'unsafeDoSApplyAddMat' on the
-- rectangular remainder.
unsafeDoSApplyMatTriMatrix :: (ReadMatrix a e m,
                               ReadMatrix b e m, WriteMatrix c e m) =>
    e -> Tri a (r,s) e -> b (s,t) e -> c (r,t) e -> m ()
unsafeDoSApplyMatTriMatrix alpha t b c =
    case (u, toLower d a, toUpper d a) of
        (Lower,Left t',_) -> do
            unsafeCopyMatrix c (coerceMatrix b)
            trmm alpha t' c
        (Lower,Right (t',r),_) -> do
            -- Split c into the rows covered by the triangle and the rest.
            let c1 = unsafeSubmatrixView c (0,0) (numRows t',numCols c)
                c2 = unsafeSubmatrixView c (numRows t',0) (numRows r ,numCols c)
            unsafeCopyMatrix c1 b
            trmm alpha t' c1
            unsafeDoSApplyAddMat alpha r b 0 c2
        (Upper,_,Left t') -> do
            unsafeCopyMatrix (coerceMatrix c) b
            trmm alpha t' (coerceMatrix c)
        (Upper,_,Right (t',r)) ->
            -- Split b into rows matching the triangle and the remainder.
            let b1 = unsafeSubmatrixView b (0,0) (numCols t',numCols b)
                b2 = unsafeSubmatrixView b (numCols t',0) (numCols r ,numCols b)
            in do
                unsafeCopyMatrix c b1
                trmm alpha t' c
                unsafeDoSApplyAddMat alpha r b2 1 c
  where
    (u,d,a) = triToBase t
-- | Split the base matrix of a lower-triangular @(m,n)@ matrix, using
-- diagonal kind @d@ for the triangle.
--
-- * @m <= n@: 'Left' the leading @(m,m)@ lower triangle.
-- * @m > n@: 'Right' @(t, r)@, where @t@ is the leading @(n,n)@ lower
--   triangle and @r@ is the @(m - n, n)@ block of rows below it.
--
-- All parts are views sharing the original storage.
toLower :: (BaseMatrix a e) => DiagEnum -> a (m,n) e
        -> Either (Tri a (m,m) e)
                  (Tri a (n,n) e, a (k,n) e)
toLower d a =
    if m <= n
        then Left $ triFromBase Lower d (unsafeSubmatrixView a (0,0) (m,m))
        else let t = triFromBase Lower d (unsafeSubmatrixView a (0,0) (n,n))
                 r = unsafeSubmatrixView a (n,0) (k,n)
             in Right (t,r)
  where
    (m,n) = shape a
    -- Number of rows below the square triangle (only used when m > n).
    k     = m - n
-- | Split the base matrix of an upper-triangular @(m,n)@ matrix, using
-- diagonal kind @d@ for the triangle.
--
-- * @n <= m@: 'Left' the leading @(n,n)@ upper triangle.
-- * @n > m@: 'Right' @(t, r)@, where @t@ is the leading @(m,m)@ upper
--   triangle and @r@ is the @(m, n - m)@ block of columns to its right.
--
-- All parts are views sharing the original storage.
toUpper :: (BaseMatrix a e) => DiagEnum -> a (m,n) e
        -> Either (Tri a (n,n) e)
                  (Tri a (m,m) e, a (m,k) e)
toUpper d a =
    if n <= m
        then Left $ triFromBase Upper d (unsafeSubmatrixView a (0,0) (n,n))
        else let t = triFromBase Upper d (unsafeSubmatrixView a (0,0) (m,m))
                 r = unsafeSubmatrixView a (0,m) (m,k)
             in Right (t,r)
  where
    (m,n) = shape a
    -- Number of columns right of the square triangle (only used when n > m).
    k     = n - m
-- | In-place triangular matrix-vector product @x := alpha * t * x@ for a
-- square triangular @t@.
--
-- * Empty @x@: nothing is done.
-- * Conjugated @x@: the product is issued as a BLAS @trmm@ on a
--   @1 x dim x@ operand with 'RightSide', a (conjugate-)transposed
--   triangle, and @conjugate alpha@.
-- * Otherwise: @x@ is first scaled by @alpha@ (BLAS @trmv@ takes no
--   scalar), then a single BLAS @trmv@ call is made. The triangle is
--   flipped and conjugate-transposed when the base matrix is a Hermitian
--   view.
trmv :: (ReadMatrix a e m, WriteVector y e m) =>
    e -> Tri a (k,k) e -> y n e -> m ()
trmv alpha t x
    | dim x == 0 =
        return ()

    | isConj x =
        let (u,d,a) = triToBase t
            side    = RightSide
            (h,u')  = if isHermMatrix a then (NoTrans  , flipUpLo u)
                                         else (ConjTrans, u)
            uploA   = u'
            transA  = h
            diagA   = d
            m       = 1
            n       = dim x
            alpha'  = conjugate alpha
            ldA     = ldaMatrix a
            ldB     = stride x
        in
            withMatrixPtr a $ \pA ->
            withVectorPtrIO x $ \pB ->
                BLAS.trmm side uploA transA diagA m n alpha' pA ldA pB ldB

    | otherwise =
        let (u,d,a)     = triToBase t
            (transA,u') = if isHermMatrix a then (ConjTrans, flipUpLo u)
                                             else (NoTrans  , u)
            uploA       = u'
            diagA       = d
            n           = dim x
            ldA         = ldaMatrix a
            incX        = stride x
        in do
            when (alpha /= 1) $ scaleByVector alpha x
            withMatrixPtr a $ \pA ->
                withVectorPtrIO x $ \pX -> do
                    BLAS.trmv uploA transA diagA n pA ldA pX incX
  where
    withMatrixPtr d f = unsafePerformIOWithMatrix d $ flip withIOMatrix f
    withVectorPtrIO = withIOVector . unsafeVectorToIOVector
-- | In-place triangular matrix-matrix product @b := alpha * t * b@ for a
-- square triangular @t@.
--
-- Empty @b@ is a no-op. Otherwise a single BLAS @trmm@ call is made. When
-- @b@ is a Hermitian view, the call uses 'RightSide', the flipped
-- transposition, swapped dimensions and @conjugate alpha@, so the product
-- is formed on @b@'s stored orientation.
trmm :: (ReadMatrix a e m, WriteMatrix b e m) =>
    e -> Tri a (k,k) e -> b (k,l) e -> m ()
trmm _ _ b
    | numRows b == 0 || numCols b == 0 = return ()
trmm alpha t b =
    let (u,d,a) = triToBase t
        (h,u')  = if isHermMatrix a then (ConjTrans, flipUpLo u) else (NoTrans, u)
        (m,n)   = shape b
        (side,h',m',n',alpha')
                = if isHermMatrix b
                      then (RightSide, flipTrans h, n, m, conjugate alpha)
                      else (LeftSide , h          , m, n, alpha          )
        uploA   = u'
        transA  = h'
        diagA   = d
        ldA     = ldaMatrix a
        ldB     = ldaMatrix b
    in
        withMatrixPtr a $ \pA ->
        withIOMatrix (unsafeMatrixToIOMatrix b) $ \pB ->
            BLAS.trmm side uploA transA diagA m' n' alpha' pA ldA pB ldB
  where
    withMatrixPtr d f = unsafePerformIOWithMatrix d $ flip withIOMatrix f
-- | Scaled triangular solve @x := alpha * t^{-1} * y@ for a possibly
-- non-square triangular @t@, using 'trsv' on the square part.
--
-- * Tall lower: only the leading rows of @y@ (as many as the triangle has)
--   are used.
-- * Wide upper: the leading part of @x@ is solved for and the trailing part
--   is set to zero.
unsafeDoSSolveTriMatrix :: (ReadMatrix a e m,
                            ReadVector y e m, WriteVector x e m) =>
    e -> Tri a (k,l) e -> y k e -> x l e -> m ()
unsafeDoSSolveTriMatrix alpha t y x =
    case (u, toLower d a, toUpper d a) of
        (Lower,Left t',_) -> do
            unsafeCopyVector x (coerceVector y)
            trsv alpha t' (coerceVector x)
        (Lower,Right (t',_),_) -> do
            let y1 = unsafeSubvectorView y 0 (numRows t')
            unsafeCopyVector x y1
            trsv alpha t' x
        (Upper,_,Left t') -> do
            unsafeCopyVector x (coerceVector y)
            trsv alpha t' x
        (Upper,_,Right (t',r)) ->
            let x1 = unsafeSubvectorView x 0 (numCols t')
                x2 = unsafeSubvectorView x (numCols t') (numCols r)
            in do
                unsafeCopyVector x1 y
                trsv alpha t' x1
                setZeroVector x2
  where
    (u,d,a) = triToBase t
-- | Scaled triangular solve @b := alpha * t^{-1} * c@ for a possibly
-- non-square triangular @t@; the matrix analogue of
-- 'unsafeDoSSolveTriMatrix', using 'trsm' on the square part. For a wide
-- upper @t@, the trailing rows of @b@ are set to zero.
unsafeDoSSolveMatTriMatrix :: (ReadMatrix a e m,
                               ReadMatrix c e m, WriteMatrix b e m) =>
    e -> Tri a (r,s) e -> c (r,t) e -> b (s,t) e -> m ()
unsafeDoSSolveMatTriMatrix alpha t c b =
    case (u, toLower d a, toUpper d a) of
        (Lower,Left t',_) -> do
            unsafeCopyMatrix b (coerceMatrix c)
            trsm alpha t' (coerceMatrix b)
        (Lower,Right (t',_),_) -> do
            -- Only the rows of c covered by the triangle take part.
            let c1 = unsafeSubmatrixView c (0,0) (numRows t',numCols c)
            unsafeCopyMatrix b c1
            trsm alpha t' b
        (Upper,_,Left t') -> do
            unsafeCopyMatrix (coerceMatrix b) c
            trsm alpha t' (coerceMatrix b)
        (Upper,_,Right (t',r)) ->
            let b1 = unsafeSubmatrixView b (0,0) (numCols t',numCols b)
                b2 = unsafeSubmatrixView b (numCols t',0) (numCols r ,numCols b)
            in do
                unsafeCopyMatrix b1 c
                trsm alpha t' b1
                setZeroMatrix b2
  where
    (u,d,a) = triToBase t
-- | In-place scaled triangular solve @x := alpha * t^{-1} * x@ for a
-- square triangular @t@.
--
-- * Empty @x@: nothing is done.
-- * Conjugated @x@: issued as a BLAS @trsm@ on a @1 x dim x@ operand with
--   'RightSide' and @conjugate alpha@.
-- * Otherwise: @x@ is scaled by @alpha@ first (BLAS @trsv@ takes no
--   scalar), then a single BLAS @trsv@ call is made.
trsv :: (ReadMatrix a e m, WriteVector y e m) =>
    e -> Tri a (k,k) e -> y n e -> m ()
trsv alpha t x
    | dim x == 0 = return ()

    | isConj x =
        let (u,d,a) = triToBase t
            side    = RightSide
            (h,u')  = if isHermMatrix a then (NoTrans, flipUpLo u) else (ConjTrans, u)
            uploA   = u'
            transA  = h
            diagA   = d
            m       = 1
            n       = dim x
            alpha'  = conjugate alpha
            ldA     = ldaMatrix a
            ldB     = stride x
        in
            withMatrixPtr a $ \pA ->
            withVectorPtrIO x $ \pB ->
                BLAS.trsm side uploA transA diagA m n alpha' pA ldA pB ldB

    | otherwise =
        let (u,d,a)     = triToBase t
            (transA,u') = if isHermMatrix a then (ConjTrans, flipUpLo u)
                                             else (NoTrans  , u)
            uploA       = u'
            diagA       = d
            n           = dim x
            ldA         = ldaMatrix a
            incX        = stride x
        in do
            when (alpha /= 1) $ scaleByVector alpha x
            withMatrixPtr a $ \pA ->
                withVectorPtrIO x $ \pX ->
                    BLAS.trsv uploA transA diagA n pA ldA pX incX
  where
    withVectorPtrIO = withIOVector . unsafeVectorToIOVector
    withMatrixPtr d f = unsafePerformIOWithMatrix d $ flip withIOMatrix f
-- | In-place scaled triangular solve @b := alpha * t^{-1} * b@ for a
-- square triangular @t@.
--
-- Empty @b@ is a no-op. Otherwise a single BLAS @trsm@ call is made. When
-- @b@ is a Hermitian view, the call uses 'RightSide', the flipped
-- transposition, swapped dimensions and @conjugate alpha@.
trsm :: (ReadMatrix a e m, WriteMatrix b e m) =>
    e -> Tri a (k,k) e -> b (k,l) e -> m ()
trsm _ _ b
    | numRows b == 0 || numCols b == 0 = return ()
trsm alpha t b =
    let (u,d,a) = triToBase t
        (h,u')  = if isHermMatrix a then (ConjTrans, flipUpLo u) else (NoTrans, u)
        (m,n)   = shape b
        (side,h',m',n',alpha')
                = if isHermMatrix b
                      then (RightSide, flipTrans h, n, m, conjugate alpha)
                      else (LeftSide , h          , m, n, alpha          )
        uploA   = u'
        transA  = h'
        diagA   = d
        ldA     = ldaMatrix a
        ldB     = ldaMatrix b
    in
        withMatrixPtr a $ \pA ->
        withIOMatrix (unsafeMatrixToIOMatrix b) $ \pB -> do
            BLAS.trsm side uploA transA diagA m' n' alpha' pA ldA pB ldB
  where
    withMatrixPtr d f = unsafePerformIOWithMatrix d $ flip withIOMatrix f
-- | Matrix-like values that can solve linear systems in the monad @m@.
--
-- The defaults are mutually recursive in pairs (unscaled <-> scaled), so
-- an instance must define at least one method of each pair. The
-- out-of-place defaults scale the result after solving; the in-place
-- (@_@-suffixed) defaults scale the right-hand side before solving.
class (MatrixShaped a, BLAS3 e, Monad m) => MSolve a e m where
    -- | Solve @a * x = y@ into @x@.
    unsafeDoSolve :: (ReadVector y e m, WriteVector x e m) =>
        a (k,l) e -> y k e -> x l e -> m ()
    unsafeDoSolve = unsafeDoSSolve 1

    -- | Solve @a * b = c@ into @b@.
    unsafeDoSolveMat :: (ReadMatrix c e m, WriteMatrix b e m) =>
        a (r,s) e -> c (r,t) e -> b (s,t) e -> m ()
    unsafeDoSolveMat = unsafeDoSSolveMat 1

    -- | @x := alpha * a^{-1} * y@.
    unsafeDoSSolve :: (ReadVector y e m, WriteVector x e m) =>
        e -> a (k,l) e -> y k e -> x l e -> m ()
    unsafeDoSSolve alpha a y x = do
        unsafeDoSolve a y x
        scaleByVector alpha x

    -- | @b := alpha * a^{-1} * c@.
    unsafeDoSSolveMat :: (ReadMatrix c e m, WriteMatrix b e m) =>
        e -> a (r,s) e -> c (r,t) e -> b (s,t) e -> m ()
    unsafeDoSSolveMat alpha a c b = do
        unsafeDoSolveMat a c b
        scaleByMatrix alpha b

    -- | In-place @x := a^{-1} * x@ for square @a@.
    unsafeDoSolve_ :: (WriteVector x e m) => a (k,k) e -> x k e -> m ()
    unsafeDoSolve_ = unsafeDoSSolve_ 1

    -- | In-place @x := alpha * a^{-1} * x@ for square @a@.
    unsafeDoSSolve_ :: (WriteVector x e m) => e -> a (k,k) e -> x k e -> m ()
    unsafeDoSSolve_ alpha a x = do
        scaleByVector alpha x
        unsafeDoSolve_ a x

    -- | In-place @b := a^{-1} * b@ for square @a@.
    unsafeDoSolveMat_ :: (WriteMatrix b e m) => a (k,k) e -> b (k,l) e -> m ()
    unsafeDoSolveMat_ = unsafeDoSSolveMat_ 1

    -- | In-place @b := alpha * a^{-1} * b@ for square @a@.
    unsafeDoSSolveMat_ :: (WriteMatrix b e m) => e -> a (k,k) e -> b (k,l) e -> m ()
    unsafeDoSSolveMat_ alpha a b = do
        scaleByMatrix alpha b
        unsafeDoSolveMat_ a b
-- | 'IOMatrix' is the base representation. Every method delegates to the
-- corresponding @...IOMatrix@ function, and the conversions to and from
-- 'IOMatrix' are the identity.
instance (Elem e) => BaseMatrix IOMatrix e where
    -- storage layout
    isHermMatrix = isHermIOMatrix
    ldaMatrix    = ldaMatrixIOMatrix

    -- views into the storage
    unsafeColView       = unsafeColViewIOMatrix
    unsafeRowView       = unsafeRowViewIOMatrix
    unsafeDiagView      = unsafeDiagViewIOMatrix
    unsafeSubmatrixView = unsafeSubmatrixViewIOMatrix

    -- vector <-> matrix views
    maybeViewVectorAsCol    = maybeViewVectorAsColIOMatrix
    maybeViewVectorAsRow    = maybeViewVectorAsRowIOMatrix
    maybeViewVectorAsMatrix = maybeViewVectorAsIOMatrix
    maybeViewMatrixAsVector = maybeViewIOMatrixAsVector

    -- representation changes
    unsafeMatrixToIOMatrix = id
    unsafeIOMatrixToMatrix = id
-- | Dense 'IOMatrix' application in 'IO': products go straight to BLAS
-- ('gemv' / 'gemm'), and rows and columns are copied out of views.
instance (BLAS3 e) => MMatrix IOMatrix e IO where
    getCols              = getColsIO
    getRows              = getRowsIO
    unsafeGetCol         = unsafeGetColMatrix
    unsafeGetRow         = unsafeGetRowMatrix
    unsafeDoSApplyAddMat = gemm
    unsafeDoSApplyAdd    = gemv
-- | Hermitian 'IOMatrix' application in 'IO' via BLAS @hemv@ / @hemm@.
instance (BLAS3 e) => MMatrix (Herm IOMatrix) e IO where
    getCols              = getColsIO
    getRows              = getRowsIO
    unsafeDoSApplyAddMat = hemm'
    unsafeDoSApplyAdd    = hemv'
-- | Triangular 'IOMatrix' application in 'IO'. The in-place variants go
-- straight to 'trmv' / 'trmm'.
instance (BLAS3 e) => MMatrix (Tri IOMatrix) e IO where
    getCols              = getColsIO
    getRows              = getRowsIO
    unsafeDoSApplyMat_   = trmm
    unsafeDoSApply_      = trmv
    unsafeDoSApplyAddMat = unsafeDoSApplyAddMatTriMatrix
    unsafeDoSApplyAdd    = unsafeDoSApplyAddTriMatrix
-- | Triangular solves for 'IOMatrix' in 'IO'. The in-place variants go
-- straight to 'trsv' / 'trsm'.
instance (BLAS3 e) => MSolve (Tri IOMatrix) e IO where
    unsafeDoSSolveMat_ = trsm
    unsafeDoSSolve_    = trsv
    unsafeDoSSolveMat  = unsafeDoSSolveMatTriMatrix
    unsafeDoSSolve     = unsafeDoSSolveTriMatrix
-- | Reading an 'IOMatrix' in 'IO'. Running an action on the underlying
-- storage is just application, because the matrix already is that storage.
instance (BLAS3 e) => ReadMatrix IOMatrix e IO where
    unsafeFreezeMatrix = unsafeFreezeIOMatrix
    freezeMatrix       = freezeIOMatrix
    unsafePerformIOWithMatrix mat action = action mat
-- | Creating and modifying an 'IOMatrix' in 'IO'.
instance (BLAS3 e) => WriteMatrix IOMatrix e IO where
    unsafeThawMatrix      = unsafeThawIOMatrix
    thawMatrix            = thawIOMatrix
    unsafeConvertIOMatrix = id
    newMatrix_            = newIOMatrix_
-- | Pure constructor: a zero matrix with the given entries set (see
-- 'newMatrix'; invalid indices raise an error). 'unsafePerformIO' is used
-- on freshly allocated storage that does not escape before freezing.
matrix :: (BLAS3 e) => (Int,Int) -> [((Int,Int), e)] -> Matrix (n,p) e
matrix mn ies = unsafePerformIO (newMatrix mn ies >>= unsafeFreezeIOMatrix)
-- | Like 'matrix' but without index validation (see 'unsafeNewMatrix').
unsafeMatrix :: (BLAS3 e) => (Int,Int) -> [((Int,Int), e)] -> Matrix (n,p) e
unsafeMatrix mn ies =
    unsafePerformIO (unsafeNewMatrix mn ies >>= unsafeFreezeIOMatrix)
-- | Pure constructor from elements in column-major order (see
-- 'newListMatrix').
listMatrix :: (BLAS3 e) => (Int,Int) -> [e] -> Matrix (n,p) e
listMatrix mn es = unsafePerformIO $ do
    fresh <- newListMatrix mn es
    unsafeFreezeIOMatrix fresh
-- | A copy of the matrix with the listed entries overwritten, in list
-- order, using 'writeElem'. The original matrix is unchanged.
replaceMatrix :: (BLAS1 e) => Matrix np e -> [((Int,Int),e)] -> Matrix np e
replaceMatrix (Matrix a) ies = unsafePerformIO $ do
    copy <- newCopyIOMatrix a
    forM_ ies $ \(ij, e) -> writeElem copy ij e
    return (Matrix copy)
-- | Like 'replaceMatrix' but writes with 'unsafeWriteElem'.
unsafeReplaceMatrix :: (BLAS1 e) => Matrix np e -> [((Int,Int),e)] -> Matrix np e
unsafeReplaceMatrix (Matrix a) ies = unsafePerformIO $ do
    copy <- newCopyIOMatrix a
    forM_ ies $ \(ij, e) -> unsafeWriteElem copy ij e
    return (Matrix copy)
-- | Pure constructor from a list of row vectors (see 'newRowsMatrix').
rowsMatrix :: (BLAS3 e) => (Int,Int) -> [Vector p e] -> Matrix (n,p) e
rowsMatrix mn rs = unsafePerformIO (newRowsMatrix mn rs >>= unsafeFreezeIOMatrix)
-- | Pure constructor from a list of column vectors (see 'newColsMatrix').
colsMatrix :: (BLAS3 e) => (Int,Int) -> [Vector n e] -> Matrix (n,p) e
colsMatrix mn cs = unsafePerformIO (newColsMatrix mn cs >>= unsafeFreezeIOMatrix)
-- | A one-row matrix from a vector. The vector's storage is reused when it
-- can be viewed as a row; otherwise the data is copied.
matrixFromRow :: (BLAS3 e) => Vector p e -> Matrix (one,p) e
matrixFromRow (Vector x) = maybe copied Matrix (maybeViewVectorAsRow x)
  where
    copied = unsafePerformIO $ unsafeFreezeIOMatrix =<< newRowMatrix x
-- | A one-column matrix from a vector. The vector's storage is reused when
-- it can be viewed as a column; otherwise the data is copied.
matrixFromCol :: (BLAS3 e) => Vector n e -> Matrix (n,one) e
matrixFromCol (Vector x) = maybe copied Matrix (maybeViewVectorAsCol x)
  where
    copied = unsafePerformIO $ unsafeFreezeIOMatrix =<< newColMatrix x
-- | Reshape a vector into an @(m,n)@ matrix. The vector's storage is reused
-- when it can be viewed as a matrix; otherwise its elements are copied.
-- Raises an error unless @dim x == m*n@.
matrixFromVector :: (BLAS3 e) => (Int,Int) -> Vector np e -> Matrix (n,p) e
matrixFromVector (m,n) x
    | dim x == m*n =
        maybe (listMatrix (m,n) (elems x)) id (maybeViewVectorAsMatrix (m,n) x)
    | otherwise =
        error $ "matrixFromVector " ++ show (m,n) ++ "<vector of dim "
             ++ show (dim x) ++ ">: vector dimension must be equal to "
             ++ "the number of elements in the desired matrix"
-- | Flatten a matrix into a vector in column-major order. The storage is
-- reused when the matrix can be viewed as a vector; otherwise the columns
-- are concatenated into a new vector.
vectorFromMatrix :: (BLAS3 e) => Matrix (n,p) e -> Vector np e
vectorFromMatrix a = maybe viaColumns id (maybeViewMatrixAsVector a)
  where
    viaColumns = listVector (size a) (concatMap elems (colViews a))
-- | The all-zero matrix of the given shape.
zeroMatrix :: (BLAS3 e) => (Int,Int) -> Matrix (n,p) e
zeroMatrix mn = unsafePerformIO (newZeroMatrix mn >>= unsafeFreezeIOMatrix)
-- | The matrix of the given shape with every entry equal to @e@.
constantMatrix :: (BLAS3 e) => (Int,Int) -> e -> Matrix (n,p) e
constantMatrix mn e =
    unsafePerformIO (newConstantMatrix mn e >>= unsafeFreezeIOMatrix)
-- | The matrix of the given shape with ones on the main diagonal and zeros
-- elsewhere.
identityMatrix :: (BLAS3 e) => (Int,Int) -> Matrix (n,p) e
identityMatrix mn =
    unsafePerformIO (newIdentityMatrix mn >>= unsafeFreezeIOMatrix)
-- | Bounds-checked submatrix (top-left corner @ij@, shape @mn@) sharing the
-- original storage.
submatrix :: (Elem e) => Matrix (n,p) e -> (Int,Int) -> (Int,Int) -> Matrix (n',p') e
submatrix (Matrix a) ij mn = Matrix view
  where
    view = submatrixView a ij mn
-- | Like 'submatrix' but without bounds checking.
unsafeSubmatrix :: (Elem e) => Matrix (n,p) e -> (Int,Int) -> (Int,Int) -> Matrix (n',p') e
unsafeSubmatrix (Matrix a) ij mn = Matrix view
  where
    view = unsafeSubmatrixView a ij mn
-- | Bounds-checked view of diagonal @i@ (@0@ is the main diagonal) as an
-- immutable vector.
diag :: (Elem e) => Matrix (n,p) e -> Int -> Vector k e
diag (Matrix a) = Vector . diagView a
-- | View of diagonal @i@ as an immutable vector, without bounds checking.
-- It uses 'unsafeDiagView', matching 'unsafeSubmatrix'; the checked variant
-- is 'diag'.
unsafeDiag :: (Elem e) => Matrix (n,p) e -> Int -> Vector k e
unsafeDiag (Matrix a) i = Vector (unsafeDiagView a i)
-- | Read element @(i,j)@ of an immutable matrix without bounds checking.
--
-- The 'IOMatrix' fields used are the foreign pointer @f@, the element
-- pointer @p@, the leading dimension @l@ and the Hermitian flag @h@; the
-- other two fields are not needed here.
--
-- * Hermitian view: the element at the transposed offset @i*l + j@ is read
--   and conjugated.
-- * Otherwise: the element at the column-major offset @i + j*l@ is read.
--
-- 'touchForeignPtr' is called after the read so the storage stays alive
-- until then, and the element is forced with 'seq' before returning.
-- NOTE(review): presumably this is needed because 'inlinePerformIO'
-- offers no ordering guarantees of its own — confirm.
unsafeAtMatrix :: (Elem e) => Matrix np e -> (Int,Int) -> e
unsafeAtMatrix (Matrix (IOMatrix f p _ _ l h)) (i,j)
    | h = inlinePerformIO $ do
        e  <- liftM conjugate $ peekElemOff p (i*l+j)
        io <- touchForeignPtr f
        e `seq` io `seq` return e
    | otherwise = inlinePerformIO $ do
        e  <- peekElemOff p (i+j*l)
        io <- touchForeignPtr f
        e `seq` io `seq` return e
-- | All valid indices of the matrix, as produced by 'indicesIOMatrix'.
indicesMatrix :: Matrix np e -> [(Int,Int)]
indicesMatrix m = case m of
    Matrix a -> indicesIOMatrix a
-- | All elements of the matrix. If the storage can be viewed as a single
-- vector, that vector's elements are returned. Otherwise the elements are
-- concatenated row by row for a Hermitian view, or column by column
-- otherwise.
elemsMatrix :: (Elem e) => Matrix np e -> [e]
elemsMatrix (Matrix a) =
    maybe fromViews (elemsVector . Vector) (maybeViewIOMatrixAsVector a)
  where
    fromViews = concatMap (elemsVector . Vector) (vecViews a)
    vecViews
        | isHermIOMatrix a = rowViews . coerceMatrix
        | otherwise        = colViews . coerceMatrix
-- | Index/element pairs, pairing 'indicesMatrix' with 'elemsMatrix'
-- position by position.
assocsMatrix :: (Elem e) => Matrix np e -> [((Int,Int),e)]
assocsMatrix a = indicesMatrix a `zip` elemsMatrix a
-- | Apply @f@ to every element, producing a new matrix. For a Hermitian
-- view, the result is built as the 'herm' of an @(n,m)@ matrix of
-- conjugated values, so it is also a Hermitian view.
tmapMatrix :: (BLAS3 e) => (e -> e) -> Matrix np e -> Matrix np e
tmapMatrix f a@(Matrix ma) =
    if isHermIOMatrix ma
        then coerceMatrix $ herm $ listMatrix (n,m) $ map (conjugate . f) es
        else coerceMatrix $ listMatrix (m,n) $ map f es
  where
    (m,n) = shape a
    es    = elems a
-- | Combine two matrices of the same shape elementwise with @f@. Elements
-- are paired in column-major order. Raises an error when the shapes
-- differ.
tzipWithMatrix :: (BLAS3 e) =>
    (e -> e -> e) -> Matrix np e -> Matrix np e -> Matrix np e
tzipWithMatrix f a b
    | shape b == mn =
        coerceMatrix $ listMatrix mn $ zipWith f (colElems a) (colElems b)
    | otherwise =
        error ("tzipWith: matrix shapes differ; first has shape `" ++
               show mn ++ "' and second has shape `" ++
               show (shape b) ++ "'")
  where
    mn       = shape a
    colElems = concatMap elems . colViews . coerceMatrix
-- | Shape and bounds are those of the wrapped 'IOMatrix'.
instance Shaped Matrix (Int,Int) where
    bounds (Matrix storage) = boundsIOMatrix storage
    shape  (Matrix storage) = shapeIOMatrix storage
-- | The conjugate transpose is computed on the wrapped 'IOMatrix'.
instance MatrixShaped Matrix where
    herm (Matrix storage) = Matrix $ herm storage
-- | Row, column and diagonal views of an immutable 'Matrix' are immutable
-- 'Vector's.
instance HasVectorView Matrix where
    type VectorView Matrix = Vector
-- | Every operation unwraps the 'Matrix' newtype, delegates to the IOMatrix
-- implementation, and rewraps results as 'Matrix' / 'Vector'.
instance (Elem e) => BaseMatrix Matrix e where
    ldaMatrix (Matrix a)    = ldaMatrixIOMatrix a
    isHermMatrix (Matrix a) = isHermMatrix a

    -- Unchecked views; callers are responsible for valid indices.
    unsafeSubmatrixView (Matrix a) ij mn =
        Matrix (unsafeSubmatrixViewIOMatrix a ij mn)
    unsafeDiagView (Matrix a) i = Vector (unsafeDiagViewIOMatrix a i)
    unsafeRowView  (Matrix a) i = Vector (unsafeRowViewIOMatrix a i)
    unsafeColView  (Matrix a) j = Vector (unsafeColViewIOMatrix a j)

    -- Reinterpretations between matrices and vectors, when possible.
    maybeViewMatrixAsVector (Matrix a) =
        fmap Vector (maybeViewMatrixAsVector a)
    maybeViewVectorAsMatrix mn (Vector x) =
        fmap Matrix (maybeViewVectorAsIOMatrix mn x)
    maybeViewVectorAsRow (Vector x) = fmap Matrix (maybeViewVectorAsRow x)
    maybeViewVectorAsCol (Vector x) = fmap Matrix (maybeViewVectorAsCol x)

    -- Zero-cost conversions to and from the mutable representation.
    unsafeMatrixToIOMatrix (Matrix a) = a
    unsafeIOMatrixToMatrix            = Matrix
-- | Immutable tensor interface for dense matrices.
--
-- Scaling ('*>') and 'shift' build a fresh IOMatrix through the mutable API
-- and freeze it without copying.  The new buffer is not shared, so exposing
-- it as immutable through 'unsafePerformIO' is intended to be safe.
instance (BLAS3 e) => ITensor Matrix (Int,Int) e where
    size (Matrix a) = sizeIOMatrix a
    unsafeAt        = unsafeAtMatrix
    indices         = indicesMatrix
    elems           = elemsMatrix
    assocs          = assocsMatrix
    (//)            = replaceMatrix
    unsafeReplace   = unsafeReplaceMatrix
    tmap            = tmapMatrix
    (*>) k (Matrix a) = unsafePerformIO $ do
        scaled <- getScaledMatrix k (coerceMatrix a)
        frozen <- unsafeFreezeIOMatrix scaled
        return (coerceMatrix frozen)
    shift k (Matrix a) = unsafePerformIO $ do
        shifted <- getShiftedMatrix k (coerceMatrix a)
        frozen  <- unsafeFreezeIOMatrix shifted
        return (coerceMatrix frozen)
-- | Reading an immutable matrix in any monad just returns the pure
-- 'ITensor' results.  The primed and unprimed variants coincide.
instance (BLAS3 e, Monad m) => ReadTensor Matrix (Int,Int) e m where
    getSize x          = return (size x)
    getIndices x       = return (indices x)
    getElems x         = return (elems x)
    getAssocs x        = return (assocs x)
    getIndices' x      = return (indices x)
    getElems' x        = return (elems x)
    getAssocs' x       = return (assocs x)
    unsafeReadElem x i = return (unsafeAt x i)
-- | Dense matrices as matrix operators in 'IO'.  Matrix-vector and
-- matrix-matrix apply-add delegate to 'gemv' / 'gemm'.
instance (BLAS3 e) => MMatrix Matrix e IO where
    getRows              = getRowsIO
    getCols              = getColsIO
    unsafeGetRow         = unsafeGetRowMatrix
    unsafeGetCol         = unsafeGetColMatrix
    unsafeDoSApplyAddMat = gemm
    unsafeDoSApplyAdd    = gemv
-- | Hermitian views of dense matrices as operators in 'IO'.  Apply-add
-- delegates to 'hemv'' / 'hemm''; row and column access use the IO
-- defaults.
instance (BLAS3 e) => MMatrix (Herm Matrix) e IO where
    getCols              = getColsIO
    getRows              = getRowsIO
    unsafeDoSApplyAddMat = hemm'
    unsafeDoSApplyAdd    = hemv'
-- | Triangular views of dense matrices as operators in 'IO'.  In-place
-- application uses 'trmv' / 'trmm'.
instance (BLAS3 e) => MMatrix (Tri Matrix) e IO where
    getCols              = getColsIO
    getRows              = getRowsIO
    unsafeDoSApplyMat_   = trmm
    unsafeDoSApply_      = trmv
    unsafeDoSApplyAddMat = unsafeDoSApplyAddMatTriMatrix
    unsafeDoSApplyAdd    = unsafeDoSApplyAddTriMatrix
-- | Triangular solves in 'IO'.  In-place solves use 'trsv' / 'trsm'.
instance (BLAS3 e) => MSolve (Tri Matrix) e IO where
    unsafeDoSSolveMat_ = trsm
    unsafeDoSSolve_    = trsv
    unsafeDoSSolveMat  = unsafeDoSSolveMatTriMatrix
    unsafeDoSSolve     = unsafeDoSSolveTriMatrix
-- | Reading an immutable matrix in 'IO' operates directly on its IOMatrix.
instance (BLAS3 e) => ReadMatrix Matrix e IO where
    freezeMatrix (Matrix a)       = freezeIOMatrix a
    unsafeFreezeMatrix (Matrix a) = unsafeFreezeIOMatrix a
    unsafePerformIOWithMatrix (Matrix a) f = f a
-- | Dense matrices as matrix operators in 'ST'.  Apply-add delegates to
-- 'gemv' / 'gemm'.
instance (BLAS3 e) => MMatrix Matrix e (ST s) where
    getRows              = getRowsST
    getCols              = getColsST
    unsafeGetRow         = unsafeGetRowMatrix
    unsafeGetCol         = unsafeGetColMatrix
    unsafeDoSApplyAddMat = gemm
    unsafeDoSApplyAdd    = gemv
-- | Hermitian views of dense matrices as operators in 'ST'.
instance (BLAS3 e) => MMatrix (Herm Matrix) e (ST s) where
    getCols              = getColsST
    getRows              = getRowsST
    unsafeDoSApplyAddMat = hemm'
    unsafeDoSApplyAdd    = hemv'
-- | Triangular views of dense matrices as operators in 'ST'.
instance (BLAS3 e) => MMatrix (Tri Matrix) e (ST s) where
    getCols              = getColsST
    getRows              = getRowsST
    unsafeDoSApplyMat_   = trmm
    unsafeDoSApply_      = trmv
    unsafeDoSApplyAddMat = unsafeDoSApplyAddMatTriMatrix
    unsafeDoSApplyAdd    = unsafeDoSApplyAddTriMatrix
-- | Triangular solves in 'ST'.
instance (BLAS3 e) => MSolve (Tri Matrix) e (ST s) where
    unsafeDoSSolveMat_ = trsm
    unsafeDoSSolve_    = trsv
    unsafeDoSSolveMat  = unsafeDoSSolveMatTriMatrix
    unsafeDoSSolve     = unsafeDoSSolveTriMatrix
-- | Reading an immutable matrix in 'ST' lifts the IO implementations with
-- 'unsafeIOToST'.
instance (BLAS3 e) => ReadMatrix Matrix e (ST s) where
    freezeMatrix (Matrix a)       = unsafeIOToST (freezeIOMatrix a)
    unsafeFreezeMatrix (Matrix a) = unsafeIOToST (unsafeFreezeIOMatrix a)
    unsafePerformIOWithMatrix (Matrix a) f = unsafeIOToST (f a)
-- | Compare two matrices element by element with the given predicate.
--
-- Matrices of different shape are never equal.  When both share the same
-- Hermitian flag, their 'elems' are compared directly (through 'herm' when
-- the flag is set).  Otherwise both are read column by column so the
-- element orders line up.
compareMatrixWith :: (BLAS3 e) =>
    (e -> e -> Bool) -> Matrix (n,p) e -> Matrix (n,p) e -> Bool
compareMatrixWith cmp a b
    | shape a /= shape b      = False
    | hermA /= isHermMatrix b = pairwise (colElems a) (colElems b)
    | hermA                   = pairwise (elems (herm a)) (elems (herm b))
    | otherwise               = pairwise (elems a) (elems b)
  where
    hermA = isHermMatrix a
    pairwise xs ys = and (zipWith cmp xs ys)
    colElems c = concatMap elems (colViews $ coerceMatrix c)
-- | Structural equality: same shape and element-wise '=='.
instance (BLAS3 e, Eq e) => Eq (Matrix (n,p) e) where
    a == b = compareMatrixWith (==) a b
-- | Exact and approximate equality, lifted element-wise.
instance (BLAS3 e, AEq e) => AEq (Matrix (n,p) e) where
    a === b = compareMatrixWith (===) a b
    a ~== b = compareMatrixWith (~==) a b
-- | Render a matrix as an expression that rebuilds it.
--
-- A Hermitian-flagged matrix is shown as @herm (...)@ around 'herm' of
-- itself.  Otherwise it is shown as @listMatrix shape elems@.
--
-- Defined through 'showsPrec' rather than 'show', so the output is
-- parenthesised when it appears as an argument (e.g. @Just (listMatrix ...)@).
-- Previously nested output came out as @Just listMatrix ...@.  Top-level
-- 'show' output is unchanged.
instance (BLAS3 e, Show e) => Show (Matrix (n,p) e) where
    showsPrec d a
        | isHermMatrix a =
            showParen (d > 10) $
                showString "herm " . showsPrec 11 (herm a)
        | otherwise =
            showParen (d > 10) $
                showString "listMatrix " . showsPrec 11 (shape a)
                    . showChar ' ' . showsPrec 11 (elems a)
-- | Numeric operations on matrices of a fixed shape.
--
-- '+', '-' and '*' delegate to 'getAddMatrix', 'getSubMatrix' and
-- 'getMulMatrix' (presumably element-wise; confirm against their
-- definitions).  Each freezes the freshly built result without copying.
-- 'negate' scales by -1, 'abs' / 'signum' map over the elements, and
-- 'fromInteger' produces a 1x1 constant matrix.
--
-- Fixes: subtraction was written as @()@, which does not parse.  'negate'
-- scaled by @1@, which made it the identity.
instance (BLAS3 e) => Num (Matrix (n,p) e) where
    (+) x y = unsafePerformIO $ unsafeFreezeIOMatrix =<< getAddMatrix x y
    (-) x y = unsafePerformIO $ unsafeFreezeIOMatrix =<< getSubMatrix x y
    (*) x y = unsafePerformIO $ unsafeFreezeIOMatrix =<< getMulMatrix x y
    negate      = ((-1) *>)
    abs         = tmap abs
    signum      = tmap signum
    fromInteger = coerceMatrix . (constantMatrix (1,1)) . fromInteger
-- | Division delegates to 'getDivMatrix' and freezes the fresh result;
-- 'recip' maps over the elements.  'fromRational' builds a 1x1 constant
-- matrix.
instance (BLAS3 e) => Fractional (Matrix (n,p) e) where
    x / y = unsafePerformIO $ do
        q <- getDivMatrix x y
        unsafeFreezeIOMatrix q
    recip        = tmap recip
    fromRational = coerceMatrix . constantMatrix (1,1) . fromRational
-- | Floating-point functions applied element-wise.  'pi' is a 1x1 constant
-- matrix, and '**' pairs up elements of equally-shaped matrices.
instance (BLAS3 e, Floating e) => Floating (Matrix (m,n) e) where
    pi    = constantMatrix (1,1) pi
    (**)  = tzipWithMatrix (**)

    exp   = tmap exp
    log   = tmap log
    sqrt  = tmap sqrt

    sin   = tmap sin
    cos   = tmap cos
    tan   = tmap tan
    asin  = tmap asin
    acos  = tmap acos
    atan  = tmap atan

    sinh  = tmap sinh
    cosh  = tmap cosh
    tanh  = tmap tanh
    asinh = tmap asinh
    acosh = tmap acosh
    atanh = tmap atanh
-- | Run a vector action over the whole matrix.
--
-- If the storage can be viewed as one vector, the action runs once on that
-- view.  Otherwise it runs on each row view (Hermitian-flagged storage) or
-- each column view (ordinary storage), in order.
liftMatrix :: (ReadMatrix a e m) =>
    (forall k. VectorView a k e -> m ()) -> a (n,p) e -> m ()
liftMatrix f a = maybe perView f (maybeViewMatrixAsVector a)
  where
    perView
        | isHermMatrix a = mapM_ f (rowViews (coerceMatrix a))
        | otherwise      = mapM_ f (colViews (coerceMatrix a))
-- | Run a binary vector action over two equally-shaped matrices.
--
-- When both have the same Hermitian flag and both can be viewed as single
-- vectors, the action runs once on those views.  Otherwise the matrices are
-- walked view by view.  The orientation (rows vs. columns) is chosen from
-- @a@'s flag alone and applied to BOTH matrices, so paired views always
-- cover the same logical row or column.
liftMatrix2 :: (ReadMatrix a e m, ReadMatrix b f m) =>
    (forall k. VectorView a k e -> VectorView b k f -> m ()) ->
        a (n,p) e -> b (n,p) f -> m ()
liftMatrix2 f a b
    | sameLayout
    , Just x <- maybeViewMatrixAsVector a
    , Just y <- maybeViewMatrixAsVector b
    = f x y
    | byRows =
        zipWithM_ f (rowViews (coerceMatrix a)) (rowViews (coerceMatrix b))
    | otherwise =
        zipWithM_ f (colViews (coerceMatrix a)) (colViews (coerceMatrix b))
  where
    byRows     = isHermMatrix a
    sameLayout = byRows == isHermMatrix b
-- | Apply a binary matrix operation after passing both shapes to
-- 'checkBinaryOp'.
checkMatrixOp2 :: (BaseMatrix x e, BaseMatrix y f) =>
    (x n e -> y n f -> a) ->
        x n e -> y n f -> a
checkMatrixOp2 f x y = checkBinaryOp sx sy (f x y)
  where
    sx = shape x
    sy = shape y
-- | Copy the input into a new mutable matrix, run the in-place action on
-- the copy, and return the copy.  The input is not modified.
getUnaryMatrixOp :: (ReadMatrix a e m, WriteMatrix b e m) =>
    (b (n,p) e -> m ()) -> a (n,p) e -> m (b (n,p) e)
getUnaryMatrixOp f a =
    newCopyMatrix a >>= \b -> f b >> return b
-- | Copy the first operand into a new mutable matrix, run the in-place
-- binary action on the copy with the second operand, and return the copy.
-- No shape check is done here.
unsafeGetBinaryMatrixOp ::
    (WriteMatrix c e m, ReadMatrix a e m, ReadMatrix b f m) =>
    (c (n,p) e -> b (n,p) f -> m ()) ->
        a (n,p) e -> b (n,p) f -> m (c (n,p) e)
unsafeGetBinaryMatrixOp f a b =
    newCopyMatrix a >>= \c -> f c b >> return c
-- | The BLAS transpose flag matching the storage: 'ConjTrans' for
-- Hermitian-flagged matrices, 'NoTrans' otherwise.
transMatrix :: (BaseMatrix a e) => a (n,p) e -> TransEnum
transMatrix a = if isHermMatrix a then ConjTrans else NoTrans
-- | Storage offset of logical element @(i,j)@.  The layout is column-major
-- with leading dimension 'ldaMatrix'.  For Hermitian-flagged storage the
-- indices are swapped before the offset is computed.
indexOfMatrix :: (BaseMatrix a e) => a (n,p) e -> (Int,Int) -> Int
indexOfMatrix a (i,j)
    | isHermMatrix a = j + i * lda
    | otherwise      = i + j * lda
  where
    lda = ldaMatrix a