module Data.Matrix.Tri.Dense (
module Data.Matrix.Tri,
module BLAS.Matrix.Immutable,
module BLAS.Matrix.ReadOnly,
module BLAS.Matrix.Solve,
trmv,
trsv,
trmm,
trsm,
) where
import Control.Monad ( when )
import Data.Maybe ( fromJust )
import Data.Matrix.Dense.Internal hiding ( diag )
import qualified Data.Matrix.Dense.Internal as M
import qualified Data.Matrix.Dense.Operations as M
import Data.Vector.Dense.Internal
import Data.Vector.Dense.Operations( unsafeCopyVector )
import Data.Matrix.Dense.Operations( unsafeCopyMatrix )
import qualified Data.Vector.Dense.Internal as V
import qualified Data.Vector.Dense.Operations as V
import BLAS.Access
import BLAS.Elem ( Elem, BLAS3 )
import qualified BLAS.Elem as E
import BLAS.C.Types ( cblasDiag, cblasUpLo, cblasTrans, colMajor,
noTrans, conjTrans, leftSide, rightSide )
import BLAS.Types ( Trans(..), flipTrans, flipUpLo )
import qualified BLAS.C as BLAS
import qualified BLAS.C.Types as BLAS
import BLAS.Matrix.Immutable
import BLAS.Matrix.ReadOnly
import BLAS.Matrix.Solve
import Data.Matrix.Tri
instance (BLAS3 e) => IMatrix (Tri (DMatrix Imm)) e where
instance (BLAS3 e) => ISolve (Tri (DMatrix Imm)) e where
toLower :: (Elem e) => Diag -> DMatrix s (m,n) e
-> Either (Tri (DMatrix s) (m,m) e)
(Tri (DMatrix s) (n,n) e, DMatrix s (d,n) e)
toLower diag a =
if m <= n
then Left $ fromBase Lower diag (unsafeSubmatrix a (0,0) (m,m))
else let t = fromBase Lower diag (unsafeSubmatrix a (0,0) (n,n))
r = unsafeSubmatrix a (n,0) (d,n)
in Right (t,r)
where
(m,n) = shape a
d = m n
toUpper :: (Elem e) => Diag -> DMatrix s (m,n) e
-> Either (Tri (DMatrix s) (n,n) e)
(Tri (DMatrix s) (m,m) e, DMatrix s (m,d) e)
toUpper diag a =
if n <= m
then Left $ fromBase Upper diag (unsafeSubmatrix a (0,0) (n,n))
else let t = fromBase Upper diag (unsafeSubmatrix a (0,0) (m,m))
r = unsafeSubmatrix a (0,m) (m,d)
in Right (t,r)
where
(m,n) = shape a
d = n m
instance (BLAS3 e) => RMatrix (Tri (DMatrix s)) e where
unsafeDoSApply_ = trmv
unsafeDoSApplyMat_ = trmm
unsafeDoSApply alpha t x y =
case (u, toLower d a, toUpper d a) of
(Lower,Left t',_) -> do
unsafeCopyVector y (coerceVector x)
trmv alpha t' y
(Lower,Right (t',r),_) -> do
let y1 = unsafeSubvector y 0 (numRows t')
y2 = unsafeSubvector y (numRows t') (numRows r)
unsafeCopyVector y1 x
trmv alpha t' y1
unsafeDoSApply alpha r x y2
(Upper,_,Left t') -> do
unsafeCopyVector (coerceVector y) x
trmv alpha t' (coerceVector y)
(Upper,_,Right (t',r)) ->
let x1 = unsafeSubvector x 0 (numCols t')
x2 = unsafeSubvector x (numCols t') (numCols r)
in do
unsafeCopyVector y x1
trmv alpha t' y
unsafeDoSApplyAdd alpha r x2 1 y
where
(u,d,a) = toBase t
unsafeDoSApplyMat alpha t b c =
case (u, toLower d a, toUpper d a) of
(Lower,Left t',_) -> do
unsafeCopyMatrix c (coerceMatrix b)
trmm alpha t' c
(Lower,Right (t',r),_) -> do
let c1 = unsafeSubmatrix c (0,0) (numRows t',numCols c)
c2 = unsafeSubmatrix c (numRows t',0) (numRows r ,numCols c)
unsafeCopyMatrix c1 b
trmm alpha t' c1
unsafeDoSApplyMat alpha r b c2
(Upper,_,Left t') -> do
unsafeCopyMatrix (coerceMatrix c) b
trmm alpha t' (coerceMatrix c)
(Upper,_,Right (t',r)) ->
let b1 = unsafeSubmatrix b (0,0) (numCols t',numCols b)
b2 = unsafeSubmatrix b (numCols t',0) (numCols r ,numCols b)
in do
unsafeCopyMatrix c b1
trmm alpha t' c
unsafeDoSApplyAddMat alpha r b2 1 c
where
(u,d,a) = toBase t
instance (BLAS3 e) => RSolve (Tri (DMatrix s)) e where
unsafeDoSSolve_ = trsv
unsafeDoSSolveMat_ = trsm
unsafeDoSSolve alpha t y x =
case (u, toLower d a, toUpper d a) of
(Lower,Left t',_) -> do
unsafeCopyVector x (coerceVector y)
trsv alpha t' (coerceVector x)
(Lower,Right (t',_),_) -> do
let y1 = unsafeSubvector y 0 (numRows t')
unsafeCopyVector x y1
trsv alpha t' x
(Upper,_,Left t') -> do
unsafeCopyVector x (coerceVector y)
trsv alpha t' x
(Upper,_,Right (t',r)) ->
let x1 = unsafeSubvector x 0 (numCols t')
x2 = unsafeSubvector x (numCols t') (numCols r)
in do
unsafeCopyVector x1 y
trsv alpha t' x1
setZero x2
where
(u,d,a) = toBase t
unsafeDoSSolveMat alpha t c b =
case (u, toLower d a, toUpper d a) of
(Lower,Left t',_) -> do
unsafeCopyMatrix b (coerceMatrix c)
trsm alpha t' (coerceMatrix b)
(Lower,Right (t',_),_) -> do
let c1 = unsafeSubmatrix c (0,0) (numRows t',numCols c)
unsafeCopyMatrix b c1
trsm alpha t' b
(Upper,_,Left t') -> do
unsafeCopyMatrix (coerceMatrix b) c
trsm alpha t' (coerceMatrix b)
(Upper,_,Right (t',r)) ->
let b1 = unsafeSubmatrix b (0,0) (numCols t',numCols b)
b2 = unsafeSubmatrix b (numCols t',0) (numCols r ,numCols b)
in do
unsafeCopyMatrix b1 c
trsm alpha t' b1
setZero b2
where
(u,d,a) = toBase t
trmv :: (BLAS3 e) => e -> Tri (DMatrix t) (n,n) e -> IOVector n e -> IO ()
trmv alpha t x
| dim x == 0 =
return ()
| isConj x =
let b = fromJust $ maybeFromCol x
in trmm alpha t b
| otherwise =
let (u,d,a) = toBase t
order = colMajor
(transA,u') = if isHerm a then (conjTrans, flipUpLo u) else (noTrans, u)
uploA = cblasUpLo u'
diagA = cblasDiag d
n = dim x
ldA = ldaOf a
incX = strideOf x
in M.unsafeWithElemPtr a (0,0) $ \pA ->
V.unsafeWithElemPtr x 0 $ \pX -> do
when (alpha /= 1) $ V.scaleBy alpha x
BLAS.trmv order uploA transA diagA n pA ldA pX incX
trmm :: (BLAS3 e) => e -> Tri (DMatrix t) (m,m) e -> IOMatrix (m,n) e -> IO ()
trmm _ _ b
| M.numRows b == 0 || M.numCols b == 0 = return ()
trmm alpha t b =
let (u,d,a) = toBase t
order = colMajor
(h,u') = if isHerm a then (ConjTrans, flipUpLo u) else (NoTrans, u)
(m,n) = shape b
(side,h',m',n',alpha')
= if M.isHerm b
then (rightSide, flipTrans h, n, m, E.conj alpha)
else (leftSide , h , m, n, alpha )
uploA = cblasUpLo u'
transA = cblasTrans h'
diagA = cblasDiag d
ldA = ldaOf a
ldB = ldaOf b
in M.unsafeWithElemPtr a (0,0) $ \pA ->
M.unsafeWithElemPtr b (0,0) $ \pB ->
BLAS.trmm order side uploA transA diagA m' n' alpha' pA ldA pB ldB
trsv :: (BLAS3 e) => e -> Tri (DMatrix t) (n,n) e -> IOVector n e -> IO ()
trsv _ _ x
| dim x == 0 = return ()
trsv alpha t x
| isConj x =
let b = fromJust $ maybeFromCol x
in trsm alpha t b
trsv alpha t x =
let (u,d,a) = toBase t
order = colMajor
(transA,u') = if isHerm a then (conjTrans, flipUpLo u) else (noTrans, u)
uploA = cblasUpLo u'
diagA = cblasDiag d
n = dim x
ldA = ldaOf a
incX = strideOf x
in M.unsafeWithElemPtr a (0,0) $ \pA ->
V.unsafeWithElemPtr x 0 $ \pX -> do
when (alpha /= 1) $ V.scaleBy alpha x
BLAS.trsv order uploA transA diagA n pA ldA pX incX
trsm :: (BLAS3 e) => e -> Tri (DMatrix t) (m,m) e -> IOMatrix (m,n) e -> IO ()
trsm _ _ b
| M.numRows b == 0 || M.numCols b == 0 = return ()
trsm alpha t b =
let (u,d,a) = toBase t
order = colMajor
(h,u') = if isHerm a then (ConjTrans, flipUpLo u) else (NoTrans, u)
(m,n) = shape b
(side,h',m',n',alpha')
= if isHerm b
then (rightSide, flipTrans h, n, m, E.conj alpha)
else (leftSide , h , m, n, alpha )
uploA = cblasUpLo u'
transA = cblasTrans h'
diagA = cblasDiag d
ldA = ldaOf a
ldB = ldaOf b
in M.unsafeWithElemPtr a (0,0) $ \pA ->
M.unsafeWithElemPtr b (0,0) $ \pB -> do
BLAS.trsm order side uploA transA diagA m' n' alpha' pA ldA pB ldB