module Data.Matrix.Tri.Dense (
module Data.Matrix.Tri,
module BLAS.Matrix.Immutable,
module BLAS.Matrix.ReadOnly,
module BLAS.Matrix.Solve,
trmv,
trsv,
trmm,
trsm
) where
import Control.Monad ( when )
import Data.Maybe ( fromJust )
import System.IO.Unsafe
import Unsafe.Coerce
import Data.Matrix.Dense.Internal
import qualified Data.Matrix.Dense.Internal as M
import Data.Vector.Dense.Internal
import qualified Data.Vector.Dense.Internal as V
import qualified Data.Vector.Dense.Operations as V
import BLAS.Access
import BLAS.Elem ( BLAS3 )
import qualified BLAS.Elem as E
import BLAS.C.Types ( cblasDiag, cblasUpLo, cblasTrans, colMajor,
noTrans, conjTrans, leftSide, rightSide )
import BLAS.Types ( Trans(..), flipTrans, flipUpLo )
import qualified BLAS.C as BLAS
import qualified BLAS.C.Types as BLAS
import BLAS.Matrix.Immutable
import BLAS.Matrix.ReadOnly
import BLAS.Matrix.Solve
import Data.Matrix.Tri
-- Pure application of an immutable dense triangular matrix.
--
-- Both operators run the corresponding monadic action ('getApply' for a
-- vector, 'getApplyMat' for a matrix) under 'unsafePerformIO'.
-- NOTE(review): this is only referentially transparent if those actions
-- never mutate their arguments (the RMatrix instance in this module works on
-- a fresh 'newCopy'); confirm before changing either side.
instance (BLAS3 e) => IMatrix (Tri (DMatrix Imm)) e where
    t <*>  x = unsafePerformIO (getApply t x)
    t <**> a = unsafePerformIO (getApplyMat t a)
-- Pure triangular solves against an immutable dense triangular matrix.
--
-- @t <\\> x@ and @t <\\\\> a@ run 'getSolve' / 'getSolveMat' under
-- 'unsafePerformIO'.
-- NOTE(review): safe only while those actions leave their inputs untouched
-- (the RSolve instance in this module solves into a 'newCopy'); confirm.
instance (BLAS3 e) => ISolve (Tri (DMatrix Imm)) e where
    t <\>  x = unsafePerformIO (getSolve t x)
    t <\\> a = unsafePerformIO (getSolveMat t a)
-- Monadic application of a dense triangular matrix.
--
-- Each method copies its operand with 'newCopy', overwrites the copy in
-- place ('trmv' for vectors, 'trmm' for matrices) and returns the copy, so
-- the caller's operand is never modified.
-- NOTE(review): the 'unsafeCoerce' calls presumably only retag phantom
-- access/type parameters between the read-only and IO views; confirm.
instance (BLAS3 e) => RMatrix (Tri (DMatrix s)) e where
    getApply t x =
        newCopy x >>= \y ->
            trmv (unsafeCoerce t) (V.unsafeThaw y) >> return (unsafeCoerce y)
    getApplyMat t a =
        newCopy a >>= \b ->
            trmm (unsafeCoerce t) (M.unsafeThaw b) >> return (unsafeCoerce b)
-- Monadic triangular solves with a dense triangular matrix.
--
-- Each method copies its right-hand side with 'newCopy', solves in place on
-- the copy ('trsv' for vectors, 'trsm' for matrices) and returns it, leaving
-- the caller's right-hand side unchanged.
-- NOTE(review): the 'unsafeCoerce' calls presumably only retag phantom
-- access/type parameters; confirm.
instance (BLAS3 e) => RSolve (Tri (DMatrix s)) e where
    getSolve t x =
        newCopy x >>= \y ->
            trsv (unsafeCoerce t) (V.unsafeThaw y) >> return (unsafeCoerce y)
    getSolveMat t a =
        newCopy a >>= \b ->
            trsm (unsafeCoerce t) (M.unsafeThaw b) >> return (unsafeCoerce b)
-- | In-place triangular matrix-vector product: overwrites @x@ with @t * x@.
--
-- 'toBase' splits @t@ into a triangle @u@, a diagonal kind @d@, a scale
-- factor @alpha@ and dense storage @a@.  BLAS @trmv@ multiplies by @a@, and
-- the result is then scaled by @alpha@ (skipped when @alpha == 1@).
--
-- * An empty vector is left untouched.
-- * A conjugated vector is viewed as a one-column matrix and handed to
--   'trmm', which handles that storage layout.
-- * When @isHerm a@ holds, BLAS is called with @conjTrans@ and the flipped
--   triangle.
trmv :: (BLAS3 e) => Tri (DMatrix t) (n,n) e -> IOVector n e -> IO ()
trmv _ x
    | dim x == 0 = return ()
trmv t x
    | isConj x =
        -- Avoid the partial 'fromJust': fail with a message that says where
        -- and why, instead of "Maybe.fromJust: Nothing".
        case maybeFromCol x of
            Just b  -> trmm t b
            Nothing -> error "Data.Matrix.Tri.Dense.trmv: maybeFromCol returned Nothing for a conjugated vector"
