{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE TypeOperators #-}
module Data.Array.Knead.Parameterized.Slice (
   T,
   apply,
   Linear,
   passAny,
   pass,
   pick,
   extrude,
   (Core.$:.),
   ) where

import qualified Data.Array.Knead.Parameterized.Private as Priv
import Data.Array.Knead.Parameterized.Private (Array(Array), )

import qualified Data.Array.Knead.Simple.Slice as Slice
import qualified Data.Array.Knead.Simple.Private as Core

import qualified Data.Array.Knead.Index.Linear as Linear
import qualified Data.Array.Knead.Index.Nested.Shape as Shape
import qualified Data.Array.Knead.Parameter as Param
import qualified Data.Array.Knead.Expression as Expr
import Data.Array.Knead.Expression (Exp, )
import Data.Array.Knead.Index.Linear ((:.), )

import qualified LLVM.Extra.Multi.Value.Memory as MultiValueMemory
import qualified LLVM.Extra.Multi.Value as MultiValue

import Foreign.Storable (Storable, )


{-
This wrapper data type has almost the same structure as Parameterized.Array,
but there seems to be no benefit in sharing one data structure between them.
-}
-- | A slice from shape @sh0@ to shape @sh1@ that may depend on a
-- parameter of type @p@.
--
-- The marshalled parameter type and the context type are hidden
-- (existentially quantified); only the constraints needed to store the
-- marshalled parameter are retained.
data T p sh0 sh1 =
   forall param ctx.
   (Storable param, MultiValueMemory.C param) =>
   Cons {
      -- | Build the core slice from the marshalled parameter.
      _core :: MultiValue.T param -> Slice.T sh0 sh1,
      -- | Produce a context and the marshalled parameter from a @p@ value.
      _createContext :: p -> IO (ctx, param),
      -- | Action run on a context created by '_createContext'
      -- (presumably releases it; see 'Priv.deletePlain' usage).
      _deleteContext :: ctx -> IO ()
   }

-- | Run a parameterized slice on a parameterized array.
--
-- Slice and array each carry their own marshalled parameter. The
-- resulting array stores both as a pair (built by 'Priv.combineCreate'
-- and torn down by 'Priv.combineDelete'). Its core splits that pair with
-- 'MultiValue.unzip' and hands each half to the slice or the array
-- respectively.
apply ::
   (Shape.C sh0, Shape.C sh1, MultiValue.C a) =>
   T p sh0 sh1 ->
   Array p sh0 a ->
   Array p sh1 a
apply
      (Cons sliceCore sliceCreate sliceDelete)
      (Array arrCore arrCreate arrDelete) =
   Array core
      (Priv.combineCreate sliceCreate arrCreate)
      (Priv.combineDelete sliceDelete arrDelete)
   where
      core params =
         case MultiValue.unzip params of
            (sliceParam, arrParam) ->
               Slice.apply (sliceCore sliceParam) (arrCore arrParam)


-- | A parameterized slice between two linear ('Linear.Shape') shapes.
type Linear p sh0 sh1 =
   T p
      (Linear.Shape sh0)
      (Linear.Shape sh1)


-- | Parameterized counterpart of 'Slice.passAny'.
--
-- It ignores its parameter. The marshalled parameter is the unit value,
-- created via 'Priv.createPlain' and removed via 'Priv.deletePlain'.
passAny :: Linear p sh sh
passAny =
   Cons
      (\_ -> Slice.passAny)
      (Priv.createPlain (\_ -> ()))
      Priv.deletePlain

-- | Keep the trailing @i@ dimension as it is and apply the given
-- parameterized slice to the remaining dimensions ('Slice.pass').
--
-- Parameter creation and deletion are passed through unchanged.
pass ::
   Linear p sh0 sh1 ->
   Linear p (sh0:.i) (sh1:.i)
pass (Cons core createCtx deleteCtx) =
   Cons (\param -> Slice.pass (core param)) createCtx deleteCtx

-- | Parameterized counterpart of 'Slice.pick'.
--
-- The index is taken from the parameter accessor. As the type shows, the
-- trailing @i@ dimension of the source shape does not appear in the
-- target shape.
pick ::
   (MultiValueMemory.C i, Storable i) =>
   Param.T p i ->
   Linear p sh0 sh1 ->
   Linear p (sh0:.i) sh1
pick index slice = lift Slice.pick index slice

-- | Parameterized counterpart of 'Slice.extrude'.
--
-- The value of type @i@ comes from the parameter accessor. As the type
-- shows, the target shape gains a trailing @i@ dimension (presumably the
-- parameter gives its extent; confirm against 'Slice.extrude').
extrude ::
   (MultiValueMemory.C i, Storable i) =>
   Param.T p i ->
   Linear p sh0 sh1 ->
   Linear p sh0 (sh1:.i)
extrude extent slice = lift Slice.extrude extent slice

-- | Shared implementation of 'pick' and 'extrude'.
--
-- Wraps a core slice transformer @f@ that needs an index expression. The
-- index is read from the parameter accessor @i@ (via 'Param.withMulti').
-- The resulting slice pairs the wrapped slice's marshalled parameter with
-- the marshalled index:
--
-- * creation runs the wrapped create action and appends @getI p@;
--
-- * the core unzips the pair, converts the index half with
--   'Expr.lift0' and applies @f@;
--
-- * deletion is the wrapped delete action, unchanged.
lift ::
   (MultiValueMemory.C i, Storable i) =>
   (Exp i -> Slice.Linear sh0 sh1 -> Slice.Linear sh2 sh3) ->
   Param.T p i ->
   Linear p sh0 sh1 -> Linear p sh2 sh3
lift f i (Cons slice create delete) =
   Param.withMulti i (\getI valueI ->
      let core params =
             case MultiValue.unzip params of
                (sliceParam, indexParam) ->
                   f (Expr.lift0 (valueI indexParam)) (slice sliceParam)
          createBoth p =
             create p >>= \(ctx, sliceParam) ->
             return (ctx, (sliceParam, getI p))
      in  Cons core createBoth delete)

instance Core.Process (T p sh0 sh1) where