{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE ForeignFunctionInterface #-}
module Data.Array.Knead.Simple.Physical (
   Array(Array, shape, buffer), -- data constructor intended for PhysicalParameterized
   toList,
   fromList,
   vectorFromList,
   with,
   render,
   scanl1,
   mapAccumLSimple,
   scatter,
   scatterMaybe,
   permute,
   ) where

import qualified Data.Array.Knead.Simple.PhysicalPrivate as Priv
import qualified Data.Array.Knead.Simple.Private as Sym
import qualified Data.Array.Knead.Shape.Nested as Shape
import qualified Data.Array.Knead.Expression as Expr
import qualified Data.Array.Knead.Code as Code
import Data.Array.Knead.Expression (Exp, unExp, )
import Data.Array.Knead.Code (getElementPtr, compile, )

import qualified LLVM.Extra.Multi.Value.Memory as MultiValueMemory
import qualified LLVM.Extra.Multi.Value as MultiValue
import qualified LLVM.Extra.Arithmetic as A
import qualified LLVM.Extra.Memory as Memory
import qualified LLVM.Extra.Maybe as Maybe

import qualified LLVM.Core as LLVM

import Foreign.Marshal.Array (pokeArray, peekArray, )
import Foreign.Marshal.Alloc (alloca, )
import Foreign.Storable (Storable, peek, )
import Foreign.ForeignPtr (ForeignPtr, withForeignPtr, mallocForeignPtrArray, )
import Foreign.Ptr (FunPtr, Ptr, )

import Control.Monad.HT (void, )
import Control.Applicative (liftA2, (<$>), )

import Prelude hiding (scanl1, )


{- |
An array that physically exists in memory:
a shape together with a foreign buffer holding the elements.
The functions of this module treat the buffer
as containing @Shape.size shape@ elements.
-}
data Array sh a =
   Array
      { shape :: sh
        -- ^ extent of the array
      , buffer :: ForeignPtr a
        -- ^ memory block that stores the elements
      }


{- |
Read every element of a physical array into a list.
The number of elements read is @Shape.size@ of the array's shape.
-}
toList ::
   (Shape.C sh, Storable a) =>
   Array sh a -> IO [a]
toList (Array extent store) =
   withForeignPtr store $ \ptr ->
      peekArray (Shape.size extent) ptr

{- |
Allocate a fresh buffer of @Shape.size sh@ elements
and fill it from the front of the given list.
Surplus list elements are ignored.
If the list is too short, the missing positions are padded
with an 'error' value, so writing them is expected to abort
with the \"list too short\" message.
-}
fromList ::
   (Shape.C sh, Storable a) =>
   sh -> [a] -> IO (Array sh a)
fromList sh xs = do
   let n = Shape.size sh
       tooShort =
          error "Array.Knead.Physical.fromList: list too short for shape"
       -- exactly n elements; any padding is the error value
       padded = take n (xs ++ repeat tooShort)
   fptr <- mallocForeignPtrArray n
   withForeignPtr fptr (\ptr -> pokeArray ptr padded)
   return $ Array sh fptr

{- |
Store all elements of a list in a freshly allocated buffer.
The shape of the result is the list length converted with 'fromIntegral',
which is why the shape type needs a 'Num' instance.
-}
vectorFromList ::
   (Shape.C sh, Num sh, Storable a) =>
   [a] -> IO (Array sh a)
vectorFromList xs = do
   let len = length xs
   fptr <- mallocForeignPtrArray len
   withForeignPtr fptr $ \ptr -> pokeArray ptr xs
   return $ Array (fromIntegral len) fptr


{- |
Run an action on a symbolic view of a physical array.
The symbolic array reads its elements directly from the buffer:
for an index it computes the element address with 'getElementPtr'
and loads the value from there.

The symbolic array is only valid inside the enclosed action,
because the buffer pointer is obtained from 'withForeignPtr'.
-}
with ::
   (Shape.C sh, MultiValueMemory.C a) =>
   (Sym.Array sh a -> IO b) ->
   Array sh a -> IO b
with f (Array sh fptr) =
   withForeignPtr fptr $ \ptr ->
      let base = LLVM.valueOf (MultiValueMemory.castStructPtr ptr)
      in  f (Sym.Array
               (Shape.value sh)
               (\ix ->
                  getElementPtr (Shape.value sh) base ix >>= Memory.load))


{-
Type of a conversion from a C function pointer
to a Haskell function of type @f@.
It is the type scheme of the \"dynamic\" foreign imports in this module.
-}
type Importer f = FunPtr f -> f

{-
Dynamic import that makes a function pointer callable
as a function from a shape pointer to a 'Shape.Size'.
'materialize' uses it for the generated \"shape\" function.
-}
foreign import ccall safe "dynamic" callShaper ::
   Importer (Ptr sh -> IO Shape.Size)

{-
Dynamic import that makes a function pointer callable
as an action on a shape pointer and an array buffer pointer.
'materialize' uses it for the generated \"fill\" function.
-}
foreign import ccall safe "dynamic" callRenderer ::
   Importer (Ptr sh -> Ptr am -> IO ())


{-
Common driver for all operations that produce a physical array.

Two functions are compiled under the given name:

* \"shape\" evaluates the shape expression, stores the shape
  at the passed pointer and returns the size computed by @Shape.sizeCode@.

* \"fill\" runs the supplied code generator on the shape pointer
  and the pointer to the element buffer.

At run-time the shape function is called first on temporary storage,
then a buffer with the returned number of elements is allocated,
the fill function is run on it,
and finally the shape is read back from the temporary storage.
-}
materialize ::
   (Shape.C sh, Storable sh, MultiValueMemory.C sh,
    Storable a, MultiValueMemory.C a, MultiValueMemory.Struct a ~ am) =>
   String ->
   Exp sh ->
   (LLVM.Value (Ptr (MultiValueMemory.Struct sh)) ->
    LLVM.Value (Ptr am) -> LLVM.CodeGenFunction () ()) ->
   IO (Array sh a)
materialize name esh code =
   alloca $ \shptr -> do
      let lshptr = MultiValueMemory.castStructPtr shptr
          shapeFunc =
             Code.createFunction callShaper "shape" $ \ptr -> do
                sh <- unExp esh
                MultiValueMemory.store sh ptr
                LLVM.ret =<< Shape.sizeCode sh
          fillFunc =
             Code.createFunction callRenderer "fill" $
                \paramPtr arrayPtr -> do
                   code paramPtr arrayPtr
                   LLVM.ret ()
      (fsh, farr) <- compile name (liftA2 (,) shapeFunc fillFunc)
      -- the shape function also writes the shape to shptr
      n <- fsh lshptr
      fptr <- mallocForeignPtrArray (fromIntegral n)
      withForeignPtr fptr $ \ptr ->
         farr lshptr (MultiValueMemory.castStructPtr ptr)
      sh <- peek shptr
      return $ Array sh fptr

{- |
Turn a symbolic array into a physical one.
The generated code loads the shape,
then runs @Shape.loop@ over it:
each step stores the element for the current index
at the current buffer position
and advances the position by one element.
-}
render ::
   (Shape.C sh, Storable sh, MultiValueMemory.C sh,
    Storable a, MultiValueMemory.C a) =>
   Sym.Array sh a -> IO (Array sh a)
render (Sym.Array esh code) =
   materialize "render" esh $ \shapePtr arrayPtr -> do
      sh <- Shape.load esh shapePtr
      void $
         Shape.loop
            (\ix cursor -> do
               x <- code ix
               Memory.store x cursor
               A.advanceArrayElementPtr cursor)
            sh arrayPtr

{- |
Scan along the last dimension of an array of shape @(sh, n)@.
For every index of the outer shape @sh@
an inner loop over @n@ is run with an optional accumulator
that starts out empty.
When there is no accumulator yet, the element is stored unchanged;
otherwise @f accumulator element@ is stored.
The stored value becomes the accumulator for the next element.
-}
scanl1 ::
   (Shape.C sh, Storable sh, MultiValueMemory.C sh,
    Shape.C n, Storable n, MultiValueMemory.C n,
    Storable a, MultiValueMemory.C a) =>
   (Exp a -> Exp a -> Exp a) ->
   Sym.Array (sh, n) a -> IO (Array (sh, n) a)
scanl1 f (Sym.Array esh code) =
   materialize "scanl1" esh $ \shapePtr arrayPtr -> do
      (sh, n) <- fmap MultiValue.unzip $ Shape.load esh shapePtr
      let -- one step of the inner loop: combine, store, advance
          scanElement ix k (cursor, macc) = do
             a <- code (MultiValue.zip ix k)
             acc <-
                Maybe.run macc
                   (return a)
                   (\prev -> Expr.unliftM2 f prev a)
             Memory.store acc cursor
             next <- A.advanceArrayElementPtr cursor
             return (next, Maybe.just acc)
          -- inner loop for one outer index; only the pointer is passed on
          scanRow ix rowStart =
             fst <$>
                Shape.loop (scanElement ix) n (rowStart, Maybe.nothing)
      void $ Shape.loop scanRow sh arrayPtr

{- |
Physical-array front end for @Priv.mapAccumLSimple@.
The result has the shape of the data array;
its buffer is filled by the code
that @Priv.mapAccumLSimple@ generates
from the step function, the array of initial accumulators
and the data array.
-}
mapAccumLSimple ::
   (Shape.C sh, Storable sh, MultiValueMemory.C sh,
    Shape.C n, Storable n, MultiValueMemory.C n,
    MultiValue.C acc,
    Storable x, MultiValueMemory.C x,
    Storable y, MultiValueMemory.C y) =>
   (Exp acc -> Exp x -> Exp (acc,y)) ->
   Sym.Array sh acc -> Sym.Array (sh, n) x -> IO (Array (sh, n) y)
mapAccumLSimple f arrInit arrData =
   materialize "mapAccumLSimple" resultShape fill
   where
      resultShape = Sym.shape arrData
      fill = Priv.mapAccumLSimple f arrInit arrData

{- |
Physical-array front end for @Priv.scatterMaybe@.
The result has the shape of the initial array;
its buffer is filled by the code
that @Priv.scatterMaybe@ generates
from the accumulation function, the initial array
and the array of optional (index, value) pairs.
-}
scatterMaybe ::
   (Shape.C sh0, Shape.Index sh0 ~ ix0,
    Shape.C sh1, Shape.Index sh1 ~ ix1,
    Storable sh1, MultiValueMemory.C sh1,
    Storable a, MultiValueMemory.C a) =>
   (Exp a -> Exp a -> Exp a) ->
   Sym.Array sh1 a ->
   Sym.Array sh0 (Maybe (ix1, a)) -> IO (Array sh1 a)
scatterMaybe accum arrInit arrMap =
   materialize "scatterMaybe" resultShape fill
   where
      resultShape = Sym.shape arrInit
      fill = Priv.scatterMaybe accum arrInit arrMap

{- |
Physical-array front end for @Priv.scatter@.
The result has the shape of the initial array;
its buffer is filled by the code
that @Priv.scatter@ generates
from the accumulation function, the initial array
and the array of (index, value) pairs.
-}
scatter ::
   (Shape.C sh0, Shape.Index sh0 ~ ix0,
    Shape.C sh1, Shape.Index sh1 ~ ix1,
    Storable sh1, MultiValueMemory.C sh1,
    Storable a, MultiValueMemory.C a) =>
   (Exp a -> Exp a -> Exp a) ->
   Sym.Array sh1 a ->
   Sym.Array sh0 (ix1, a) -> IO (Array sh1 a)
scatter accum arrInit arrMap =
   materialize "scatter" resultShape fill
   where
      resultShape = Sym.shape arrInit
      fill = Priv.scatter accum arrInit arrMap

{- |
Variant of 'scatter' where the target index is not stored in the input
but computed from the source index by the given index mapping.
Every input element is paired with the mapped index
and the resulting array of pairs is passed to 'scatter'
together with the accumulation function and the default array.
-}
permute ::
   (Shape.C sh0, Shape.Index sh0 ~ ix0,
    Shape.C sh1, Shape.Index sh1 ~ ix1,
    Storable sh1, MultiValueMemory.C sh1,
    Storable a, MultiValueMemory.C a) =>
   (Exp a -> Exp a -> Exp a) ->
   Sym.Array sh1 a ->
   (Exp ix0 -> Exp ix1) ->
   Sym.Array sh0 a ->
   IO (Array sh1 a)
permute accum deflt ixmap input =
   scatter accum deflt $
      Sym.mapWithIndex
         (\ix x -> Expr.lift2 MultiValue.zip (ixmap ix) x)
         input