{-# LANGUAGE TypeFamilies #-}
module Numeric.LAPACK.Permutation.Private where

import qualified Numeric.LAPACK.Matrix.Layout.Private as Layout
import qualified Numeric.LAPACK.Matrix.Extent.Private as Extent
import qualified Numeric.LAPACK.Output as Output
import Numeric.LAPACK.Output (Output, formatAligned)
import Numeric.LAPACK.Matrix.Layout.Private (Order(RowMajor, ColumnMajor))
import Numeric.LAPACK.Matrix.Modifier
         (Transposition(NonTransposed,Transposed),
          Inversion(NonInverted,Inverted))
import Numeric.LAPACK.Matrix.Private (Full, Square)
import Numeric.LAPACK.Vector (Vector)
import Numeric.LAPACK.Scalar (zero, one)
import Numeric.LAPACK.Private (copyBlock, copyToTemp)

import qualified Numeric.LAPACK.FFI.Generic as LapackGen
import qualified Numeric.Netlib.Utility as Call
import qualified Numeric.Netlib.Class as Class

import qualified Data.Array.Comfort.Storable.Mutable.Unchecked as MutArray
import qualified Data.Array.Comfort.Storable.Unchecked as Array
import qualified Data.Array.Comfort.Storable as CheckedArray
import qualified Data.Array.Comfort.Shape as Shape
import Data.Array.Comfort.Storable.Unchecked (Array(Array), (!))

import Foreign.C.Types (CInt)
import Foreign.ForeignPtr (withForeignPtr, castForeignPtr)
import Foreign.Ptr (Ptr, castPtr)
import Foreign.Storable (Storable, sizeOf, alignment, poke, peek)

import Control.Monad.Trans.Cont (ContT(ContT), evalContT)
import Control.Monad.IO.Class (liftIO)
import Control.Monad.ST (ST, runST)
import Control.Monad (when, forM_)
import Control.Applicative (liftA2, (<$>))

import qualified Data.Tuple.HT as Tuple
import Data.Function.HT (powerAssociative)
import Data.Monoid (Monoid, mempty, mappend)
import Data.Semigroup (Semigroup, (<>))

import Prelude hiding (odd)


{- $setup
>>> import qualified Test.QuickCheck as QC
>>> import Test.Permutation (genPerm, genPivots)
>>>
>>> import qualified Numeric.LAPACK.Permutation as Perm
>>> import Numeric.LAPACK.Permutation
>>>          (Permutation, determinant, multiply, transpose)
>>> import Numeric.LAPACK.Matrix (ShapeInt)
>>>
>>> import Data.Semigroup ((<>))
>>>
>>> import Control.Applicative (liftA2)
>>>
>>> genPerm2 :: QC.Gen (Permutation ShapeInt, Permutation ShapeInt)
>>> genPerm2 = do
>>>    nat <- QC.arbitrary
>>>    liftA2 (,) (genPerm nat) (genPerm nat)
-}


{- |
A permutation of the indices of the shape @sh@.

The wrapped vector is indexed by the shape and holds one index per position.
'format' renders a stored entry @k@ as a @1@ in the k-th column (counting
from one) of the corresponding row.
'Eq' and 'Show' are derived structurally from the wrapped vector.
-}
newtype Permutation sh = Permutation (Vector (Shape sh) (Element sh))
   deriving (Eq, Show)

{- |
Render the permutation as a square grid of single characters.
Row i has a @1@ in the column named by the i-th stored entry
and @.@ in every other column.
-}
format :: (Shape.C sh, Output out) => Permutation sh -> out
format (Permutation perm) =
   let n = Shape.size $ Array.shape perm
   in formatAligned $
      -- every character of a row becomes its own aligned cell
      map (map ((:[]) . Output.text . (:""))) $
      -- an entry k puts the '1' at position k, counting from one
      map (\k -> replicate (k-1) '.' ++ '1' : replicate (n-k) '.') $
      map (fromIntegral . deconsElement) $ Array.toList perm


{- |
The shape the permutation acts on.
-}
size :: Permutation sh -> sh
size (Permutation (Array (Shape shape) _perm)) = shape

{- |
Number of elements must be maintained.

Replace the shape by @f shape@ without any check.
The element buffer is reused as-is via 'castForeignPtr' and is not copied.
-}
mapSizeUnchecked :: (shA -> shB) -> Permutation shA -> Permutation shB
mapSizeUnchecked f (Permutation (Array (Shape shape) perm)) =
   Permutation $ Array (Shape $ f shape) $ castForeignPtr perm

{- |
Number of elements must be maintained.

This is the checked counterpart of 'mapSizeUnchecked'. The shape function is
wrapped in 'Layout.mapChecked', presumably a size-preservation check that
reports @"Permutation.mapSize"@ on failure.
-}
mapSize ::
   (Shape.C shA, Shape.C shB) =>
   (shA -> shB) -> Permutation shA -> Permutation shB
mapSize f = mapSizeUnchecked (Layout.mapChecked "Permutation.mapSize" f)

{- |
The identity permutation on the given shape: every index maps to itself.
-}
identity :: (Shape.C sh) => sh -> Permutation sh
identity shape = Permutation $ CheckedArray.sample (Shape shape) id

{- |
Build a permutation from a pivot vector.
Starting from the identity, entry @i@ is swapped with entry @ipiv!i@ for every
index @i@, in the order given by 'indices' for the requested 'Inversion'.

prop> QC.forAll QC.arbitraryBoundedEnum $ \inv -> QC.forAll (QC.arbitrary >>= genPivots) $ \xs -> xs == Perm.toPivots inv (Perm.fromPivots inv xs)
-}
fromPivots ::
   (Shape.C sh) =>
   Inversion -> Vector (Shape sh) (Element sh) -> Permutation sh
fromPivots inverted ipiv = fromPivotsGen id inverted (Array.shape ipiv) ipiv

{-
We could use laswp if it were available for CInt elements.
-}
{- |
Like 'fromPivots', except that the pivot vector is indexed by the
rectangular-diagonal shape of a @height@ by @width@ matrix.
The resulting permutation acts on @height@, which is taken from the
diagonal shape via 'Layout.bandedHeight'.
-}
fromTruncatedPivots ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz) =>
   (Shape.C height, Shape.C width) =>
   (diagShape ~ Layout.RectangularDiagonal meas vert horiz height width) =>
   Inversion -> Vector (Shape diagShape) (Element height) -> Permutation height
