-- |
-- Diagonal matrices represented by a vector holding the diagonal entries.
--
-- Exports the 'Diag' type without its constructor, the conversions
-- 'diagFromVector' and 'vectorFromDiag', the shape cast 'coerceDiag', and
-- re-exports the names this module imports unqualified from "BLAS.Matrix".
module Data.Matrix.Diag (
Diag,
diagFromVector,
vectorFromDiag,
coerceDiag,
module BLAS.Matrix,
) where
import Control.Monad( zipWithM_ )
import Control.Monad.ST( ST )
import BLAS.Elem( BLAS1 )
import BLAS.Matrix hiding ( BaseMatrix )
import qualified BLAS.Matrix as BLAS
import BLAS.Tensor
import BLAS.UnsafeIOToM
import Unsafe.Coerce
import Data.AEq
import Data.Matrix.Dense.Class( ReadMatrix, WriteMatrix, unsafeCopyMatrix,
rowViews, coerceMatrix )
import Data.Matrix.Dense.ST( runSTMatrix )
import Data.Vector.Dense( Vector )
import Data.Vector.Dense.IO( IOVector )
import Data.Vector.Dense.ST( STVector, runSTVector )
import Data.Vector.Dense.Class( BaseVector, ReadVector, WriteVector,
conj, dim, coerceVector, scaleBy, unsafeCopyVector, unsafeAxpyVector,
unsafeMulVector, unsafeDivVector, newCopyVector )
-- | A diagonal matrix backed by a vector of type @x n e@.  The vector's
-- length index @n@ is existentially quantified inside the constructor, and
-- the shape index @nn@ is a phantom parameter: it does not occur in the
-- constructor's field.
data Diag x nn e = forall n. Diag (x n e)
-- | Cast a diagonal matrix to a different phantom shape type.  This is a
-- plain 'unsafeCoerce'; nothing checks that the two shapes agree.
coerceDiag :: Diag x mn e -> Diag x mn' e
coerceDiag d = unsafeCoerce d
-- | Build a diagonal matrix from the vector of its diagonal entries.  A
-- vector indexed by @n@ yields a matrix with shape index @(n,n)@.
diagFromVector :: (BaseVector x) => x n e -> Diag x (n,n) e
diagFromVector v = Diag (coerceVector v)
-- | Recover the vector of diagonal entries from a diagonal matrix.  The
-- wrapped vector is re-indexed with 'coerceVector' so that its length index
-- matches the matrix's @(n,n)@ shape index.
vectorFromDiag :: (BaseVector x) => Diag x (n,n) e -> x n e
vectorFromDiag d = case d of
    Diag v -> coerceVector v
-- Tensor view of a diagonal matrix, indexed by (row, column) pairs.
-- With n = dim of the underlying vector, the shape is (n,n).
instance (BaseVector x) => BaseTensor (Diag x) (Int,Int) where
    shape (Diag x) = (n,n)
      where n = dim x

    -- Indices are zero-based, so the largest valid index in each dimension
    -- is n-1.  (The previous code referred to an unbound name @n1@ here.)
    bounds (Diag x) = ((0,0),(n-1,n-1))
      where n = dim x
-- The 'herm' of a diagonal matrix is again diagonal: it is obtained by
-- applying 'conj' to the underlying vector and re-wrapping the result.
instance (BaseVector x) => BLAS.BaseMatrix (Diag x) where
    herm d = case d of
        Diag v -> Diag (conj v)
-- Immutable apply operations for a diagonal stored in a pure 'Vector'.
-- Each method runs the corresponding monadic @unsafeGet...@ action and
-- freezes the result with 'runSTVector' / 'runSTMatrix'.
instance (BLAS1 e) => IMatrix (Diag Vector) e where
    unsafeSApply k d v = runSTVector (unsafeGetSApply k d v)
    unsafeSApplyMat k d m = runSTMatrix (unsafeGetSApplyMat k d m)
