{-# LANGUAGE MultiParamTypeClasses, FlexibleInstances #-}
-----------------------------------------------------------------------------
-- |
-- Module     : Data.Matrix.Tri.Banded
-- Copyright  : Copyright (c) , Patrick Perry
-- License    : BSD3
-- Maintainer : Patrick Perry
-- Stability  : experimental
--
-- Matrix-apply and solve instances for triangular views of banded
-- matrices, implemented on top of the BLAS @tbmv@ and @tbsv@ routines.
--

module Data.Matrix.Tri.Banded (
    module Data.Matrix.Tri,
    module BLAS.Matrix.Immutable,
    module BLAS.Matrix.ReadOnly,
    module BLAS.Matrix.Solve,
    ) where

import Data.Matrix.Banded.Internal
import qualified Data.Matrix.Banded.Internal as B
import Data.Matrix.Dense.IO ( IOMatrix, coerceMatrix, unsafeCopyMatrix )
import qualified Data.Matrix.Dense.IO as M
import Data.Vector.Dense.IO ( IOVector, isConj, strideOf, conj,
    coerceVector, unsafeCopyVector )
import qualified Data.Vector.Dense.IO as V

import BLAS.Access
import BLAS.Types ( flipUpLo )

import BLAS.C ( BLAS2, cblasDiag, cblasUpLo, colMajor, noTrans, conjTrans )
import qualified BLAS.C as BLAS

import BLAS.Matrix.Immutable
import BLAS.Matrix.ReadOnly
import BLAS.Matrix.Solve

import Data.Matrix.Tri


-- No methods are overridden here; the class defaults are used.
instance (BLAS2 e) => IMatrix (Tri (BMatrix Imm)) e where

-- The out-of-place operations copy the input into the destination and then
-- delegate to the in-place variants, which are 'tbmv' and 'tbmm'.
instance (BLAS2 e) => RMatrix (Tri (BMatrix s)) e where
    unsafeDoSApply scale tri src dst =
        unsafeCopyVector (coerceVector dst) src
            >> unsafeDoSApply_ scale (coerceTri tri) dst

    unsafeDoSApplyMat scale tri src dst =
        unsafeCopyMatrix (coerceMatrix dst) src
            >> unsafeDoSApplyMat_ scale (coerceTri tri) dst

    unsafeDoSApply_    = tbmv
    unsafeDoSApplyMat_ = tbmm

-- No methods are overridden here; the class defaults are used.
instance (BLAS2 e) => ISolve (Tri (BMatrix Imm)) e where

-- The out-of-place solves copy the right-hand side into the destination and
-- then delegate to the in-place variants, which are 'tbsv' and 'tbsm'.
instance (BLAS2 e) => RSolve (Tri (BMatrix s)) e where
    unsafeDoSSolve scale tri rhs sol =
        unsafeCopyVector (coerceVector sol) rhs
            >> unsafeDoSSolve_ scale (coerceTri tri) sol

    unsafeDoSSolveMat scale tri rhs sol =
        unsafeCopyMatrix (coerceMatrix sol) rhs
            >> unsafeDoSSolveMat_ scale (coerceTri tri) sol

    unsafeDoSSolve_    = tbsv
    unsafeDoSSolveMat_ = tbsm


-- | In-place triangular banded matrix-vector operation: scale @x@ by
-- @alpha@, then hand it to @BLAS.tbmv@ together with the banded storage
-- of @t@.
--
-- A conjugated @x@ is conjugated in place, processed through its 'conj'
-- view, and conjugated back afterwards.
tbmv :: (BLAS2 e) => e -> Tri (BMatrix t) (n,n) e -> IOVector n e -> IO ()
tbmv alpha t x
    | isConj x = do
        V.doConj x
        tbmv alpha t (conj x)
        V.doConj x
    | otherwise =
        withBandPtr $ \pA ->
            V.unsafeWithElemPtr x 0 $ \pX -> do
                V.scaleBy alpha x
                BLAS.tbmv colMajor (cblasUpLo uploBm) transBm
                          (cblasDiag diagT) nOrder kDiags
                          pA (ldaOf bm) pX (strideOf x)
  where
    (uploT, diagT, bm) = toBase t

    -- When the banded matrix reports 'isHerm', the call uses the conjugate
    -- transpose flag and the opposite triangle.
    (transBm, uploBm)
        | isHerm bm = (conjTrans, flipUpLo uploT)
        | otherwise = (noTrans, uploT)

    nOrder = numCols bm

    -- The band count follows the triangle of the view, not the flipped one.
    kDiags = case uploT of
        Upper -> numUpper bm
        Lower -> numLower bm

    -- The starting pointer follows the triangle passed to BLAS.
    withBandPtr = case uploBm of
        Upper -> B.unsafeWithBasePtr bm
        Lower -> B.unsafeWithElemPtr bm (0,0)

-- | In-place triangular banded matrix-matrix operation, performed by
-- running 'tbmv' on every column of @b@.  A scale other than 1 is applied
-- to the whole of @b@ first.
tbmm :: (BLAS2 e) => e -> Tri (BMatrix t) (m,m) e -> IOMatrix (m,n) e -> IO ()
tbmm alpha t b
    | alpha == 1 = mapM_ (tbmv 1 t) (M.cols b)
    | otherwise  = do
        M.scaleBy alpha b
        tbmm 1 t b

-- | In-place triangular banded solve: scale @x@ by @alpha@, then hand it
-- to @BLAS.tbsv@ together with the banded storage of @t@.
--
-- A conjugated @x@ is conjugated in place, processed through its 'conj'
-- view, and conjugated back afterwards.
tbsv :: (BLAS2 e) => e -> Tri (BMatrix t) (n,n) e -> IOVector n e -> IO ()
tbsv alpha t x
    | isConj x = do
        V.doConj x
        tbsv alpha t (conj x)
        V.doConj x
    | otherwise =
        withBandPtr $ \pA ->
            V.unsafeWithElemPtr x 0 $ \pX -> do
                V.scaleBy alpha x
                BLAS.tbsv colMajor (cblasUpLo uploBm) transBm
                          (cblasDiag diagT) nOrder kDiags
                          pA (ldaOf bm) pX (strideOf x)
  where
    (uploT, diagT, bm) = toBase t

    -- When the banded matrix reports 'isHerm', the call uses the conjugate
    -- transpose flag and the opposite triangle.
    (transBm, uploBm)
        | isHerm bm = (conjTrans, flipUpLo uploT)
        | otherwise = (noTrans, uploT)

    nOrder = numCols bm

    -- The band count follows the triangle of the view, not the flipped one.
    kDiags = case uploT of
        Upper -> numUpper bm
        Lower -> numLower bm

    -- The starting pointer follows the triangle passed to BLAS.
    withBandPtr = case uploBm of
        Upper -> B.unsafeWithBasePtr bm
        Lower -> B.unsafeWithElemPtr bm (0,0)

-- | In-place triangular banded solve with a matrix right-hand side,
-- performed by running 'tbsv' on every column of @b@.  A scale other than
-- 1 is applied to the whole of @b@ first.
tbsm :: (BLAS2 e) => e -> Tri (BMatrix t) (m,m) e -> IOMatrix (m,n) e -> IO ()
tbsm alpha t b
    | alpha == 1 = mapM_ (tbsv 1 t) (M.cols b)
    | otherwise  = do
        M.scaleBy alpha b
        tbsm 1 t b