{-# LANGUAGE MultiParamTypeClasses, FlexibleInstances #-} ----------------------------------------------------------------------------- -- | -- Module : Data.Matrix.Perm -- Copyright : Copyright (c) , Patrick Perry -- License : BSD3 -- Maintainer : Patrick Perry -- Stability : experimental -- module Data.Matrix.Perm ( module BLAS.Matrix, Perm(..), -- * The identity permutation identity, -- * Converting to/from @Permutation@s fromPermutation, toPermutation, -- * Coercing coercePerm ) where import Control.Monad ( forM_ ) import Data.AEq import BLAS.Elem import BLAS.Matrix import BLAS.Tensor import BLAS.UnsafeIOToM import Data.Matrix.Dense.Class( ReadMatrix, WriteMatrix, unsafeRowView, unsafeAxpyMatrix, coerceMatrix ) import Data.Vector.Dense.Class( ReadVector, WriteVector, dim, isConj, scaleBy, unsafeAxpyVector, unsafeSwapVector, coerceVector ) import Data.Vector.Dense.ST( runSTVector ) import Data.Matrix.Dense.ST( runSTMatrix ) {- import Data.Vector.Dense.IO ( dim, isConj, conj, unsafeSwapVectors, unsafeCopyVector, unsafeWithElemPtr, unsafeReadElem, unsafeWriteElem, coerceVector ) import qualified Data.Vector.Dense.IO as V import Data.Matrix.Dense.IO ( unsafeRow, unsafeCopyMatrix, coerceMatrix ) import qualified Data.Matrix.Dense.IO as M -} import Data.Permutation ( Permutation ) import qualified Data.Permutation as P import Unsafe.Coerce data Perm mn e = P { baseOf :: !Permutation , isHerm :: !Bool } | I !Int identity :: Int -> Perm (n,n) e identity = I fromPermutation :: Permutation -> Perm (n,n) e fromPermutation = flip P False toPermutation :: Perm (n,n) e -> Permutation toPermutation (I n) = P.identity n toPermutation (P sigma h) = if h then P.inverse sigma else sigma coercePerm :: Perm mn e -> Perm mn' e coercePerm = unsafeCoerce instance BaseTensor Perm (Int,Int) where shape (P sigma _) = (n,n) where n = P.size sigma shape (I n) = (n,n) bounds p = ((0,0), (m-1,n-1)) where (m,n) = shape p instance BaseMatrix Perm where herm a@(P _ _) = a{ isHerm=(not . isHerm) a } herm a@(I _) = coercePerm a {- instance (BLAS1 e) => IMatrix Perm e where instance (BLAS1 e) => RMatrix Perm e where -} instance (BLAS1 e) => IMatrix Perm e where unsafeSApply alpha a x = runSTVector $ unsafeGetSApply alpha a x unsafeSApplyMat alpha a b = runSTMatrix $ unsafeGetSApplyMat alpha a b instance (BLAS1 e, UnsafeIOToM m) => MMatrix Perm e m where unsafeDoSApplyAdd = unsafeDoSApplyAddPerm unsafeDoSApplyAddMat = unsafeDoSApplyAddMatPerm unsafeDoSApply_ = unsafeDoSApplyPerm_ unsafeDoSApplyMat_ = unsafeDoSApplyMatPerm_ unsafeDoSApplyPerm_ :: (WriteVector y m, BLAS1 e) => e -> Perm (k,k) e -> y k e -> m () unsafeDoSApplyPerm_ alpha (I _) x = scaleBy alpha x unsafeDoSApplyPerm_ alpha p@(P _ _) x | isHerm p = P.invertWith swap sigma >> scaleBy alpha x | otherwise = P.applyWith swap sigma >> scaleBy alpha x where sigma = baseOf p swap = unsafeSwapElems x unsafeDoSApplyMatPerm_ :: (WriteMatrix c z m, BLAS1 e) => e -> Perm (k,k) e -> c (k,l) e -> m () unsafeDoSApplyMatPerm_ alpha (I _) a = scaleBy alpha a unsafeDoSApplyMatPerm_ alpha p@(P _ _) a | isHerm p = P.invertWith swap sigma >> scaleBy alpha a | otherwise = P.applyWith swap sigma >> scaleBy alpha a where sigma = baseOf p swap i j = unsafeSwapVector (unsafeRowView a i) (unsafeRowView a j) unsafeDoSApplyAddPerm :: (ReadVector x m, WriteVector y m, BLAS1 e) => e -> Perm (k,l) e -> x l e -> e -> y k e -> m () unsafeDoSApplyAddPerm alpha (I _) x beta y = do scaleBy beta y unsafeAxpyVector alpha (coerceVector x) y unsafeDoSApplyAddPerm alpha p x beta y | isConj x = unsafeDoSApplyAddPerm (conj alpha) p (conj x) (conj beta) (conj y) | otherwise = let n = dim x sigma = baseOf p in do scaleBy beta y forM_ [0..(n-1)] $ \i -> let i' = P.unsafeApply sigma i in case (isHerm p) of False -> do e <- unsafeReadElem x i f <- unsafeReadElem y i' unsafeWriteElem y i' (alpha*e + f) True -> do e <- unsafeReadElem x i' f <- unsafeReadElem y i unsafeWriteElem y i (alpha*e + f) unsafeDoSApplyAddMatPerm :: (ReadMatrix b x m, WriteMatrix c y m, BLAS1 e) => e -> Perm (r,s) e -> b (s,t) e -> e -> c (r,t) e -> m () unsafeDoSApplyAddMatPerm alpha (I _) b beta c = do scaleBy beta c unsafeAxpyMatrix alpha b (coerceMatrix c) unsafeDoSApplyAddMatPerm alpha p b beta c = let m = numCols p sigma = baseOf p in do scaleBy beta c forM_ [0..(m-1)] $ \i -> let i' = P.unsafeApply sigma i in case (isHerm p) of False -> unsafeAxpyVector alpha (unsafeRowView b i) (unsafeRowView c i') True -> unsafeAxpyVector alpha (unsafeRowView b i') (unsafeRowView c i) instance (BLAS1 e) => ISolve Perm e where unsafeSSolve alpha a y = runSTVector $ unsafeGetSSolve alpha a y unsafeSSolveMat alpha a c = runSTMatrix $ unsafeGetSSolveMat alpha a c instance (BLAS1 e, UnsafeIOToM m) => MSolve Perm e m where unsafeDoSSolve_ alpha p = unsafeDoSApplyPerm_ alpha (herm p) unsafeDoSSolveMat_ alpha p = unsafeDoSApplyMatPerm_ alpha (herm p) unsafeDoSSolve alpha p x y = unsafeDoSApplyAddPerm alpha (coercePerm $ herm p) x 0 y unsafeDoSSolveMat alpha p a b = unsafeDoSApplyAddMatPerm alpha (coercePerm $ herm p) a 0 b instance (Elem e) => Show (Perm (n,n) e) where show (I n) = "identity " ++ show n show p | isHerm p = "herm (" ++ show (herm p) ++ ")" | otherwise = "fromPermutation (" ++ show (baseOf p) ++ ")" instance Eq (Perm (n,n) e) where (==) (I n) (I n') = n == n' (==) (I n) p | isHerm p = (==) (I n) (herm p) | otherwise = (==) (fromPermutation $ P.identity n) p (==) p (I n) = (==) (I n) p (==) (P sigma h) (P sigma' h') | h == h' = sigma == sigma' | otherwise = P.size sigma == P.size sigma' && sigma == (P.inverse sigma') instance AEq (Perm (n,n) e) where (===) = (==) (~==) = (==)