-- Mutable apply operations for a diagonal stored in an 'IOVector', run in
-- 'IO'.  Every method forwards unchanged to the generic helper of the same
-- purpose defined in this module.
instance (BLAS1 e) => MMatrix (Diag IOVector) e IO where
    unsafeDoSApplyAdd alpha a x beta y =
        unsafeDoSApplyAddDiagVector alpha a x beta y
    unsafeDoSApplyAddMat alpha a b beta c =
        unsafeDoSApplyAddMatDiagVector alpha a b beta c
    unsafeDoSApply_ alpha a x =
        unsafeDoSApplyDiagVector_ alpha a x
    unsafeDoSApplyMat_ alpha a b =
        unsafeDoSApplyMatDiagVector_ alpha a b
-- Mutable apply operations for a diagonal stored in an 'STVector', run in
-- the matching 'ST' monad.  Every method forwards unchanged to the generic
-- helper of the same purpose defined in this module.
instance (BLAS1 e) => MMatrix (Diag (STVector s)) e (ST s) where
    unsafeDoSApplyAdd alpha a x beta y =
        unsafeDoSApplyAddDiagVector alpha a x beta y
    unsafeDoSApplyAddMat alpha a b beta c =
        unsafeDoSApplyAddMatDiagVector alpha a b beta c
    unsafeDoSApply_ alpha a x =
        unsafeDoSApplyDiagVector_ alpha a x
    unsafeDoSApplyMat_ alpha a b =
        unsafeDoSApplyMatDiagVector_ alpha a b
-- Mutable apply operations where the diagonal itself is an immutable
-- 'Vector'; usable in any monad @m@ with an 'UnsafeIOToM' instance.  Every
-- method forwards unchanged to the generic helper of the same purpose
-- defined in this module.
instance (BLAS1 e, UnsafeIOToM m) => MMatrix (Diag Vector) e m where
    unsafeDoSApplyAdd alpha a x beta y =
        unsafeDoSApplyAddDiagVector alpha a x beta y
    unsafeDoSApplyAddMat alpha a b beta c =
        unsafeDoSApplyAddMatDiagVector alpha a b beta c
    unsafeDoSApply_ alpha a x =
        unsafeDoSApplyDiagVector_ alpha a x
    unsafeDoSApplyMat_ alpha a b =
        unsafeDoSApplyMatDiagVector_ alpha a b
-- | Scaled apply-and-add of a diagonal matrix to a vector, writing into @y@.
--
-- A fresh copy of @x@ is multiplied by the diagonal (with scale factor 1),
-- @y@ is scaled by @beta@, and finally the copy is accumulated into @y@ with
-- factor @alpha@.  Assuming the usual BLAS meaning of the scale and axpy
-- primitives, the net effect is @y := alpha * a * x + beta * y@.
--
-- The input @x@ itself is never written to; only its copy is modified.
unsafeDoSApplyAddDiagVector :: (ReadVector z m, ReadVector x m, WriteVector y m, BLAS1 e) =>
    e -> Diag z (r,s) e -> x s e -> e -> y r e -> m ()
unsafeDoSApplyAddDiagVector alpha a x beta y =
    newCopyVector x >>= \ax -> do
        -- ax := a * x, computed in the temporary
        unsafeDoSApplyDiagVector_ 1 (coerceDiag a) ax
        scaleBy beta y
        unsafeAxpyVector alpha ax (coerceVector y)
-- | Scaled apply-and-add of a diagonal matrix to a dense matrix, writing
-- into @c@.
--
-- @c@ is first scaled by @beta@.  Then, walking the diagonal entries
-- together with the rows of @b@ and the rows of @c@, each row of @b@ is
-- accumulated into the matching row of @c@ with factor @alpha * k@, where
-- @k@ is that row's diagonal entry.  Traversal stops at the shortest of the
-- three sequences.
unsafeDoSApplyAddMatDiagVector :: (ReadVector x m, ReadMatrix b y m, WriteMatrix c z m, BLAS1 e) =>
    e -> Diag x (r,s) e -> b (s,t) e -> e -> c (r,t) e -> m ()
