-- The 'Herm' wrapper (a base matrix tagged with an 'UpLo' value) together
-- with its constructors, accessors and shape-coercion helper.
module Data.Matrix.Herm
    ( Herm(..)
    , UpLo(..)
    , fromBase
    , toBase
    , mapHerm
    , hermL
    , hermU
    , coerceHerm
    ) where
import Control.Monad( zipWithM_ )
import Control.Monad.ST( ST )
import Unsafe.Coerce
import BLAS.C( BLAS2, BLAS3, colMajor, rightSide, leftSide, cblasUpLo )
import qualified BLAS.C as BLAS
import BLAS.UnsafeIOToM
import BLAS.Matrix
import BLAS.Types ( UpLo(..), flipUpLo )
import Data.Matrix.Banded( Banded )
import Data.Matrix.Banded.Class
import Data.Matrix.Banded.IO( IOBanded )
import Data.Matrix.Banded.ST( STBanded )
import Data.Matrix.Dense( Matrix )
import Data.Matrix.Dense.Class hiding ( BaseMatrix )
import Data.Matrix.Dense.IO( IOMatrix )
import Data.Matrix.Dense.ST( STMatrix, runSTMatrix )
import Data.Vector.Dense.Class
import Data.Vector.Dense.ST( runSTVector )
-- | A base matrix of type @a nn e@ paired with an 'UpLo' tag.
--
-- NOTE(review): judging by the 'hermL' / 'hermU' helpers, the tag presumably
-- records which triangle of the base matrix is referenced -- confirm against
-- the library documentation.
data Herm a nn e
    = Herm
        UpLo      -- the Upper/Lower tag
        (a nn e)  -- the wrapped base matrix
-- | Re-type the shape index of a 'Herm' value without touching its contents.
--
-- Implemented with 'unsafeCoerce'.
-- NOTE(review): this is only sound if the runtime representation of
-- @a mn e@ does not depend on the index @mn@ -- confirm for every base
-- type this is used with.
coerceHerm :: Herm a mn e -> Herm a mn' e
coerceHerm h = unsafeCoerce h
-- | Apply a function to the wrapped base matrix, keeping the 'UpLo' tag
-- unchanged.
mapHerm :: (a (n,n) e -> b (n,n) e) -> Herm a (n,n) e -> Herm b (n,n) e
mapHerm f (Herm uplo base) = Herm uplo (f base)
-- | Build a 'Herm' value from an 'UpLo' tag and a square base matrix.
fromBase :: UpLo -> a (n,n) e -> Herm a (n,n) e
fromBase uplo base = Herm uplo base
-- | Take a 'Herm' value apart into its 'UpLo' tag and its base matrix.
toBase :: Herm a (n,n) e -> (UpLo, a (n,n) e)
toBase h = case h of
    Herm uplo base -> (uplo, base)
-- | Wrap a square base matrix in 'Herm' with the 'Lower' tag.
hermL :: a (n,n) e -> Herm a (n,n) e
hermL base = Herm Lower base
-- | Wrap a square base matrix in 'Herm' with the 'Upper' tag.
hermU :: a (n,n) e -> Herm a (n,n) e
hermU base = Herm Upper base
-- | Tensor interface for 'Herm', indexed by @(Int,Int)@ pairs.
--
-- The wrapper is treated as square with side
-- @n = min (numRows a) (numCols a)@, where @a@ is the wrapped base matrix.
instance BaseMatrix a => BaseTensor (Herm a) (Int,Int) where
    -- Dimensions: (n, n).
    shape (Herm _ a) = (n, n)
      where
        n = min (numRows a) (numCols a)

    -- Inclusive index range, starting at (0,0). The upper corner is
    -- (n-1, n-1); the previous code referred to an unbound name @n1@ here
    -- and left the locally bound @n@ unused.
    bounds (Herm _ a) = ((0, 0), (n - 1, n - 1))
      where
        n = min (numRows a) (numCols a)
-- | 'BaseMatrix' interface for 'Herm'.
--
-- 'herm' only re-types the shape index via 'coerceHerm'; no data is touched.
-- NOTE(review): presumably justified because a Hermitian matrix equals its
-- own conjugate transpose -- confirm against the class documentation.
instance BaseMatrix a => BaseMatrix (Herm a) where
    herm h = coerceHerm h
