{-# LANGUAGE TypeFamilies #-}
module Numeric.LAPACK.Permutation.Private where

import qualified Numeric.LAPACK.Matrix.Shape.Private as MatrixShape
import qualified Numeric.LAPACK.Matrix.Extent.Private as Extent
import qualified Numeric.LAPACK.Shape as ExtShape
import qualified Numeric.LAPACK.Output as Output
import Numeric.LAPACK.Output (Output, formatAligned)
import Numeric.LAPACK.Matrix.Shape.Private (Order(RowMajor, ColumnMajor))
import Numeric.LAPACK.Matrix.Modifier
         (Transposition(NonTransposed,Transposed),
          Inversion(NonInverted,Inverted))
import Numeric.LAPACK.Matrix.Private (Full, Square, shapeInt)
import Numeric.LAPACK.Vector (Vector)
import Numeric.LAPACK.Scalar (zero, one)
import Numeric.LAPACK.Private (copyBlock, copyToTemp)

import qualified Numeric.LAPACK.FFI.Generic as LapackGen
import qualified Numeric.Netlib.Utility as Call
import qualified Numeric.Netlib.Class as Class

import qualified Data.Array.Comfort.Storable.Mutable.Unchecked as MutArray
import qualified Data.Array.Comfort.Storable.Unchecked as Array
import qualified Data.Array.Comfort.Storable as CheckedArray
import qualified Data.Array.Comfort.Shape as Shape
import Data.Array.Comfort.Storable.Unchecked (Array(Array), (!))

import Foreign.C.Types (CInt)
import Foreign.ForeignPtr (withForeignPtr)
import Foreign.Ptr (Ptr, castPtr)
import Foreign.Storable (Storable, sizeOf, alignment, poke, peek)

import Control.Monad.Trans.Cont (ContT(ContT), evalContT)
import Control.Monad.IO.Class (liftIO)
import Control.Monad.ST (ST, runST)
import Control.Monad (when, forM_)
import Control.Applicative (liftA2, (<$>))

import qualified Data.Tuple.HT as Tuple
import Data.Function.HT (powerAssociative)
import Data.Monoid (Monoid, mempty, mappend)
import Data.Semigroup (Semigroup, (<>))

import Prelude hiding (odd)


-- | A permutation of the index set of @sh@, stored as a vector of
-- 'Element' indices (1-based, see the 'Shape' index instance).
newtype Permutation sh = Permutation (Vector (Shape sh) (Element sh))
   deriving (Show)


-- | Render the permutation matrix as text.
-- The row for entry value @k@ has a @1@ in column @k@ and @.@ elsewhere.
format :: (Shape.C sh, Output out) => Permutation sh -> out
format (Permutation perm) =
   formatAligned $
   map (renderRow . fromIntegral . deconsElement) (Array.toList perm)
   where
      width = Shape.size $ Array.shape perm
      renderRow col =
         map ((:[]) . Output.text . (:"")) $
            replicate (col-1) '.' ++ '1' : replicate (width-col) '.'

-- | The shape whose indices are permuted.
size :: Permutation sh -> sh
size (Permutation vec) = deconsShape (Array.shape vec)

-- | The permutation that maps every index to itself.
identity :: (Shape.C sh) => sh -> Permutation sh
identity = Permutation . flip CheckedArray.sample id . Shape


-- | Build a permutation by applying the interchanges @i <-> ipiv!i@
-- to the identity; the permutation has the same shape as the pivot vector.
fromPivots ::
   (Shape.C sh) =>
   Inversion -> Vector (Shape sh) (Element sh) -> Permutation sh
fromPivots inv pivots = fromPivotsGen inv (Array.shape pivots) pivots

{-
laswp could be used here if it were available for CInt elements.
-}
-- | Like 'fromPivots', but the pivot vector may be shorter than the
-- permutation; the permutation shape is obtained via 'ExtShape.minShape1'.
fromTruncatedPivots ::
   (Shape.C sh, Shape.C sh1) =>
   Inversion -> Vector (ExtShape.Min sh1 (Shape sh)) (Element sh) ->
   Permutation sh
fromTruncatedPivots inv pivots =
   fromPivotsGen inv (ExtShape.minShape1 $ Array.shape pivots) pivots

-- | Start from the identity on @sh@ and swap positions @i@ and @ipiv!i@
-- for every pivot index @i@. 'NonInverted' visits the pivots in ascending
-- order, 'Inverted' in descending order (see 'indices').
fromPivotsGen ::
   (Shape.C sh, Shape.Indexed small, Shape.Index small ~ Element sh) =>
   Inversion -> Shape sh -> Vector small (Element sh) -> Permutation sh
fromPivotsGen inverted sh ipiv = Permutation $ runST $ do
   perm <- initMutable sh $ \arr i -> MutArray.write arr i i
   mapM_ (\i -> swap perm i (ipiv!i)) $ indices inverted $ Array.shape ipiv
   MutArray.unsafeFreeze perm

-- | Exchange the entries at positions @i@ and @j@ of a mutable array.
swap ::
   (Shape.Indexed sh, Shape.Index sh ~ ix, Storable a) =>
   MutArray.Array (ST s) sh a -> ix -> ix -> ST s ()
swap arr i j = do
   x <- MutArray.read arr i
   y <- MutArray.read arr j
   MutArray.write arr i y
   MutArray.write arr j x

-- | The 1-based pivot positions @1 .. size sh@, ascending for
-- 'NonInverted' and descending for 'Inverted'.
indices ::
   (Shape.C sh, Shape.Indexed small, Shape.Index small ~ Element sh) =>
   Inversion -> small -> [Element sh]
indices inverted sh = take count $ map Element $ start inverted
   where
      count = Shape.size sh
      start NonInverted = iterate (1+) 1
      start Inverted = iterate (subtract 1) (fromIntegral count)

-- | Decompose a permutation into a pivot (interchange) vector.
-- Intended as the counterpart of 'fromPivots' for the same 'Inversion'
-- (round trip not verified here).
-- Two working copies are used: a thawed copy of the permutation and its
-- transpose; 'Inverted' swaps their roles.
toPivots :: (Shape.C sh) => Inversion -> Permutation sh -> Vector sh (Element sh)
toPivots inverted (Permutation a) =
   Array.reshape (deconsShape sh) $ runST $ do
      (inv, perm) <-
         orient <$> liftA2 (,) (MutArray.thaw a) (transposeToMutable a)
      forM_ (Shape.indices sh) $ \i -> do
         j <- MutArray.read inv i
         k <- MutArray.read perm i
         MutArray.write perm j k
         MutArray.write inv k j
      MutArray.unsafeFreeze inv
   where
      sh = Array.shape a
      orient =
         case inverted of
            NonInverted -> id
            Inverted -> Tuple.swap


-- | Sign of a permutation; combining signs multiplies them.
data Sign = Positive | Negative
   deriving (Eq, Show, Enum, Bounded)

instance Semigroup Sign where
   Positive <> s = s
   Negative <> Positive = Negative
   Negative <> Negative = Positive

instance Monoid Sign where
   mempty = Positive
   mappend = (<>)

{-
Alternatively we could count the cycles of even length;
that might be slightly faster.
-}
-- | Sign of the permutation: 'Negative' iff its pivot vector contains an
-- odd number of actual interchanges.
determinant :: (Shape.C sh) => Permutation sh -> Sign
determinant p =
   if odd $ map deconsElement $ Array.toList $ toPivots NonInverted p
      then Negative
      else Positive

{- |
> numberFromSign s == (-1)^fromEnum s
-}
numberFromSign :: (Class.Floating a) => Sign -> a
numberFromSign Negative = -1
numberFromSign Positive = 1


-- | Negate when the pivot vector describes an odd number of interchanges.
condNegate :: (Class.Floating a) => [CInt] -> a -> a
condNegate ipiv
   | odd ipiv = negate
   | otherwise = id

-- | True iff the number of positions @k@ (counting from 1) whose pivot
-- differs from @k@ is odd.
odd :: [CInt] -> Bool
odd ipiv = not $ null $ dropEven [() | (k, p) <- zip [1 ..] ipiv, k /= p]
-- | Remove list elements two at a time.
-- For a finite list the result is empty exactly when the length is even
-- and a singleton otherwise, which makes it a cheap parity test.
dropEven :: [a] -> [a]
dropEven (_:_:xs) = dropEven xs
dropEven xs = xs


-- | Transpose of the permutation matrix.
-- For a bijective index vector this is the inverse permutation:
-- the result maps @perm!i@ back to @i@.
transpose :: (Shape.C sh) => Permutation sh -> Permutation sh
transpose (Permutation perm) =
   Permutation $ runST (MutArray.unsafeFreeze =<< transposeToMutable perm)

-- | Mutable array @inv@ with @inv!(perm!i) = i@ for every index @i@.
-- Positions are only all written if @perm@ is a bijection; 'initMutable'
-- itself does not initialise the contents.
transposeToMutable ::
   (Shape.Indexed sh, Shape.Index sh ~ ix, Storable ix) =>
   Array sh ix -> ST s (MutArray.Array (ST s) sh ix)
transposeToMutable perm =
   initMutable (Array.shape perm) $ \inv i -> MutArray.write inv (perm!i) i

-- | A transposed permutation matrix is the inverse permutation.
inversionFromTransposition :: Transposition -> Inversion
inversionFromTransposition trans =
   case trans of
      NonTransposed -> NonInverted
      Transposed -> Inverted


-- | Compose two permutations of equal size: @result!i = b!(a!i)@.
-- Calls 'error' if the sizes differ.
multiply ::
   (Shape.C sh, Eq sh) => Permutation sh -> Permutation sh -> Permutation sh
multiply a b =
   if size a /= size b
      then error "Permutation.multiply: sizes mismatch"
      else multiplyUnchecked a b

-- | Compose a permutation with itself (no size check needed).
square :: (Shape.C sh) => Permutation sh -> Permutation sh
square p = multiplyUnchecked p p

-- | @power n p@ composes @p@ with itself @n@ times by repeated squaring;
-- @power 0 p@ is the identity.
--
-- A negative exponent raises the inverse ('transpose') to @-n@.
-- 'powerAssociative' expects a non-negative exponent: it halves @n@ with
-- 'div' until it reaches zero, which never happens for negative @n@, so
-- passing a negative exponent through unchanged would not terminate.
power :: (Shape.C sh) => Integer -> Permutation sh -> Permutation sh
power n p
   | n < 0 = go (transpose p) (negate n)
   | otherwise = go p n
   where
      go = powerAssociative multiplyUnchecked (identity $ size p)

-- | Like 'multiply' without the size check; the result takes the shape of
-- the first argument.
multiplyUnchecked ::
   (Shape.C sh) => Permutation sh -> Permutation sh -> Permutation sh
multiplyUnchecked (Permutation a) (Permutation b) =
   Permutation $ CheckedArray.sample (Array.shape a) $ \i -> b!(a!i)


-- | Diagonal of the permutation matrix: 1 at fixed points, 0 elsewhere.
takeDiagonal ::
   (Shape.C sh, Class.Floating a) => Permutation sh -> Vector sh a
takeDiagonal (Permutation a) =
   Array.mapShape deconsShape $
   CheckedArray.sample (Array.shape a) $ \i -> if a!i == i then 1 else 0

-- | Dense row-major permutation matrix: row @k@ has a 1 in column @perm!k@.
toMatrix ::
   (Shape.C sh, Class.Floating a) => Permutation sh -> Square sh a
toMatrix (Permutation perm) =
   let shape = Array.shape perm
   in  Array.reshape (MatrixShape.square RowMajor $ deconsShape shape) $
       runST $ do
          a <- MutArray.new (shape,shape) zero
          forM_ (Shape.indices shape) $ \k -> MutArray.write a (k, perm!k) one
          MutArray.unsafeFreeze a


-- | Permute the rows of a full matrix with LAPACK's @lapmt@/@lapmr@.
--
-- The permutation height must equal the matrix height (checked with
-- 'Call.assert'). 'NonInverted' selects the forward permutation.
-- Row-major storage is handed to @lapmt@ as its column-major transpose,
-- so a row permutation becomes a column permutation there.
apply ::
   (Extent.C vert, Extent.C horiz,
    Shape.C height, Eq height, Shape.C width, Class.Floating a) =>
   Inversion -> Permutation height ->
   Full vert horiz height width a ->
   Full vert horiz height width a
apply inverted
      (Permutation (Array (Shape shapeP) perm))
      (Array shape@(MatrixShape.Full order extent) a) =
   Array.unsafeCreateWithSize shape $ \blockSize bPtr -> do
      let (height,width) = Extent.dimensions extent
      Call.assert "Permutation.apply: heights mismatch" (height == shapeP)
      let m = Shape.size height
      let n = Shape.size width
      evalContT $ do
         fwdPtr <- Call.bool $ inverted==NonInverted
         mPtr <- Call.cint m
         nPtr <- Call.cint n
         -- NOTE(review): the pivot vector is copied, presumably so LAPACK
         -- may use it as workspace without touching the immutable array.
         kPtr <- deconsElementPtr <$> copyToTemp m perm
         aPtr <- ContT $ withForeignPtr a
         liftIO $ do
            copyBlock blockSize aPtr bPtr
            -- Skip LAPACK for empty matrices; presumably to avoid passing
            -- a zero leading dimension.
            when (m>0 && n>0) $
               case order of
                  RowMajor ->
                     LapackGen.lapmt fwdPtr nPtr mPtr bPtr nPtr kPtr
                  ColumnMajor ->
                     LapackGen.lapmr fwdPtr mPtr nPtr bPtr mPtr kPtr


-- | Allocate an array over @sh@ without initialising its contents,
-- then run @f arr i@ for every index @i@ in 'Shape.indices' order.
initMutable ::
   (Shape.Indexed sh, Shape.Index sh ~ ix, Storable a) =>
   sh -> (MutArray.Array (ST s) sh a -> ix -> ST s ()) ->
   ST s (MutArray.Array (ST s) sh a)
initMutable sh f = do
   arr <- MutArray.unsafeCreate sh (\ _ -> return ())
   mapM_ (f arr) $ Shape.indices sh
   return arr


-- cf. Shape.Deferred
-- | Index space of a permutation vector over @sh@; its indices are
-- 1-based 'Element's.
newtype Shape sh = Shape {deconsShape :: sh}
   deriving (Eq, Show)

-- | A 1-based index into @'Shape' sh@, stored as a 'CInt'.
newtype Element sh = Element {deconsElement :: CInt}
   deriving (Eq, Show)

-- | View an 'Element' buffer as a 'CInt' buffer ('Element' is a newtype).
deconsElementPtr :: Ptr (Element sh) -> Ptr CInt
deconsElementPtr = castPtr

instance (Shape.C sh) => Shape.C (Shape sh) where
   size (Shape sh) = Shape.size sh
   uncheckedSize (Shape sh) = Shape.uncheckedSize sh

instance (Shape.C sh) => Shape.Indexed (Shape sh) where
   type Index (Shape sh) = Element sh
   indices (Shape sh) = map Element $ take (Shape.size sh) [1 ..]
   -- Convert the 1-based element to a 0-based offset.
   offset (Shape sh) (Element k) =
      Shape.offset (shapeInt $ Shape.size sh) (fromIntegral k - 1)
   uncheckedOffset _ (Element k) = fromIntegral k - 1
   inBounds (Shape sh) (Element k) =
      Shape.inBounds (shapeInt $ Shape.size sh) (fromIntegral k - 1)

instance (Shape.C sh) => Shape.InvIndexed (Shape sh) where
   indexFromOffset (Shape sh) k =
      Element $ 1 + fromIntegral (Shape.indexFromOffset (shapeInt $ Shape.size sh) k)
   uncheckedIndexFromOffset _sh = Element . (1+) . fromIntegral

-- | Same memory layout as 'CInt'. Matching on the 'Element' newtype
-- constructor is irrefutable, so 'sizeOf' and 'alignment' never force
-- their argument.
instance Storable (Element sh) where
   {-# INLINE sizeOf #-}
   {-# INLINE alignment #-}
   {-# INLINE peek #-}
   {-# INLINE poke #-}
   sizeOf (Element k) = sizeOf k
   alignment (Element k) = alignment k
   poke p (Element k) = poke (deconsElementPtr p) k
   peek p = fmap Element $ peek (deconsElementPtr p)