module Data.Matrix.Perm (
module BLAS.Matrix,
Perm(..),
identity,
fromPermutation,
toPermutation,
coercePerm
) where
import Control.Monad ( forM_ )
import Data.AEq
import BLAS.Elem
import BLAS.Matrix
import BLAS.Tensor
import BLAS.UnsafeIOToM
import Data.Matrix.Dense.Class( ReadMatrix, WriteMatrix, unsafeRowView,
unsafeAxpyMatrix, coerceMatrix )
import Data.Vector.Dense.Class( ReadVector, WriteVector, dim, isConj,
scaleBy, unsafeAxpyVector, unsafeSwapVector, coerceVector )
import Data.Vector.Dense.ST( runSTVector )
import Data.Matrix.Dense.ST( runSTMatrix )
import Data.Permutation ( Permutation )
import qualified Data.Permutation as P
import Unsafe.Coerce
data Perm mn e =
P { baseOf :: !Permutation
, isHerm :: !Bool
}
| I !Int
identity :: Int -> Perm (n,n) e
identity = I
fromPermutation :: Permutation -> Perm (n,n) e
fromPermutation = flip P False
toPermutation :: Perm (n,n) e -> Permutation
toPermutation (I n) = P.identity n
toPermutation (P sigma h) = if h then P.inverse sigma else sigma
coercePerm :: Perm mn e -> Perm mn' e
coercePerm = unsafeCoerce
instance BaseTensor Perm (Int,Int) where
shape (P sigma _) = (n,n) where n = P.size sigma
shape (I n) = (n,n)
bounds p = ((0,0), (m1,n1)) where (m,n) = shape p
instance BaseMatrix Perm where
herm a@(P _ _) = a{ isHerm=(not . isHerm) a }
herm a@(I _) = coercePerm a
instance (BLAS1 e) => IMatrix Perm e where
unsafeSApply alpha a x = runSTVector $ unsafeGetSApply alpha a x
unsafeSApplyMat alpha a b = runSTMatrix $ unsafeGetSApplyMat alpha a b
instance (BLAS1 e, UnsafeIOToM m) => MMatrix Perm e m where
unsafeDoSApplyAdd = unsafeDoSApplyAddPerm
unsafeDoSApplyAddMat = unsafeDoSApplyAddMatPerm
unsafeDoSApply_ = unsafeDoSApplyPerm_
unsafeDoSApplyMat_ = unsafeDoSApplyMatPerm_
unsafeDoSApplyPerm_ :: (WriteVector y m, BLAS1 e) =>
e -> Perm (k,k) e -> y k e -> m ()
unsafeDoSApplyPerm_ alpha (I _) x = scaleBy alpha x
unsafeDoSApplyPerm_ alpha p@(P _ _) x
| isHerm p = P.invertWith swap sigma >> scaleBy alpha x
| otherwise = P.applyWith swap sigma >> scaleBy alpha x
where
sigma = baseOf p
swap = unsafeSwapElems x
unsafeDoSApplyMatPerm_ :: (WriteMatrix c z m, BLAS1 e) =>
e -> Perm (k,k) e -> c (k,l) e -> m ()
unsafeDoSApplyMatPerm_ alpha (I _) a = scaleBy alpha a
unsafeDoSApplyMatPerm_ alpha p@(P _ _) a
| isHerm p = P.invertWith swap sigma >> scaleBy alpha a
| otherwise = P.applyWith swap sigma >> scaleBy alpha a
where
sigma = baseOf p
swap i j = unsafeSwapVector (unsafeRowView a i) (unsafeRowView a j)
unsafeDoSApplyAddPerm :: (ReadVector x m, WriteVector y m, BLAS1 e) =>
e -> Perm (k,l) e -> x l e -> e -> y k e -> m ()
unsafeDoSApplyAddPerm alpha (I _) x beta y = do
scaleBy beta y
unsafeAxpyVector alpha (coerceVector x) y
unsafeDoSApplyAddPerm alpha p x beta y
| isConj x =
unsafeDoSApplyAddPerm (conj alpha) p (conj x) (conj beta) (conj y)
| otherwise =
let n = dim x
sigma = baseOf p
in do
scaleBy beta y
forM_ [0..(n1)] $ \i ->
let i' = P.unsafeApply sigma i
in case (isHerm p) of
False -> do
e <- unsafeReadElem x i
f <- unsafeReadElem y i'
unsafeWriteElem y i' (alpha*e + f)
True -> do
e <- unsafeReadElem x i'
f <- unsafeReadElem y i
unsafeWriteElem y i (alpha*e + f)
unsafeDoSApplyAddMatPerm :: (ReadMatrix b x m, WriteMatrix c y m, BLAS1 e) =>
e -> Perm (r,s) e -> b (s,t) e -> e -> c (r,t) e -> m ()
unsafeDoSApplyAddMatPerm alpha (I _) b beta c = do
scaleBy beta c
unsafeAxpyMatrix alpha b (coerceMatrix c)
unsafeDoSApplyAddMatPerm alpha p b beta c =
let m = numCols p
sigma = baseOf p
in do
scaleBy beta c
forM_ [0..(m1)] $ \i ->
let i' = P.unsafeApply sigma i
in case (isHerm p) of
False -> unsafeAxpyVector alpha (unsafeRowView b i)
(unsafeRowView c i')
True -> unsafeAxpyVector alpha (unsafeRowView b i')
(unsafeRowView c i)
instance (BLAS1 e) => ISolve Perm e where
unsafeSSolve alpha a y = runSTVector $ unsafeGetSSolve alpha a y
unsafeSSolveMat alpha a c = runSTMatrix $ unsafeGetSSolveMat alpha a c
instance (BLAS1 e, UnsafeIOToM m) => MSolve Perm e m where
unsafeDoSSolve_ alpha p = unsafeDoSApplyPerm_ alpha (herm p)
unsafeDoSSolveMat_ alpha p = unsafeDoSApplyMatPerm_ alpha (herm p)
unsafeDoSSolve alpha p x y = unsafeDoSApplyAddPerm alpha (coercePerm $ herm p) x 0 y
unsafeDoSSolveMat alpha p a b = unsafeDoSApplyAddMatPerm alpha (coercePerm $ herm p) a 0 b
instance (Elem e) => Show (Perm (n,n) e) where
show (I n) = "identity " ++ show n
show p | isHerm p = "herm (" ++ show (herm p) ++ ")"
| otherwise = "fromPermutation (" ++ show (baseOf p) ++ ")"
instance Eq (Perm (n,n) e) where
(==) (I n) (I n') = n == n'
(==) (I n) p
| isHerm p = (==) (I n) (herm p)
| otherwise = (==) (fromPermutation $ P.identity n) p
(==) p (I n) = (==) (I n) p
(==) (P sigma h) (P sigma' h')
| h == h' = sigma == sigma'
| otherwise = P.size sigma == P.size sigma'
&& sigma == (P.inverse sigma')
instance AEq (Perm (n,n) e) where
(===) = (==)
(~==) = (==)