{-# LANGUAGE TypeFamilies #-}
{- |
Simplify running the @render@ function by taking care of parameter passing.
-}
module Data.Array.Knead.Parameterized.Render (
   run,
   Scatter(..),
   ScatterMaybe(..),
   MapAccumL(..),
   FoldOuterL(..),
   ) where

import qualified Data.Array.Knead.Parameterized.PhysicalHull as PhysHullP
import qualified Data.Array.Knead.Parameterized.Physical as PhysP
import qualified Data.Array.Knead.Parameterized.Private as Sym
import qualified Data.Array.Knead.Simple.Physical as Phys
import qualified Data.Array.Knead.Simple.Private as Core
import qualified Data.Array.Knead.Parameter as Param
import qualified Data.Array.Knead.Index.Nested.Shape as Shape
import Data.Array.Knead.Parameterized.PhysicalHull
         (Scatter, ScatterMaybe, MapAccumL, FoldOuterL)
import Data.Array.Knead.Expression (Exp, )

import qualified LLVM.Extra.Multi.Value.Memory as MultiValueMemory
import qualified LLVM.Extra.Multi.Value as MultiValue

import Foreign.Storable (Storable, )

import Control.Arrow (arr, )
import Control.Applicative (liftA2, liftA3, pure, (<*>), )

import Data.Tuple.HT (fst3, snd3, thd3, )


-- | Symbolic values that 'run' can turn into plain Haskell values.
--
-- @Plain f@ is the Haskell-side type corresponding to the symbolic type @f@.
class C f where
   type Plain f
   -- | Turn a hull parameterized by @p@ into an 'IO' action that yields a
   -- function from the parameter value to the plain result.
   build :: Sym.Hull p f -> IO (p -> Plain f)

-- A symbolic array becomes an action producing the corresponding physical
-- array; the work is delegated to 'PhysHullP.render'.
instance
   (Shape.C sh, Storable sh, MultiValueMemory.C sh,
    Storable a, MultiValueMemory.C a) =>
      C (Core.Array sh a) where
   type Plain (Core.Array sh a) = IO (Phys.Array sh a)
   build = PhysHullP.render

-- A 'Scatter' description becomes an action producing a physical array of
-- the target shape @sh1@; the work is delegated to 'PhysHullP.scatter'.
instance
   (Shape.C sh0, Storable sh0, MultiValueMemory.C sh0,
    Shape.C sh1, Storable sh1, MultiValueMemory.C sh1,
    Storable a, MultiValueMemory.C a) =>
      C (Scatter sh0 sh1 a) where
   type Plain (Scatter sh0 sh1 a) = IO (Phys.Array sh1 a)
   build = PhysHullP.scatter

-- A 'ScatterMaybe' description becomes an action producing a physical array
-- of the target shape @sh1@; the work is delegated to
-- 'PhysHullP.scatterMaybe'.
instance
   (Shape.C sh0, Storable sh0, MultiValueMemory.C sh0,
    Shape.C sh1, Storable sh1, MultiValueMemory.C sh1,
    Storable a, MultiValueMemory.C a) =>
      C (ScatterMaybe sh0 sh1 a) where
   type Plain (ScatterMaybe sh0 sh1 a) = IO (Phys.Array sh1 a)
   build = PhysHullP.scatterMaybe

-- A 'MapAccumL' description becomes an action producing a physical array of
-- shape @(sh,n)@ with elements of type @b@; the work is delegated to
-- 'PhysHullP.mapAccumL'.  The accumulator type @acc@ only needs
-- 'MultiValue.C', not a memory/'Storable' representation.
instance
   (Shape.C sh, Storable sh, MultiValueMemory.C sh,
    Shape.C n, Storable n, MultiValueMemory.C n,
    MultiValue.C acc,
    Storable a, MultiValueMemory.C a,
    Storable b, MultiValueMemory.C b) =>
      C (MapAccumL sh n acc a b) where
   type Plain (MapAccumL sh n acc a b) = IO (Phys.Array (sh,n) b)
   build = PhysHullP.mapAccumL

-- A 'FoldOuterL' description becomes an action producing a physical array
-- of shape @sh@ with elements of type @a@ (@n@ does not occur in the result
-- shape); the work is delegated to 'PhysHullP.foldOuterL'.
-- NOTE(review): @b@ does not appear in the result type but still carries
-- Storable/MultiValueMemory constraints, presumably because
-- 'PhysHullP.foldOuterL' requires them; confirm against that function.
instance
   (Shape.C sh, Storable sh, MultiValueMemory.C sh,
    Shape.C n, Storable n, MultiValueMemory.C n,
    Storable a, MultiValueMemory.C a,
    Storable b, MultiValueMemory.C b) =>
      C (FoldOuterL n sh a b) where
   type Plain (FoldOuterL n sh a b) = IO (Phys.Array sh a)
   build = PhysHullP.foldOuterL


-- | Wrap a scalar expression as an array of shape @()@, so that scalar
-- results can go through the array rendering machinery.
singleton :: Exp a -> Core.Array () a
singleton scalar = Core.fromScalar scalar

-- A scalar expression is first wrapped as an array of shape @()@ (via
-- 'singleton'), so it can be rendered like any array.  'PhysP.the' then
-- turns that into a function yielding the plain scalar in 'IO'.
instance (Storable a, MultiValueMemory.C a) => C (Exp a) where
   type Plain (Exp a) = IO a
   build hull = PhysP.the (Sym.runHull (fmap singleton hull))

-- A symbolic function becomes a plain function.  The hull's parameter is
-- widened to a pair: the first component is the previous parameter and the
-- second is the plain argument, which 'buildArg' turns into the symbolic
-- argument.  The resulting pair-taking function is then curried.
instance (Argument arg, C func) => C (arg -> func) where
   type Plain (arg -> func) = PlainArg arg -> Plain func
   build hull = do
      uncurried <- build (Sym.extendHull fst hull <*> buildArg (arr snd))
      return (curry uncurried)


-- | Symbolic types that may appear as arguments of a function handled by
-- the @C (arg -> func)@ instance.
--
-- @PlainArg arg@ is the Haskell type the caller supplies for @arg@.
class Argument arg where
   type PlainArg arg
   -- | Build a hull producing the symbolic argument from a parameter
   -- accessor that extracts the corresponding plain value.
   buildArg :: Param.T p (PlainArg arg) -> Sym.Hull p arg

-- A symbolic array argument is supplied as a physical array, which is fed
-- into the hull through 'PhysP.feed'.
instance
   (Shape.C sh, Storable sh, MultiValueMemory.C sh, MultiValueMemory.C a) =>
      Argument (Core.Array sh a) where
   type PlainArg (Core.Array sh a) = Phys.Array sh a
   buildArg param = Sym.arrayHull (PhysP.feed param)

-- A scalar expression argument is supplied as a plain value and read from
-- the parameter through 'Sym.expParam'.
instance (Storable a, MultiValueMemory.C a) => Argument (Exp a) where
   type PlainArg (Exp a) = a
   buildArg param = Sym.expHull (Sym.expParam param)

-- A pair of symbolic arguments is supplied as a pair of plain values; each
-- component is built from the matching projection of the parameter.
instance (Argument a, Argument b) => Argument (a,b) where
   type PlainArg (a,b) = (PlainArg a, PlainArg b)
   buildArg param = liftA2 (,) (buildArg paramA) (buildArg paramB)
      where paramA = fmap fst param
            paramB = fmap snd param

-- A triple of symbolic arguments is supplied as a triple of plain values;
-- each component is built from the matching projection of the parameter.
instance (Argument a, Argument b, Argument c) => Argument (a,b,c) where
   type PlainArg (a,b,c) = (PlainArg a, PlainArg b, PlainArg c)
   buildArg param =
      liftA3 (,,) (buildArg paramA) (buildArg paramB) (buildArg paramC)
      where paramA = fmap fst3 param
            paramB = fmap snd3 param
            paramC = fmap thd3 param


-- | Turn a symbolic value (an array, a scalar expression, or a function
-- of symbolic arguments) into its plain Haskell counterpart.
--
-- Symbolic 'Core.Array' arguments become 'Phys.Array' arguments, and
-- 'Exp' arguments become plain scalars.  Tuples of these are also
-- accepted as arguments.  Building happens in the outer 'IO'.  The
-- results of the returned value are themselves 'IO' actions, as given
-- by 'Plain'.
run :: (C f) => f -> IO (Plain f)
run f = do
   -- The hull starts with the unit parameter, so the built function is
   -- applied to ().
   compiled <- build (pure f)
   return (compiled ())



-- | Spells out by hand the parameter plumbing that 'run' automates, for a
-- function taking one scalar and two array arguments.  The parameter is
-- the nested pair @((scalar, first array), second array)@.  The rendered
-- function is curried twice to take the three values separately.
_example ::
   (Storable x, MultiValueMemory.C x,
    Shape.C sha, Storable sha, MultiValueMemory.C sha, MultiValueMemory.C a,
    Shape.C shb, Storable shb, MultiValueMemory.C shb, MultiValueMemory.C b,
    Shape.C shc, Storable shc, MultiValueMemory.C shc, MultiValueMemory.C c,
    Storable c) =>
   (Exp x -> Core.Array sha a -> Core.Array shb b -> Core.Array shc c) ->
   IO (x -> Phys.Array sha a -> Phys.Array shb b -> IO (Phys.Array shc c))
_example f =
   fmap (curry . curry) (PhysP.render (Sym.runHull hull))
   where
      -- Each symbolic argument is read from its own component of the
      -- nested parameter pair.
      hull =
         pure f
            <*> Sym.expHull (Sym.expParam (arr (fst . fst)))
            <*> Sym.arrayHull (PhysP.feed (arr (snd . fst)))
            <*> Sym.arrayHull (PhysP.feed (arr snd))