{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE ForeignFunctionInterface #-}
module Data.Array.Knead.Simple.Physical (
   Array(Array, shape, buffer), -- data constructor intended for PhysicalParameterized
   toList,
   fromList,
   vectorFromList,
   with,
   render,
   scanl1,
   scatter,
   permute,
   ) where

import qualified Data.Array.Knead.Simple.Private as Sym
import qualified Data.Array.Knead.Index.Nested.Shape as Shape
import qualified Data.Array.Knead.Expression as Expr
import qualified Data.Array.Knead.Code as Code
import Data.Array.Knead.Expression (Exp, unExp, )
import Data.Array.Knead.Code (getElementPtr, compile, )

import qualified LLVM.Extra.Multi.Value.Memory as MultiValueMemory
import qualified LLVM.Extra.Multi.Value as MultiValue
import qualified LLVM.Extra.Arithmetic as A
import qualified LLVM.Extra.Control as C
import qualified LLVM.Extra.Memory as Memory
import qualified LLVM.Extra.Maybe as Maybe

import qualified LLVM.Core as LLVM

import Foreign.Marshal.Array (pokeArray, peekArray, )
import Foreign.Marshal.Alloc (alloca, )
import Foreign.Storable (Storable, peek, )
import Foreign.ForeignPtr (ForeignPtr, withForeignPtr, mallocForeignPtrArray, )
import Foreign.Ptr (FunPtr, Ptr, )

import Control.Monad.HT (void, )
import Control.Applicative (liftA2, )
import Data.Word (Word32, )

import Prelude hiding (scanl1, )


{- |
Array whose elements live in a block of foreign memory.
The functions of this module treat the buffer
as holding @Shape.size shape@ elements.
-}
data Array sh a = Array
   { shape :: sh
     -- ^ extent of the array
   , buffer :: ForeignPtr a
     -- ^ memory block that stores the elements
   }


{- |
Read the elements of the array into a list.
The number of elements read is @Shape.size@ of the array's shape.
-}
toList ::
   (Shape.C sh, Storable a) =>
   Array sh a -> IO [a]
toList arr =
   withForeignPtr (buffer arr) $ \ptr ->
      peekArray (Shape.size (shape arr)) ptr

{- |
Allocate an array of the given shape and fill it
with the leading elements of the list.
Surplus list elements are ignored.
If the list is shorter than @Shape.size sh@,
the missing elements are 'error' values
that fail with the message below once they are evaluated.
-}
fromList ::
   (Shape.C sh, Storable a) =>
   sh -> [a] -> IO (Array sh a)
fromList sh xs = do
   let n = Shape.size sh
       missing =
          error "Array.Knead.Physical.fromList: list too short for shape"
   fptr <- mallocForeignPtrArray n
   withForeignPtr fptr $ \ptr ->
      pokeArray ptr (take n (xs ++ repeat missing))
   return (Array sh fptr)

{- |
Store all elements of the list in a freshly allocated array.
The shape is the length of the list, converted with 'fromIntegral'.
-}
vectorFromList ::
   (Shape.C sh, Num sh, Storable a) =>
   [a] -> IO (Array sh a)
vectorFromList xs = do
   let n = length xs
   fptr <- mallocForeignPtrArray n
   withForeignPtr fptr (\ptr -> pokeArray ptr xs)
   return (Array { shape = fromIntegral n, buffer = fptr })


{- |
Run an action on a symbolic view of a physical array.
The symbolic array reads its elements directly from the buffer pointer,
which is only kept alive by 'withForeignPtr' for the duration of the action.
Thus the symbolic array is only valid inside the enclosed action.
-}
with ::
   (Shape.C sh, MultiValueMemory.C a) =>
   (Sym.Array sh a -> IO b) ->
   Array sh a -> IO b
with f (Array sh fptr) =
   withForeignPtr fptr $ \ptr ->
      f (Sym.Array
            (Shape.value sh)
            -- element access: compute the element address, then load from it
            (\ix -> do
               elemPtr <-
                  getElementPtr (Shape.value sh)
                     (LLVM.valueOf (MultiValueMemory.castStructPtr ptr)) ix
               Memory.load elemPtr))


{- |
Type of a function that turns a function pointer
into a callable Haskell function of the same type.
-}
type Importer f = FunPtr f -> f

{- |
Dynamic foreign import for functions that take one pointer
and return a 'Word32'.
'materialize' uses it to call the compiled \"shape\" function.
-}
foreign import ccall safe "dynamic" callShaper ::
   Importer (Ptr sh -> IO Word32)

{- |
Dynamic foreign import for functions that take two pointers
and return no result.
'materialize' uses it to call the compiled \"fill\" function.
-}
foreign import ccall safe "dynamic" callRenderer ::
   Importer (Ptr sh -> Ptr am -> IO ())


{- |
Common driver for all operations that produce a physical array.

Two functions are compiled under the given name:

* \"shape\" evaluates the shape expression, stores the shape
  at the passed pointer and returns the value of @Shape.sizeCode@.

* \"fill\" runs the supplied code generator
  with the shape pointer and the pointer to the element buffer.

The shape function is called first, then a buffer
with the returned number of elements is allocated,
then the fill function is called on that buffer.
The resulting array carries the shape read back from the shape pointer.
-}
materialize ::
   (Shape.C sh, Storable sh, MultiValueMemory.C sh,
    Storable a, MultiValueMemory.C a, MultiValueMemory.Struct a ~ am) =>
   String ->
   Exp sh ->
   (LLVM.Value (Ptr (MultiValueMemory.Struct sh)) ->
    LLVM.Value (Ptr am) -> LLVM.CodeGenFunction () ()) ->
   IO (Array sh a)
materialize name esh code =
   alloca $ \shptr -> do
      (runShape, runFill) <-
         compile name $
         liftA2 (,)
            (Code.createFunction callShaper "shape" $ \ptr -> do
               shVal <- unExp esh
               MultiValueMemory.store shVal ptr
               size <- Shape.sizeCode shVal
               LLVM.ret size)
            (Code.createFunction callRenderer "fill" $ \paramPtr arrayPtr -> do
               code paramPtr arrayPtr
               LLVM.ret ())
      let structPtr = MultiValueMemory.castStructPtr shptr
      -- the shape function also writes the shape to shptr
      count <- runShape structPtr
      fptr <- mallocForeignPtrArray (fromIntegral count)
      withForeignPtr fptr $ \elemPtr ->
         runFill structPtr (MultiValueMemory.castStructPtr elemPtr)
      finalShape <- peek shptr
      return (Array finalShape fptr)

{- |
Evaluate a symbolic array and store its elements in a physical array.
The elements are written one after another
in the order in which @Shape.loop@ visits the indices.
-}
render ::
   (Shape.C sh, Storable sh, MultiValueMemory.C sh,
    Storable a, MultiValueMemory.C a) =>
   Sym.Array sh a -> IO (Array sh a)
render (Sym.Array esh code) =
   materialize "render" esh $ \shapePtr arrayPtr -> do
      sh <- Shape.load esh shapePtr
      -- write the element for index ix and move on to the next slot
      let writeElem ix cursor = do
             x <- code ix
             Memory.store x cursor
             A.advanceArrayElementPtr cursor
      void (Shape.loop writeElem sh arrayPtr)

{- |
Running accumulation along the last ('Word32') dimension.
For every index of the leading shape an inner loop
runs over the last dimension:
the first element is stored unchanged,
every following element @x@ is replaced by @f acc x@,
where @acc@ is the value stored at the previous position.
-}
scanl1 ::
   (Shape.C sh, Storable sh, MultiValueMemory.C sh,
    Storable a, MultiValueMemory.C a) =>
   (Exp a -> Exp a -> Exp a) ->
   Sym.Array (sh, Word32) a -> IO (Array (sh, Word32) a)
scanl1 f (Sym.Array esh code) =
   materialize "scanl1" esh $ \shapePtr arrayPtr -> do
      (outer, MultiValue.Cons len) <-
         fmap MultiValue.unzip (Shape.load esh shapePtr)
      -- Scan one row; the result is the pointer behind the row,
      -- which is where the next row starts.
      let scanRow ix rowStart = do
             ((rowEnd, _), _) <-
                C.fixedLengthLoop len ((rowStart, A.zero), Maybe.nothing) $
                \((cursor, k), mprev) -> do
                   x <- code (MultiValue.zip ix (MultiValue.Cons k))
                   -- no previous accumulator means first element of the row
                   acc <-
                      Maybe.run mprev
                         (return x)
                         (\prev -> Expr.unliftM2 f prev x)
                   Memory.store acc cursor
                   cursorNext <- A.advanceArrayElementPtr cursor
                   kNext <- A.inc k
                   return ((cursorNext, kNext), Maybe.just acc)
             return rowEnd
      void (Shape.loop scanRow outer arrayPtr)

{- |
Scatter the elements of an array of (index, value) pairs
into a new array.
First the result is initialized with the elements of the default array,
which also determines the result shape.
Then for every pair @(jx, a)@ of the source array
the element currently stored at @jx@ is replaced by @accum old a@.
-}
scatter ::
   (Shape.C sh0, Shape.Index sh0 ~ ix0,
    Shape.C sh1, Shape.Index sh1 ~ ix1,
    Storable sh1, MultiValueMemory.C sh1,
    Storable a, MultiValueMemory.C a) =>
   (Exp a -> Exp a -> Exp a) ->
   Sym.Array sh1 a ->
   Sym.Array sh0 (ix1, a) -> IO (Array sh1 a)
scatter accum (Sym.Array esh defltCode) (Sym.Array eish code) =
   materialize "scatter" esh $ \shapePtr arrayPtr -> do
      -- first pass: initialize the result with the default elements
      sh <- Shape.load esh shapePtr
      let initElem ix cursor = do
             x <- defltCode ix
             Memory.store x cursor
             A.advanceArrayElementPtr cursor
      void (Shape.loop initElem sh arrayPtr)

      -- second pass: merge every source element into its target slot
      srcShape <- unExp eish
      let accumulate ix () = do
             (dst, x) <- fmap MultiValue.unzip (code ix)
             slot <- getElementPtr sh arrayPtr dst
             old <- Memory.load slot
             new <- Expr.unliftM2 (flip accum) x old
             Memory.store new slot
      void (Shape.loop accumulate srcShape ())

{- |
Variant of 'scatter' where the target index of every input element
is computed from its source index by the given index mapping.
-}
permute ::
   (Shape.C sh0, Shape.Index sh0 ~ ix0,
    Shape.C sh1, Shape.Index sh1 ~ ix1,
    Storable sh1, MultiValueMemory.C sh1,
    Storable a, MultiValueMemory.C a) =>
   (Exp a -> Exp a -> Exp a) ->
   Sym.Array sh1 a ->
   (Exp ix0 -> Exp ix1) ->
   Sym.Array sh0 a ->
   IO (Array sh1 a)
permute accum deflt ixmap input =
   scatter accum deflt $
   -- pair every element with its mapped target index
   Sym.mapWithIndex (\ix -> Expr.lift2 MultiValue.zip (ixmap ix)) input