fromTruncatedPivots inverted ipiv =
   fromPivotsGen
      -- re-tag a diagonal index as a row index; the CInt payload is unchanged
      (\(Element i) -> Element i)
      inverted (Layout.bandedHeight <$> Array.shape ipiv) ipiv

{- |
Shared implementation of 'fromPivots' and 'fromTruncatedPivots'.

* @mapIx@ converts an index of the pivot vector into an index of @sh@.
* @sh@ is the shape of the resulting permutation.
* @ipiv@ holds, for each of its indices @i@, the index that position
  @mapIx i@ is swapped with.

The array starts as the identity. The swaps are applied in the order given by
'indices'.
-}
fromPivotsGen ::
   (Shape.C sh, Shape.C small) =>
   (Element small -> Element sh) ->
   Inversion -> Shape sh -> Vector (Shape small) (Element sh) -> Permutation sh
fromPivotsGen mapIx inverted sh ipiv =
   Permutation $
   runST (do
      -- identity permutation: position i holds i
      perm <- initMutable sh $ \arr i -> MutArray.write arr i i
      forM_ (indices inverted $ Array.shape ipiv) $ \i ->
         swap perm (mapIx i) (ipiv!i)
      MutArray.unsafeFreeze perm)

{- |
Exchange the elements at indices @i@ and @j@ of a mutable array in place.
-}
swap ::
   (Shape.Indexed sh, Shape.Index sh ~ ix, Storable a) =>
   MutArray.Array (ST s) sh a -> ix -> ix -> ST s ()
swap arr i j = do
   a <- MutArray.read arr i
   MutArray.write arr i =<< MutArray.read arr j
   MutArray.write arr j a

{- |
The order in which pivots are applied, with @n = Shape.size sh@.
'NonInverted' gives the indices @1, 2, ..., n@ and 'Inverted' gives
@n, n-1, ..., 1@.
-}
indices :: (Shape.C sh) => Inversion -> Shape sh -> [Element sh]
indices inverted sh =
   let numIPiv = Shape.size sh
   in take numIPiv $ map Element $
      case inverted of
         Inverted -> iterate (subtract 1) (fromIntegral numIPiv)
         NonInverted -> iterate (1+) 1