unsafeDoSApplyAddMatDiagVector alpha a b beta c = do
    scaleBy beta c
    diagElems <- getElems (vectorFromDiag (coerceDiag a))
    sequence_ $ zipWith3
        (\k src dst -> unsafeAxpyVector (alpha*k) src dst)
        diagElems (rowViews b) (rowViews c)
-- | In-place scaled apply of a diagonal matrix to a vector.
--
-- @x@ is first combined with the diagonal through 'unsafeMulVector' and then
-- scaled by @alpha@; presumably this yields @x := alpha * a * x@, given the
-- conventional meaning of those primitives.
unsafeDoSApplyDiagVector_ :: (ReadVector x m, WriteVector y m, BLAS1 e) =>
    e -> Diag x (s,s) e -> y s e -> m ()
unsafeDoSApplyDiagVector_ alpha a x =
    unsafeMulVector x (vectorFromDiag a) >> scaleBy alpha x
-- | In-place scaled apply of a diagonal matrix to a dense matrix.
--
-- Reads the diagonal entries, then scales each row of @b@ by @alpha * k@,
-- where @k@ is the diagonal entry paired with that row.
unsafeDoSApplyMatDiagVector_ :: (ReadVector x m, WriteMatrix b y m, BLAS1 e) =>
    e -> Diag x (s,s) e -> b (s,t) e -> m ()
unsafeDoSApplyMatDiagVector_ alpha a b =
    getElems (vectorFromDiag a) >>= \diagElems ->
        mapM_ (\(k, row) -> scaleBy (alpha*k) row)
              (zip diagElems (rowViews b))
-- Immutable solve operations for a diagonal stored in a pure 'Vector'.
-- Each method runs the corresponding monadic @unsafeGet...@ action and
-- freezes the result with 'runSTVector' / 'runSTMatrix'.
instance (BLAS1 e) => ISolve (Diag Vector) e where
    unsafeSSolve k d rhs = runSTVector (unsafeGetSSolve k d rhs)
    unsafeSSolveMat k d rhs = runSTMatrix (unsafeGetSSolveMat k d rhs)
-- Mutable solve operations for a diagonal stored in an 'IOVector', run in
-- 'IO'.  Every method forwards unchanged to the generic helper of the same
-- purpose defined in this module.
instance (BLAS1 e) => MSolve (Diag IOVector) e IO where
    unsafeDoSSolve alpha a y x =
        unsafeDoSSolveDiagVector alpha a y x
    unsafeDoSSolveMat alpha a c b =
        unsafeDoSSolveMatDiagVector alpha a c b
    unsafeDoSSolve_ alpha a x =
        unsafeDoSSolveDiagVector_ alpha a x
    unsafeDoSSolveMat_ alpha a b =
        unsafeDoSSolveMatDiagVector_ alpha a b
-- Mutable solve operations for a diagonal stored in an 'STVector', run in
-- the matching 'ST' monad.  Every method forwards unchanged to the generic
-- helper of the same purpose defined in this module.
instance (BLAS1 e) => MSolve (Diag (STVector s)) e (ST s) where
    unsafeDoSSolve alpha a y x =
        unsafeDoSSolveDiagVector alpha a y x
    unsafeDoSSolveMat alpha a c b =
        unsafeDoSSolveMatDiagVector alpha a c b
    unsafeDoSSolve_ alpha a x =
        unsafeDoSSolveDiagVector_ alpha a x
    unsafeDoSSolveMat_ alpha a b =
        unsafeDoSSolveMatDiagVector_ alpha a b
-- Mutable solve operations where the diagonal itself is an immutable
-- 'Vector'; usable in any monad @m@ with an 'UnsafeIOToM' instance.  Every
-- method forwards unchanged to the generic helper of the same purpose
-- defined in this module.
instance (BLAS1 e, UnsafeIOToM m) => MSolve (Diag Vector) e m where
    unsafeDoSSolve alpha a y x =
        unsafeDoSSolveDiagVector alpha a y x
    unsafeDoSSolveMat alpha a c b =
        unsafeDoSSolveMatDiagVector alpha a c b
    unsafeDoSSolve_ alpha a x =
        unsafeDoSSolveDiagVector_ alpha a x
    unsafeDoSSolveMat_ alpha a b =
        unsafeDoSSolveMatDiagVector_ alpha a b