-- | Render as @hermL (<base>)@ or @hermU (<base>)@, where @<base>@ is the
-- 'show' output of the wrapped matrix.
instance Show (a mn e) => Show (Herm a mn e) where
    show (Herm uplo base) = concat [label uplo, " (", show base, ")"]
      where
        -- Name of the helper that would rebuild a value with this tag.
        label Lower = "hermL"
        label Upper = "hermU"
-- | Matrix-vector multiply-accumulate for a 'Herm' over a dense base matrix,
-- delegating to 'BLAS.hemv' in column-major order.
--
-- Per the BLAS reference, @hemv@ computes @y := alpha*A*x + beta*y@.
--
-- Arguments: scalar @alpha@, the 'Herm' operand @h@, input vector @x@,
-- scalar @beta@, and the vector @y@ that is updated in place.
hemv :: (ReadMatrix a z m, ReadVector x m, WriteVector y m, BLAS2 e) =>
    e -> Herm a (k,k) e -> x k e -> e -> y k e -> m ()
hemv alpha h x beta y
    -- Empty operand: nothing to compute.
    | numRows h == 0 =
        return ()
    -- Conjugated output view: run 'doConj' on y, recurse on @conj y@,
    -- then run 'doConj' on y again.
    | isConj y = do
        doConj y
        hemv alpha h x beta (conj y)
        doConj y
    -- Conjugated input view: work on a copy made with 'newCopyVector',
    -- apply 'doConj' to the copy, and recurse on @conj@ of that copy.
    | isConj x = do
        xCopy <- newCopyVector x
        doConj xCopy
        hemv alpha h (conj xCopy) beta y
    | otherwise =
        unsafeIOToM $
            withMatrixPtr base $ \ptrA ->
            withVectorPtr x $ \ptrX ->
            withVectorPtr y $ \ptrY ->
                BLAS.hemv colMajor uplo dim alpha ptrA lda ptrX strideX beta ptrY strideY
  where
    (tag, base) = toBase h
    dim         = numCols base
    -- When the base reports 'isHermMatrix', the tag is flipped before
    -- being handed to BLAS.
    storedTag   = if isHermMatrix base then flipUpLo tag else tag
    uplo        = cblasUpLo storedTag
    lda         = ldaOfMatrix base
    strideX     = stride x
    strideY     = stride y
-- | Matrix-matrix multiply-accumulate for a 'Herm' over a dense base matrix.
--
-- Per the BLAS reference, @hemm@ computes @C := alpha*A*B + beta*C@ (left
-- side) or @C := alpha*B*A + beta*C@ (right side).
--
-- Three cases:
--
-- * any of the tested dimensions is zero: no-op;
-- * the 'isHermMatrix' flags of @a@, @b@ and @c@ do not all agree: fall back
--   to one 'hemv' per column pair of @b@ and @c@;
-- * otherwise: a single 'BLAS.hemm' call in column-major order. When the
--   base reports 'isHermMatrix', the call uses the right side, a flipped
--   'UpLo' tag and swapped dimensions.
hemm :: (ReadMatrix a x m, ReadMatrix b y m, WriteMatrix c z m, BLAS3 e) =>
    e -> Herm a (k,k) e -> b (k,l) e -> e -> c (k,l) e -> m ()
hemm alpha h b beta c
    | numRows b == 0 || numCols b == 0 || numCols c == 0 = return ()
    | aIsHerm /= isHermMatrix c || aIsHerm /= isHermMatrix b =
        zipWithM_ (\colB colC -> hemv alpha h colB beta colC)
                  (colViews b)
                  (colViews c)
    | otherwise =
        unsafeIOToM $
            withMatrixPtr a $ \ptrA ->
            withMatrixPtr b $ \ptrB ->
            withMatrixPtr c $ \ptrC ->
                BLAS.hemm colMajor side uplo rows cols alpha ptrA ldA ptrB ldB beta ptrC ldC
  where
    (u, a)  = toBase h
    aIsHerm = isHermMatrix a
    (m, n)  = shape c
    -- Side, triangle selector and dimensions handed to BLAS.
    (side, uplo, rows, cols)
        | aIsHerm   = (rightSide, cblasUpLo (flipUpLo u), n, m)
        | otherwise = (leftSide,  cblasUpLo u,            m, n)
    ldA = ldaOfMatrix a
    ldB = ldaOfMatrix b
    ldC = ldaOfMatrix c