{- |
Convert a permutation back into a pivot vector.
For pivot vectors generated as in the property on 'fromPivots',
@toPivots inv (fromPivots inv xs) == xs@ holds.

Internally the permutation and its inverse (from 'transposeToMutable') are
kept in two mutable arrays. 'Inverted' swaps the roles of the two arrays.
-}
toPivots ::
   (Shape.C sh) => Inversion -> Permutation sh -> Vector (Shape sh) (Element sh)
toPivots inverted (Permutation a) =
   let sh = Array.shape a
   in Array.reshape sh $
      runST (do
         (inv,perm) <-
            (case inverted of Inverted -> Tuple.swap; NonInverted -> id)
            <$>
            liftA2 (,)
               (MutArray.thaw a)
               (transposeToMutable a)
         -- At index i, 'inv' already holds the pivot j for i. Redirect
         -- the entries that referred to i so the remaining indices stay
         -- consistent.
         forM_ (Shape.indices sh) $ \i -> do
            j <- MutArray.read inv i
            k <- MutArray.read perm i
            MutArray.write perm j k
            MutArray.write inv k j
         MutArray.unsafeFreeze inv)


{- |
The sign of a permutation, as returned by 'determinant'.
Through the derived 'Enum' instance, 'Positive' maps to 0 and 'Negative' to 1.
-}
data Sign = Positive | Negative
   deriving (Eq, Show, Enum, Bounded)

{- |
Multiplication of signs: equal signs give 'Positive', different signs give
'Negative'.
-}
instance Semigroup Sign where
   x <> y = if x == y then Positive else Negative

{- |
'Positive' is the neutral element of sign multiplication.
-}
instance Monoid Sign where
   mempty = Positive
   mappend = (<>)

{-
We could also count the cycles of even length. This might be a little faster.
-}
{- |
The sign of the permutation. The permutation is converted to its
'NonInverted' pivot vector. The result is 'Negative' exactly when an odd
number of pivots differ from their own index (see 'odd').

prop> QC.forAll genPerm2 $ \(p0,p1) -> determinant (multiply p0 p1) == determinant p0 <> determinant p1
-}
determinant :: (Shape.C sh) => Permutation sh -> Sign
determinant =
   (\oddp -> if oddp then Negative else Positive) .
   odd . map deconsElement . Array.toList . toPivots NonInverted

{- |
Embed a sign as the number @1@ or @-1@.

> numberFromSign s == (-1)^fromEnum s
-}
numberFromSign :: (Class.Floating a) => Sign -> a
numberFromSign s =
   case s of
      Negative -> -1
      Positive -> 1


{- |
Return 'negate' if the pivot list describes an odd permutation (see 'odd'),
and 'id' otherwise.
-}
condNegate :: (Class.Floating a) => [CInt] -> a -> a
condNegate ipiv = if odd ipiv then negate else id

{- |
Parity of a one-based pivot list.
The result is 'True' exactly when the number of positions @i@ (counting from
one) whose pivot differs from @i@ is odd, i.e. when an odd number of actual
swaps is performed.
This shadows 'Prelude.odd', which the module hides.
-}
odd :: [CInt] -> Bool
odd = not . null . dropEven . filter id . zipWith (/=) [1..]

{- |
Drop elements two at a time.
The result is empty if the input length is even and a single-element list
otherwise, so it decides parity without computing the full length.
-}
dropEven :: [a] -> [a]
dropEven (_:_:xs) = dropEven xs
dropEven xs = xs


{- |
Transpose of the permutation matrix, which is the inverse permutation:
the result maps @perm!i@ back to @i@.

prop> QC.forAll genPerm2 $ \(p0,p1) -> transpose (multiply p0 p1) == multiply (transpose p1) (transpose p0)
-}
transpose :: (Shape.C sh) => Permutation sh -> Permutation sh
transpose (Permutation perm) =
   Permutation $ runST (MutArray.unsafeFreeze =<< transposeToMutable perm)

{- |
Inverse of a permutation array, returned as a fresh mutable array:
for every index @i@, position @perm!i@ of the result receives @i@.
-}
transposeToMutable ::
   (Shape.Indexed sh, Shape.Index sh ~ ix, Storable ix) =>
   Array sh ix -> ST s (MutArray.Array (ST s) sh ix)
