{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE ForeignFunctionInterface #-}
module Data.Array.Knead.Simple.Physical (
   Array,
   shape,
   toList,
   fromList,
   vectorFromList,
   with,
   render,
   scanl1,
   mapAccumLSimple,
   scatter,
   scatterMaybe,
   permute,
   ) where

import qualified Data.Array.Knead.Simple.PhysicalPrivate as Priv
import qualified Data.Array.Knead.Simple.Private as Sym
import qualified Data.Array.Knead.Shape as Shape
import qualified Data.Array.Knead.Expression as Expr
import qualified Data.Array.Knead.Code as Code
import Data.Array.Knead.Expression (Exp, unExp, )
import Data.Array.Knead.Code (getElementPtr, compile, )

import qualified Data.Array.Comfort.Storable.Mutable.Unchecked as MutArray
import qualified Data.Array.Comfort.Storable.Unchecked as Array
import qualified Data.Array.Comfort.Shape as ComfortShape
import Data.Array.Comfort.Storable.Unchecked (Array(Array))

import qualified LLVM.Extra.Multi.Value.Memory as MultiValueMemory
import qualified LLVM.Extra.Multi.Value as MultiValue
import qualified LLVM.Extra.Arithmetic as A
import qualified LLVM.Extra.Memory as Memory
import qualified LLVM.Extra.Maybe as Maybe

import qualified LLVM.Core as LLVM

import Foreign.Marshal.Alloc (alloca, )
import Foreign.Storable (Storable, peek, )
import Foreign.ForeignPtr (withForeignPtr, mallocForeignPtrArray, )
import Foreign.Ptr (FunPtr, Ptr, )

import Control.Monad.HT (void, (<=<), )
import Control.Applicative (liftA2, (<$>), )

import Prelude2010 hiding (scanl1)
import Prelude ()


-- | The shape stored alongside the elements of a physical array.
shape :: Array sh a -> sh
shape (Array sh _) = sh

{- |
Read the elements of a physical array into a Haskell list.

'MutArray.unsafeThaw' is only used to obtain a view that
'MutArray.toList' can read; nothing is written through it.
-}
toList ::
   (Shape.C sh, Storable a) =>
   Array sh a -> IO [a]
toList arr = MutArray.unsafeThaw arr >>= MutArray.toList

{- |
Build a physical array of the given shape from a list of elements.

The mutable array produced by 'MutArray.fromList' is frozen in place with
'MutArray.unsafeFreeze'.  Whether a list that does not match the size of
the shape is rejected, truncated or padded is decided by
'MutArray.fromList', not checked here.
-}
fromList ::
   (Shape.C sh, Storable a) =>
   sh -> [a] -> IO (Array sh a)
fromList sh xs = MutArray.fromList sh xs >>= MutArray.unsafeFreeze

{- |
Build a one-dimensional, zero-based physical array from a list.

The size stored in the 'Shape.ZeroBased' shape returned by
'MutArray.vectorFromList' is converted to the requested numeric type @n@
with 'fromIntegral'.
-}
vectorFromList ::
   (Num n, Storable a) =>
   [a] -> IO (Array (ComfortShape.ZeroBased n) a)
vectorFromList xs =
   fmap (Array.mapShape convertSize) $
      MutArray.unsafeFreeze =<< MutArray.vectorFromList xs
   where
      convertSize (Shape.ZeroBased k) = Shape.ZeroBased (fromIntegral k)


{- |
Run an action on a symbolic view of a physical array.

Reading an element of the symbolic array compiles to a load from the
physical array's buffer.  The address of that buffer is embedded into the
generated code as a constant via 'LLVM.valueOf'.

The symbolic array is only valid inside the enclosed action: the buffer is
only kept alive by 'withForeignPtr' while the action runs.
-}
with ::
   (Shape.C sh, MultiValueMemory.C a) =>
   (Sym.Array sh a -> IO b) ->
   Array sh a -> IO b
with act (Array sh fptr) =
   withForeignPtr fptr $ \basePtr ->
      act
         (Sym.Array (Shape.value sh)
            (\ix -> do
               elemPtr <-
                  getElementPtr (Shape.value sh)
                     (LLVM.valueOf (MultiValueMemory.castStructPtr basePtr))
                     ix
               Memory.load elemPtr))


-- | Type of a @\"dynamic\"@ foreign import: it turns a pointer to a
-- foreign function into a Haskell function of the same type.
type Importer fn = FunPtr fn -> fn

-- | Call a compiled shape function through its function pointer.
-- In 'materialize' this is the generated @shape@ function, which stores
-- the shape through the pointer and returns 'Shape.size' of it.
foreign import ccall safe "dynamic" callShaper ::
   Importer (Ptr shape -> IO Shape.Size)

-- | Call a compiled fill function through its function pointer.
-- In 'materialize' its arguments are the shape buffer and the element
-- buffer of the array being produced.
foreign import ccall safe "dynamic" callRenderer ::
   Importer (Ptr shape -> Ptr elems -> IO ())


{- |
Generate, compile and run code that produces a physical array.

The first argument is passed to 'compile' (presumably as a name for the
generated code).  Two functions are generated in that single 'compile' call:

* @shape@ evaluates the shape expression @esh@, stores the resulting shape
  through its pointer argument and returns 'Shape.size' of that shape.

* @fill@ runs the supplied code generator on a pointer to the stored shape
  and a pointer to the element buffer, then returns.
  In the element pointer type, @am@ is the LLVM struct type of @a@.

On the Haskell side, a temporary shape buffer is allocated with 'alloca'.
Then @shape@ is called to learn how many elements to allocate, and the
element buffer is allocated with 'mallocForeignPtrArray'.  Finally @fill@
is run, and the shape is read back from the temporary buffer to build the
result.
-}
materialize ::
   (Shape.C sh, Storable sh, MultiValueMemory.C sh,
    Storable a, MultiValueMemory.C a, MultiValueMemory.Struct a ~ am) =>
   String ->
   Exp sh ->
   (LLVM.Value (Ptr (MultiValueMemory.Struct sh)) ->
    LLVM.Value (Ptr am) -> LLVM.CodeGenFunction () ()) ->
   IO (Array sh a)
materialize name esh code =
   -- Temporary storage for the shape; it is only needed until 'peek' below.
   alloca $ \shptr -> do
      -- Both generated functions are compiled together in one 'compile' call.
      (fsh, farr) <-
         compile name $
         liftA2 (,)
            (Code.createFunction callShaper "shape" $ \ptr -> do
               sh <- unExp esh
               MultiValueMemory.store sh ptr
               Shape.size sh >>= LLVM.ret)
            (Code.createFunction callRenderer "fill"
               (\paramPtr arrayPtr -> code paramPtr arrayPtr >> LLVM.ret ()))
      -- View the Haskell shape buffer as a pointer to the LLVM struct layout of sh.
      let lshptr = MultiValueMemory.castStructPtr shptr
      -- Run "shape": fills the shape buffer and yields the element count.
      n <- fsh lshptr
      -- NOTE(review): fromIntegral is unchecked; confirm Shape.Size always fits in Int.
      fptr <- mallocForeignPtrArray (fromIntegral n)
      -- Run "fill" while the element buffer is kept alive by withForeignPtr.
      withForeignPtr fptr $ farr lshptr . MultiValueMemory.castStructPtr
      -- Read back the shape that "shape" stored, to pair it with the elements.
      sh <- peek shptr
      return (Array sh fptr)

{- |
Evaluate a symbolic array into a freshly allocated physical array.

'Shape.loop' visits every index of the shape.  The element computed for
each index is stored at the current output position, and the position is
then advanced by one element.  The elements therefore end up contiguous,
starting at the beginning of the buffer.
-}
render ::
   (Shape.C sh, Storable sh, MultiValueMemory.C sh,
    Storable a, MultiValueMemory.C a) =>
   Sym.Array sh a -> IO (Array sh a)
render (Sym.Array esh code) =
   materialize "render" esh $ \shapePtr arrPtr -> do
      sh <- Shape.load esh shapePtr
      let writeAt ix elemPtr = do
             x <- code ix
             Memory.store x elemPtr
             A.advanceArrayElementPtr elemPtr
      void $ Shape.loop writeAt sh arrPtr

{- |
Left scan without a start value along the inner dimension @n@ of an array
of shape @(sh, n)@.  The scan runs independently for every outer index in
@sh@.

The first element of every row is copied unchanged.  Every later output is
@f previousOutput currentInput@.  Rows are written one after another into
the output buffer.
-}
scanl1 ::
   (Shape.C sh, Storable sh, MultiValueMemory.C sh,
    Shape.C n, Storable n, MultiValueMemory.C n,
    Storable a, MultiValueMemory.C a) =>
   (Exp a -> Exp a -> Exp a) ->
   Sym.Array (sh, n) a -> IO (Array (sh, n) a)
scanl1 f (Sym.Array esh code) =
   materialize "scanl1" esh $ \shapePtr arrPtr -> do
      (outerShape, innerShape) <- MultiValue.unzip <$> Shape.load esh shapePtr
      let scanRow ix rowPtr =
             let visit k (elemPtr, mprev) = do
                    x <- code (MultiValue.zip ix k)
                    acc <-
                       Maybe.run mprev (return x)
                          (\prev -> Expr.unliftM2 f prev x)
                    Memory.store acc elemPtr
                    nextPtr <- A.advanceArrayElementPtr elemPtr
                    return (nextPtr, Maybe.just acc)
             in  fst <$> Shape.loop visit innerShape (rowPtr, Maybe.nothing)
      void $ Shape.loop scanRow outerShape arrPtr

{- |
Materialize 'Priv.mapAccumLSimple' into a physical array.

The result has the shape of the data array.  Presumably there is one
initial accumulator per outer index in @sh@, threaded through the inner
dimension @n@; see 'Priv.mapAccumLSimple' for the exact semantics.
-}
mapAccumLSimple ::
   (Shape.C sh, Storable sh, MultiValueMemory.C sh,
    Shape.C n, Storable n, MultiValueMemory.C n,
    MultiValue.C acc,
    Storable x, MultiValueMemory.C x,
    Storable y, MultiValueMemory.C y) =>
   (Exp acc -> Exp x -> Exp (acc,y)) ->
   Sym.Array sh acc -> Sym.Array (sh, n) x -> IO (Array (sh, n) y)
mapAccumLSimple step initial input =
   materialize "mapAccumLSimple" (Sym.shape input)
      (Priv.mapAccumLSimple step initial input)

{- |
Materialize 'Priv.scatterMaybe' into a physical array.

The result has the shape of the initial array.  Presumably the entries that
carry a target index are combined into the initial values with the
accumulating function; see 'Priv.scatterMaybe' for the exact semantics.
-}
scatterMaybe ::
   (Shape.C sh0, Shape.Index sh0 ~ ix0,
    Shape.C sh1, Shape.Index sh1 ~ ix1,
    Storable sh1, MultiValueMemory.C sh1,
    Storable a, MultiValueMemory.C a) =>
   (Exp a -> Exp a -> Exp a) ->
   Sym.Array sh1 a ->
   Sym.Array sh0 (Maybe (ix1, a)) -> IO (Array sh1 a)
scatterMaybe combine initial updates =
   materialize "scatterMaybe" (Sym.shape initial)
      (Priv.scatterMaybe combine initial updates)

{- |
Materialize 'Priv.scatter' into a physical array.

The result has the shape of the initial array.  Presumably every
(target index, value) pair is combined into the initial values with the
accumulating function; see 'Priv.scatter' for the exact semantics.
-}
scatter ::
   (Shape.C sh0, Shape.Index sh0 ~ ix0,
    Shape.C sh1, Shape.Index sh1 ~ ix1,
    Storable sh1, MultiValueMemory.C sh1,
    Storable a, MultiValueMemory.C a) =>
   (Exp a -> Exp a -> Exp a) ->
   Sym.Array sh1 a ->
   Sym.Array sh0 (ix1, a) -> IO (Array sh1 a)
scatter combine initial updates =
   materialize "scatter" (Sym.shape initial)
      (Priv.scatter combine initial updates)

{- |
Move every element of @input@ to the index computed by @ixmap@ from the
element's own index.  The moved elements are merged into @deflt@ with
@accum@.

This is 'scatter' applied to the input, where each element is paired with
its target index.
-}
permute ::
   (Shape.C sh0, Shape.Index sh0 ~ ix0,
    Shape.C sh1, Shape.Index sh1 ~ ix1,
    Storable sh1, MultiValueMemory.C sh1,
    Storable a, MultiValueMemory.C a) =>
   (Exp a -> Exp a -> Exp a) ->
   Sym.Array sh1 a ->
   (Exp ix0 -> Exp ix1) ->
   Sym.Array sh0 a ->
   IO (Array sh1 a)
permute accum deflt ixmap input =
   scatter accum deflt $
      Sym.mapWithIndex (\ix x -> Expr.lift2 MultiValue.zip (ixmap ix) x) input