-- | Variant of 'hemv' with a general @(r,s)@ shape index: the 'Herm' operand
-- and the output vector are re-typed with 'coerceHerm' / 'coerceVector'
-- before delegating to 'hemv'.
hemv' :: (ReadMatrix a z m, ReadVector x m, WriteVector y m, BLAS2 e) =>
    e -> Herm a (r,s) e -> x s e -> e -> y r e -> m ()
hemv' alpha a x beta y =
    let squareA = coerceHerm a
        yOut    = coerceVector y
    in hemv alpha squareA x beta yOut
-- | Variant of 'hemm' with a general @(r,s)@ shape index: the 'Herm' operand
-- and the output matrix are re-typed with 'coerceHerm' / 'coerceMatrix'
-- before delegating to 'hemm'.
hemm' :: (ReadMatrix a x m, ReadMatrix b y m, WriteMatrix c z m, BLAS3 e) =>
    e -> Herm a (r,s) e -> b (s,t) e -> e -> c (r,t) e -> m ()
hemm' alpha a b beta c =
    let squareA = coerceHerm a
        cOut    = coerceMatrix c
    in hemm alpha squareA b beta cOut
-- | Pure interface for 'Herm' over an immutable dense 'Matrix': each method
-- runs the corresponding @unsafeGet...@ action through 'runSTVector' /
-- 'runSTMatrix'.
instance (BLAS3 e) => IMatrix (Herm Matrix) e where
    unsafeSApply alpha a x =
        runSTVector (unsafeGetSApply alpha a x)
    unsafeSApplyMat alpha a b =
        runSTMatrix (unsafeGetSApplyMat alpha a b)
-- | Monadic interface for 'Herm' over an 'STMatrix', in 'ST'.
-- Both methods forward to the dense helpers 'hemv'' and 'hemm''.
instance (BLAS3 e) => MMatrix (Herm (STMatrix s)) e (ST s) where
    unsafeDoSApplyAdd alpha a x beta y    = hemv' alpha a x beta y
    unsafeDoSApplyAddMat alpha a b beta c = hemm' alpha a b beta c
-- | Monadic interface for 'Herm' over an 'IOMatrix', in 'IO'.
-- Both methods forward to the dense helpers 'hemv'' and 'hemm''.
instance (BLAS3 e) => MMatrix (Herm IOMatrix) e IO where
    unsafeDoSApplyAdd alpha a x beta y    = hemv' alpha a x beta y
    unsafeDoSApplyAddMat alpha a b beta c = hemm' alpha a b beta c
-- | Monadic interface for 'Herm' over an immutable dense 'Matrix', usable in
-- any monad with an 'UnsafeIOToM' instance.
-- Both methods forward to the dense helpers 'hemv'' and 'hemm''.
instance (BLAS3 e, UnsafeIOToM m) => MMatrix (Herm Matrix) e m where
    unsafeDoSApplyAdd alpha a x beta y    = hemv' alpha a x beta y
    unsafeDoSApplyAddMat alpha a b beta c = hemm' alpha a b beta c
-- | Matrix-vector multiply-accumulate for a 'Herm' over a banded base
-- matrix, delegating to 'BLAS.hbmv' in column-major order.
--
-- Per the BLAS reference, @hbmv@ computes @y := alpha*A*x + beta*y@.
--
-- Arguments: scalar @alpha@, the 'Herm' operand @h@, input vector @x@,
-- scalar @beta@, and the vector @y@ that is updated in place.
hbmv :: (ReadBanded a z m, ReadVector x m, WriteVector y m, BLAS2 e) =>
    e -> Herm a (k,k) e -> x k e -> e -> y k e -> m ()
hbmv alpha h x beta y
    -- Empty operand: nothing to compute.
    | numRows h == 0 =
        return ()
    -- Conjugated output view: run 'doConj' on y, recurse on @conj y@,
    -- then run 'doConj' on y again.
    | isConj y = do
        doConj y
        hbmv alpha h x beta (conj y)
        doConj y
    -- Conjugated input view: work on a copy made with 'newCopyVector',
    -- apply 'doConj' to the copy, and recurse on @conj@ of that copy.
    | isConj x = do
        xCopy <- newCopyVector x
        doConj xCopy
        hbmv alpha h (conj xCopy) beta y
    | otherwise =
        unsafeIOToM $
            withBase $ \ptrA ->
            withVectorPtr x $ \ptrX ->
            withVectorPtr y $ \ptrY ->
                BLAS.hbmv colMajor uplo dim bandwidth alpha ptrA lda ptrX strideX beta ptrY strideY
  where
    (tag, base) = toBase h
    dim         = numCols base
    -- Band count is chosen from the ORIGINAL tag, before any flip.
    bandwidth   = case tag of
        Upper -> numUpper base
        Lower -> numLower base
    -- When the base reports 'isHermBanded', the tag is flipped before
    -- being handed to BLAS.
    storedTag   = if isHermBanded base then flipUpLo tag else tag
    uplo        = cblasUpLo storedTag
    lda         = ldaOfBanded base
    strideX     = stride x
    strideY     = stride y
    -- Pointer handed to BLAS depends on the (possibly flipped) tag: the
    -- banded pointer for Upper, the pointer to element (0,0) for Lower.
    withBase    = case storedTag of
        Upper -> withBandedPtr base
        Lower -> withBandedElemPtr base (0,0)