transposeToMutable perm =
   initMutable (Array.shape perm) $ \inv i -> MutArray.write inv (perm!i) i

{- |
For permutations, transposing is the same as inverting, so each
'Transposition' maps directly onto the corresponding 'Inversion'.
-}
inversionFromTransposition :: Transposition -> Inversion
inversionFromTransposition trans =
   case trans of
      NonTransposed -> NonInverted
      Transposed -> Inverted


{- |
Compose two permutations of the same shape.
The result maps index @i@ to @b!(a!i)@ (see 'multiplyUnchecked').
Calls 'error' if the shapes differ.
-}
multiply :: (Shape.C sh, Eq sh) =>
   Permutation sh -> Permutation sh -> Permutation sh
multiply a b =
   if size a /= size b
      then error "Permutation.multiply: sizes mismatch"
      else multiplyUnchecked a b

{- |
Compose a permutation with itself.
No size check is needed because both operands are the same permutation.
-}
square :: (Shape.C sh) => Permutation sh -> Permutation sh
square p = multiplyUnchecked p p

{- |
The @n@-th power of a permutation.
It is computed with 'powerAssociative', starting from the identity on the
permutation's own shape.
-}
power :: (Shape.C sh) => Integer -> Permutation sh -> Permutation sh
power n p = powerAssociative multiplyUnchecked (identity $ size p) p n

{- |
Compose two permutations without comparing their shapes: the result at
index @i@ is @b!(a!i)@. The result takes the shape of @a@.
Callers must ensure that both shapes agree.
-}
multiplyUnchecked :: (Shape.C sh) =>
   Permutation sh -> Permutation sh -> Permutation sh
multiplyUnchecked (Permutation a) (Permutation b) =
   Permutation $ CheckedArray.sample (Array.shape a) $ \i -> b!(a!i)


{- |
Diagonal of the permutation matrix: @1@ at every fixed point (@a!i == i@)
and @0@ elsewhere.
-}
takeDiagonal ::
   (Shape.C sh, Class.Floating a) => Permutation sh -> Vector sh a
takeDiagonal (Permutation a) =
   Array.mapShape deconsShape $
   CheckedArray.sample (Array.shape a) $ \i -> if a!i == i then 1 else 0


-- | Dense square matrix of a permutation, laid out in 'RowMajor' order.
--
-- A zero-initialised matrix is filled with 'one' at row @k@,
-- column @perm!k@ for every index @k@ of the permutation.
toMatrix :: (Shape.C sh, Class.Floating a) => Permutation sh -> Square sh a
toMatrix (Permutation perm) =
   Array.reshape (Layout.square RowMajor (deconsShape permShape)) $
   -- runST stays inline: with MonoLocalBinds (implied by TypeFamilies)
   -- a where-bound ST action mentioning 'perm' would not be generalised.
   runST (do
      mat <- MutArray.new (permShape, permShape) zero
      mapM_
         (\row -> MutArray.write mat (row, perm!row) one)
         (Shape.indices permShape)
      MutArray.unsafeFreeze mat)
   where
      permShape = Array.shape perm


-- | Permute the rows (height dimension) of a full matrix by a permutation.
--
-- The input matrix is copied into a fresh array, which is then permuted in
-- place by LAPACK:
--
-- * for 'RowMajor' storage: @lapmt@ on the transposed (width x height) view;
-- * for 'ColumnMajor' storage: @lapmr@ directly.
--
-- LAPACK's FORWRD flag is set exactly when the inversion argument is
-- 'NonInverted'. The call asserts (via 'Call.assert') that the matrix
-- height equals the permutation's shape.
apply ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C height, Eq height, Shape.C width, Class.Floating a) =>
   Inversion -> Permutation height ->
   Full meas vert horiz height width a ->
   Full meas vert horiz height width a
