{- |
Generate and apply index maps.
This unifies the @replicate@ and @slice@ functions of the @accelerate@ package.
However, the structure of slicing and replicating cannot depend on parameters.
If you need that, you must use 'ShapeDep.backpermute' and friends.
-}
{-
Some notes on the design choice:

Instead of the shallow embedding implemented by the 'T' type,
we could maintain a symbolic representation of the Slice and Replicate pattern,
like the accelerate package does.
We actually used that representation in former versions.
However, it has some drawbacks:

* We need additional type functions that map from the pattern
  to the source and the target shape, and we need a proof
  that the images of these type functions are actually shapes.
  This did work, but it was rather cumbersome.

* We need a way to store and pass this pattern through the Parameter handler.
  This yields new problems:
  We need a wrapper type for wrapping Index, Shape, Slice, Replicate, Fold patterns.
  Then the question is whether we use one Wrap type with a phantom parameter
  or whether we define a Wrap type for every pattern type.
  That is, the options are to write either

  > Wrap Shape (Z:.Int:.Int)

  or

  > Shape (Z:.Int:.Int)

  The first one seems to save us many duplicate instances of
  Storable, MultiValue etc.
  and it allows us easily to reuse the (:.) for all kinds of patterns.
  However, we need a way to restrict the element type of the (:.)-list elements.
  We can express that restriction with a constraint-kinded type parameter
  (via ConstraintKinds),
  but then, e.g., we are not able to add a Storable superclass constraint
  to the instance Storable (Wrap constr).
  That is, we are left with the second option
  and would have to define a lot of similar Storable and MultiValue instances.
-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
module Data.Array.Knead.Simple.Slice (
   T,
   Cubic,
   apply,
   passAny,
   pass,
   pick,
   pickFst,
   pickSnd,
   extrude,
   extrudeFst,
   extrudeSnd,
   transpose,
   (Core.$:.),

   id,
   first,
   second,
   compose,
   ) where

import qualified Data.Array.Knead.Simple.ShapeDependent as ShapeDep
import qualified Data.Array.Knead.Simple.Private as Core

import qualified Data.Array.Knead.Shape.Cubic.Int as Index
import qualified Data.Array.Knead.Shape.Cubic as Cubic
import qualified Data.Array.Knead.Shape as Shape
import qualified Data.Array.Knead.Expression as Expr
import Data.Array.Knead.Shape.Cubic ((#:.), (:.)((:.)), )
import Data.Array.Knead.Expression (Exp, )

import qualified LLVM.Extra.Multi.Value as MultiValue
import LLVM.Extra.Multi.Value (atom, )

import qualified Type.Data.Num.Unary as Unary

import qualified Prelude as P
import Prelude hiding (id, zipWith, zipWith3, zip, zip3, replicate, )



{- |
A slicing/replication pattern from source shape @sh0@ to target shape @sh1@.

This type is almost identical to 'Core.Array'.
The only difference is that the shape @sh1@ can depend on another shape @sh0@.

The first function computes the target shape from the source shape.
The second maps an index of the target shape back to an index of the source
shape, which is exactly what 'ShapeDep.backpermute' consumes.
The index types are existential and tied to the shapes
by the 'Shape.Index' equalities.
-}
data T sh0 sh1 where
   Cons ::
      (Shape.Index sh0 ~ ix0, Shape.Index sh1 ~ ix1) =>
      (Exp sh0 -> Exp sh1) ->
      (Exp ix1 -> Exp ix0) ->
      T sh0 sh1

{- |
Run a pattern on an array.

The target shape is derived from the array's shape,
and every target index is mapped back to the source index it reads from.
This is essentially a 'ShapeDep.backpermute'.
-}
apply ::
   (Core.C array, Shape.C sh0, Shape.C sh1, MultiValue.C a) =>
   T sh0 sh1 ->
   array sh0 a ->
   array sh1 a
apply pattern source =
   case pattern of
      Cons shapeMap indexMap ->
         ShapeDep.backpermute shapeMap indexMap source


{- |
Select a slice of a pair shape by fixing the index of the first component
to @i@.
The first component disappears from the result shape.
-}
pickFst :: Exp (Shape.Index n) -> T (n,sh) sh
pickFst i =
   Cons
      (\sh -> Expr.snd sh)
      (\ix -> Expr.zip i ix)

{- |
Select a slice of a pair shape by fixing the index of the second component
to @i@.
The second component disappears from the result shape.
-}
pickSnd :: Exp (Shape.Index n) -> T (sh,n) sh
pickSnd i =
   Cons
      (\sh -> Expr.fst sh)
      (\ix -> Expr.zip ix i)

{- |
Prepend a new first component with shape @n@ to the shape.
Every target index reads the source at its second component,
so all positions along the new component see the same data.

Extrusion can duplicate work.
Use it only to add dimensions of size 1 (e.g. numeric 1 or unit @()@),
or to duplicate slices of physical arrays.
-}
extrudeFst :: Exp n -> T sh (n,sh)
extrudeFst n =
   Cons
      (\sh -> Expr.zip n sh)
      (\ix -> Expr.snd ix)

{- |
Append a new second component with shape @n@ to the shape.
Every target index reads the source at its first component.
The same caveat about duplicate work as for 'extrudeFst' applies.
-}
extrudeSnd :: Exp n -> T sh (sh,n)
extrudeSnd n =
   Cons
      (\sh -> Expr.zip sh n)
      (\ix -> Expr.fst ix)

-- | Exchange the two components of a pair shape, and of its indices.
transpose :: T (sh0,sh1) (sh1,sh0)
transpose =
   Cons
      (\sh -> Expr.swap sh)
      (\ix -> Expr.swap ix)


-- Arrow combinators

-- | The pattern that leaves both shape and index unchanged.
id :: T sh sh
id = Cons (\sh -> sh) (\ix -> ix)

-- | Apply a pattern to the first component of a pair shape; keep the second.
first :: T sh0 sh1 -> T (sh0,sh) (sh1,sh)
first pattern =
   case pattern of
      Cons shapeMap indexMap ->
         Cons (Expr.mapFst shapeMap) (Expr.mapFst indexMap)

-- | Apply a pattern to the second component of a pair shape; keep the first.
second :: T sh0 sh1 -> T (sh,sh0) (sh,sh1)
second pattern =
   case pattern of
      Cons shapeMap indexMap ->
         Cons (Expr.mapSnd shapeMap) (Expr.mapSnd indexMap)

infixr 1 `compose`

{- |
Sequential composition: the shape is transformed first by @a@, then by @b@.
Target indices are mapped back through @b@ first and then through @a@.
-}
compose :: T sh0 sh1 -> T sh1 sh2 -> T sh0 sh2
compose (Cons shapeA indexA) (Cons shapeB indexB) =
   Cons
      (\sh -> shapeB (shapeA sh))
      (\ix -> indexA (indexB ix))


-- | A pattern from a cubic shape of rank @rank0@ to a cubic shape of rank @rank1@.
type Cubic rank0 rank1 = T (Cubic.Shape rank0) (Cubic.Shape rank1)

{- |
Like @Any@ in @accelerate@:
keep all dimensions of a cubic shape, and the indices, untouched.
-}
passAny :: Cubic rank rank
passAny = Cons (\sh -> sh) (\ix -> ix)

{- |
Like @All@ in @accelerate@:
keep the dimension at the end of the @:.@ list unchanged,
and apply the given pattern to the remaining dimensions.
-}
pass ::
   (Unary.Natural rank0, Unary.Natural rank1) =>
   Cubic rank0 rank1 ->
   Cubic (Unary.Succ rank0) (Unary.Succ rank1)
pass pattern =
   case pattern of
      Cons shapeMap indexMap ->
         Cons
            (Expr.modify (atom:.atom)
               (\(sh:.s) -> shapeMap sh :. s))
            (Expr.modify (atom:.atom)
               (\(ix:.i) -> indexMap ix :. i))

{- |
Like @Int@ in @accelerate/slice@:
remove one dimension by fixing its index to @i@.
The source shape is reduced with 'Cubic.tail' before the inner pattern
is applied.
The inner pattern's source index is extended with @i@ via @#:.@.
-}
pick ::
   (Unary.Natural rank0, Unary.Natural rank1) =>
   Exp Index.Int ->
   Cubic rank0 rank1 ->
   Cubic (Unary.Succ rank0) rank1
pick i pattern =
   case pattern of
      Cons shapeMap indexMap ->
         Cons
            (\sh -> shapeMap (Cubic.tail sh))
            ((#:. i) . indexMap)

{- |
Like @Int@ in @accelerate/replicate@:
add one dimension of extent @n@ to the target shape via @#:.@.
Target indices are reduced with 'Cubic.tail' before being mapped back,
so the new dimension does not influence which source element is read.
-}
extrude ::
   (Unary.Natural rank0, Unary.Natural rank1) =>
   Exp Index.Int ->
   Cubic rank0 rank1 ->
   Cubic rank0 (Unary.Succ rank1)
extrude n pattern =
   case pattern of
      Cons shapeMap indexMap ->
         Cons
            ((#:. n) . shapeMap)
            (\ix -> indexMap (Cubic.tail ix))


instance Core.Process (T sh0 sh1) where