{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE ForeignFunctionInterface #-}
module Data.Array.Knead.Simple.Physical (
   Array,
   shape,
   toList,
   fromList,
   vectorFromList,
   with,
   render,
   scanl1,
   mapAccumLSimple,
   scatter,
   scatterMaybe,
   permute,
   ) where

import qualified Data.Array.Knead.Simple.PhysicalPrivate as Priv
import qualified Data.Array.Knead.Simple.Private as Sym
import qualified Data.Array.Knead.Shape as Shape
import qualified Data.Array.Knead.Expression as Expr
import Data.Array.Knead.Simple.PhysicalPrivate (MarshalPtr)
import Data.Array.Knead.Code (getElementPtr)

import qualified LLVM.DSL.Execution as Code
import LLVM.DSL.Expression (Exp, unExp)

import qualified Data.Array.Comfort.Storable.Mutable.Unchecked as MutArray
import qualified Data.Array.Comfort.Storable.Unchecked as Array
import qualified Data.Array.Comfort.Shape as ComfortShape
import Data.Array.Comfort.Storable.Unchecked (Array(Array))

import qualified LLVM.Extra.Multi.Value as MultiValue
import qualified LLVM.Extra.Storable as Storable
import qualified LLVM.Extra.Marshal as Marshal
import qualified LLVM.Extra.Memory as Memory
import qualified LLVM.Extra.Maybe as Maybe

import qualified LLVM.Core as LLVM

import Foreign.Storable (Storable, )
import Foreign.ForeignPtr (withForeignPtr, mallocForeignPtrArray, )
import Foreign.Ptr (FunPtr, Ptr, )

import Control.Monad.HT (void, (<=<), )
import Control.Applicative (liftA2, (<$>), )

import Prelude2010 hiding (scanl1)
import Prelude ()


{- |
Extract the shape of a physical array.
-}
shape :: Array sh a -> sh
shape (Array sh _) = sh

{- |
Read all elements of a physical array into a list.
The immutable array is turned into a mutable one without copying
('MutArray.unsafeThaw') and then read through the mutable interface.
-}
toList ::
   (Shape.C sh, Storable a) =>
   Array sh a -> IO [a]
toList arr = MutArray.unsafeThaw arr >>= MutArray.toList

{- |
Build a physical array of the given shape from a list of elements.
The mutable array filled by 'MutArray.fromList' is frozen without copying
('MutArray.unsafeFreeze').
-}
fromList ::
   (Shape.C sh, Storable a) =>
   sh -> [a] -> IO (Array sh a)
fromList sh xs = MutArray.fromList sh xs >>= MutArray.unsafeFreeze

{- |
Build a one-dimensional, zero-based array from a list.
The length reported by 'MutArray.vectorFromList' is converted to the
requested index type @n@ with 'fromIntegral'; the buffer is frozen
without copying.
-}
vectorFromList ::
   (Num n, Storable a) =>
   [a] -> IO (Array (ComfortShape.ZeroBased n) a)
vectorFromList xs =
   fmap (Array.mapShape convertSize) $
      MutArray.unsafeFreeze =<< MutArray.vectorFromList xs
   where
      convertSize (ComfortShape.ZeroBased len) =
         ComfortShape.ZeroBased (fromIntegral len)


{- |
Run an action on a symbolic view of a physical array.

Each element access in the generated code loads directly from the
array's buffer. The buffer address is embedded into the code as a
constant ('LLVM.valueOf'), and 'withForeignPtr' keeps the buffer alive
only while the action runs.
Hence the symbolic array is only valid inside the enclosed action.
-}
with ::
   (Shape.C sh, Storable.C a) =>
   (Sym.Array sh a -> IO b) ->
   Array sh a -> IO b
with f (Array sh fptr) =
   let symShape = Shape.value sh
   in  withForeignPtr fptr $ \ptr ->
          f (Sym.Array symShape
               (\ix ->
                  getElementPtr symShape (LLVM.valueOf ptr) ix >>=
                  Storable.loadMultiValue))


-- | Type of the wrapper generated by a @\"dynamic\"@ foreign import:
-- it turns a function pointer (here: one obtained from compiled code)
-- into a callable Haskell function of the same type.
type Importer f = FunPtr f -> f

-- | Call a compiled function that takes a pointer to shape storage and
-- returns a 'Shape.Size'.
-- In 'materialize' this is the @shape@ function, which stores the shape
-- through the pointer and returns its size.
foreign import ccall safe "dynamic" callShaper ::
   Importer (LLVM.Ptr sh -> IO Shape.Size)

-- | Call a compiled function that takes the shape pointer and the
-- destination element buffer and returns nothing.
-- In 'materialize' this is the @fill@ function.
foreign import ccall safe "dynamic" callRenderer ::
   Importer (LLVM.Ptr sh -> Ptr a -> IO ())


{- |
Compile and run generated code that produces a physical array.

Two functions are compiled together under the name @name@
via 'Code.compile':

* @shape@ evaluates the shape expression @esh@, stores it through the
  given pointer and returns 'Shape.size' of it.

* @fill@ runs @code@ with a pointer to the stored shape and a pointer
  to the element buffer, then returns.

The shape function runs first; its result is the number of elements
allocated for the buffer. After @fill@ has run, the stored shape is
read back and becomes the shape of the returned array.
-}
materialize ::
   (Shape.C sh, Marshal.MV sh, Storable.C a) =>
   String ->
   Exp sh ->
   (LLVM.Value (MarshalPtr sh) ->
    LLVM.Value (Ptr a) -> LLVM.CodeGenFunction () ()) ->
   IO (Array sh a)
materialize name esh code =
   -- lshptr: scoped storage for the marshalled shape,
   -- passed to both compiled functions
   Marshal.alloca $ \lshptr -> do
      (fsh, farr) <-
         Code.compile name $
         liftA2 (,)
            (Code.createFunction callShaper "shape" $ \ptr -> do
               sh <- unExp esh
               Memory.store sh ptr
               Shape.size sh >>= LLVM.ret)
            (Code.createFunction callRenderer "fill"
               (\paramPtr arrayPtr -> code paramPtr arrayPtr >> LLVM.ret ()))
      -- must run first: fills lshptr and yields the element count
      n <- fsh lshptr
      fptr <- mallocForeignPtrArray (fromIntegral n)
      -- keep the buffer alive while the generated code writes into it
      withForeignPtr fptr $ farr lshptr
      sh <- Marshal.peek lshptr
      return (Array sh fptr)

{- |
Turn a symbolic array into a physical one.
The generated code evaluates every element in the order in which
'Shape.loop' visits the indices. It stores the results consecutively
into a freshly allocated buffer.
-}
render ::
   (Shape.C sh, Marshal.MV sh, Storable.C a) =>
   Sym.Array sh a -> IO (Array sh a)
render (Sym.Array esh code) =
   materialize "render" esh fill
   where
      fill sptr dst = do
         sh <- Shape.load esh sptr
         void $ Shape.loop storeElem sh dst
      storeElem ix p = do
         x <- code ix
         Storable.storeNextMultiValue x p

{- |
Left scan without a seed along the inner dimension @n@. Each outer
index of @sh@ is scanned independently.
The first element of every row is stored unchanged. Every further
element is stored as @f acc x@, where @acc@ is the previously stored
value of the row and @x@ the current input element.
-}
scanl1 ::
   (Shape.C sh, Marshal.MV sh,
    Shape.C n, Marshal.MV n,
    Storable.C a, MultiValue.C a) =>
   (Exp a -> Exp a -> Exp a) ->
   Sym.Array (sh, n) a -> IO (Array (sh, n) a)
scanl1 f (Sym.Array esh code) =
   materialize "scanl1" esh $ \sptr ptr -> do
      (sh, n) <- MultiValue.unzip <$> Shape.load esh sptr
      let scanRow ix rowPtr =
             fst <$> Shape.loop (scanStep ix) n (rowPtr, Maybe.nothing)
          scanStep ix k (dst, macc) = do
             x <- code (MultiValue.zip ix k)
             acc <- Maybe.run macc (return x) (\prev -> Expr.unliftM2 f prev x)
             dst' <- Storable.storeNextMultiValue acc dst
             return (dst', Maybe.just acc)
      void $ Shape.loop scanRow sh ptr

{- |
Physical counterpart of 'Priv.mapAccumLSimple'.
The result has the shape of @arrData@, and the element code is
generated entirely by 'Priv.mapAccumLSimple'. Presumably it threads an
accumulator, initialised from @arrInit@, along the inner dimension
@n@; see that function for the exact semantics.
-}
mapAccumLSimple ::
   (Shape.C sh, Marshal.MV sh,
    Shape.C n, Marshal.MV n,
    MultiValue.C acc, Storable.C x, Storable.C y) =>
   (Exp acc -> Exp x -> Exp (acc,y)) ->
   Sym.Array sh acc -> Sym.Array (sh, n) x -> IO (Array (sh, n) y)
mapAccumLSimple f arrInit arrData =
   materialize "mapAccumLSimple" dataShape fillCode
   where
      dataShape = Sym.shape arrData
      fillCode = Priv.mapAccumLSimple f arrInit arrData

{- |
Physical counterpart of 'Priv.scatterMaybe'.
The result has the shape of @arrInit@, and the element code is
generated by 'Priv.scatterMaybe'. Presumably 'Nothing' entries of
@arrMap@ are skipped, and @Just (ix, a)@ entries are combined into
position @ix@ with @accum@; see that function to confirm.
-}
scatterMaybe ::
   (Shape.C sh0, Shape.Index sh0 ~ ix0,
    Shape.C sh1, Shape.Index sh1 ~ ix1, Marshal.MV sh1,
    Storable.C a) =>
   (Exp a -> Exp a -> Exp a) ->
   Sym.Array sh1 a ->
   Sym.Array sh0 (Maybe (ix1, a)) -> IO (Array sh1 a)
scatterMaybe accum arrInit arrMap =
   materialize "scatterMaybe" targetShape fillCode
   where
      targetShape = Sym.shape arrInit
      fillCode = Priv.scatterMaybe accum arrInit arrMap

{- |
Physical counterpart of 'Priv.scatter'.
The result has the shape of @arrInit@, and the element code is
generated by 'Priv.scatter'. Presumably each @(ix, a)@ entry of
@arrMap@ is combined into position @ix@ with @accum@; see that
function to confirm.
-}
scatter ::
   (Shape.C sh0, Shape.Index sh0 ~ ix0,
    Shape.C sh1, Shape.Index sh1 ~ ix1, Marshal.MV sh1,
    Storable.C a) =>
   (Exp a -> Exp a -> Exp a) ->
   Sym.Array sh1 a ->
   Sym.Array sh0 (ix1, a) -> IO (Array sh1 a)
scatter accum arrInit arrMap =
   materialize "scatter" targetShape fillCode
   where
      targetShape = Sym.shape arrInit
      fillCode = Priv.scatter accum arrInit arrMap

{- |
Send every element of @input@ to the target index that @ixmap@ computes
from its source index, starting from the contents of @deflt@.
Each element is paired with its target index, and the pairs are handed
to 'scatter'. That function presumably combines elements landing on the
same index with @accum@.
-}
permute ::
   (Shape.C sh0, Shape.Index sh0 ~ ix0,
    Shape.C sh1, Shape.Index sh1 ~ ix1, Marshal.MV sh1,
    Storable.C a) =>
   (Exp a -> Exp a -> Exp a) ->
   Sym.Array sh1 a ->
   (Exp ix0 -> Exp ix1) ->
   Sym.Array sh0 a ->
   IO (Array sh1 a)
permute accum deflt ixmap input =
   scatter accum deflt $ Sym.mapWithIndex withTarget input
   where
      withTarget srcIx x = Expr.lift2 MultiValue.zip (ixmap srcIx) x