{-# LANGUAGE GADTs #-}
{-# LANGUAGE ForeignFunctionInterface #-}
module Data.Array.Knead.Parameterized.PhysicalHull (
   render,
   MapFilter(..),
   mapFilter,
   FilterOuter(..),
   filterOuter,
   Scatter(..),
   scatter,
   ScatterMaybe(..),
   scatterMaybe,
   MapAccumLSimple(..),
   mapAccumLSimple,
   MapAccumLSequence(..),
   mapAccumLSequence,
   MapAccumL(..),
   mapAccumL,
   FoldOuterL(..),
   foldOuterL,
   AddDimension(..),
   addDimension,
   ) where

import qualified Data.Array.Knead.Parameterized.Private as Sym
import qualified Data.Array.Knead.Simple.PhysicalPrivate as Priv
import qualified Data.Array.Knead.Simple.Private as Core
import qualified Data.Array.Knead.Shape as Shape
import qualified Data.Array.Knead.Expression as Expr
import qualified Data.Array.Knead.Code as Code
import Data.Array.Knead.Expression (Exp, unExp, )
import Data.Array.Knead.Code (compile, )

import Data.Array.Comfort.Storable.Unchecked (Array(Array))

import qualified LLVM.Extra.Multi.Value.Memory as MultiValueMemory
import qualified LLVM.Extra.Multi.Value as MultiValue
import qualified LLVM.Extra.Memory as Memory
import qualified LLVM.Extra.Arithmetic as A

import qualified LLVM.Core as LLVM

import Foreign.Marshal.Utils (with, )
import Foreign.Marshal.Array (allocaArray, )
import Foreign.Marshal.Alloc (alloca, )
import Foreign.Storable (Storable, peek, peekElemOff, )
import Foreign.ForeignPtr (ForeignPtr, withForeignPtr, mallocForeignPtrArray, )
import Foreign.Ptr (FunPtr, Ptr, )

import Control.Exception (bracket, )
import Control.Monad.HT (void, )
import Control.Applicative (liftA2, )


-- | Allocate a 'ForeignPtr' buffer with room for the given number of elements.
-- The element count is converted from 'Shape.Size' with 'fromIntegral'.
mallocArray :: (Storable a) => Shape.Size -> IO (ForeignPtr a)
mallocArray size = mallocForeignPtrArray (fromIntegral size)

-- | Run an action on the raw pointer behind a 'ForeignPtr',
-- after converting it with 'MultiValueMemory.castStructPtr'
-- to a pointer to the element's memory struct.
withForeignMemPtr ::
   ForeignPtr a -> (Ptr (MultiValueMemory.Struct a) -> IO b) -> IO b
withForeignMemPtr fptr act =
   withForeignPtr fptr (\ptr -> act (MultiValueMemory.castStructPtr ptr))


-- | Type of the @foreign import ... "dynamic"@ stubs in this module:
-- given a function pointer, yield a Haskell function of the pointed-to type.
type Importer f = FunPtr f -> f


-- | Call stub for the compiled @shape@ function
-- used by 'materialize' and 'materializeExpArray'.
-- It receives the parameter pointer and a pointer where the shape is stored,
-- and returns the size that is used for allocating the element buffer.
foreign import ccall safe "dynamic" callShaper ::
   Importer (Ptr param -> Ptr shape -> IO Shape.Size)

-- | Call stub for the compiled @fill@ function used by 'materialize'.
-- It receives the parameter pointer, the shape pointer
-- and the pointer to the element buffer.
foreign import ccall safe "dynamic" callFill ::
   Importer (Ptr param -> Ptr shape -> Ptr am -> IO ())


{- |
Turn a code generator into an action that produces a physical array.

Two functions are built with 'compile' under the given name:

* @shape@ loads the parameter, evaluates the shape expression,
  stores the shape at the given pointer and returns its size;

* @fill@ loads the parameter and runs the supplied generator
  on the shape pointer and the element buffer pointer.

The returned action obtains the parameter with the hull's @create@
and hands the accompanying context to @delete@ afterwards.
It allocates a buffer of the size reported by @shape@, runs @fill@
and reads the shape back only after @fill@ has finished.

Attention:
The 'fill' function may alter the shape.
An example is 'mapFilter'.
-}
materialize ::
   (Shape.C sh, Storable sh, MultiValueMemory.C sh,
    Storable a, MultiValueMemory.C a) =>
   String ->
   (core -> Exp sh) ->
   (core ->
    LLVM.Value (Ptr (MultiValueMemory.Struct sh)) ->
    LLVM.Value (Ptr (MultiValueMemory.Struct a)) ->
    LLVM.CodeGenFunction () ()) ->
   Sym.Hull p core -> IO (p -> IO (Array sh a))
materialize name shape fill (Sym.Hull core create delete) = do
   (shaper, filler) <-
      compile name $
      liftA2 (,)
         (Code.createFunction callShaper "shape" $
          \paramPtr shapePtr -> do
            param <- Memory.load paramPtr
            sh <- unExp (shape (core param))
            MultiValueMemory.store sh shapePtr
            LLVM.ret =<< Shape.size sh)
         (Code.createFunction callFill "fill" $
          \paramPtr shapePtr bufferPtr -> do
            param <- Memory.load paramPtr
            fill (core param) shapePtr bufferPtr
            LLVM.ret ())

   return $ \p ->
      bracket (create p) (\(ctx, _param) -> delete ctx) $ \(_ctx, param) ->
      alloca $ \shapeSlot ->
      with param $ \paramSlot -> do
         let paramArg = MultiValueMemory.castStructPtr paramSlot
             shapeArg = MultiValueMemory.castStructPtr shapeSlot
         numElems <- shaper paramArg shapeArg
         buffer <- mallocArray numElems
         withForeignMemPtr buffer $ \bufferArg ->
            filler paramArg shapeArg bufferArg
         -- read the shape only now, because the fill function may have changed it
         finalShape <- peek shapeSlot
         return (Array finalShape buffer)


-- | Call stub for the compiled @fill@ function of 'materializeExpArray'.
-- Compared to 'callFill' it takes an additional pointer (second argument)
-- to the slot from which the extra result value is read after the call.
foreign import ccall safe "dynamic" callFillExpArray ::
   Importer (Ptr param -> Ptr final -> Ptr shape -> Ptr am -> IO ())


{- |
Variant of 'materialize' for generators
that produce a single value of type @b@ in addition to the array.

The compiled @fill@ function gets an extra pointer to a slot for that value.
The returned action allocates the slot, runs @shape@ and @fill@,
and then reads both the shape and the extra value back,
returning them as @(value, array)@.
-}
materializeExpArray ::
   (Shape.C sh, Storable sh, MultiValueMemory.C sh,
    Storable a, MultiValueMemory.C a,
    Storable b, MultiValueMemory.C b) =>
   String ->
   (core -> Exp sh) ->
   (core ->
    LLVM.Value (Ptr (MultiValueMemory.Struct b)) ->
    LLVM.Value (Ptr (MultiValueMemory.Struct sh)) ->
    LLVM.Value (Ptr (MultiValueMemory.Struct a)) ->
    LLVM.CodeGenFunction () ()) ->
   Sym.Hull p core -> IO (p -> IO (b, Array sh a))
materializeExpArray name shape fill (Sym.Hull core create delete) = do
   (shaper, filler) <-
      compile name $
      liftA2 (,)
         (Code.createFunction callShaper "shape" $
          \paramPtr shapePtr -> do
            param <- Memory.load paramPtr
            sh <- unExp (shape (core param))
            MultiValueMemory.store sh shapePtr
            LLVM.ret =<< Shape.size sh)
         (Code.createFunction callFillExpArray "fill" $
          \paramPtr finalPtr shapePtr bufferPtr -> do
            param <- Memory.load paramPtr
            fill (core param) finalPtr shapePtr bufferPtr
            LLVM.ret ())

   return $ \p ->
      bracket (create p) (\(ctx, _param) -> delete ctx) $ \(_ctx, param) ->
      alloca $ \shapeSlot ->
      alloca $ \finalSlot ->
      with param $ \paramSlot -> do
         let paramArg = MultiValueMemory.castStructPtr paramSlot
             finalArg = MultiValueMemory.castStructPtr finalSlot
             shapeArg = MultiValueMemory.castStructPtr shapeSlot
         numElems <- shaper paramArg shapeArg
         buffer <- mallocArray numElems
         withForeignMemPtr buffer $ \bufferArg ->
            filler paramArg finalArg shapeArg bufferArg
         -- both slots are read only after the fill function has run
         finalShape <- peek shapeSlot
         finalValue <- peek finalSlot
         return (finalValue, Array finalShape buffer)


-- | Call stub for the compiled @shape@ function of 'materialize2'.
-- It receives the parameter pointer, one pointer for each of the two shapes,
-- and a pointer to 'Shape.Size' storage
-- into which the two sizes are written one after another.
foreign import ccall safe "dynamic" callShaper2 ::
   Importer (Ptr param -> Ptr shapeA -> Ptr shapeB -> Ptr Shape.Size -> IO ())

-- | Call stub for the compiled @fill@ function of 'materialize2'.
-- After the parameter pointer it takes a shape pointer and a buffer pointer
-- for the first array, followed by the same pair for the second array.
foreign import ccall safe "dynamic" callFill2 ::
   Importer (Ptr param -> Ptr shapeA -> Ptr am -> Ptr shapeB -> Ptr bm -> IO ())


{- |
Variant of 'materialize' for generators that produce two arrays at once.

The shape expression yields a pair of shapes.
The compiled @shape@ function stores each shape at its own pointer
and writes the two sizes into consecutive 'Shape.Size' slots.
The returned action allocates one buffer per array from these sizes,
runs the compiled @fill@ function on both (shape pointer, buffer pointer) pairs
and reads both shapes back after @fill@ has finished.
-}
materialize2 ::
   (Shape.C sha, Storable sha, MultiValueMemory.C sha,
    Shape.C shb, Storable shb, MultiValueMemory.C shb,
    Storable a, MultiValueMemory.C a,
    Storable b, MultiValueMemory.C b) =>
   String ->
   (core -> Exp (sha,shb)) ->
   (core ->
    (LLVM.Value (Ptr (MultiValueMemory.Struct sha)),
     LLVM.Value (Ptr (MultiValueMemory.Struct a))) ->
    (LLVM.Value (Ptr (MultiValueMemory.Struct shb)),
     LLVM.Value (Ptr (MultiValueMemory.Struct b))) ->
    LLVM.CodeGenFunction () ()) ->
   Sym.Hull p core -> IO (p -> IO (Array sha a, Array shb b))
materialize2 name shape fill (Sym.Hull core create delete) = do
   (shaper, filler) <-
      compile name $
      liftA2 (,)
         (Code.createFunction callShaper2 "shape" $
          \paramPtr shapeAPtr shapeBPtr sizesPtr -> do
            param <- Memory.load paramPtr
            (sha,shb) <- fmap MultiValue.unzip (unExp (shape (core param)))
            MultiValueMemory.store sha shapeAPtr
            MultiValueMemory.store shb shapeBPtr
            -- first size goes to slot 0, second size to the slot after it
            slotA <- LLVM.bitcast sizesPtr
            sizeA <- Shape.size sha
            LLVM.store sizeA slotA
            slotB <- A.advanceArrayElementPtr slotA
            sizeB <- Shape.size shb
            LLVM.store sizeB slotB
            LLVM.ret ())
         (Code.createFunction callFill2 "fill" $
          \paramPtr shapeAPtr bufferAPtr shapeBPtr bufferBPtr -> do
            param <- Memory.load paramPtr
            fill (core param) (shapeAPtr, bufferAPtr) (shapeBPtr, bufferBPtr)
            LLVM.ret ())

   return $ \p ->
      bracket (create p) (\(ctx, _param) -> delete ctx) $ \(_ctx, param) ->
      alloca $ \shapeASlot ->
      alloca $ \shapeBSlot ->
      allocaArray 2 $ \sizesPtr ->
      with param $ \paramSlot -> do
         let paramArg = MultiValueMemory.castStructPtr paramSlot
             shapeAArg = MultiValueMemory.castStructPtr shapeASlot
             shapeBArg = MultiValueMemory.castStructPtr shapeBSlot
         shaper paramArg shapeAArg shapeBArg sizesPtr
         numElemsA <- peekElemOff sizesPtr 0
         bufferA <- mallocArray numElemsA
         numElemsB <- peekElemOff sizesPtr 1
         bufferB <- mallocArray numElemsB
         withForeignMemPtr bufferA $ \bufferAArg ->
            withForeignMemPtr bufferB $ \bufferBArg ->
            filler paramArg shapeAArg bufferAArg shapeBArg bufferBArg
         -- shapes are read only after the fill function has run
         finalShapeA <- peek shapeASlot
         finalShapeB <- peek shapeBSlot
         return (Array finalShapeA bufferA, Array finalShapeB bufferB)


-- | Compile a symbolic array into an action that writes all its elements
-- to a freshly allocated buffer.
-- The shape is loaded from the shape pointer and the element code is run
-- for every index visited by 'Shape.loop',
-- storing each result and advancing the buffer pointer by one element.
render ::
   (Shape.C sh, Shape.Index sh ~ ix,
    Storable sh, MultiValueMemory.C sh,
    Storable a, MultiValueMemory.C a) =>
   Sym.Hull p (Core.Array sh a) -> IO (p -> IO (Array sh a))
render =
   materialize "render" Core.shape $
      \(Core.Array shapeExp element) shapePtr bufferPtr -> do
         sh <- Shape.load shapeExp shapePtr
         let writeNext ix ptr = do
                x <- element ix
                Memory.store x ptr
                A.advanceArrayElementPtr ptr
         void $ Shape.loop writeNext sh bufferPtr


-- | Inputs of 'scatter'.
data Scatter sh0 sh1 a =
   Scatter {
      -- | Combines two element expressions into one.
      scatterAccum :: Exp a -> Exp a -> Exp a,
      -- | Array of shape @sh1@; 'scatter' takes the result shape from it.
      scatterInit :: Core.Array sh1 a,
      -- | Array of shape @sh0@ whose elements pair an @sh1@ index with a value.
      scatterMap :: Core.Array sh0 (Shape.Index sh1, a)
   }

-- | Compile a 'Scatter' description into an array-producing action.
-- The result has the shape of 'scatterInit';
-- the buffer is filled by 'Priv.scatter'.
scatter ::
   (Shape.C sh0, Shape.Index sh0 ~ ix0,
    Shape.C sh1, Shape.Index sh1 ~ ix1,
    Storable sh1, MultiValueMemory.C sh1,
    Storable a, MultiValueMemory.C a) =>
   Sym.Hull p (Scatter sh0 sh1 a) -> IO (p -> IO (Array sh1 a))
scatter =
   materialize "scatter"
      (\core -> Core.shape (scatterInit core))
      (\(Scatter combine initial updates) ->
         Priv.scatter combine initial updates)



-- | Inputs of 'scatterMaybe'.
-- Like 'Scatter', but every element of the update array is wrapped in 'Maybe'.
data ScatterMaybe sh0 sh1 a =
   ScatterMaybe {
      -- | Combines two element expressions into one.
      scatterMaybeAccum :: Exp a -> Exp a -> Exp a,
      -- | Array of shape @sh1@; 'scatterMaybe' takes the result shape from it.
      scatterMaybeInit :: Core.Array sh1 a,
      -- | Array of shape @sh0@ whose elements are optional pairs
      -- of an @sh1@ index and a value.
      scatterMaybeMap :: Core.Array sh0 (Maybe (Shape.Index sh1, a))
   }

-- | Compile a 'ScatterMaybe' description into an array-producing action.
-- The result has the shape of 'scatterMaybeInit';
-- the buffer is filled by 'Priv.scatterMaybe'.
scatterMaybe ::
   (Shape.C sh0, Shape.Index sh0 ~ ix0,
    Shape.C sh1, Shape.Index sh1 ~ ix1,
    Storable sh1, MultiValueMemory.C sh1,
    Storable a, MultiValueMemory.C a) =>
   Sym.Hull p (ScatterMaybe sh0 sh1 a) -> IO (p -> IO (Array sh1 a))
scatterMaybe =
   materialize "scatterMaybe"
      (\core -> Core.shape (scatterMaybeInit core))
      (\(ScatterMaybe combine initial updates) ->
         Priv.scatterMaybe combine initial updates)


-- | Inputs of 'mapAccumLSimple'.
data MapAccumLSimple sh n acc a b =
   MapAccumLSimple {
      -- | Step function: from an accumulator and an input element
      -- to a pair of new accumulator and output element.
      mapAccumLSimpleAccum :: Exp acc -> Exp a -> Exp (acc,b),
      -- | Array of shape @sh@ with accumulator values.
      mapAccumLSimpleInit :: Core.Array sh acc,
      -- | Input array of shape @(sh, n)@;
      -- 'mapAccumLSimple' takes the result shape from it.
      mapAccumLSimpleArray :: Core.Array (sh, n) a
   }

-- | Compile a 'MapAccumLSimple' description into an array-producing action.
-- The result has the shape of 'mapAccumLSimpleArray';
-- the buffer is filled by 'Priv.mapAccumLSimple'.
mapAccumLSimple ::
   (Shape.C sh, Storable sh, MultiValueMemory.C sh,
    Shape.C n, Storable n, MultiValueMemory.C n,
    MultiValue.C acc,
    Storable a, MultiValueMemory.C a,
    Storable b, MultiValueMemory.C b) =>
   Sym.Hull p (MapAccumLSimple sh n acc a b) -> IO (p -> IO (Array (sh,n) b))
mapAccumLSimple =
   materialize "mapAccumLSimple"
      (\core -> Core.shape (mapAccumLSimpleArray core))
      (\(MapAccumLSimple step initial source) ->
         Priv.mapAccumLSimple step initial source)


-- | Inputs of 'mapAccumLSequence'.
data MapAccumLSequence n acc final a b =
   MapAccumLSequence {
      -- | Step function: from an accumulator and an input element
      -- to a pair of new accumulator and output element.
      mapAccumLSequenceAccum :: Exp acc -> Exp a -> Exp (acc,b),
      -- | Converts an accumulator into a value of the @final@ type.
      mapAccumLSequenceFinal :: Exp acc -> Exp final,
      -- | A single accumulator expression.
      mapAccumLSequenceInit :: Exp acc,
      -- | Input array of shape @n@;
      -- 'mapAccumLSequence' takes the result shape from it.
      mapAccumLSequenceArray :: Core.Array n a
   }

-- FIXME: check correct size of array of initial values

-- | Compile a 'MapAccumLSequence' description into an action
-- that returns a @final@ value together with an array.
-- The array has the shape of 'mapAccumLSequenceArray';
-- both results are produced by 'Priv.mapAccumLSequence'.
mapAccumLSequence ::
   (Shape.C n, Storable n, MultiValueMemory.C n,
    MultiValue.C acc,
    Storable final, MultiValueMemory.C final,
    Storable a, MultiValueMemory.C a,
    Storable b, MultiValueMemory.C b) =>
   Sym.Hull p (MapAccumLSequence n acc final a b) ->
   IO (p -> IO (final, Array n b))
mapAccumLSequence =
   materializeExpArray "mapAccumLSequence"
      (\core -> Core.shape (mapAccumLSequenceArray core))
      (\(MapAccumLSequence step finish start source) ->
         Priv.mapAccumLSequence step finish start source)


-- | Inputs of 'mapAccumL'.
data MapAccumL sh n acc final a b =
   MapAccumL {
      -- | Step function: from an accumulator and an input element
      -- to a pair of new accumulator and output element.
      mapAccumLAccum :: Exp acc -> Exp a -> Exp (acc,b),
      -- | Converts an accumulator into a value of the @final@ type.
      mapAccumLFinal :: Exp acc -> Exp final,
      -- | Array of shape @sh@ with accumulator values;
      -- 'mapAccumL' uses its shape for the first result array.
      mapAccumLInit :: Core.Array sh acc,
      -- | Input array of shape @(sh, n)@;
      -- 'mapAccumL' uses its shape for the second result array.
      mapAccumLArray :: Core.Array (sh, n) a
   }

-- FIXME: check correct size of array of initial values

-- | Compile a 'MapAccumL' description into an action returning two arrays.
-- The first has the shape of 'mapAccumLInit',
-- the second the shape of 'mapAccumLArray';
-- both buffers are filled by 'Priv.mapAccumL'.
mapAccumL ::
   (Shape.C sh, Storable sh, MultiValueMemory.C sh,
    Shape.C n, Storable n, MultiValueMemory.C n,
    MultiValue.C acc,
    Storable final, MultiValueMemory.C final,
    Storable a, MultiValueMemory.C a,
    Storable b, MultiValueMemory.C b) =>
   Sym.Hull p (MapAccumL sh n acc final a b) ->
   IO (p -> IO (Array sh final, Array (sh,n) b))
mapAccumL =
   materialize2 "mapAccumL"
      (\(MapAccumL _ _ initial source) ->
         Expr.zip (Core.shape initial) (Core.shape source))
      (\(MapAccumL step finish initial source) ->
         Priv.mapAccumL step finish initial source)


-- | Inputs of 'foldOuterL'.
data FoldOuterL n sh a b =
   FoldOuterL {
      -- | Combines an accumulator of type @a@ with an element of type @b@
      -- into a new accumulator.
      foldOuterLAccum :: Exp a -> Exp b -> Exp a,
      -- | Array of shape @sh@ with accumulator values;
      -- 'foldOuterL' takes the result shape from it.
      foldOuterLInit :: Core.Array sh a,
      -- | Input array of shape @(n,sh)@.
      foldOuterLArray :: Core.Array (n,sh) b
   }

-- FIXME: check correct size of array of initial values

-- | Compile a 'FoldOuterL' description into an array-producing action.
-- The result has the shape of 'foldOuterLInit';
-- the buffer is filled by 'Priv.foldOuterL'.
foldOuterL ::
   (Shape.C n, Storable n, MultiValueMemory.C n,
    Shape.C sh, Storable sh, MultiValueMemory.C sh,
    Storable a, MultiValueMemory.C a) =>
   Sym.Hull p (FoldOuterL n sh a b) -> IO (p -> IO (Array sh a))
foldOuterL =
   materialize "foldOuterL"
      (\core -> Core.shape (foldOuterLInit core))
      (\(FoldOuterL step initial source) ->
         Priv.foldOuterL step initial source)


-- | Inputs of 'mapFilter'.
data MapFilter n a b =
   MapFilter {
      -- | Maps an input element to an output element.
      mapFilterMap :: Exp a -> Exp b,
      -- | Boolean test on an input element.
      mapFilterPredicate :: Exp a -> Exp Bool,
      -- | Input array of shape @n@;
      -- 'mapFilter' sizes the output buffer from its shape.
      mapFilterArray :: Core.Array n a
   }

-- | Compile a 'MapFilter' description into an array-producing action.
-- The buffer is allocated for the shape of 'mapFilterArray'.
-- The shape returned by 'Priv.mapFilter' is then written back
-- to the shape pointer, so the resulting array carries that shape.
mapFilter ::
   (Shape.Sequence n,
    Storable n, MultiValueMemory.C n,
    Storable b, MultiValueMemory.C b) =>
   Sym.Hull p (MapFilter n a b) -> IO (p -> IO (Array n b))
mapFilter =
   materialize "mapFilter"
      (\core -> Core.shape (mapFilterArray core))
      (\(MapFilter mapElem keep source) shapePtr bufferPtr -> do
         newShape <- Priv.mapFilter mapElem keep source shapePtr bufferPtr
         MultiValueMemory.store newShape shapePtr)


-- | Inputs of 'filterOuter'.
data FilterOuter n sh a =
   FilterOuter {
      -- | Array of shape @n@ with one 'Bool' per index.
      filterOuterPredicate :: Core.Array n Bool,
      -- | Input array of shape @(n,sh)@;
      -- 'filterOuter' sizes the output buffer from its shape.
      filterOuterArray :: Core.Array (n,sh) a
   }

-- FIXME: check correct size of row selection array

-- | Compile a 'FilterOuter' description into an array-producing action.
-- The buffer is allocated for the shape of 'filterOuterArray'.
-- The shape returned by 'Priv.filterOuter' is then written back
-- to the shape pointer, so the resulting array carries that shape.
filterOuter ::
   (Shape.Sequence n, Storable n, MultiValueMemory.C n,
    Shape.C sh, Storable sh, MultiValueMemory.C sh,
    Storable a, MultiValueMemory.C a) =>
   Sym.Hull p (FilterOuter n sh a) -> IO (p -> IO (Array (n,sh) a))
filterOuter =
   materialize "filterOuter"
      (\core -> Core.shape (filterOuterArray core))
      (\(FilterOuter keep source) shapePtr bufferPtr -> do
         newShape <- Priv.filterOuter keep source shapePtr bufferPtr
         MultiValueMemory.store newShape shapePtr)


-- | Inputs of 'addDimension'.
data AddDimension sh n a b =
   AddDimension {
      -- | Shape of the added dimension;
      -- 'addDimension' zips it with the shape of 'addDimensionArray'.
      addDimensionSize :: Exp n,
      -- | Computes an output element from an index into the added dimension
      -- and an input element.
      addDimensionSelect :: Exp (Shape.Index n) -> Exp a -> Exp b,
      -- | Input array of shape @sh@.
      addDimensionArray :: Core.Array sh a
   }

-- | Compile an 'AddDimension' description into an array-producing action.
-- The result shape is the shape of 'addDimensionArray'
-- zipped with 'addDimensionSize';
-- the buffer is filled by 'Priv.addDimension'.
addDimension ::
   (Shape.C sh, Storable sh, MultiValueMemory.C sh,
    Shape.C n, Storable n, MultiValueMemory.C n,
    Storable b, MultiValueMemory.C b) =>
   Sym.Hull p (AddDimension sh n a b) -> IO (p -> IO (Array (sh,n) b))
addDimension =
   materialize "addDimension"
      (\core ->
         Expr.zip
            (Core.shape $ addDimensionArray core)
            (addDimensionSize core))
      (\(AddDimension extraDim pick source) ->
         Priv.addDimension extraDim pick source)