{-# LANGUAGE GADTs #-}
{-# LANGUAGE ForeignFunctionInterface #-}
module Data.Array.Knead.Parameterized.PhysicalHull (
   render,
   Scatter(..),
   scatter,
   ScatterMaybe(..),
   scatterMaybe,
   MapAccumL(..),
   mapAccumL,
   FoldOuterL(..),
   foldOuterL,
   ) where

import qualified Data.Array.Knead.Parameterized.Private as Sym
import qualified Data.Array.Knead.Simple.PhysicalPrivate as Priv
import qualified Data.Array.Knead.Simple.Physical as Phys
import qualified Data.Array.Knead.Simple.Private as Core
import qualified Data.Array.Knead.Index.Nested.Shape as Shape
import qualified Data.Array.Knead.Code as Code
import Data.Array.Knead.Expression (Exp, unExp, )
import Data.Array.Knead.Code (compile, )

import qualified LLVM.Extra.Multi.Value.Memory as MultiValueMemory
import qualified LLVM.Extra.Multi.Value as MultiValue
import qualified LLVM.Extra.Memory as Memory
import qualified LLVM.Extra.Arithmetic as A

import qualified LLVM.Core as LLVM

import Foreign.Marshal.Utils (with, )
import Foreign.Marshal.Alloc (alloca, )
import Foreign.Storable (Storable, peek, )
import Foreign.ForeignPtr (withForeignPtr, mallocForeignPtrArray, )
import Foreign.Ptr (FunPtr, Ptr, )

import Control.Exception (bracket, )
import Control.Monad.HT (void, )
import Control.Applicative (liftA2, )
import Data.Word (Word32, )


-- | Type of a GHC @\"dynamic\"@ foreign import: it turns a pointer to
-- foreign code of type @f@ into a Haskell function of type @f@.
type Importer f = FunPtr f -> f


-- | Caller for the JIT-compiled \"shape\" function generated in
-- 'materialize'.  The first pointer addresses the parameter record, the
-- second the location the generated code stores the shape into; the
-- 'Word32' result is the value that function returns
-- ('Shape.sizeCode' of the shape).
foreign import ccall safe "dynamic" callShaper ::
   Importer (Ptr param -> Ptr shape -> IO Word32)

-- | Caller for the JIT-compiled \"fill\" function generated in
-- 'materialize'.  Arguments: parameter record, shape, output buffer.
foreign import ccall safe "dynamic" callFill ::
   Importer (Ptr param -> Ptr shape -> Ptr am -> IO ())

-- | Shared back end of every entry point in this module.
--
-- Compiles, once, two LLVM functions under the module name @name@:
--
-- * @\"shape\"@ loads the parameter, evaluates @shape@ on it, stores the
--   resulting shape through its second argument and returns
--   'Shape.sizeCode' of that shape (as a 'Word32');
--
-- * @\"fill\"@ loads the parameter and runs the code produced by @fill@
--   on the shape pointer and the output buffer.
--
-- The returned action, for every parameter value @p@, acquires the hull's
-- resources with @create@ (released by @delete@ through 'bracket'), calls
-- \"shape\" to get the element count, allocates a buffer of that many
-- elements, calls \"fill\" on it, and pairs the buffer with the shape read
-- back from the temporary shape storage.
materialize ::
   (Shape.C sh, Storable sh, MultiValueMemory.C sh,
    Storable a, MultiValueMemory.C a) =>
   String ->
   (core -> Exp sh) ->
   (core ->
    LLVM.Value (Ptr (MultiValueMemory.Struct sh)) ->
    LLVM.Value (Ptr (MultiValueMemory.Struct a)) ->
    LLVM.CodeGenFunction () ()) ->
   Sym.Hull p core -> IO (p -> IO (Phys.Array sh a))
materialize name shape fill (Sym.Hull core create delete) = do
   let shapeFn =
          Code.createFunction callShaper "shape" $ \paramPtr resultPtr -> do
             sh <- unExp . shape . core =<< Memory.load paramPtr
             MultiValueMemory.store sh resultPtr
             LLVM.ret =<< Shape.sizeCode sh
       fillFn =
          Code.createFunction callFill "fill" $ \paramPtr shapePtr bufferPtr -> do
             param <- Memory.load paramPtr
             fill (core param) shapePtr bufferPtr
             LLVM.ret ()
   (computeShape, fillBuffer) <- compile name (liftA2 (,) shapeFn fillFn)

   return $ \p ->
      bracket (create p) (delete . fst) $ \(_ctx, param) ->
      with param $ \paramPtr ->
      alloca $ \shapeStorage -> do
         -- Both compiled functions see the same parameter and shape storage.
         let paramArg = MultiValueMemory.castStructPtr paramPtr
             shapeArg = MultiValueMemory.castStructPtr shapeStorage
         count <- computeShape paramArg shapeArg
         buffer <- mallocForeignPtrArray (fromIntegral count)
         withForeignPtr buffer $ \bufferStart ->
            fillBuffer paramArg shapeArg
               (MultiValueMemory.castStructPtr bufferStart)
         -- The shape is only valid after "shape" has written it.
         fmap (flip Phys.Array buffer) (peek shapeStorage)


-- | Compile a parameterized symbolic array into a function that evaluates
-- it into a physical array with the array's own 'Core.shape'.
--
-- The generated code loads the shape, then walks it with 'Shape.loop'.  At
-- each index it stores the element computed by the array's element code at
-- the current buffer position and advances to the next element slot.
render ::
   (Shape.C sh, Shape.Index sh ~ ix,
    Storable sh, MultiValueMemory.C sh,
    Storable a, MultiValueMemory.C a) =>
   Sym.Hull p (Core.Array sh a) -> IO (p -> IO (Phys.Array sh a))
render =
   materialize "render" Core.shape $
   \(Core.Array esh code) shapePtr bufferPtr -> do
      sh <- Shape.load esh shapePtr
      let emit ix ptr = do
             value <- code ix
             Memory.store value ptr
             A.advanceArrayElementPtr ptr
      void $ Shape.loop emit sh bufferPtr


-- | Arguments to 'scatter': a combining function, an array of initial
-- values that fixes the result shape, and an array of
-- (target index, value) pairs.
data Scatter sh0 sh1 a = Scatter
   { scatterAccum :: Exp a -> Exp a -> Exp a
     -- ^ Combining function handed to 'Priv.scatter'; its argument order
     -- is defined there.
   , scatterInit :: Core.Array sh1 a
     -- ^ Initial contents; 'scatter' takes the result shape from this
     -- array.
   , scatterMap :: Core.Array sh0 (Shape.Index sh1, a)
     -- ^ Pairs of an index into the result and a value for that position.
   }

-- | Compile a 'Scatter' description into a function producing a physical
-- array.  The result has the shape of 'scatterInit'.  The generated code
-- is 'Priv.scatter' applied to the three fields.
scatter ::
   (Shape.C sh0, Shape.Index sh0 ~ ix0,
    Shape.C sh1, Shape.Index sh1 ~ ix1,
    Storable sh1, MultiValueMemory.C sh1,
    Storable a, MultiValueMemory.C a) =>
   Sym.Hull p (Scatter sh0 sh1 a) -> IO (p -> IO (Phys.Array sh1 a))
scatter =
   materialize "scatter"
      (\s -> Core.shape (scatterInit s))
      (\(Scatter combine initial updates) ->
         Priv.scatter combine initial updates)



-- | Arguments to 'scatterMaybe': like 'Scatter', except that every entry of
-- the map array is optional.
data ScatterMaybe sh0 sh1 a = ScatterMaybe
   { scatterMaybeAccum :: Exp a -> Exp a -> Exp a
     -- ^ Combining function handed to 'Priv.scatterMaybe'; its argument
     -- order is defined there.
   , scatterMaybeInit :: Core.Array sh1 a
     -- ^ Initial contents; 'scatterMaybe' takes the result shape from this
     -- array.
   , scatterMaybeMap :: Core.Array sh0 (Maybe (Shape.Index sh1, a))
     -- ^ Optional pairs of an index into the result and a value.
     -- Presumably 'Nothing' entries are skipped; TODO confirm against
     -- 'Priv.scatterMaybe'.
   }

-- | Compile a 'ScatterMaybe' description into a function producing a
-- physical array.  The result has the shape of 'scatterMaybeInit'.  The
-- generated code is 'Priv.scatterMaybe' applied to the three fields.
scatterMaybe ::
   (Shape.C sh0, Shape.Index sh0 ~ ix0,
    Shape.C sh1, Shape.Index sh1 ~ ix1,
    Storable sh1, MultiValueMemory.C sh1,
    Storable a, MultiValueMemory.C a) =>
   Sym.Hull p (ScatterMaybe sh0 sh1 a) -> IO (p -> IO (Phys.Array sh1 a))
scatterMaybe =
   materialize "scatterMaybe"
      (\s -> Core.shape (scatterMaybeInit s))
      (\(ScatterMaybe combine initial updates) ->
         Priv.scatterMaybe combine initial updates)


-- | Arguments to 'mapAccumL': a step function, initial accumulators
-- indexed by @sh@, and input elements indexed by @(sh, n)@.
data MapAccumL sh n acc a b = MapAccumL
   { mapAccumLAccum :: Exp acc -> Exp a -> Exp (acc,b)
     -- ^ Step function: from an accumulator and an input element, returns
     -- a pair of an accumulator value and an output element.
   , mapAccumLInit :: Core.Array sh acc
     -- ^ Initial accumulator values, indexed by @sh@.
   , mapAccumLMap :: Core.Array (sh, n) a
     -- ^ Input elements; 'mapAccumL' uses this array's shape for its
     -- result.
   }

-- | Compile a 'MapAccumL' description into a function producing a physical
-- array of shape @(sh,n)@, taken from 'mapAccumLMap'.  The generated code
-- is 'Priv.mapAccumL' applied to the three fields.
mapAccumL ::
   (Shape.C sh, Storable sh, MultiValueMemory.C sh,
    Shape.C n, Storable n, MultiValueMemory.C n,
    MultiValue.C acc,
    Storable a, MultiValueMemory.C a,
    Storable b, MultiValueMemory.C b) =>
   Sym.Hull p (MapAccumL sh n acc a b) -> IO (p -> IO (Phys.Array (sh,n) b))
mapAccumL =
   materialize "mapAccumL"
      (\m -> Core.shape (mapAccumLMap m))
      (\(MapAccumL step initial inputs) ->
         Priv.mapAccumL step initial inputs)


-- | Arguments to 'foldOuterL': a step function, initial values indexed by
-- @sh@, and input elements indexed by @(n, sh)@.
data FoldOuterL n sh a b = FoldOuterL
   { foldOuterLAccum :: Exp a -> Exp b -> Exp a
     -- ^ Step function: combines an accumulator with an input element.
   , foldOuterLInit :: Core.Array sh a
     -- ^ Initial accumulator values; 'foldOuterL' takes the result shape
     -- from this array.
   , foldOuterLMap :: Core.Array (n,sh) b
     -- ^ Input elements, with the outer dimension @n@ first.
   }

-- | Compile a 'FoldOuterL' description into a function producing a physical
-- array of shape @sh@, taken from 'foldOuterLInit'.  The generated code is
-- 'Priv.foldOuterL' applied to the three fields.
--
-- FIXME: check correct size of array of initial values.  Nothing here
-- verifies that the shape of 'foldOuterLInit' matches the @sh@ part of the
-- shape of 'foldOuterLMap'.
foldOuterL ::
   (Shape.C sh, Storable sh, MultiValueMemory.C sh,
    Shape.C n, Storable n, MultiValueMemory.C n,
    Storable a, MultiValueMemory.C a) =>
   Sym.Hull p (FoldOuterL n sh a b) -> IO (p -> IO (Phys.Array sh a))
foldOuterL =
   materialize "foldOuterL"
      (\fo -> Core.shape (foldOuterLInit fo))
      (\(FoldOuterL step initial inputs) ->
         Priv.foldOuterL step initial inputs)