{-# LANGUAGE GADTs #-}
{-# LANGUAGE ForeignFunctionInterface #-}
module Data.Array.Knead.Parameterized.Physical (
   Phys.Array,
   Array.shape,
   Phys.fromList,
   feed,
   the,
   theMarshal,
   render,
   renderShape,
   mapAccumLSimple,
   foldOuterL,
   scatter,
   scatterMaybe,
   permute,
   ) where

import qualified Data.Array.Knead.Parameterized.PhysicalHull as PhysHull
import qualified Data.Array.Knead.Parameterized.Private as Sym
import qualified Data.Array.Knead.Simple.Physical as Phys
import qualified Data.Array.Knead.Simple.Private as Core
import qualified Data.Array.Knead.Shape as Shape
import qualified Data.Array.Knead.Expression as Expr
import Data.Array.Knead.Code (getElementPtr)

import qualified Data.Array.Comfort.Storable.Unchecked as Array

import qualified LLVM.DSL.Parameter as Param
import qualified LLVM.DSL.Execution as Code
import LLVM.DSL.Expression (Exp, unExp)

import qualified LLVM.Extra.Multi.Value as MultiValue
import qualified LLVM.Extra.Storable as Storable
import qualified LLVM.Extra.Marshal as Marshal
import qualified LLVM.Extra.Memory as Memory

import qualified LLVM.Core as LLVM

import Foreign.Marshal.Alloc (alloca, )
import Foreign.Storable (peek, )
import Foreign.ForeignPtr (withForeignPtr, touchForeignPtr, )
import Foreign.Ptr (FunPtr, Ptr, )

import Control.Exception (bracket, )
import Control.Monad.HT ((<=<), )
import Control.Applicative (liftA2, )
import Data.Tuple.HT (mapFst, )


{-# INLINE feed #-}
-- | Turn a parameterized, already materialized 'Phys.Array' into a
-- symbolic array.
--
-- The shape of the array is carried in the parameter via
-- 'Param.withMulti'. Element access computes the element address with
-- 'getElementPtr' from the shape and the buffer pointer, then loads the
-- element with 'Storable.loadMultiValue'.
--
-- NOTE: the raw buffer pointer is deliberately returned out of
-- 'withForeignPtr'. It stays valid only because the 'ForeignPtr' is
-- returned as the context, and the release action ('touchForeignPtr') is
-- applied to it after the generated code has run. Consumers in this module
-- ('the', 'theMarshal', 'renderShape') do this with @bracket@.
feed ::
   (Shape.C sh, Marshal.MV sh, Storable.C a) =>
   Param.T p (Phys.Array sh a) -> Sym.Array p sh a
feed arr =
   Param.withMulti (fmap Array.shape arr) $ \getShape valueShape ->
   Sym.Array
      -- Code generator: split the parameter value into its shape part
      -- (converted with valueShape) and the buffer pointer.
      (\p ->
         case mapFst valueShape $ MultiValue.unzip p of
            (sh, MultiValue.Cons ptr) ->
               Core.Array (Expr.lift0 sh) $
                  Storable.loadMultiValue <=< getElementPtr sh ptr)
      -- Creation: build the runtime parameter (shape, raw pointer) and
      -- return the ForeignPtr as context for the release action.
      (\p ->
         case Array.buffer $ Param.get arr p of
            fptr ->
               withForeignPtr fptr $ \ptr ->
                  return (fptr, (getShape p, ptr)))
      -- Release: keep the buffer alive until the generated code is done.
      touchForeignPtr


-- | Type of a dynamic foreign import. It converts a function pointer into
-- a callable Haskell function. Used in this module to call the functions
-- built with 'Code.createFunction'.
type Importer f = FunPtr f -> f

-- | Caller for the function generated by 'the'. Arguments: a pointer to the
-- marshalled parameter and a plain 'Ptr' to the result slot.
foreign import ccall safe "dynamic" callThe ::
   Importer (LLVM.Ptr param -> Ptr a -> IO ())


-- | Compile a symbolic array of scalar shape into an action that, for a
-- given parameter, evaluates the element at 'Shape.zeroIndex' and returns
-- it. The result is read back with 'Foreign.Storable.peek'.
--
-- Code generation happens once, when the outer 'IO' action runs. The
-- returned function may be called repeatedly. Each call runs the array's
-- creation action, and its release action runs even if the call throws.
the ::
   (Shape.Scalar z, Storable.C a, MultiValue.C a) =>
   Sym.Array p z a -> IO (p -> IO a)
the (Sym.Array arr create delete) = do
   evalFn <-
      Code.compile "the" $
      Code.createFunction callThe "eval" $
      \paramPtr resultPtr -> do
         param <- Memory.load paramPtr
         case arr param of
            Core.Array z code -> do
               x <- code (Shape.zeroIndex z)
               Storable.storeMultiValue x resultPtr
         LLVM.ret ()
   -- Per call: acquire the parameter, marshal it, give the generated code a
   -- scratch slot for the result, then read that slot back.
   let run p =
          bracket (create p) (delete . fst) $ \(_ctx, param) ->
          Marshal.with param $ \paramPtr ->
          alloca $ \resultPtr -> do
             evalFn paramPtr resultPtr
             peek resultPtr
   return run

-- | Caller for the function generated by 'theMarshal'. Unlike 'callThe',
-- the result slot is an 'LLVM.Ptr', as provided by 'Marshal.alloca'.
foreign import ccall safe "dynamic" callTheMarshal ::
   Importer (LLVM.Ptr param -> LLVM.Ptr a -> IO ())

-- | Variant of 'the' that transfers the result through 'Marshal.C'
-- ('Memory.store' in generated code, 'Marshal.peek' on the Haskell side)
-- instead of 'Storable.C'.
--
-- Code generation happens once, when the outer 'IO' action runs. Each call
-- of the returned function brackets the array's creation and release.
theMarshal ::
   (Shape.Scalar z, Marshal.C a, MultiValue.C a) =>
   Sym.Array p z a -> IO (p -> IO a)
theMarshal (Sym.Array arr create delete) =
   fmap run $
      Code.compile "the-marshal" $
      Code.createFunction callTheMarshal "eval" $
      \paramPtr resultPtr -> do
         param <- Memory.load paramPtr
         case arr param of
            Core.Array z code -> do
               x <- code (Shape.zeroIndex z)
               Memory.store x resultPtr
         LLVM.ret ()
   where
      -- Invoke the compiled function for one parameter value.
      run evalFn p =
         bracket (create p) (delete . fst) $ \(_ctx, param) ->
         Marshal.with param $ \paramPtr ->
         Marshal.alloca $ \resultPtr ->
            evalFn paramPtr resultPtr >> Marshal.peek resultPtr


-- | Caller for the function generated by 'renderShape'. It receives the
-- parameter pointer and a slot to store the shape into, and returns the
-- 'Shape.Size' computed by the generated code.
foreign import ccall safe "dynamic" callShaper ::
   Importer (LLVM.Ptr param -> LLVM.Ptr shape -> IO Shape.Size)


-- | Compile an action that computes only the shape of a symbolic array,
-- together with the 'Shape.size' of that shape, for a given parameter.
-- The element code of the array is ignored, so no elements are evaluated.
renderShape ::
   (Shape.C sh, Marshal.MV sh,
    Storable.C a, MultiValue.C a) =>
   Sym.Array p sh a -> IO (p -> IO (sh, Shape.Size))
renderShape (Sym.Array arr create delete) = do
   shapeFn <-
      Code.compile "renderShape" $
      Code.createFunction callShaper "shape" $
      \paramPtr resultPtr -> do
         param <- Memory.load paramPtr
         case arr param of
            Core.Array esh _code -> do
               sh <- unExp esh
               -- Hand the shape back through the result slot and its size
               -- through the return value.
               Memory.store sh resultPtr
               LLVM.ret =<< Shape.size sh
   return $ \p ->
      bracket (create p) (delete . fst) $ \(_ctx, param) ->
      Marshal.alloca $ \shapePtr ->
      Marshal.with param $ \paramPtr -> do
         size <- shapeFn paramPtr shapePtr
         sh <- Marshal.peek shapePtr
         return (sh, size)


-- | Compile a symbolic array into an action that produces a 'Phys.Array'
-- for a given parameter. This is 'PhysHull.render' applied to the array's
-- hull ('Sym.arrayHull').
render ::
   (Shape.C sh, Marshal.MV sh, Storable.C a) =>
   Sym.Array p sh a -> IO (p -> IO (Phys.Array sh a))
render arr = PhysHull.render (Sym.arrayHull arr)


-- | Accumulating map over a two-dimensional symbolic array. It combines the
-- hulls of @arrInit@ (one accumulator start value per @sh@ index) and
-- @arrMap@ into a 'PhysHull.MapAccumLSimple' with step function @f@ and
-- compiles it with 'PhysHull.mapAccumLSimple'.
--
-- NOTE(review): presumably the accumulator is threaded left-to-right along
-- the @n@ dimension for each @sh@ index; confirm against
-- 'PhysHull.mapAccumLSimple'.
mapAccumLSimple ::
   (Shape.C sh, Marshal.MV sh,
    Shape.C n, Marshal.MV n,
    MultiValue.C acc,
    Storable.C a, MultiValue.C a,
    Storable.C b, MultiValue.C b) =>
   (Exp acc -> Exp a -> Exp (acc,b)) ->
   Sym.Array p sh acc ->
   Sym.Array p (sh, n) a ->
   IO (p -> IO (Phys.Array (sh,n) b))
mapAccumLSimple f arrInit arrMap =
   PhysHull.mapAccumLSimple
      (PhysHull.MapAccumLSimple f
         <$> Sym.arrayHull arrInit
         <*> Sym.arrayHull arrMap)

-- | Fold along the outer dimension. It combines the hulls of @arrInit@ and
-- @arrMap@ into a 'PhysHull.FoldOuterL' with combining function @f@ and
-- compiles it with 'PhysHull.foldOuterL'.
--
-- NOTE(review): presumably @arrMap@ is folded over its outer @n@ dimension,
-- starting from @arrInit@; confirm against 'PhysHull.foldOuterL'.
foldOuterL ::
   (Shape.C sh, Marshal.MV sh,
    Shape.C n, Marshal.MV n,
    Storable.C a, MultiValue.C a) =>
   (Exp a -> Exp b -> Exp a) ->
   Sym.Array p sh a ->
   Sym.Array p (n,sh) b ->
   IO (p -> IO (Phys.Array sh a))
foldOuterL f arrInit arrMap = PhysHull.foldOuterL hull
   where
      hull =
         liftA2 (PhysHull.FoldOuterL f)
            (Sym.arrayHull arrInit)
            (Sym.arrayHull arrMap)

-- | Scatter @(index, value)@ pairs from @arrMap@ into a copy of @arrBase@,
-- using @accum@ to combine values. It builds a 'PhysHull.Scatter' from the
-- two hulls and compiles it with 'PhysHull.scatter'.
--
-- NOTE(review): the argument order of @accum@ (new value vs. existing
-- value) is determined by 'PhysHull.scatter'; confirm there.
scatter ::
   (Shape.C sh0, Shape.Index sh0 ~ ix0,
    Shape.C sh1, Shape.Index sh1 ~ ix1, Marshal.MV sh1,
    Storable.C a, MultiValue.C a) =>
   (Exp a -> Exp a -> Exp a) ->
   Sym.Array p sh1 a ->
   Sym.Array p sh0 (ix1, a) -> IO (p -> IO (Phys.Array sh1 a))
scatter accum arrBase arrMap =
   PhysHull.scatter $
      PhysHull.Scatter accum
         <$> Sym.arrayHull arrBase
         <*> Sym.arrayHull arrMap

-- | Like 'scatter', but each element of @arrMap@ is an optional
-- @(index, value)@ pair. It builds a 'PhysHull.ScatterMaybe' from the two
-- hulls and compiles it with 'PhysHull.scatterMaybe'.
--
-- NOTE(review): presumably 'Nothing' entries are skipped; confirm against
-- 'PhysHull.scatterMaybe'.
scatterMaybe ::
   (Shape.C sh0, Shape.Index sh0 ~ ix0,
    Shape.C sh1, Shape.Index sh1 ~ ix1, Marshal.MV sh1,
    Storable.C a, MultiValue.C a) =>
   (Exp a -> Exp a -> Exp a) ->
   Sym.Array p sh1 a ->
   Sym.Array p sh0 (Maybe (ix1, a)) -> IO (p -> IO (Phys.Array sh1 a))
scatterMaybe accum arrBase arrMap = PhysHull.scatterMaybe hull
   where
      hull =
         liftA2 (PhysHull.ScatterMaybe accum)
            (Sym.arrayHull arrBase)
            (Sym.arrayHull arrMap)

-- | 'scatter' where the target index of every element of @input@ is
-- computed from its source index by @ixmap@. @deflt@ provides the base
-- array, and @accum@ combines colliding values.
permute ::
   (Shape.C sh0, Shape.Index sh0 ~ ix0,
    Shape.C sh1, Shape.Index sh1 ~ ix1, Marshal.MV sh1,
    Storable.C a, MultiValue.C a) =>
   (Exp a -> Exp a -> Exp a) ->
   Sym.Array p sh1 a ->
   (Exp ix0 -> Exp ix1) ->
   Sym.Array p sh0 a ->
   IO (p -> IO (Phys.Array sh1 a))
permute accum deflt ixmap input =
   scatter accum deflt $
   Core.mapWithIndex
      (\ix x -> Expr.lift2 MultiValue.zip (ixmap ix) x)
      input