-- | Matrix-matrix multiply-accumulate for a 'Herm' over a banded base
-- matrix, carried out as one 'hbmv' call per pair of corresponding columns
-- of @b@ and @c@.
hbmm :: (ReadBanded a x m, ReadMatrix b y m, WriteMatrix c z m, BLAS2 e) =>
    e -> Herm a (k,k) e -> b (k,l) e -> e -> c (k,l) e -> m ()
hbmm alpha h b beta c =
    zipWithM_ (\colB colC -> hbmv alpha h colB beta colC)
              (colViews b)
              (colViews c)
-- | Variant of 'hbmv' with a general @(r,s)@ shape index: the 'Herm' operand
-- and the output vector are re-typed with 'coerceHerm' / 'coerceVector'
-- before delegating to 'hbmv'.
hbmv' :: (ReadBanded a z m, ReadVector x m, WriteVector y m, BLAS2 e) =>
    e -> Herm a (r,s) e -> x s e -> e -> y r e -> m ()
hbmv' alpha a x beta y =
    let squareA = coerceHerm a
        yOut    = coerceVector y
    in hbmv alpha squareA x beta yOut
-- | Variant of 'hbmm' with a general @(r,s)@ shape index: the 'Herm' operand
-- and the output matrix are re-typed with 'coerceHerm' / 'coerceMatrix'
-- before delegating to 'hbmm'.
hbmm' :: (ReadBanded a x m, ReadMatrix b y m, WriteMatrix c z m, BLAS3 e) =>
    e -> Herm a (r,s) e -> b (s,t) e -> e -> c (r,t) e -> m ()
hbmm' alpha a b beta c =
    let squareA = coerceHerm a
        cOut    = coerceMatrix c
    in hbmm alpha squareA b beta cOut
-- | Pure interface for 'Herm' over an immutable 'Banded' matrix: each method
-- runs the corresponding @unsafeGet...@ action through 'runSTVector' /
-- 'runSTMatrix'.
instance (BLAS3 e) => IMatrix (Herm Banded) e where
    unsafeSApply alpha a x =
        runSTVector (unsafeGetSApply alpha a x)
    unsafeSApplyMat alpha a b =
        runSTMatrix (unsafeGetSApplyMat alpha a b)
-- | Monadic interface for 'Herm' over an 'STBanded', in 'ST'.
-- Both methods forward to the banded helpers 'hbmv'' and 'hbmm''.
instance (BLAS3 e) => MMatrix (Herm (STBanded s)) e (ST s) where
    unsafeDoSApplyAdd alpha a x beta y    = hbmv' alpha a x beta y
    unsafeDoSApplyAddMat alpha a b beta c = hbmm' alpha a b beta c
-- | Monadic interface for 'Herm' over an 'IOBanded', in 'IO'.
-- Both methods forward to the banded helpers 'hbmv'' and 'hbmm''.
instance (BLAS3 e) => MMatrix (Herm IOBanded) e IO where
    unsafeDoSApplyAdd alpha a x beta y    = hbmv' alpha a x beta y
    unsafeDoSApplyAddMat alpha a b beta c = hbmm' alpha a b beta c
-- | Monadic interface for 'Herm' over an immutable 'Banded' matrix, usable
-- in any monad with an 'UnsafeIOToM' instance.
-- Both methods forward to the banded helpers 'hbmv'' and 'hbmm''.
instance (BLAS3 e, UnsafeIOToM m) => MMatrix (Herm Banded) e m where
    unsafeDoSApplyAdd alpha a x beta y    = hbmv' alpha a x beta y
    unsafeDoSApplyAddMat alpha a b beta c = hbmm' alpha a b beta c