-- | Scaled diagonal solve for a vector right-hand side, writing into @x@.
--
-- The right-hand side @y@ is copied into @x@, after which the in-place
-- solve 'unsafeDoSSolveDiagVector_' is run on @x@ with the same @alpha@.
-- @y@ is only read.
unsafeDoSSolveDiagVector :: (ReadVector z m, ReadVector y m, WriteVector x m, BLAS1 e) =>
    e -> Diag z (r,s) e -> y r e -> x s e -> m ()
unsafeDoSSolveDiagVector alpha a y x =
    unsafeCopyVector x (coerceVector y)
        >> unsafeDoSSolveDiagVector_ alpha (coerceDiag a) x
-- | Scaled diagonal solve for a matrix right-hand side, writing into @b@.
--
-- The right-hand side @c@ is copied into @b@, after which the in-place
-- solve 'unsafeDoSSolveMatDiagVector_' is run on @b@ with the same @alpha@.
-- @c@ is only read.
unsafeDoSSolveMatDiagVector :: (ReadVector x m, ReadMatrix c z m, WriteMatrix b y m, BLAS1 e) =>
    e -> Diag x (r,s) e -> c (r,t) e -> b (s,t) e -> m ()
unsafeDoSSolveMatDiagVector alpha a c b =
    unsafeCopyMatrix b (coerceMatrix c)
        >> unsafeDoSSolveMatDiagVector_ alpha (coerceDiag a) b
-- | In-place scaled diagonal solve for a vector.
--
-- @x@ is scaled by @alpha@ and then combined with the diagonal through
-- 'unsafeDivVector'; presumably this divides each entry of @x@ by the
-- matching diagonal entry, giving @x := alpha * a^-1 * x@.
unsafeDoSSolveDiagVector_ :: (ReadVector x m, WriteVector y m, BLAS1 e) =>
    e -> Diag x (k,k) e -> y k e -> m ()
unsafeDoSSolveDiagVector_ alpha a x =
    scaleBy alpha x >> unsafeDivVector x (vectorFromDiag a)
-- | In-place scaled diagonal solve for a dense matrix.
--
-- @b@ is scaled by @alpha@ as a whole; then each row of @b@ is scaled by
-- @1/k@, where @k@ is the diagonal entry paired with that row.
unsafeDoSSolveMatDiagVector_ :: (ReadVector x m, WriteMatrix a y m, BLAS1 e) =>
    e -> Diag x (k,k) e -> a (k,l) e -> m ()
unsafeDoSSolveMatDiagVector_ alpha a b = do
    scaleBy alpha b
    -- The diagonal read is wrapped in 'unsafeInterleaveM', so it is deferred
    -- until the row traversal below first demands the list.
    diagElems <- unsafeInterleaveM (getElems (vectorFromDiag a))
    mapM_ (\(k, row) -> scaleBy (1/k) row)
          (zip diagElems (rowViews b))
-- Renders a diagonal matrix as the 'diagFromVector' call that wraps the
-- shown form of its diagonal vector.
instance (Show e, BLAS1 e) => Show (Diag Vector (n,n) e) where
    show d = concat ["diagFromVector (", show (vectorFromDiag d), ")"]
-- Two diagonal matrices are equal exactly when their diagonal vectors are.
instance (Eq e, BLAS1 e) => Eq (Diag Vector (n,n) e) where
    lhs == rhs = vectorFromDiag lhs == vectorFromDiag rhs
-- Both 'AEq' comparisons are decided by applying the same comparison to
-- the two diagonal vectors.
instance (AEq e, BLAS1 e) => AEq (Diag Vector (n,n) e) where
    lhs === rhs = vectorFromDiag lhs === vectorFromDiag rhs
    lhs ~== rhs = vectorFromDiag lhs ~== vectorFromDiag rhs