{-# LANGUAGE TypeFamilies #-}
module Data.Array.Knead.Simple.PhysicalPrivate where

import qualified Data.Array.Knead.Simple.Private as Sym
import qualified Data.Array.Knead.Index.Nested.Shape as Shape
import qualified Data.Array.Knead.Expression as Expr
import Data.Array.Knead.Expression (Exp, unExp)
import Data.Array.Knead.Code (getElementPtr)

import qualified LLVM.Extra.Multi.Value.Memory as MultiValueMemory
import qualified LLVM.Extra.Multi.Value as MultiValue
import qualified LLVM.Extra.Arithmetic as A
import qualified LLVM.Extra.Control as C
import qualified LLVM.Extra.Memory as Memory

import qualified LLVM.Core as LLVM

import Foreign.Storable (Storable, )
import Foreign.Ptr (Ptr, )

import Control.Monad.HT (void, )
import Control.Applicative ((<$>), )

import Data.Tuple.HT (mapSnd, )



-- | Generate code for a row-wise @mapAccumL@ over a @(sh,n)@-shaped array.
--
-- The shape is obtained with 'Shape.load' from the input array's shape
-- expression and @sptr@. For every outer index @ix@ (in 'Shape.loop'
-- order):
--
-- * the accumulator is seeded from the @sh@-shaped init array at @ix@;
-- * the inner dimension @n@ is walked, feeding the accumulator and the
--   input element at @(ix,k)@ to @f@;
-- * each produced @y@ is stored at the current destination pointer, which
--   then advances by one element.
--
-- The destination pointer is threaded across rows, so outputs are written
-- consecutively starting at @ptr@. Final accumulators are discarded.
mapAccumL ::
   (Shape.C sh, Storable sh, MultiValueMemory.C sh,
    Shape.C n, Storable n, MultiValueMemory.C n,
    MultiValue.C acc,
    Storable x, MultiValueMemory.C x,
    Storable y, MultiValueMemory.C y) =>
   (Exp acc -> Exp x -> Exp (acc,y)) ->
   Sym.Array sh acc -> Sym.Array (sh, n) x ->
   LLVM.Value (Ptr (MultiValueMemory.Struct (sh,n))) ->
   LLVM.Value (Ptr (MultiValueMemory.Struct y)) ->
   LLVM.CodeGenFunction r ()
mapAccumL f (Sym.Array _ initCode) (Sym.Array esh code) sptr ptr = do
   (outerShape, innerShape) <- fmap MultiValue.unzip $ Shape.load esh sptr
   -- One outer row: seed the accumulator, scan the inner dimension and
   -- return the destination pointer just past the last stored element.
   let row ix rowPtr = do
          seed <- initCode ix
          let cell k (dst, acc) = do
                 elt <- code (MultiValue.zip ix k)
                 (acc', out) <- fmap MultiValue.unzip (Expr.unliftM2 f acc elt)
                 Memory.store out dst
                 next <- A.advanceArrayElementPtr dst
                 return (next, acc')
          fmap fst (Shape.loop cell innerShape (rowPtr, seed))
   _ <- Shape.loop row outerShape ptr
   return ()

-- | Generate code that folds the outermost dimension of an @(n,sh)@-shaped
-- array into an @sh@-shaped result stored at @ptr@.
--
-- First every destination cell @ix@ is filled from the init array. Then,
-- for each outer index @k@ (in 'Shape.loop' order), every cell is replaced
-- by @f old b@, where @b@ is the input element at @(k,ix)@. The inner
-- sweep restarts at @ptr@ for each @k@.
--
-- The shape-pointer argument is currently unused. Both @n@ and @sh@ are
-- taken by evaluating the input array's shape expression.
foldOuterL ::
   (Shape.C sh, Storable sh, MultiValueMemory.C sh,
    Shape.C n, Storable n, MultiValueMemory.C n,
    MultiValueMemory.C a) =>
   (Exp a -> Exp b -> Exp a) ->
   Sym.Array sh a -> Sym.Array (n,sh) b ->
   LLVM.Value (Ptr (MultiValueMemory.Struct sh)) ->
   LLVM.Value (Ptr (MultiValueMemory.Struct a)) ->
   LLVM.CodeGenFunction r ()
foldOuterL f (Sym.Array _ initCode) (Sym.Array esh code) _sptr ptr = do
   -- NOTE(review): the shape pointer only describes @sh@, presumably why the
   -- @(n,sh)@ shape is evaluated from the expression instead of loaded:
   -- (n,sh) <- MultiValue.unzip <$> Shape.load esh sptr
   (n,sh) <- MultiValue.unzip <$> unExp esh

   -- Initialise each destination cell from the init array.
   let fillInit ix ptr0 = do
          a <- initCode ix
          Memory.store a ptr0
          A.advanceArrayElementPtr ptr0
   void $ Shape.loop fillInit sh ptr

   -- Combine slice @k@ of the input into the destination, cell by cell.
   -- The body is indented past the binding name so that it parses under
   -- plain Haskell 2010 layout, without relying on NondecreasingIndentation.
   let step k ix ptr0 = do
          b <- code $ MultiValue.zip k ix
          a0 <- Memory.load ptr0
          a1 <- Expr.unliftM2 f a0 b
          Memory.store a1 ptr0
          A.advanceArrayElementPtr ptr0
   void $ Shape.loop (\k () -> void $ Shape.loop (step k) sh ptr) n ()

-- | Generate code for a scatter that skips 'Nothing' entries.
--
-- The destination shape is read with 'Shape.load' from @sptr@, and every
-- destination cell is initialised from the @sh1@-shaped init array. Then,
-- for each source index (in 'Shape.loop' order), the mapped value is
-- examined. If it is 'Just' @(target, val)@, destination cell @target@ is
-- replaced by @accum old val@. Otherwise the element is skipped.
--
-- This code performs no explicit bounds check on @target@.
-- NOTE(review): confirm that callers guarantee valid target indices.
scatterMaybe ::
   (Shape.C sh0, Shape.Index sh0 ~ ix0,
    Shape.C sh1, Shape.Index sh1 ~ ix1,
    Storable sh1, MultiValueMemory.C sh1,
    Storable a, MultiValueMemory.C a) =>
   (Exp a -> Exp a -> Exp a) ->
   Sym.Array sh1 a -> Sym.Array sh0 (Maybe (ix1, a)) ->
   LLVM.Value (Ptr (MultiValueMemory.Struct sh1)) ->
   LLVM.Value (Ptr (MultiValueMemory.Struct a)) ->
   LLVM.CodeGenFunction r ()
scatterMaybe accum (Sym.Array esh codeInit) (Sym.Array eish codeMap)
      sptr ptr = do
   dstShape <- Shape.load esh sptr
   void $ Shape.loop initCell dstShape ptr
   srcShape <- unExp eish
   Shape.loop (combine dstShape) srcShape ()
  where
   -- Store the initial value for @ix@ at @dst@ and step to the next cell.
   initCell ix dst = do
      v <- codeInit ix
      Memory.store v dst
      A.advanceArrayElementPtr dst
   -- Merge the source element at @ix@ into its target cell, if present.
   combine dstShape ix () = do
      (MultiValue.Cons valid, (target, val)) <-
         fmap (mapSnd MultiValue.unzip . MultiValue.splitMaybe) (codeMap ix)
      C.ifThen valid () $ do
         cell <- getElementPtr dstShape ptr target
         old <- Memory.load cell
         new <- Expr.unliftM2 (flip accum) val old
         Memory.store new cell

-- | Generate code for a scatter with accumulation.
--
-- The destination shape is read with 'Shape.load' from @sptr@, and every
-- destination cell is initialised from the @sh1@-shaped init array. Then,
-- for each source index (in 'Shape.loop' order), the mapped pair
-- @(target, val)@ replaces destination cell @target@ with @accum old val@.
--
-- This code performs no explicit bounds check on @target@.
-- NOTE(review): confirm that callers guarantee valid target indices.
scatter ::
   (Shape.C sh0, Shape.Index sh0 ~ ix0,
    Shape.C sh1, Shape.Index sh1 ~ ix1,
    Storable sh1, MultiValueMemory.C sh1,
    Storable a, MultiValueMemory.C a) =>
   (Exp a -> Exp a -> Exp a) ->
   Sym.Array sh1 a ->
   Sym.Array sh0 (Shape.Index sh1, a) ->
   LLVM.Value (Ptr (MultiValueMemory.Struct sh1)) ->
   LLVM.Value (Ptr (MultiValueMemory.Struct a)) ->
   LLVM.CodeGenFunction r ()
scatter accum (Sym.Array esh codeInit) (Sym.Array eish codeMap) sptr ptr = do
   dstShape <- Shape.load esh sptr
   void $ Shape.loop initCell dstShape ptr
   srcShape <- unExp eish
   Shape.loop (combine dstShape) srcShape ()
  where
   -- Store the initial value for @ix@ at @dst@ and step to the next cell.
   initCell ix dst = do
      v <- codeInit ix
      Memory.store v dst
      A.advanceArrayElementPtr dst
   -- Merge the source element at @ix@ into its target cell.
   combine dstShape ix () = do
      (target, val) <- MultiValue.unzip <$> codeMap ix
      cell <- getElementPtr dstShape ptr target
      old <- Memory.load cell
      new <- Expr.unliftM2 (flip accum) val old
      Memory.store new cell