apply inverted
      (Permutation (Array (Shape shapeP) perm))
      (Array shape@(Layout.Full order extent) srcFPtr) =
   Array.unsafeCreateWithSize shape $ \blockSize bPtr -> do
      Call.assert "Permutation.apply: heights mismatch" (height == shapeP)
      evalContT $ do
         fwdPtr <- Call.bool forward
         mPtr <- Call.cint m
         nPtr <- Call.cint n
         -- LAPACK declares K as input/output (used as workspace), so it
         -- gets a temporary copy instead of the permutation's own buffer.
         kPtr <- deconsElementPtr <$> copyToTemp m perm
         srcPtr <- ContT $ withForeignPtr srcFPtr
         liftIO $ do
            copyBlock blockSize srcPtr bPtr
            -- Empty matrices: nothing to permute, skip the LAPACK call.
            when nonEmpty $
               case order of
                  RowMajor -> LapackGen.lapmt fwdPtr nPtr mPtr bPtr nPtr kPtr
                  ColumnMajor -> LapackGen.lapmr fwdPtr mPtr nPtr bPtr mPtr kPtr
   where
      (height, width) = Extent.dimensions extent
      m = Shape.size height
      n = Shape.size width
      forward = inverted == NonInverted
      nonEmpty = m > 0 && n > 0


-- | Allocate a mutable array of the given shape and initialise it by
-- calling the supplied action once per index, in 'Shape.indices' order.
--
-- The buffer is created with 'MutArray.unsafeCreate' and a no-op filler,
-- so every element the action does not write stays uninitialised.
initMutable ::
   (Shape.Indexed sh, Shape.Index sh ~ ix, Storable a) =>
   sh -> (MutArray.Array (ST s) sh a -> ix -> ST s ()) ->
   ST s (MutArray.Array (ST s) sh a)
initMutable sh initAt = do
   marr <- MutArray.unsafeCreate sh (const (return ()))
   forM_ (Shape.indices sh) (initAt marr)
   return marr



-- cf. Shape.Deferred
-- | Wrapper around a shape whose 'Shape.Index' is a permutation 'Element'.
newtype Shape sh = Shape {deconsShape :: sh}
   deriving (Show, Eq)

-- | One entry of a permutation vector, stored as a 'CInt' so that the
-- buffer can be handed to LAPACK as an integer array. The @sh@ parameter
-- is phantom: it only ties the entry to its shape type.
newtype Element sh = Element {deconsElement :: CInt}
   deriving (Show, Eq)

-- | View a pointer to permutation entries as a raw 'CInt' pointer
-- (a plain pointer cast; no data is touched).
deconsElementPtr :: Ptr (Element sh) -> Ptr CInt
deconsElementPtr ptr = castPtr ptr

-- | Map over the wrapped shape.
instance Functor Shape where
   fmap f = Shape . f . deconsShape

-- | A wrapped shape has the same size as the underlying shape.
instance (Shape.C sh) => Shape.C (Shape sh) where
   size = Shape.size . deconsShape

-- | One-based integer shape with the same number of elements as the argument.
shapeInt :: (Shape.C sh) => sh -> Shape.OneBased Int
shapeInt sh = Shape.OneBased (Shape.size sh)

-- | Indices of a wrapped shape are the one-based entries @1 .. size@,
-- delegating offset and bounds computations to 'shapeInt'.
instance (Shape.C sh) => Shape.Indexed (Shape sh) where
   type Index (Shape sh) = Element sh
   indices (Shape sh) = [Element j | j <- take (Shape.size sh) [1 ..]]
   unifiedOffset (Shape sh) =
      Shape.unifiedOffset (shapeInt sh) . fromIntegral . deconsElement
   inBounds (Shape sh) =
      Shape.inBounds (shapeInt sh) . fromIntegral . deconsElement

-- | Convert a linear offset back to a one-based 'Element' via 'shapeInt'.
instance (Shape.C sh) => Shape.InvIndexed (Shape sh) where
   unifiedIndexFromOffset (Shape sh) offset =
      (\j -> Element (fromIntegral j)) <$>
         Shape.unifiedIndexFromOffset (shapeInt sh) offset

-- | An 'Element' is stored exactly like the 'CInt' it wraps.
instance Storable (Element sh) where
   {-# INLINE sizeOf #-}
   {-# INLINE alignment #-}
   {-# INLINE peek #-}
   {-# INLINE poke #-}
   sizeOf = sizeOf . deconsElement
   alignment = alignment . deconsElement
   poke ptr = poke (deconsElementPtr ptr) . deconsElement
   peek = fmap Element . peek . deconsElementPtr