{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE GADTs #-}
module Data.Array.Knead.Parameterized.Private where

import qualified Data.Array.Knead.Simple.Symbolic as Core

import qualified Data.Array.Knead.Parameter as Param
import qualified Data.Array.Knead.Index.Nested.Shape as Shape
import qualified Data.Array.Knead.Expression as Expr
import Data.Array.Knead.Expression (Exp, )

import qualified LLVM.Extra.Multi.Value.Memory as MultiValueMemory
import qualified LLVM.Extra.Multi.Value as MultiValue

import Foreign.Storable (Storable, )

import Control.Applicative (Applicative (pure, (<*>)), )

import Prelude hiding (id, map, zipWith, replicate, )


{- |
An array description whose construction depends on a run-time parameter
of type @p@.

The constructor hides two types:

* @parameter@ is the value handed to 'core'. It must be 'Storable' and
  an instance of 'MultiValueMemory.C'.

* @context@ is whatever 'createContext' produces alongside the parameter
  and 'deleteContext' later receives back.

In principle this type could be defined as @'Hull' p ('Core.Array' sh a)@.
'arrayHull' and 'runHull' convert between the two by rewrapping the fields.
-}
data Array p sh a =
   forall parameter context .
   (Storable parameter, MultiValueMemory.C parameter) =>
   Array
      { core :: MultiValue.T parameter -> Core.Array sh a
        -- ^ Builds the core array from the parameter value.
      , createContext :: p -> IO (context, parameter)
        -- ^ Derives a context and the parameter value from @p@.
      , deleteContext :: context -> IO ()
        -- ^ Consumes a context produced by 'createContext'
        --   (presumably releasing whatever it holds).
      }

{- |
Lifts transformations of 'Core.Array's to parameterized arrays.

* Unary lifting leaves parameter and context handling untouched.

* Binary lifting pairs the parameters and contexts of both operands
  via 'combineCreate' and 'combineDelete'.
-}
instance Core.C (Array p) where
   -- A constant array needs no parameter: use @()@ and a trivial context.
   lift0 arr = Array (\_ -> arr) (createPlain (\_ -> ())) deletePlain

   lift1 f (Array build create delete) =
      Array (\param -> f (build param)) create delete

   lift2 f (Array buildA createA deleteA) (Array buildB createB deleteB) =
      Array
         buildBoth
         (combineCreate createA createB)
         (combineDelete deleteA deleteB)
      where
         -- The parameter is the pair assembled by 'combineCreate'.
         -- Split it and feed each half to its own operand.
         buildBoth param =
            case MultiValue.unzip param of
               (paramA, paramB) -> f (buildA paramA) (buildB paramB)


{- |
Pick a single element of a parameterized array.

The index is itself a parameter, so it is supplied at run time through @p@.
The selected element is wrapped by 'Core.fromScalar' into an array of the
scalar shape @z@.
-}
(!) ::
   (Shape.C sh, Shape.Index sh ~ ix,
    Storable ix, MultiValueMemory.C ix,
    Shape.Scalar z) =>
   Array p sh a -> Param.T p ix -> Array p z a
(!) arr pix =
   runHull
      (mapHullWithExp
         (\ix -> Core.fromScalar . (Core.! ix))
         (expParam pix)
         (arrayHull arr))


{- |
An array of shape @sh@ in which every element is built from @a@ using
'Core.fill'.

Both the shape and the element value are run-time parameters taken from @p@.
They travel together as one paired parameter with a trivial @()@ context.
-}
fill ::
   (Shape.C sh, Storable sh, MultiValueMemory.C sh,
    Storable a, MultiValueMemory.C a) =>
   Param.T p sh -> Param.T p a -> Array p sh a
fill sh a =
   Shape.paramWith sh $ \getShape shapeValue ->
   Param.withMulti a $ \getElem elemValue ->
   let build param =
          -- Unzip in the same order in which 'createPlain' below pairs them.
          case MultiValue.unzip param of
             (shapeParam, elemParam) ->
                Core.fill
                   (shapeValue shapeParam)
                   (Expr.lift0 (elemValue elemParam))
   in  Array
          build
          (createPlain (\p -> (getShape p, getElem p)))
          deletePlain

{- |
Gather elements through an index array. This is a thin wrapper around
'Core.gather'.

The first argument holds, for every position of shape @sh1@, an index into
an array of shape @sh0@. The second argument is that source array. The result
has the shape @sh1@ of the index array, presumably holding the source element
found at each index.
-}
gather ::
   (Shape.C sh0, Shape.Index sh0 ~ ix0,
    Shape.C sh1, MultiValue.C a) =>
   Array p sh1 ix0 ->
   Array p sh0 a ->
   Array p sh1 a
gather indices source = Core.gather indices source


{- |
An array of the given run-time shape whose elements have the shape's index
type.

It is built with 'Core.id', which presumably maps every position to its own
index. The shape is the only parameter, and the context is trivial.
-}
id ::
   (Shape.C sh, Storable sh, MultiValueMemory.C sh, Shape.Index sh ~ ix) =>
   Param.T p sh -> Array p sh ix
id sh =
   Shape.paramWith sh
      (\getShape shapeValue ->
         Array
            (\param -> Core.id (shapeValue param))
            (createPlain getShape)
            deletePlain)

{- |
Element-wise transformation built on 'Core.map'.

The function also receives the run-time parameter @c@ as an expression, so
the mapping may depend on values supplied through @p@.
-}
map ::
   (Shape.C sh, MultiValueMemory.C c, Storable c) =>
   (Exp c -> Exp a -> Exp b) ->
   Param.T p c -> Array p sh a -> Array p sh b
map f c arr = lift Core.map f c arr

{- |
Like 'map', but the function additionally receives the index of each
element. It is built on 'Core.mapWithIndex'.

The parameter @c@ is passed to the function as its first argument.
-}
mapWithIndex ::
   (Shape.C sh, MultiValueMemory.C c, Storable c, Shape.Index sh ~ ix) =>
   (Exp c -> Exp ix -> Exp a -> Exp b) ->
   Param.T p c -> Array p sh a -> Array p sh b
mapWithIndex f c arr = lift Core.mapWithIndex f c arr


{- |
Reduce the @sh1@ component of a pair-shaped array with a binary operator,
leaving an array of shape @sh0@. It is built on 'Core.fold1'.

No initial value is passed. The operator receives the parameter @c@ as its
first argument.
-}
fold1 ::
   (Shape.C sh0, Shape.C sh1,
    MultiValueMemory.C c, Storable c, MultiValue.C a) =>
   (Exp c -> Exp a -> Exp a -> Exp a) ->
   Param.T p c -> Array p (sh0, sh1) a -> Array p sh0 a
fold1 f c arr = lift Core.fold1 f c arr

{- |
Reduce all elements with a binary operator into an array of the scalar
shape @z@. It is built on 'Core.fold1All'.

No initial value is passed. The operator receives the parameter @c@ as its
first argument.
-}
fold1All ::
   (Shape.C sh, Shape.Scalar z,
    MultiValueMemory.C c, Storable c, MultiValue.C a) =>
   (Exp c -> Exp a -> Exp a -> Exp a) ->
   Param.T p c -> Array p sh a -> Array p z a
fold1All f c arr = lift Core.fold1All f c arr

{- |
Shared helper for the parameterized array operations.

@g@ is a core array transformation whose first argument depends on an
expression. @f@ builds that argument from the parameter @c@. The parameter
value is attached next to the array's own parameter via 'mapHullWithExp',
and @f@ receives it as an 'Exp'.
-}
lift ::
   (Shape.C sh0, Shape.C sh1,
    MultiValueMemory.C c, Storable c) =>
   (f -> Core.Array sh0 a -> Core.Array sh1 b) ->
   (Exp c -> f) ->
   Param.T p c -> Array p sh0 a -> Array p sh1 b
lift g f c arr =
   runHull (mapHullWithExp (g . f) (expParam c) (arrayHull arr))


{- |
A value of type @a@ computed from a hidden, storable parameter. It comes
together with the actions that derive that parameter, plus a context, from
an outer parameter @p@, and that later consume the context.

It has the same fields as 'Array', but carries an arbitrary payload instead
of a 'Core.Array'. This is what allows 'Functor' and 'Applicative'
instances.
-}
data Hull p a =
   forall parameter context .
   (Storable parameter, MultiValueMemory.C parameter) =>
   Hull
      { hullCore :: MultiValue.T parameter -> a
        -- ^ Builds the payload from the parameter value.
      , hullCreateContext :: p -> IO (context, parameter)
        -- ^ Derives a context and the parameter value from @p@.
      , hullDeleteContext :: context -> IO ()
        -- ^ Consumes a context produced by 'hullCreateContext'.
      }

-- | Post-process the payload. Parameter and context handling stay unchanged.
instance Functor (Hull p) where
   fmap f (Hull build create delete) =
      Hull (\param -> f (build param)) create delete

{- |
Combining hulls.

* 'pure' needs no run-time data, so it uses @()@ as both parameter and
  context.

* '<*>' pairs the parameters and contexts of both sides via
  'combineCreate' and 'combineDelete'.
-}
instance Applicative (Hull p) where
   pure a = Hull (\_ -> a) (\_ -> return ((), ())) return
   Hull buildF createF deleteF <*> Hull buildX createX deleteX =
      Hull
         applyBoth
         (combineCreate createF createX)
         (combineDelete deleteF deleteX)
      where
         -- Split the paired parameter in the order 'combineCreate' built it.
         applyBoth param =
            case MultiValue.unzip param of
               (paramF, paramX) -> buildF paramF (buildX paramX)

{- |
Attach the value of a parameter tunnel to a 'Hull' and combine it, as an
expression, with the hull's payload using @f@.

This behaves like @liftA2 f (expHull tunnel)@. The difference is that the
tunnel's value is stored directly next to the hull's parameter, and the
hull's own context is kept as is. Going through 'expHull' would instead
pair that context with an extra empty @()@ context.
-}
mapHullWithExp ::
   (Exp sl -> a -> b) ->
   Param.Tunnel p sl -> Hull p a -> Hull p b
mapHullWithExp f tunnel (Hull build create delete) =
   case tunnel of
      Param.Tunnel getSl valueSl ->
         -- The combined parameter is the pair (hull parameter, tunnel value).
         -- 'createBoth' builds it and 'buildBoth' takes it apart again.
         let buildBoth param =
                case MultiValue.unzip param of
                   (hullParam, slParam) ->
                      f (Expr.lift0 (valueSl slParam)) (build hullParam)
             createBoth p = do
                (ctx, hullParam) <- create p
                return (ctx, (hullParam, getSl p))
         in  Hull buildBoth createBoth delete

-- | Turn a parameter tunnel into a 'Hull' whose payload is the tunnelled
-- value as an expression. It carries no context of its own; the context is
-- @()@.
expHull :: Param.Tunnel p sl -> Hull p (Exp sl)
expHull (Param.Tunnel getSl valueSl) =
   Hull
      (\param -> Expr.lift0 (valueSl param))
      (\p -> return ((), getSl p))
      return

-- | View a parameterized array as a 'Hull' carrying a 'Core.Array'.
-- All fields are passed through unchanged. The inverse is 'runHull'.
arrayHull :: Array p sh a -> Hull p (Core.Array sh a)
arrayHull array =
   case array of
      Array build create delete -> Hull build create delete

-- | Turn a 'Hull' carrying a 'Core.Array' back into a parameterized array.
-- All fields are passed through unchanged. The inverse is 'arrayHull'.
runHull :: Hull p (Core.Array sh a) -> Array p sh a
runHull hull =
   case hull of
      Hull build create delete -> Array build create delete

-- | Adapt a 'Hull' to a different outer parameter type @q@ by converting
-- @q@ to @p@ before the context is created.
extendHull :: (q -> p) -> Hull p a -> Hull q a
extendHull f (Hull build create delete) =
   Hull build (\q -> create (f q)) delete



-- | Make a parameter available as a tunnel so that it can be used as an
-- expression. The value is converted with 'MultiValue.cons'.
expParam ::
   (Storable a, MultiValueMemory.C a) => Param.T p a -> Param.Tunnel p a
expParam param = Param.tunnel MultiValue.cons param



-- | Context creation for parameters that need no resources.
-- The context is @()@, and the parameter is @f@ applied to the outer
-- parameter.
createPlain :: (Monad m) => (p -> pl) -> p -> m ((), pl)
createPlain f = return . (,) () . f

-- | Counterpart of 'createPlain'. A @()@ context holds nothing to release.
deletePlain :: (Monad m) => () -> m ()
deletePlain unit =
   case unit of
      () -> return ()


{- |
Run two context creators on the same outer parameter, first @createA@ and
then @createB@. Their contexts and their parameters are paired up in that
order.

NOTE(review): nothing here releases @createA@'s context if @createB@ fails
afterwards. Exception safety is left to the caller.
-}
combineCreate ::
   Monad m =>
   (p -> m (ctxA, paramA)) -> (p -> m (ctxB, paramB)) ->
   p -> m ((ctxA, ctxB), (paramA, paramB))
combineCreate createA createB p =
   createA p >>= \(ctxA, paramA) ->
   createB p >>= \(ctxB, paramB) ->
   return ((ctxA, ctxB), (paramA, paramB))

-- | Delete a pair of contexts created by 'combineCreate'.
-- The first component is deleted first, then the second (creation order,
-- not reverse order).
combineDelete ::
   Monad m =>
   (ctxA -> m ()) -> (ctxB -> m ()) -> (ctxA, ctxB) -> m ()
combineDelete deleteA deleteB (ctxA, ctxB) =
   deleteA ctxA >> deleteB ctxB


-- | Adapt a parameterized array to a different outer parameter type @q@ by
-- converting @q@ to @p@ before the context is created.
extendParameter ::
   (q -> p) -> Array p sh a -> Array q sh a
extendParameter f array =
   case array of
      Array build create delete -> Array build (\q -> create (f q)) delete