{-# LANGUAGE TypeFamilies #-}
module Numeric.LAPACK.Permutation.Private where

import qualified Numeric.LAPACK.Matrix.Shape.Private as MatrixShape
import qualified Numeric.LAPACK.Matrix.Extent.Private as Extent
import qualified Numeric.LAPACK.Shape as ExtShape
import qualified Numeric.LAPACK.Output as Output
import Numeric.LAPACK.Output (Output, formatAligned)
import Numeric.LAPACK.Matrix.Shape.Private (Order(RowMajor, ColumnMajor))
import Numeric.LAPACK.Matrix.Modifier
         (Transposition(NonTransposed,Transposed),
          Inversion(NonInverted,Inverted))
import Numeric.LAPACK.Matrix.Private (Full, Square, shapeInt)
import Numeric.LAPACK.Vector (Vector)
import Numeric.LAPACK.Scalar (zero, one)
import Numeric.LAPACK.Private (copyBlock, copyToTemp)

import qualified Numeric.LAPACK.FFI.Generic as LapackGen
import qualified Numeric.Netlib.Utility as Call
import qualified Numeric.Netlib.Class as Class

import qualified Data.Array.Comfort.Storable.Mutable.Unchecked as MutArray
import qualified Data.Array.Comfort.Storable.Unchecked as Array
import qualified Data.Array.Comfort.Storable as CheckedArray
import qualified Data.Array.Comfort.Shape as Shape
import Data.Array.Comfort.Storable.Unchecked (Array(Array), (!))

import Foreign.C.Types (CInt)
import Foreign.ForeignPtr (withForeignPtr)
import Foreign.Ptr (Ptr, castPtr)
import Foreign.Storable (Storable, sizeOf, alignment, poke, peek)

import Control.Monad.Trans.Cont (ContT(ContT), evalContT)
import Control.Monad.IO.Class (liftIO)
import Control.Monad.ST (ST, runST)
import Control.Monad (when, forM_)
import Control.Applicative (liftA2, (<$>))

import qualified Data.Tuple.HT as Tuple
import Data.Function.HT (powerAssociative)
import Data.Monoid (Monoid, mempty, mappend)
import Data.Semigroup (Semigroup, (<>))

import Prelude hiding (odd)


{- $setup
>>> import qualified Test.QuickCheck as QC
>>> import Test.Permutation (genPerm, genPivots)
>>>
>>> import qualified Numeric.LAPACK.Permutation as Perm
>>> import Numeric.LAPACK.Permutation (Permutation, Inversion(NonInverted), determinant, multiply, transpose)
>>> import Numeric.LAPACK.Matrix (ShapeInt)
>>>
>>> import qualified Data.Array.Comfort.Storable as Array
>>> import Data.Eq.HT (equating)
>>> import Data.Semigroup ((<>))
>>>
>>> import Control.Applicative (liftA2)
>>>
>>> genPerm2 :: QC.Gen (Permutation ShapeInt, Permutation ShapeInt)
>>> genPerm2 = do
>>>    nat <- QC.arbitrary
>>>    liftA2 (,) (genPerm nat) (genPerm nat)
-}


-- | A permutation of the indices of shape @sh@.
--
-- The wrapped vector stores, for every index @i@ of @'Shape' sh@,
-- the index that @i@ is mapped to ('identity' stores @i@ itself).
-- 'toMatrix' turns entry @i@ into a one at row @i@, column @perm!i@.
newtype Permutation sh = Permutation (Vector (Shape sh) (Element sh))
   deriving (Show)

-- | Render the permutation as a picture of its permutation matrix:
-- one line per entry of the permutation vector, with @'1'@ in the column
-- given by that entry and @'.'@ everywhere else.
-- Every character becomes its own cell for 'formatAligned'.
format :: (Shape.C sh, Output out) => Permutation sh -> out
format (Permutation perm) =
   let n = Shape.size $ Array.shape perm
       -- line for an image @k@ (1-based): k-1 dots, a one, then n-k dots
       row k = replicate (k-1) '.' ++ '1' : replicate (n-k) '.'
   in formatAligned $
      map (map ((:[]) . Output.text . (:""))) $
      map row $
      map (fromIntegral . deconsElement) $ Array.toList perm


-- | The shape whose indices the permutation acts on.
size :: Permutation sh -> sh
size (Permutation (Array (Shape shape) _perm)) = shape

-- | The identity permutation on @shape@: every index is mapped to itself.
identity :: (Shape.C sh) => sh -> Permutation sh
identity shape = Permutation $ CheckedArray.sample (Shape shape) id

{- |
Build a permutation from a pivot vector,
i.e. a sequence of interchanges where position @i@
is swapped with position @ipiv!i@, starting from the identity.
'NonInverted' applies the interchanges for ascending @i@,
'Inverted' applies them for descending @i@.

prop> QC.forAll QC.arbitraryBoundedEnum $ \inv -> QC.forAll (QC.arbitrary >>= genPivots) $ \xs -> Array.toList xs == Array.toList (Perm.toPivots inv (Perm.fromPivots inv xs))
-}
fromPivots ::
   (Shape.C sh) =>
   Inversion -> Vector (Shape sh) (Element sh) -> Permutation sh
fromPivots inverted ipiv =
   fromPivotsGen inverted (Array.shape ipiv) ipiv

{-
We could use laswp if it were available for CInt elements.
-}
-- | Like 'fromPivots', but the pivot vector has shape @Min sh1 (Shape sh)@.
-- The permutation acts on the second component of that shape
-- ('ExtShape.minShape1'), and one interchange is applied per pivot entry.
fromTruncatedPivots ::
   (Shape.C sh, Shape.C sh1) =>
   Inversion ->
   Vector (ExtShape.Min sh1 (Shape sh)) (Element sh) -> Permutation sh
fromTruncatedPivots inverted ipiv =
   fromPivotsGen inverted (ExtShape.minShape1 $ Array.shape ipiv) ipiv

-- | Shared implementation of 'fromPivots' and 'fromTruncatedPivots'.
--
-- Starts from the identity on @sh@ and swaps position @i@ with @ipiv!i@
-- for every pivot position @i@, in the order given by 'indices'.
-- The pivot vector has its own shape @small@, whose indices are
-- 'Element's of @sh@ ('fromTruncatedPivots' passes a @Min@ shape here).
fromPivotsGen ::
   (Shape.C sh, Shape.Indexed small, Shape.Index small ~ Element sh) =>
   Inversion -> Shape sh -> Vector small (Element sh) -> Permutation sh
fromPivotsGen inverted sh ipiv =
   Permutation $
   runST (do
      perm <- initMutable sh $ \arr i -> MutArray.write arr i i
      forM_ (indices inverted $ Array.shape ipiv) $
         \i -> swap perm i (ipiv!i)
      MutArray.unsafeFreeze perm)

-- | Exchange the elements at indices @i@ and @j@ of a mutable array.
-- Both elements are read before either is written,
-- so @i == j@ leaves the array unchanged.
swap ::
   (Shape.Indexed sh, Shape.Index sh ~ ix, Storable a) =>
   MutArray.Array (ST s) sh a -> ix -> ix -> ST s ()
swap arr i j = do
   x <- MutArray.read arr i
   y <- MutArray.read arr j
   MutArray.write arr i y
   MutArray.write arr j x
a

-- | The 1-based positions of a pivot vector with shape @pivShape@,
-- wrapped as 'Element's, in the order in which the interchanges are applied:
-- ascending for 'NonInverted', descending for 'Inverted'.
-- An empty pivot shape yields no positions.
indices ::
   (Shape.C sh, Shape.Indexed small, Shape.Index small ~ Element sh) =>
   Inversion -> small -> [Element sh]
indices inverted pivShape =
   let n = fromIntegral (Shape.size pivShape) :: CInt
   in map Element $
      case inverted of
         NonInverted -> [1 .. n]
         Inverted -> [n, n-1 .. 1]


-- | Convert a permutation into a pivot vector.
-- The round trip with 'fromPivots' (for the same 'Inversion')
-- is checked by the property attached to 'fromPivots'.
--
-- Two mutable arrays are kept: a copy of the permutation and its inverse
-- ('transposeToMutable'); 'Inverted' exchanges their roles.
-- The array named @inv@ is frozen and returned, reshaped from
-- @'Shape' sh@ back to @sh@.
toPivots ::
   (Shape.C sh) => Inversion -> Permutation sh -> Vector sh (Element sh)
toPivots inverted (Permutation a) =
   let sh = Array.shape a
   in Array.reshape (deconsShape sh) $
      runST (do
         (inv,perm) <-
            (case inverted of Inverted -> Tuple.swap; NonInverted -> id)
            <$>
            liftA2 (,) (MutArray.thaw a) (transposeToMutable a)
         forM_ (Shape.indices sh) $ \i -> do
            j <- MutArray.read inv i
            k <- MutArray.read perm i
            -- keep both arrays consistent after fixing position i
            MutArray.write perm j k
            MutArray.write inv k j
         MutArray.unsafeFreeze inv)


-- | Sign of a permutation, as computed by 'determinant':
-- 'Positive' for even permutations, 'Negative' for odd ones.
-- The derived 'Enum' maps them to 0 and 1, respectively.
data Sign = Positive | Negative
   deriving (Eq, Show, Enum, Bounded)

-- | Multiplication of signs: equal signs give 'Positive',
-- different signs give 'Negative'.
instance Semigroup Sign where
   x <> y = if x == y then Positive else Negative

-- | 'Positive' is the neutral element of sign multiplication.
instance Monoid Sign where
   mempty = Positive
   mappend = (<>)

{-
We could also count the cycles of even length instead: the sign is negative exactly when the number of even-length cycles is odd. This might be a little faster.
-}
{- |
Sign of the permutation: 'Negative' if it is odd, 'Positive' otherwise.
The parity is obtained by counting the entries of the 'NonInverted'
pivot vector ('toPivots') that differ from their own position ('odd').

prop> QC.forAll genPerm2 $ \(p0,p1) -> determinant (multiply p0 p1) == determinant p0 <> determinant p1
-}
determinant :: (Shape.C sh) => Permutation sh -> Sign
determinant =
   (\oddp -> if oddp then Negative else Positive) .
   odd . map deconsElement . Array.toList . toPivots NonInverted

{- |
Convert a sign into the scalar @-1@ ('Negative') or @1@ ('Positive').

> numberFromSign s == (-1)^fromEnum s
-}
numberFromSign :: (Class.Floating a) => Sign -> a
numberFromSign s =
   case s of
      Negative -> -1
      Positive -> 1


-- | Negate the value if the raw pivot list @ipiv@ has odd parity
-- (see 'odd'); otherwise return it unchanged.
condNegate :: (Class.Floating a) => [CInt] -> a -> a
condNegate ipiv = if odd ipiv then negate else id

-- | Parity of a pivot list given as raw 1-based 'CInt's.
-- Returns 'True' iff the number of entries that differ from their own
-- 1-based position (interchanges that actually move something) is odd.
-- The name shadows 'Prelude.odd', which this module hides.
odd :: [CInt] -> Bool
odd = not . null . dropEven . filter id . zipWith (/=) [1..]

-- | Drop elements from the front two at a time.
-- The result is empty for lists of even length
-- and consists of the last element alone for lists of odd length.
dropEven :: [a] -> [a]
dropEven (_:_:xs) = dropEven xs
dropEven xs = xs


{- |
Inverse permutation: if @p@ maps @i@ to @j@, then @transpose p@ maps @j@ to @i@.
For the permutation matrix ('toMatrix') this is the matrix transposition.

prop> QC.forAll genPerm2 $ \(p0,p1) -> equating (Array.toList . Perm.toPivots NonInverted) (transpose $ multiply p0 p1) (multiply (transpose p1) (transpose p0))
-}
transpose :: (Shape.C sh) => Permutation sh -> Permutation sh
transpose (Permutation perm) =
   Permutation $ runST (MutArray.unsafeFreeze =<< transposeToMutable perm)

-- | Invert the index map stored in @perm@ into a fresh mutable array:
-- for every index @i@, the result holds @i@ at position @perm!i@.
-- Assumes @perm@ is a bijection on the indices of its shape;
-- otherwise some positions are never written.
transposeToMutable ::
   (Shape.Indexed sh, Shape.Index sh ~ ix, Storable ix) =>
   Array sh ix -> ST s (MutArray.Array (ST s) sh ix)
transposeToMutable perm =
   initMutable (Array.shape perm) $ \inv i -> MutArray.write inv (perm!i) i

-- | Map a matrix 'Transposition' to the corresponding permutation
-- 'Inversion': transposing a permutation inverts it ('transpose').
inversionFromTransposition :: Transposition -> Inversion
inversionFromTransposition NonTransposed = NonInverted
inversionFromTransposition Transposed = Inverted


-- | Compose two permutations of the same shape.
-- The result maps @i@ to @b!(a!i)@ (see 'multiplyUnchecked').
-- Calls 'error' if the shapes of @a@ and @b@ differ.
multiply :: (Shape.C sh, Eq sh) =>
   Permutation sh -> Permutation sh -> Permutation sh
multiply a b
   | size a /= size b = error "Permutation.multiply: sizes mismatch"
   | otherwise = multiplyUnchecked a b

-- | Compose a permutation with itself.
-- The shape check of 'multiply' is unnecessary because both operands
-- are the same permutation.
square :: (Shape.C sh) => Permutation sh -> Permutation sh
square p = multiplyUnchecked p p

-- | @power n p@ composes @p@ with itself @n@ times via 'powerAssociative'.
-- @power 0 p@ is the 'identity' on @size p@.
-- A negative exponent raises the inverse permutation ('transpose') instead,
-- so that @power (-n) p@ is the inverse of @power n p@.
power :: (Shape.C sh) => Integer -> Permutation sh -> Permutation sh
power n p
   -- powerAssociative halves its exponent with 'div' until it reaches 0,
   -- which never happens for a negative exponent, so handle that case here.
   | n < 0 = power (negate n) (transpose p)
   | otherwise = powerAssociative multiplyUnchecked (identity $ size p) p n

-- | Composition without a shape check.
-- Entry @i@ of the result is @b!(a!i)@: first apply @a@, then @b@.
-- The caller must ensure both permutations have the same shape.
multiplyUnchecked :: (Shape.C sh) =>
   Permutation sh -> Permutation sh -> Permutation sh
multiplyUnchecked (Permutation a) (Permutation b) =
   Permutation $ CheckedArray.sample (Array.shape a) $ \i -> b ! (a ! i)


-- | Diagonal of the permutation matrix:
-- 'one' at the fixed points of the permutation (@a!i == i@),
-- 'zero' elsewhere.
takeDiagonal ::
   (Shape.C sh, Class.Floating a) => Permutation sh -> Vector sh a
takeDiagonal (Permutation a) =
   Array.mapShape deconsShape $
   CheckedArray.sample (Array.shape a) $
      \i -> if a!i == i then one else zero


-- | Dense square permutation matrix in 'RowMajor' order.
-- Row @k@ has a one in column @perm!k@ and zeros everywhere else.
toMatrix :: (Shape.C sh, Class.Floating a) => Permutation sh -> Square sh a
toMatrix (Permutation perm) =
   let shape = Array.shape perm
   in Array.reshape (MatrixShape.square RowMajor $ deconsShape shape) $
      runST (do
         a <- MutArray.new (shape,shape) zero
         forM_ (Shape.indices shape) $ \k ->
            MutArray.write a (k, perm!k) one
         MutArray.unsafeFreeze a)


{- |
Permute the rows of a full matrix according to a permutation and return
the result as a fresh matrix; the input matrix is only read.

'NonInverted' passes @FORWRD = true@ to LAPACK and 'Inverted' passes
@FORWRD = false@.
Column-major matrices have their rows permuted by @?lapmr@.
A row-major matrix is, from LAPACK's point of view, the transposed
column-major matrix, so its rows are permuted by @?lapmt@ (which permutes
columns) with the dimensions swapped.
The LAPACK call is skipped when the matrix has no rows or no columns.

The height of the matrix must equal the shape of the permutation;
this is checked with 'Call.assert'.
-}
apply ::
   (Extent.C vert, Extent.C horiz,
    Shape.C height, Eq height, Shape.C width, Class.Floating a) =>
   Inversion -> Permutation height ->
   Full vert horiz height width a ->
   Full vert horiz height width a
apply inverted
      (Permutation (Array (Shape shapeP) perm))
      (Array shape@(MatrixShape.Full order extent) src) =
   Array.unsafeCreateWithSize shape $ \blockSize bPtr -> do
      let (height, width) = Extent.dimensions extent
          rows = Shape.size height
          cols = Shape.size width
      Call.assert "Permutation.apply: heights mismatch" (height == shapeP)
      evalContT $ do
         forwardPtr <- Call.bool (inverted == NonInverted)
         rowsPtr <- Call.cint rows
         colsPtr <- Call.cint cols
         -- LAPACK's ?lapmr/?lapmt use the permutation array K as workspace
         -- (restoring it on exit), so hand them a temporary copy instead
         -- of the shared immutable permutation vector.
         pivotPtr <- deconsElementPtr <$> copyToTemp rows perm
         srcPtr <- ContT $ withForeignPtr src
         liftIO $ do
            copyBlock blockSize srcPtr bPtr
            -- Guard against LDX = 0, which LAPACK rejects.
            when (rows > 0 && cols > 0) $
               case order of
                  RowMajor ->
                     LapackGen.lapmt
                        forwardPtr colsPtr rowsPtr bPtr colsPtr pivotPtr
                  ColumnMajor ->
                     LapackGen.lapmr
                        forwardPtr rowsPtr colsPtr bPtr rowsPtr pivotPtr


{- |
Allocate a mutable array of shape @sh@ and run @f arr ix@ once for every
index @ix@, in 'Shape.indices' order; then return the array.

The allocation callback does nothing, so the array contents start out
uninitialized. NOTE(review): presumably @f@ is expected to write every
element; confirm at the call sites.
-}
initMutable ::
   (Shape.Indexed sh, Shape.Index sh ~ ix, Storable a) =>
   sh -> (MutArray.Array (ST s) sh a -> ix -> ST s ()) ->
   ST s (MutArray.Array (ST s) sh a)
initMutable sh f = do
   arr <- MutArray.unsafeCreate sh (const (return ()))
   forM_ (Shape.indices sh) (f arr)
   return arr



{- |
Shape wrapper whose indices are 1-based 'Element's rather than the
indices of @sh@ (cf. Shape.Deferred).
Only the size of the wrapped shape is used to enumerate and locate
indices.
-}
newtype Shape sh =
   Shape {deconsShape :: sh}
   deriving (Eq, Show)

{- |
Index of a 'Shape': a 1-based position stored as a 'CInt'.
The phantom parameter @sh@ ties the index to its shape type.
Because the representation is a plain 'CInt', a vector of 'Element's
can be handed to LAPACK as an integer array (see 'deconsElementPtr').
-}
newtype Element sh =
   Element {deconsElement :: CInt}
   deriving (Eq, Show)

-- | Reinterpret a pointer to 'Element's as a pointer to 'CInt's.
-- This is sound because 'Element' is a newtype over 'CInt' whose
-- 'Storable' instance delegates to that of 'CInt'.
deconsElementPtr :: Ptr (Element sh) -> Ptr CInt
deconsElementPtr ptr = castPtr ptr

-- The wrapper has exactly as many elements as the wrapped shape.
instance (Shape.C sh) => Shape.C (Shape sh) where
   size = Shape.size . deconsShape
   uncheckedSize = Shape.uncheckedSize . deconsShape

-- Indices run from 1 to the size of the wrapped shape.
-- Offsets and bounds checks are delegated to a zero-based 'shapeInt'
-- of that size after shifting the index down by one.
instance (Shape.C sh) => Shape.Indexed (Shape sh) where
   type Index (Shape sh) = Element sh
   indices (Shape sh) = take (Shape.size sh) (map Element [1 ..])
   offset (Shape sh) (Element k) =
      let zeroBased = shapeInt (Shape.size sh)
      in Shape.offset zeroBased (fromIntegral k - 1)
   uncheckedOffset _sh = subtract 1 . fromIntegral . deconsElement
   inBounds (Shape sh) (Element k) =
      let zeroBased = shapeInt (Shape.size sh)
      in Shape.inBounds zeroBased (fromIntegral k - 1)

-- Inverse of the 'Shape.Indexed' instance: convert a zero-based offset
-- back to a 1-based 'Element'.
-- The checked variant validates the offset against a zero-based
-- 'shapeInt' of the wrapped shape's size.
instance (Shape.C sh) => Shape.InvIndexed (Shape sh) where
   indexFromOffset (Shape sh) k =
      let j = Shape.indexFromOffset (shapeInt (Shape.size sh)) k
      in Element (fromIntegral j + 1)
   uncheckedIndexFromOffset _sh k = Element (fromIntegral k + 1)

-- 'Element' is stored exactly like the 'CInt' it wraps.
-- 'sizeOf' and 'alignment' never inspect their argument,
-- since unwrapping a newtype does not force it.
instance Storable (Element sh) where
   {-# INLINE sizeOf #-}
   {-# INLINE alignment #-}
   {-# INLINE peek #-}
   {-# INLINE poke #-}
   sizeOf = sizeOf . deconsElement
   alignment = alignment . deconsElement
   poke ptr = poke (deconsElementPtr ptr) . deconsElement
   peek ptr = Element <$> peek (deconsElementPtr ptr)