{-# LANGUAGE MultiParamTypeClasses, FlexibleInstances #-} ----------------------------------------------------------------------------- -- | -- Module : Data.Matrix.Tri.Dense -- Copyright : Copyright (c) , Patrick Perry -- License : BSD3 -- Maintainer : Patrick Perry -- Stability : experimental -- module Data.Matrix.Tri.Dense ( module Data.Matrix.Tri, module BLAS.Matrix.Immutable, module BLAS.Matrix.ReadOnly, module BLAS.Matrix.Solve, trmv, trsv, trmm, trsm, ) where import Control.Monad ( when ) import Data.Maybe ( fromJust ) import Data.Matrix.Dense.Internal hiding ( diag ) import qualified Data.Matrix.Dense.Internal as M import qualified Data.Matrix.Dense.Operations as M import Data.Vector.Dense.Internal import Data.Vector.Dense.Operations( unsafeCopyVector ) import Data.Matrix.Dense.Operations( unsafeCopyMatrix ) import qualified Data.Vector.Dense.Internal as V import qualified Data.Vector.Dense.Operations as V import BLAS.Access import BLAS.Elem ( Elem, BLAS3 ) import qualified BLAS.Elem as E import BLAS.C.Types ( cblasDiag, cblasUpLo, cblasTrans, colMajor, noTrans, conjTrans, leftSide, rightSide ) import BLAS.Types ( Trans(..), flipTrans, flipUpLo ) import qualified BLAS.C as BLAS import qualified BLAS.C.Types as BLAS import BLAS.Matrix.Immutable import BLAS.Matrix.ReadOnly import BLAS.Matrix.Solve import Data.Matrix.Tri instance (BLAS3 e) => IMatrix (Tri (DMatrix Imm)) e where instance (BLAS3 e) => ISolve (Tri (DMatrix Imm)) e where toLower :: (Elem e) => Diag -> DMatrix s (m,n) e -> Either (Tri (DMatrix s) (m,m) e) (Tri (DMatrix s) (n,n) e, DMatrix s (d,n) e) toLower diag a = if m <= n then Left $ fromBase Lower diag (unsafeSubmatrix a (0,0) (m,m)) else let t = fromBase Lower diag (unsafeSubmatrix a (0,0) (n,n)) r = unsafeSubmatrix a (n,0) (d,n) in Right (t,r) where (m,n) = shape a d = m - n toUpper :: (Elem e) => Diag -> DMatrix s (m,n) e -> Either (Tri (DMatrix s) (n,n) e) (Tri (DMatrix s) (m,m) e, DMatrix s (m,d) e) toUpper diag a = if n <= m then Left $ fromBase Upper diag (unsafeSubmatrix a (0,0) (n,n)) else let t = fromBase Upper diag (unsafeSubmatrix a (0,0) (m,m)) r = unsafeSubmatrix a (0,m) (m,d) in Right (t,r) where (m,n) = shape a d = n - m instance (BLAS3 e) => RMatrix (Tri (DMatrix s)) e where unsafeDoSApply_ = trmv unsafeDoSApplyMat_ = trmm unsafeDoSApply alpha t x y = case (u, toLower d a, toUpper d a) of (Lower,Left t',_) -> do unsafeCopyVector y (coerceVector x) trmv alpha t' y (Lower,Right (t',r),_) -> do let y1 = unsafeSubvector y 0 (numRows t') y2 = unsafeSubvector y (numRows t') (numRows r) unsafeCopyVector y1 x trmv alpha t' y1 unsafeDoSApply alpha r x y2 (Upper,_,Left t') -> do unsafeCopyVector (coerceVector y) x trmv alpha t' (coerceVector y) (Upper,_,Right (t',r)) -> let x1 = unsafeSubvector x 0 (numCols t') x2 = unsafeSubvector x (numCols t') (numCols r) in do unsafeCopyVector y x1 trmv alpha t' y unsafeDoSApplyAdd alpha r x2 1 y where (u,d,a) = toBase t unsafeDoSApplyMat alpha t b c = case (u, toLower d a, toUpper d a) of (Lower,Left t',_) -> do unsafeCopyMatrix c (coerceMatrix b) trmm alpha t' c (Lower,Right (t',r),_) -> do let c1 = unsafeSubmatrix c (0,0) (numRows t',numCols c) c2 = unsafeSubmatrix c (numRows t',0) (numRows r ,numCols c) unsafeCopyMatrix c1 b trmm alpha t' c1 unsafeDoSApplyMat alpha r b c2 (Upper,_,Left t') -> do unsafeCopyMatrix (coerceMatrix c) b trmm alpha t' (coerceMatrix c) (Upper,_,Right (t',r)) -> let b1 = unsafeSubmatrix b (0,0) (numCols t',numCols b) b2 = unsafeSubmatrix b (numCols t',0) (numCols r ,numCols b) in do unsafeCopyMatrix c b1 trmm alpha t' c unsafeDoSApplyAddMat alpha r b2 1 c where (u,d,a) = toBase t instance (BLAS3 e) => RSolve (Tri (DMatrix s)) e where unsafeDoSSolve_ = trsv unsafeDoSSolveMat_ = trsm unsafeDoSSolve alpha t y x = case (u, toLower d a, toUpper d a) of (Lower,Left t',_) -> do unsafeCopyVector x (coerceVector y) trsv alpha t' (coerceVector x) (Lower,Right (t',_),_) -> do let y1 = unsafeSubvector y 0 (numRows t') unsafeCopyVector x y1 trsv alpha t' x (Upper,_,Left t') -> do unsafeCopyVector x (coerceVector y) trsv alpha t' x (Upper,_,Right (t',r)) -> let x1 = unsafeSubvector x 0 (numCols t') x2 = unsafeSubvector x (numCols t') (numCols r) in do unsafeCopyVector x1 y trsv alpha t' x1 setZero x2 where (u,d,a) = toBase t unsafeDoSSolveMat alpha t c b = case (u, toLower d a, toUpper d a) of (Lower,Left t',_) -> do unsafeCopyMatrix b (coerceMatrix c) trsm alpha t' (coerceMatrix b) (Lower,Right (t',_),_) -> do let c1 = unsafeSubmatrix c (0,0) (numRows t',numCols c) unsafeCopyMatrix b c1 trsm alpha t' b (Upper,_,Left t') -> do unsafeCopyMatrix (coerceMatrix b) c trsm alpha t' (coerceMatrix b) (Upper,_,Right (t',r)) -> let b1 = unsafeSubmatrix b (0,0) (numCols t',numCols b) b2 = unsafeSubmatrix b (numCols t',0) (numCols r ,numCols b) in do unsafeCopyMatrix b1 c trsm alpha t' b1 setZero b2 where (u,d,a) = toBase t trmv :: (BLAS3 e) => e -> Tri (DMatrix t) (n,n) e -> IOVector n e -> IO () trmv alpha t x | dim x == 0 = return () | isConj x = let b = fromJust $ maybeFromCol x in trmm alpha t b | otherwise = let (u,d,a) = toBase t order = colMajor (transA,u') = if isHerm a then (conjTrans, flipUpLo u) else (noTrans, u) uploA = cblasUpLo u' diagA = cblasDiag d n = dim x ldA = ldaOf a incX = strideOf x in M.unsafeWithElemPtr a (0,0) $ \pA -> V.unsafeWithElemPtr x 0 $ \pX -> do when (alpha /= 1) $ V.scaleBy alpha x BLAS.trmv order uploA transA diagA n pA ldA pX incX trmm :: (BLAS3 e) => e -> Tri (DMatrix t) (m,m) e -> IOMatrix (m,n) e -> IO () trmm _ _ b | M.numRows b == 0 || M.numCols b == 0 = return () trmm alpha t b = let (u,d,a) = toBase t order = colMajor (h,u') = if isHerm a then (ConjTrans, flipUpLo u) else (NoTrans, u) (m,n) = shape b (side,h',m',n',alpha') = if M.isHerm b then (rightSide, flipTrans h, n, m, E.conj alpha) else (leftSide , h , m, n, alpha ) uploA = cblasUpLo u' transA = cblasTrans h' diagA = cblasDiag d ldA = ldaOf a ldB = ldaOf b in M.unsafeWithElemPtr a (0,0) $ \pA -> M.unsafeWithElemPtr b (0,0) $ \pB -> BLAS.trmm order side uploA transA diagA m' n' alpha' pA ldA pB ldB trsv :: (BLAS3 e) => e -> Tri (DMatrix t) (n,n) e -> IOVector n e -> IO () trsv _ _ x | dim x == 0 = return () trsv alpha t x | isConj x = let b = fromJust $ maybeFromCol x in trsm alpha t b trsv alpha t x = let (u,d,a) = toBase t order = colMajor (transA,u') = if isHerm a then (conjTrans, flipUpLo u) else (noTrans, u) uploA = cblasUpLo u' diagA = cblasDiag d n = dim x ldA = ldaOf a incX = strideOf x in M.unsafeWithElemPtr a (0,0) $ \pA -> V.unsafeWithElemPtr x 0 $ \pX -> do when (alpha /= 1) $ V.scaleBy alpha x BLAS.trsv order uploA transA diagA n pA ldA pX incX trsm :: (BLAS3 e) => e -> Tri (DMatrix t) (m,m) e -> IOMatrix (m,n) e -> IO () trsm _ _ b | M.numRows b == 0 || M.numCols b == 0 = return () trsm alpha t b = let (u,d,a) = toBase t order = colMajor (h,u') = if isHerm a then (ConjTrans, flipUpLo u) else (NoTrans, u) (m,n) = shape b (side,h',m',n',alpha') = if isHerm b then (rightSide, flipTrans h, n, m, E.conj alpha) else (leftSide , h , m, n, alpha ) uploA = cblasUpLo u' transA = cblasTrans h' diagA = cblasDiag d ldA = ldaOf a ldB = ldaOf b in M.unsafeWithElemPtr a (0,0) $ \pA -> M.unsafeWithElemPtr b (0,0) $ \pB -> do BLAS.trsm order side uploA transA diagA m' n' alpha' pA ldA pB ldB