-- | Diagonal matrices, represented by a dense vector that holds the
-- diagonal entries.  The module also re-exports the "BLAS.Matrix" and
-- "BLAS.Tensor" interfaces.
module Data.Matrix.Diag (
module BLAS.Matrix,
module BLAS.Tensor,
-- * Diagonal matrix types
DiagMatrix,
Diag,
IODiag,
-- * Converting to and from vectors
fromVector,
toVector,
) where
import BLAS.Access
import BLAS.Elem ( BLAS1, BLAS2 )
import BLAS.Matrix hiding ( Matrix )
import qualified BLAS.Matrix as Base
import BLAS.Tensor
import Control.Monad ( zipWithM_ )
import Data.AEq
import Data.Matrix.Dense.Internal ( rows )
import Data.Matrix.Dense.Operations ( unsafeCopyMatrix )
import qualified Data.Matrix.Dense.Operations as M
import qualified Data.Matrix.Dense.Internal as M
import Data.Vector.Dense.Internal
import qualified Data.Vector.Dense.Internal as V
import qualified Data.Vector.Dense.Operations as V
import Data.Vector.Dense.Operations ( (//=), scaleBy, invScaleBy,
unsafeCopyVector )
import System.IO.Unsafe
import Unsafe.Coerce
-- | A diagonal matrix, stored as a dense vector ('DVector') wrapped in the
-- 'Diag' data constructor.  The parameters @t@, @mn@ and @e@ are handed
-- straight through to the wrapped vector.
newtype DiagMatrix t mn e = Diag (DVector t mn e)
-- | A 'DiagMatrix' tagged with 'Imm' (presumably the immutable access tag;
-- confirm against "BLAS.Access").
type Diag = DiagMatrix Imm
-- | A 'DiagMatrix' tagged with 'Mut' (presumably the mutable access tag;
-- confirm against "BLAS.Access").
type IODiag = DiagMatrix Mut
-- | Re-tag a diagonal matrix with a different shape index @mn'@.
--
-- The value is passed through 'unsafeCoerce' unchanged; only the type
-- changes.  NOTE(review): this is only sound if @mn@ is a phantom index of
-- 'DVector' -- confirm against its definition.
coerceDiag :: DiagMatrix t mn e -> DiagMatrix t mn' e
coerceDiag a = unsafeCoerce a
-- | Wrap a vector indexed by @n@ as a diagonal matrix indexed by @(n,n)@.
--
-- The vector is re-tagged with 'unsafeCoerce' and then wrapped in 'Diag';
-- no elements are copied here.
fromVector :: DVector t n e -> DiagMatrix t (n,n) e
fromVector x = Diag (unsafeCoerce x)
-- | Unwrap a diagonal matrix indexed by @(n,n)@ to the vector it stores,
-- re-tagged with index @n@ via 'unsafeCoerce'.  Inverse of 'fromVector'.
toVector :: DiagMatrix t (n,n) e -> DVector t n e
toVector (Diag diagonal) = unsafeCoerce diagonal
-- | Scaling a diagonal matrix delegates to scaling the wrapped vector.
instance (BLAS1 e) => Scalable (DiagMatrix Imm (n,n)) e where
    k *> Diag x = Diag (k *> x)
-- | Lift a vector update function to diagonal matrices.
--
-- Only associations whose index lies on the diagonal (@i == j@) are kept;
-- every other association is dropped without error.  The surviving
-- @(i, e)@ pairs are handed to @f@ together with the underlying vector,
-- and the result is wrapped back up with 'fromVector'.
replaceHelp :: (BLAS1 e) =>
       (Vector n e -> [(Int,e)] -> Vector n e)
    -> Diag (n,n) e -> [((Int,Int),e)] -> Diag (n,n) e
replaceHelp f a ijes = fromVector (f (toVector a) ies)
  where
    -- Keep the diagonal associations, re-keyed by their single index.
    ies = [ (i,e) | ((i,j),e) <- ijes, i == j ]
-- | Pure tensor operations on an immutable diagonal matrix.  Everything is
-- delegated to the underlying vector, translating between the vector index
-- @i@ and the matrix index @(i,i)@.
instance (BLAS1 e) => ITensor (DiagMatrix Imm (n,n)) (Int,Int) e where
    size a = size (toVector a)

    -- Each stored entry @(i,e)@ is reported at position @(i,i)@.
    assocs a = [ ((i,i),e) | (i,e) <- assocs (toVector a) ]

    (//) = replaceHelp (//)

    amap f = fromVector . amap f . toVector

    -- Off-diagonal positions read as zero; diagonal ones read the vector.
    unsafeAt a (i,j)
        | i == j    = unsafeAt (toVector a) i
        | otherwise = 0

    unsafeReplace = replaceHelp unsafeReplace
-- | Read-only monadic tensor operations in 'IO', for any access tag @t@.
-- Each method delegates to the underlying vector.
instance (BLAS1 e) => RTensor (DiagMatrix t (n,n)) (Int,Int) e IO where
    getSize a = getSize (toVector a)

    -- Copy the underlying vector and wrap the copy as a diagonal matrix.
    newCopy a = fmap fromVector (newCopy (toVector a))

    -- Off-diagonal positions read as zero without touching the vector.
    unsafeReadElem a (i,j)
        | i == j    = unsafeReadElem (toVector a) i
        | otherwise = return 0
-- | Mutating tensor operations in 'IO' for 'Mut'-tagged diagonal matrices.
-- Each method delegates to the underlying vector.
instance (BLAS1 e) => MTensor (DiagMatrix Mut (n,n)) (Int,Int) e IO where
    setZero a = setZero (toVector a)

    setConstant k a = setConstant k (toVector a)

    -- Only in-range diagonal positions are modifiable.
    canModifyElem a (i,j) = return (onDiagonal && inBounds)
      where
        onDiagonal = i == j
        inBounds   = i >= 0 && i < numRows a

    -- The column index is ignored; the write goes to vector position @row@.
    unsafeWriteElem a (row,_) = unsafeWriteElem (toVector a) row

    modifyWith f = modifyWith f . toVector
-- | Shape queries and Hermitian transpose for diagonal matrices.
instance Base.Matrix (DiagMatrix t) where
    -- The row count is the dimension of the stored vector.
    numRows (Diag d) = dim d

    -- The column count is defined to equal the row count.
    numCols a = numRows a

    -- Apply 'conj' to the stored vector, re-wrap it, and re-tag the result
    -- to the transposed shape with 'unsafeCoerce'.
    herm (Diag d) = unsafeCoerce (Diag (conj d))
-- | 'IMatrix' instance for 'Imm'-tagged diagonal matrices.  No methods are
-- overridden here, so the class defaults apply.
instance BLAS2 e => IMatrix (DiagMatrix Imm) e
-- | Matrix-vector and matrix-matrix application for diagonal matrices in
-- 'IO'.  The diagonal entries act as per-element (vector) or per-row
-- (matrix) scale factors.
instance (BLAS2 e) => RMatrix (DiagMatrix t) e where
    -- Work on a fresh copy of @x@: apply the diagonal to the copy in place,
    -- scale @y@ by @beta@, then accumulate @alpha@ times the copy into @y@
    -- with 'V.unsafeAxpy'.
    unsafeDoSApplyAdd alpha a x beta y = do
        ax <- newCopy x
        unsafeDoSApply_ 1 (coerceDiag a) (V.unsafeThaw ax)
        scaleBy beta y
        V.unsafeAxpy alpha ax (V.coerceVector y)

    -- Scale @c@ by @beta@, then walk the diagonal entries, the rows of @b@
    -- and the rows of @c@ in lockstep, issuing one 'V.unsafeAxpy' per row
    -- with coefficient @alpha*k@.
    unsafeDoSApplyAddMat alpha a b beta c = do
        M.scaleBy beta c
        ks <- unsafeInterleaveIO $ getElems (toVector $ coerceDiag a)
        sequence_ $
            zipWith3 (\k x y -> V.unsafeAxpy (alpha*k) x y)
                     ks (rows b) (rows c)

    -- Combine @x@ with the diagonal via 'V.unsafeTimesEquals' (presumably
    -- an in-place elementwise product -- confirm), then scale by @alpha@.
    unsafeDoSApply_ alpha a x =
        V.unsafeTimesEquals x (toVector a) >> V.scaleBy alpha x

    -- Scale each row of @b@ by @alpha@ times the matching diagonal entry.
    unsafeDoSApplyMat_ alpha a b = do
        ks <- unsafeInterleaveIO $ getElems (toVector a)
        sequence_ [ scaleBy (alpha*k) r | (k,r) <- zip ks (rows b) ]
-- | 'ISolve' instance for 'Imm'-tagged diagonal matrices.  No methods are
-- overridden here, so the class defaults apply.
instance BLAS2 e => ISolve (DiagMatrix Imm) e
-- | Solves against an 'Imm'-tagged diagonal matrix in 'IO'.  The
-- out-of-place variants copy the input and then defer to the in-place ones.
instance (BLAS2 e) => RSolve (DiagMatrix Imm) e where
    -- 'unsafeCopyVector' is called with @x@ first and the re-tagged @y@
    -- second (presumably destination then source -- confirm); the in-place
    -- solve then runs on @x@.
    unsafeDoSSolve alpha a y x =
        unsafeCopyVector x (unsafeCoerce y)
            >> unsafeDoSSolve_ alpha (coerceDiag a) x

    -- Matrix analogue of the above, using 'unsafeCopyMatrix' on @b@ and @c@.
    unsafeDoSSolveMat alpha a c b =
        unsafeCopyMatrix b (unsafeCoerce c)
            >> unsafeDoSSolveMat_ alpha (coerceDiag a) b

    -- Scale @x@ by @alpha@, then apply '//=' with the diagonal (presumably
    -- in-place elementwise division -- confirm).
    unsafeDoSSolve_ alpha a x =
        V.scaleBy alpha x >> (x //= toVector a)

    -- Scale @b@ by @alpha@, then apply 'invScaleBy' to each row of @b@ with
    -- the matching diagonal entry.
    unsafeDoSSolveMat_ alpha a b = do
        M.scaleBy alpha b
        ks <- unsafeInterleaveIO $ getElems (toVector a)
        sequence_ [ invScaleBy k r | (k,r) <- zip ks (rows b) ]
-- | Render a diagonal matrix as the expression @fromVector (<vector>)@.
--
-- The instance defines 'showsPrec' rather than only 'show', so the output
-- is wrapped in parentheses when it appears as a constructor argument
-- (precedence above 10), e.g. inside @Just@.  At top level, 'show' yields
-- exactly @"fromVector (" ++ show x ++ ")"@.
instance (Show e, BLAS1 e) => Show (DiagMatrix Imm (n,n) e) where
    showsPrec d (Diag x) =
        showParen (d > 10) $
            -- The inner parentheses are written explicitly rather than
            -- relying on the vector's own 'showsPrec' to add them.
            showString "fromVector (" . shows x . showChar ')'
-- | Two diagonal matrices are equal exactly when their wrapped vectors
-- compare equal.
instance (Eq e, BLAS1 e) => Eq (DiagMatrix Imm (n,n) e) where
    Diag x == Diag y = x == y
-- | Both 'AEq' comparisons delegate to the same comparison on the wrapped
-- vectors.
instance (AEq e, BLAS1 e) => AEq (DiagMatrix Imm (n,n) e) where
    Diag x === Diag y = x === y
    Diag x ~== Diag y = x ~== y