module Data.Matrix.Tri.Banded (
module Data.Matrix.Tri,
module BLAS.Matrix.Immutable,
module BLAS.Matrix.ReadOnly,
module BLAS.Matrix.Solve,
) where
import Data.Matrix.Banded.Internal
import qualified Data.Matrix.Banded.Internal as B
import Data.Matrix.Dense.IO ( IOMatrix, coerceMatrix, unsafeCopyMatrix )
import qualified Data.Matrix.Dense.IO as M
import Data.Vector.Dense.IO ( IOVector, isConj, strideOf, conj, coerceVector,
unsafeCopyVector )
import qualified Data.Vector.Dense.IO as V
import BLAS.Access
import BLAS.Types ( flipUpLo )
import BLAS.C ( BLAS2, cblasDiag, cblasUpLo, colMajor, noTrans, conjTrans )
import qualified BLAS.C as BLAS
import BLAS.Matrix.Immutable
import BLAS.Matrix.ReadOnly
import BLAS.Matrix.Solve
import Data.Matrix.Tri
-- | Immutable triangular banded matrices as 'IMatrix' instances.
-- No methods are defined here, so only whatever defaults the class itself
-- supplies are used.
instance BLAS2 e => IMatrix (Tri (BMatrix Imm)) e
-- | Matrix application for triangular banded matrices.
--
-- The in-place operations are the primitives ('tbmv' for vectors, 'tbmm'
-- for matrices); the out-of-place ones are built on top of them.
instance BLAS2 e => RMatrix (Tri (BMatrix s)) e where
    -- In-place primitives.
    unsafeDoSApply_    = tbmv
    unsafeDoSApplyMat_ = tbmm

    -- Out-of-place vector apply: copy the input @x@ into the output @y@
    -- (presumably 'unsafeCopyVector' takes the destination first -- confirm),
    -- then run the in-place operation on @y@.
    unsafeDoSApply alpha a x y =
        unsafeCopyVector (coerceVector y) x
            >> unsafeDoSApply_ alpha (coerceTri a) y

    -- Out-of-place matrix apply: same scheme, with @b@ as input and @c@ as
    -- output.
    unsafeDoSApplyMat alpha a b c =
        unsafeCopyMatrix (coerceMatrix c) b
            >> unsafeDoSApplyMat_ alpha (coerceTri a) c
-- | Immutable triangular banded matrices as 'ISolve' instances.
-- No methods are defined here, so only whatever defaults the class itself
-- supplies are used.
instance BLAS2 e => ISolve (Tri (BMatrix Imm)) e
-- | Triangular solves for triangular banded matrices.
--
-- The in-place operations are the primitives ('tbsv' for vectors, 'tbsm'
-- for matrices); the out-of-place ones are built on top of them.
instance BLAS2 e => RSolve (Tri (BMatrix s)) e where
    -- In-place primitives.
    unsafeDoSSolve_    = tbsv
    unsafeDoSSolveMat_ = tbsm

    -- Out-of-place vector solve: copy the right-hand side @y@ into the
    -- output @x@ (presumably 'unsafeCopyVector' takes the destination
    -- first -- confirm), then solve in place on @x@.
    unsafeDoSSolve alpha a y x =
        unsafeCopyVector (coerceVector x) y
            >> unsafeDoSSolve_ alpha (coerceTri a) x

    -- Out-of-place matrix solve: same scheme, with @c@ as right-hand side
    -- and @b@ as output.
    unsafeDoSSolveMat alpha a c b =
        unsafeCopyMatrix (coerceMatrix b) c
            >> unsafeDoSSolveMat_ alpha (coerceTri a) b