trmv t x =
    let (u,d,alpha,a) = toBase t
        order         = colMajor
        (transA,u')   = if isHerm a then (conjTrans, flipUpLo u) else (noTrans, u)
        uploA         = cblasUpLo u'
        diagA         = cblasDiag d
        n             = numCols a
        ldA           = ldaOf a
        incX          = strideOf x
    in M.unsafeWithElemPtr a (0,0) $ \pA ->
           V.unsafeWithElemPtr x 0 $ \pX -> do
               BLAS.trmv order uploA transA diagA n pA ldA pX incX
               -- BLAS only multiplied by the storage matrix; apply t's scale.
               when (alpha /= 1) $ V.scaleBy alpha x
-- | In-place triangular matrix-matrix product: overwrites @b@ with @t * b@.
--
-- 'toBase' splits @t@ into a triangle, a diagonal kind, a scale factor and
-- dense storage @a@.  The scale factor is passed to BLAS @trmm@ as its alpha.
-- An empty @b@ (no rows or no columns) is left untouched.
--
-- When @isHerm a@ holds, @a@'s storage is used with 'ConjTrans' and the
-- flipped triangle.  When @M.isHerm b@ holds, the product is computed as a
-- right-side multiply on @b@'s storage: rows and columns are swapped, the
-- transpose flag is flipped, and the scale factor is conjugated.
trmm :: (BLAS3 e) => Tri (DMatrix t) (m,m) e -> IOMatrix (m,n) e -> IO ()
trmm t b
    | M.numRows b == 0 || M.numCols b == 0 = return ()
    | otherwise =
        M.unsafeWithElemPtr a (0,0) $ \pA ->
        M.unsafeWithElemPtr b (0,0) $ \pB ->
            BLAS.trmm colMajor side (cblasUpLo uplo) (cblasTrans trans)
                      (cblasDiag diagKind) rowsB colsB scaleB
                      pA (ldaOf a) pB (ldaOf b)
  where
    (u, diagKind, alpha, a) = toBase t
    (transA, uplo)
        | isHerm a  = (ConjTrans, flipUpLo u)
        | otherwise = (NoTrans, u)
    (m, n) = shape b
    (side, trans, rowsB, colsB, scaleB)
        | M.isHerm b = (rightSide, flipTrans transA, n, m, E.conj alpha)
        | otherwise  = (leftSide, transA, m, n, alpha)
-- | In-place triangular solve for a vector: overwrites @x@ with the solution
-- @y@ of @t * y = x@.
--
-- 'toBase' splits @t@ into a triangle @u@, a diagonal kind @d@, a scale
-- factor @alpha@ and dense storage @a@.  BLAS @trsv@ solves against @a@, and
-- the result is then divided by @alpha@ (skipped when @alpha == 1@).
--
-- * An empty vector is left untouched.
-- * A conjugated vector is viewed as a one-column matrix and handed to
--   'trsm', which handles that storage layout.
-- * When @isHerm a@ holds, BLAS is called with @conjTrans@ and the flipped
--   triangle.
trsv :: (BLAS3 e) => Tri (DMatrix t) (n,n) e -> IOVector n e -> IO ()
trsv _ x
    | dim x == 0 = return ()
trsv t x
    | isConj x =
        -- Avoid the partial 'fromJust': fail with a message that says where
        -- and why, instead of "Maybe.fromJust: Nothing".
        case maybeFromCol x of
            Just b  -> trsm t b
            Nothing -> error "Data.Matrix.Tri.Dense.trsv: maybeFromCol returned Nothing for a conjugated vector"
trsv t x =
    let (u,d,alpha,a) = toBase t
        order         = colMajor
        (transA,u')   = if isHerm a then (conjTrans, flipUpLo u) else (noTrans, u)
        uploA         = cblasUpLo u'
        diagA         = cblasDiag d
        n             = numCols a
        ldA           = ldaOf a
        incX          = strideOf x
    in M.unsafeWithElemPtr a (0,0) $ \pA ->
           V.unsafeWithElemPtr x 0 $ \pX -> do
               BLAS.trsv order uploA transA diagA n pA ldA pX incX
               -- (alpha * a) y = x  =>  y = (a^-1 x) / alpha
               when (alpha /= 1) $ V.invScaleBy alpha x
-- | In-place triangular solve for a matrix: overwrites @b@ with the solution
-- @y@ of @t * y = b@.
--
-- 'toBase' splits @t@ into a triangle, a diagonal kind, a scale factor and
-- dense storage @a@.  BLAS @trsm@ is given the reciprocal of the (possibly
-- conjugated) scale factor as its alpha.  An empty @b@ (no rows or no
-- columns) is left untouched.
--
-- When @isHerm a@ holds, @a@'s storage is used with 'ConjTrans' and the
-- flipped triangle.  When @isHerm b@ holds, the solve is done from the right
-- on @b@'s storage: rows and columns are swapped, the transpose flag is
-- flipped, and the scale factor is conjugated.
trsm :: (BLAS3 e) => Tri (DMatrix t) (m,m) e -> IOMatrix (m,n) e -> IO ()
trsm t b
    | M.numRows b == 0 || M.numCols b == 0 = return ()
    | otherwise =
        M.unsafeWithElemPtr a (0,0) $ \pA ->
        M.unsafeWithElemPtr b (0,0) $ \pB ->
            BLAS.trsm colMajor side (cblasUpLo uplo) (cblasTrans trans)
                      (cblasDiag diagKind) rowsB colsB (1 / scaleB)
                      pA (ldaOf a) pB (ldaOf b)
  where
    (u, diagKind, alpha, a) = toBase t
    (transA, uplo)
        | isHerm a  = (ConjTrans, flipUpLo u)
        | otherwise = (NoTrans, u)
    (m, n) = shape b
    (side, trans, rowsB, colsB, scaleB)
        | isHerm b  = (rightSide, flipTrans transA, n, m, E.conj alpha)
        | otherwise = (leftSide, transA, m, n, alpha)