{-# LANGUAGE GADTs #-}
{-# LANGUAGE ForeignFunctionInterface #-}
module Data.Array.Knead.Parameterized.Physical (
   Phys.Array,
   Phys.shape,
   Phys.fromList,
   feed,
   the,
   render,
   renderShape,
   scatter,
   permute,
   ) where

import qualified Data.Array.Knead.Parameterized.Private as Sym
import qualified Data.Array.Knead.Simple.Physical as Phys
import qualified Data.Array.Knead.Simple.Private as Core
import qualified Data.Array.Knead.Parameter as Param
import qualified Data.Array.Knead.Index.Nested.Shape as Shape
import qualified Data.Array.Knead.Expression as Expr
import qualified Data.Array.Knead.Code as Code
import Data.Array.Knead.Expression (Exp, unExp, )
import Data.Array.Knead.Code (getElementPtr, compile, )

import qualified LLVM.Extra.Multi.Value.Memory as MultiValueMemory
import qualified LLVM.Extra.Multi.Value as MultiValue
import qualified LLVM.Extra.Arithmetic as A
import qualified LLVM.Extra.Memory as Memory

import qualified LLVM.Core as LLVM

import Foreign.Marshal.Utils (with, )
import Foreign.Marshal.Alloc (alloca, )
import Foreign.Storable (Storable, peek, )
import Foreign.ForeignPtr (withForeignPtr, mallocForeignPtrArray, touchForeignPtr, )
import Foreign.Ptr (FunPtr, Ptr, )

import Control.Exception (bracket, )
import Control.Monad.HT (void, (<=<), )
import Control.Applicative (liftA2, )
import Data.Tuple.HT (mapFst, )
import Data.Word (Word32, )

import Prelude hiding (scanl1, )


-- | Embed a parameterized physical array into a symbolic array.
--
-- In generated code, the element at index @ix@ is read by computing its
-- address with 'getElementPtr' from the shape and the buffer's base
-- pointer, and then loading from that address.
--
-- At runtime the parameter is a pair of the array shape (obtained via
-- @getShape@) and a pointer into the array's buffer.  That pointer is
-- used after 'withForeignPtr' has returned.  For that reason the
-- 'ForeignPtr' itself is kept as the context, and 'touchForeignPtr' is
-- used as the release action; this keeps the buffer alive until the
-- computation using the pointer is done.
feed ::
   (Shape.C sh, Storable sh, MultiValueMemory.C sh,
    MultiValueMemory.C a) =>
   Param.T p (Phys.Array sh a) -> Sym.Array p sh a
feed arr =
   Param.withMulti (fmap Phys.shape arr) $ \getShape valueShape ->
      let readElements packed =
             case mapFst valueShape (MultiValue.unzip packed) of
                (sh, MultiValue.Cons basePtr) ->
                   Core.Array (Expr.lift0 sh)
                      (Memory.load <=< getElementPtr sh basePtr)
          acquire p =
             let fptr = Phys.buffer (Param.get arr p)
             in  withForeignPtr fptr $ \ptr ->
                    return
                       (fptr, (getShape p, MultiValueMemory.castStructPtr ptr))
      in  Sym.Array readElements acquire touchForeignPtr


-- | Type of a @\"dynamic\"@ foreign import: it turns a pointer to a
-- foreign function into a callable Haskell function of the same type.
type Importer f = FunPtr f -> f

-- | Calls the @eval@ function generated by 'the'.
--
-- Arguments: a pointer to the parameter structure, and a pointer to the
-- slot where the single result element is stored.
foreign import ccall safe "dynamic" callThe ::
   Importer (Ptr param -> Ptr am -> IO ())


-- | Compile an evaluator for an array whose shape is a scalar
-- ('Shape.Scalar').
--
-- The generated @eval@ function loads the parameter structure, computes
-- the element at 'Shape.zeroIndex' and stores it in the result slot.
--
-- The returned action does three things, using 'bracket' so that the
-- context is released even if an exception occurs:
--
-- * acquires the parameter context with @create@,
-- * runs the compiled code and reads back the element,
-- * releases the context with @delete@.
the ::
   (Shape.Scalar z, MultiValueMemory.C a, Storable a) =>
   Sym.Array p z a -> IO (p -> IO a)
the (Sym.Array arr create delete) = do
   evalScalar <-
      compile "the" $
      Code.createFunction callThe "eval" $ \paramPtr resultPtr -> do
         param <- Memory.load paramPtr
         case arr param of
            Core.Array z code -> do
               value <- code (Shape.zeroIndex z)
               Memory.store value resultPtr
         LLVM.ret ()
   let run p =
          bracket (create p) (delete . fst) $ \(_ctx, param) ->
          with param $ \paramPtr ->
          alloca $ \valuePtr -> do
             evalScalar
                (MultiValueMemory.castStructPtr paramPtr)
                (MultiValueMemory.castStructPtr valuePtr)
             peek valuePtr
   return run


-- | Calls a generated @shape@ function (used by 'renderShape', 'render'
-- and 'scatter').
--
-- Arguments: a pointer to the parameter structure, and a pointer to which
-- the array shape is written.  The result is the element count computed
-- by 'Shape.sizeCode'.
foreign import ccall safe "dynamic" callShaper ::
   Importer (Ptr param -> Ptr shape -> IO Word32)

-- | Calls the @fill@ function generated by 'render'.
--
-- Arguments, in order:
--
-- * a pointer to the parameter structure,
-- * a pointer to the shape (as previously written by the @shape@
--   function),
-- * a pointer to the output element buffer.
foreign import ccall safe "dynamic" callRenderer ::
   Importer (Ptr param -> Ptr shape -> Ptr am -> IO ())


-- | Compile a function that computes only the shape of a symbolic array.
--
-- The generated @shape@ function writes the shape into a buffer and
-- returns the element count reported by 'Shape.sizeCode'.  The element
-- code of the array is never invoked.  The returned action yields the
-- shape together with that count.  The parameter context is acquired and
-- released around the call with 'bracket'.
renderShape ::
   (Shape.C sh, Storable sh, MultiValueMemory.C sh,
    Storable a, MultiValueMemory.C a) =>
   Sym.Array p sh a -> IO (p -> IO (sh, Word32))
renderShape (Sym.Array arr create delete) = do
   computeShape <-
      compile "renderShape" $
      Code.createFunction callShaper "shape" $ \paramPtr resultPtr -> do
         param <- Memory.load paramPtr
         case arr param of
            Core.Array esh _code -> do
               sh <- unExp esh
               MultiValueMemory.store sh resultPtr
               size <- Shape.sizeCode sh
               LLVM.ret size
   return $ \p ->
      bracket (create p) (delete . fst) $ \(_ctx, param) ->
      with param $ \paramPtr ->
      alloca $ \shapePtr -> do
         size <-
            computeShape
               (MultiValueMemory.castStructPtr paramPtr)
               (MultiValueMemory.castStructPtr shapePtr)
         sh <- peek shapePtr
         return (sh, size)


-- | Compile a renderer that materialises a symbolic array in memory.
--
-- Two functions are generated in one module:
--
-- * @shape@ stores the array shape into a buffer and returns the element
--   count ('Shape.sizeCode').
-- * @fill@ loads that shape back and walks all indices with 'Shape.loop'.
--   It stores each element at the current pointer and then advances the
--   pointer by one element.
--
-- At runtime, a buffer of exactly the reported number of elements is
-- allocated with 'mallocForeignPtrArray' before @fill@ runs.  The
-- parameter context is released via 'bracket'.
render ::
   (Shape.C sh, Storable sh, MultiValueMemory.C sh,
    Storable a, MultiValueMemory.C a) =>
   Sym.Array p sh a -> IO (p -> IO (Phys.Array sh a))
render (Sym.Array arr create delete) = do
   (computeShape, fillBuffer) <-
      compile "render" $
      liftA2 (,)
         (Code.createFunction callShaper "shape" $ \paramPtr resultPtr -> do
             param <- Memory.load paramPtr
             case arr param of
                Core.Array esh _code -> do
                   sh <- unExp esh
                   MultiValueMemory.store sh resultPtr
                   size <- Shape.sizeCode sh
                   LLVM.ret size)
         (Code.createFunction callRenderer "fill" $
          \paramPtr shapePtr bufferPtr -> do
             param <- Memory.load paramPtr
             case arr param of
                Core.Array esh code -> do
                   sh <- Shape.load esh shapePtr
                   let storeAndAdvance ix ptr = do
                          value <- code ix
                          Memory.store value ptr
                          A.advanceArrayElementPtr ptr
                   _ <- Shape.loop storeAndAdvance sh bufferPtr
                   LLVM.ret ())
   return $ \p ->
      bracket (create p) (delete . fst) $ \(_ctx, param) ->
      with param $ \paramPtr ->
      alloca $ \shapePtr -> do
         let paramArg = MultiValueMemory.castStructPtr paramPtr
             shapeArg = MultiValueMemory.castStructPtr shapePtr
         size <- computeShape paramArg shapeArg
         fptr <- mallocForeignPtrArray (fromIntegral size)
         withForeignPtr fptr $ \bufferPtr ->
            fillBuffer paramArg shapeArg (MultiValueMemory.castStructPtr bufferPtr)
         sh <- peek shapePtr
         return (Phys.Array sh fptr)



-- | Calls the @fill@ function generated by 'scatter'.
--
-- Arguments, in order:
--
-- * the parameter structure of the base array,
-- * the parameter structure of the (index, value) map array,
-- * the result shape,
-- * the output element buffer.
foreign import ccall safe "dynamic" callScatterer ::
   Importer (Ptr paramBase -> Ptr paramMap -> Ptr shape -> Ptr am -> IO ())

-- | Compile a scatter operation.
--
-- The result has the shape of the base array; the generated @shape@
-- function computes it from the base parameters.  The generated @fill@
-- function works in two passes:
--
-- 1. It copies every element of the base array into the output buffer.
-- 2. It iterates over the map array, whose elements are pairs of a target
--    index and a value.  The element at the target index is replaced by
--    @accum old value@, where @old@ is the element currently stored
--    there.  Entries sharing a target are therefore combined in the
--    iteration order of the map array.
--
-- Both parameter contexts are acquired and released with 'bracket'.
--
-- NOTE(review): the target index is turned into an address with
-- 'getElementPtr' and written without any visible bounds check.  It
-- presumably must lie inside the base shape.  TODO: confirm whether
-- callers or 'getElementPtr' guarantee this.
scatter ::
   (Shape.C sh0, Shape.Index sh0 ~ ix0,
    Shape.C sh1, Shape.Index sh1 ~ ix1,
    Storable sh1, MultiValueMemory.C sh1,
    Storable a, MultiValueMemory.C a) =>
   (Exp a -> Exp a -> Exp a) ->
   Sym.Array p sh1 a ->
   Sym.Array p sh0 (ix1, a) -> IO (p -> IO (Phys.Array sh1 a))
scatter accum
      (Sym.Array arrBase createBase deleteBase)
      (Sym.Array arrMap createMap deleteMap) = do

   (fsh, farr) <-
      compile "scatter" $
      liftA2 (,)
         (Code.createFunction callShaper "shape" $
          \paramPtr resultPtr -> do
            param <- Memory.load paramPtr
            case arrBase param of
               Core.Array esh _code -> do
                  sh <- unExp esh
                  MultiValueMemory.store sh resultPtr
                  Shape.sizeCode sh >>= LLVM.ret)
         (Code.createFunction callScatterer "fill" $
          \paramBasePtr paramMapPtr shapePtr bufferPtr -> do
            paramBase <- Memory.load paramBasePtr
            paramMap <- Memory.load paramMapPtr
            case (arrBase paramBase, arrMap paramMap) of
               (Core.Array esh codeBase, Core.Array eish codeMap) -> do
                  -- first pass: initialise the whole output from the base array
                  let clear ix p = do
                         flip Memory.store p =<< codeBase ix
                         A.advanceArrayElementPtr p
                  sh <- Shape.load esh shapePtr
                  void $ Shape.loop clear sh bufferPtr

                  -- second pass: merge each (target index, value) pair
                  -- into the element already stored at the target
                  ish <- unExp eish
                  let fill ix () = do
                         (jx, a) <- fmap MultiValue.unzip $ codeMap ix
                         p <- getElementPtr sh bufferPtr jx
                         flip Memory.store p
                            =<< Expr.unliftM2 (flip accum) a
                            =<< Memory.load p
                  Shape.loop fill ish ()
            LLVM.ret ())

   return $ \p ->
      bracket (createBase p) (deleteBase . fst) $ \(_ctxBase, paramBase) ->
      bracket (createMap p) (deleteMap . fst) $ \(_ctxMap, paramMap) ->
      alloca $ \shptr ->
      with paramBase $ \paramBasePtr ->
      with paramMap $ \paramMapPtr -> do
         let paramBaseMVPtr = MultiValueMemory.castStructPtr paramBasePtr
         let paramMapMVPtr = MultiValueMemory.castStructPtr paramMapPtr
         let shapeMVPtr = MultiValueMemory.castStructPtr shptr
         -- the shape function also reports how many elements to allocate
         n <- fsh paramBaseMVPtr shapeMVPtr
         fptr <- mallocForeignPtrArray (fromIntegral n)
         withForeignPtr fptr $
            farr paramBaseMVPtr paramMapMVPtr shapeMVPtr .
            MultiValueMemory.castStructPtr
         sh <- peek shptr
         return (Phys.Array sh fptr)

-- | Combine the elements of @input@ into an array initialised from
-- @deflt@.
--
-- The element at index @ix@ of @input@ is sent to index @ixmap ix@ of the
-- result.  There it is merged with the element already present using
-- @accum@.  This is 'scatter' applied to @input@, with every element
-- paired with its target index.
permute ::
   (Shape.C sh0, Shape.Index sh0 ~ ix0,
    Shape.C sh1, Shape.Index sh1 ~ ix1,
    Storable sh1, MultiValueMemory.C sh1,
    Storable a, MultiValueMemory.C a) =>
   (Exp a -> Exp a -> Exp a) ->
   Sym.Array p sh1 a ->
   (Exp ix0 -> Exp ix1) ->
   Sym.Array p sh0 a ->
   IO (p -> IO (Phys.Array sh1 a))
permute accum deflt ixmap input =
   scatter accum deflt $ Core.mapWithIndex withTarget input
   where
      withTarget ix x = Expr.lift2 MultiValue.zip (ixmap ix) x