-- | In-place product of a triangular banded matrix with a dense vector.
--
-- The vector @x@ is first scaled by @alpha@ and then passed, together with
-- the banded storage of @t@, to the BLAS @tbmv@ routine.
tbmv :: (BLAS2 e) => e -> Tri (BMatrix t) (n,n) e -> IOVector n e -> IO ()
tbmv alpha t x
    | isConj x = do
        -- A conjugated vector is handled by conjugating its elements,
        -- recursing on the 'conj' view, and conjugating the elements back.
        V.doConj x
        tbmv alpha t (conj x)
        V.doConj x
    | otherwise =
        withMatPtr $ \ptrA ->
            V.unsafeWithElemPtr x 0 $ \ptrX -> do
                V.scaleBy alpha x
                BLAS.tbmv colMajor (cblasUpLo storedSide) transA
                          (cblasDiag unitFlag) (numCols bm) bandwidth
                          ptrA (ldaOf bm) ptrX (strideOf x)
  where
    (side, unitFlag, bm) = toBase t

    -- When 'isHerm' holds for the banded matrix, BLAS is asked for the
    -- conjugate transpose and the triangle is flipped accordingly.
    (transA, storedSide)
        | isHerm bm = (conjTrans, flipUpLo side)
        | otherwise = (noTrans, side)

    -- The band count follows the triangle of @t@ itself, not the flipped one.
    bandwidth = case side of
        Upper -> numUpper bm
        Lower -> numLower bm

    -- Upper storage starts at the base pointer, lower storage at element
    -- (0,0); presumably that is where the diagonal sits in the band layout
    -- -- confirm against the banded storage format.
    withMatPtr = case storedSide of
        Upper -> B.unsafeWithBasePtr bm
        Lower -> B.unsafeWithElemPtr bm (0,0)
-- | In-place product of a triangular banded matrix with a dense matrix.
--
-- With a unit scale factor, 'tbmv' is applied to each element of @M.cols b@
-- in turn. Any other @alpha@ is first folded into @b@ with a single scaling
-- pass, after which the unit-scale case is used.
tbmm :: (BLAS2 e) => e -> Tri (BMatrix t) (m,m) e -> IOMatrix (m,n) e -> IO ()
tbmm alpha t b
    | alpha == 1 = mapM_ (tbmv 1 t) (M.cols b)
    | otherwise  = do
        M.scaleBy alpha b
        tbmm 1 t b
-- | In-place triangular banded solve against a dense vector.
--
-- The vector @x@ is first scaled by @alpha@ and then passed, together with
-- the banded storage of @t@, to the BLAS @tbsv@ routine.
tbsv :: (BLAS2 e) => e -> Tri (BMatrix t) (n,n) e -> IOVector n e -> IO ()
tbsv alpha t x
    | isConj x = do
        -- A conjugated vector is handled by conjugating its elements,
        -- recursing on the 'conj' view, and conjugating the elements back.
        V.doConj x
        tbsv alpha t (conj x)
        V.doConj x
    | otherwise =
        withMatPtr $ \ptrA ->
            V.unsafeWithElemPtr x 0 $ \ptrX -> do
                V.scaleBy alpha x
                BLAS.tbsv colMajor (cblasUpLo storedSide) transA
                          (cblasDiag unitFlag) (numCols bm) bandwidth
                          ptrA (ldaOf bm) ptrX (strideOf x)
  where
    (side, unitFlag, bm) = toBase t

    -- When 'isHerm' holds for the banded matrix, BLAS is asked for the
    -- conjugate transpose and the triangle is flipped accordingly.
    (transA, storedSide)
        | isHerm bm = (conjTrans, flipUpLo side)
        | otherwise = (noTrans, side)

    -- The band count follows the triangle of @t@ itself, not the flipped one.
    bandwidth = case side of
        Upper -> numUpper bm
        Lower -> numLower bm

    -- Upper storage starts at the base pointer, lower storage at element
    -- (0,0); presumably that is where the diagonal sits in the band layout
    -- -- confirm against the banded storage format.
    withMatPtr = case storedSide of
        Upper -> B.unsafeWithBasePtr bm
        Lower -> B.unsafeWithElemPtr bm (0,0)
-- | In-place triangular banded solve against a dense matrix.
--
-- With a unit scale factor, 'tbsv' is applied to each element of @M.cols b@
-- in turn. Any other @alpha@ is first folded into @b@ with a single scaling
-- pass, after which the unit-scale case is used.
tbsm :: (BLAS2 e) => e -> Tri (BMatrix t) (m,m) e -> IOMatrix (m,n) e -> IO ()
tbsm alpha t b
    | alpha == 1 = mapM_ (tbsv 1 t) (M.cols b)
    | otherwise  = do
        M.scaleBy alpha b
        tbsm 1 t b