{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
module Data.Array.Knead.Symbolic.PhysicalPrivate where

import qualified Data.Array.Knead.Symbolic.Private as Sym
import qualified Data.Array.Knead.Shape as Shape
import qualified Data.Array.Knead.Expression as Expr
import Data.Array.Knead.Code (getElementPtr)

import LLVM.DSL.Expression (Exp, unExp)

import qualified LLVM.Extra.Multi.Value.Storable as Storable
import qualified LLVM.Extra.Multi.Value.Marshal as Marshal
import qualified LLVM.Extra.Multi.Value as MultiValue
import qualified LLVM.Extra.Control as C

import qualified LLVM.Core as LLVM

import Foreign.Ptr (Ptr, )

import qualified Control.Applicative.HT as App
import Control.Monad.HT (void, )
import Control.Applicative ((<$>), )

import Data.Tuple.HT (mapSnd, )

import Prelude2010
import Prelude ()



{- |
LLVM pointer to the marshalled struct representation ('Marshal.Struct')
of a value of type @a@.
In this module it is used to pass array shapes into the generated code.
-}
type MarshalPtr a = LLVM.Ptr (Marshal.Struct a)

{- |
Emit a loop over all indices of the shape @sh@.
At each index the loop evaluates @code@ and stores the element at the
running pointer, starting at @ptr@.
Elements are laid out in the iteration order of 'Shape.loop'.
The result is the pointer returned by the last 'Storable.storeNext',
i.e. the position right after the written elements.
-}
writeArray ::
   (Shape.C sh, Shape.Index sh ~ ix, Storable.C a) =>
   MultiValue.T sh ->
   (MultiValue.T ix -> LLVM.CodeGenFunction r (MultiValue.T a)) ->
   LLVM.Value (Ptr a) ->
   LLVM.CodeGenFunction r (LLVM.Value (Ptr a))
writeArray sh code ptr = Shape.loop storeElement sh ptr
   where
      -- compute the element first, then store it and advance the pointer
      storeElement ix dst = do
         x <- code ix
         Storable.storeNext x dst


{- |
Emit a mapAccumL-style loop over the indices of the shape @n@.

The accumulator starts at @accInit@.
At every index @k@ the loop fetches the input element with @code k@ and
applies @f@ to the current accumulator and that element.
The first component of the result becomes the next accumulator.
The second component is stored at the running output pointer, which
starts at @yPtr@.

Returns the output pointer after the last store and the final accumulator.
-}
mapAccumLLoop ::
   (MultiValue.C acc, Storable.C b,
    Shape.C sh, Shape.Index sh ~ ix) =>
   (MultiValue.T ix -> LLVM.CodeGenFunction r (MultiValue.T a)) ->
   (Exp acc -> Exp a -> Exp (acc, b)) ->
   MultiValue.T sh ->
   LLVM.Value (Ptr b) -> MultiValue.T acc ->
   LLVM.CodeGenFunction r (LLVM.Value (Ptr b), MultiValue.T acc)
mapAccumLLoop code f n yPtr accInit =
   Shape.loop body n (yPtr, accInit)
   where
      body k (dst, acc) = do
         x <- code k
         pairResult <- Expr.unliftM2 f acc x
         let (accNext, y) = MultiValue.unzip pairResult
         dstNext <- Storable.storeNext y dst
         return (dstNext, accNext)

{- |
Run 'mapAccumLLoop' along the inner dimension @n@ of a two-dimensional
array, once for each index of the outer dimension @sh@.

The shape @(sh, n)@ is loaded from @sptr@.
For outer index @ix@ the accumulator is initialised with the element of the
first array at @ix@.
The @y@ outputs of all rows are written consecutively starting at @ptr@.
The final accumulators of the rows are discarded.
-}
mapAccumLSimple ::
   (Shape.C sh, Marshal.C sh,
    Shape.C n, Marshal.C n,
    MultiValue.C acc,
    Storable.C x,
    Storable.C y) =>
   (Exp acc -> Exp x -> Exp (acc,y)) ->
   Sym.Array sh acc -> Sym.Array (sh, n) x ->
   LLVM.Value (MarshalPtr (sh,n)) ->
   LLVM.Value (Ptr y) ->
   LLVM.CodeGenFunction r ()
mapAccumLSimple f (Sym.Array _ initCode) (Sym.Array esh code) sptr ptr = do
   shape <- Shape.load esh sptr
   let (outer, inner) = MultiValue.unzip shape
       row ix dst = do
          acc0 <- initCode ix
          -- only the advanced output pointer is threaded to the next row
          (dstEnd, _) <-
             mapAccumLLoop (code . MultiValue.zip ix) f inner dst acc0
          return dstEnd
   _ <- Shape.loop row outer ptr
   return ()

{- |
Run a single 'mapAccumLLoop' over a one-shape array.

The shape is loaded from @sptr@.
The accumulator starts at the value of @initExp@.
The @y@ outputs are written consecutively starting at @yPtr@.
After the loop, @final@ is applied to the last accumulator and the result
is stored at @accPtr@.
-}
mapAccumLSequence ::
   (Shape.C n, Marshal.C n,
    MultiValue.C acc, Storable.C final,
    Storable.C x,
    Storable.C y) =>
   (Exp acc -> Exp x -> Exp (acc,y)) ->
   (Exp acc -> Exp final) ->
   Exp acc -> Sym.Array n x ->
   LLVM.Value (Ptr final) ->
   LLVM.Value (MarshalPtr n) ->
   LLVM.Value (Ptr y) ->
   LLVM.CodeGenFunction r ()
mapAccumLSequence f final initExp (Sym.Array esh code) accPtr sptr yPtr = do
   len <- Shape.load esh sptr
   acc0 <- unExp initExp
   (_, accEnd) <- mapAccumLLoop code f len yPtr acc0
   result <- Expr.unliftM1 final accEnd
   Storable.store result accPtr

{- |
Row-wise mapAccumL over the inner dimension @n@ of a two-dimensional array.

The shape @(sh, n)@ is loaded from @sptr@.
For each outer index @ix@ the accumulator starts at the element of the first
array at @ix@.
The @y@ outputs of all rows are written consecutively from @yPtr@.
For every row, @final@ is applied to the row's last accumulator and the
result is stored consecutively from @accPtr@.

The 'MarshalPtr' component of the first pair is not used by this function.
-}
mapAccumL ::
   (Shape.C sh, Marshal.C sh,
    Shape.C n, Marshal.C n,
    MultiValue.C acc, Storable.C final,
    Storable.C x,
    Storable.C y) =>
   (Exp acc -> Exp x -> Exp (acc,y)) ->
   (Exp acc -> Exp final) ->
   Sym.Array sh acc -> Sym.Array (sh, n) x ->
   (LLVM.Value (MarshalPtr sh), LLVM.Value (Ptr final)) ->
   (LLVM.Value (MarshalPtr (sh,n)), LLVM.Value (Ptr y)) ->
   LLVM.CodeGenFunction r ()
mapAccumL f final (Sym.Array _ initCode) (Sym.Array esh code)
      (_, accPtr) (sptr, yPtr) = do
   (outer, inner) <- fmap MultiValue.unzip (Shape.load esh sptr)
   let row ix (accDst, yDst) = do
          acc0 <- initCode ix
          (yDstNext, accEnd) <-
             mapAccumLLoop (code . MultiValue.zip ix) f inner yDst acc0
          accResult <- Expr.unliftM1 final accEnd
          accDstNext <- Storable.storeNext accResult accDst
          return (accDstNext, yDstNext)
   void $ Shape.loop row outer (accPtr, yPtr)

{- |
Left fold along the outer dimension @n@ of an @(n, sh)@ array.

The inner shape @sh@ is loaded from @sptr@.
The outer size @n@ is taken from the array's shape expression, not from
@sptr@.
The output at @ptr@ is first filled with the initial array over @sh@.
Then, for each outer index @k@ in order, every output cell at @ix@ is
replaced by @f old b@, where @b@ is the input element at @(k, ix)@.
-}
foldOuterL ::
   (Shape.C sh, Marshal.C sh,
    Shape.C n, Marshal.C n,
    Storable.C a) =>
   (Exp a -> Exp b -> Exp a) ->
   Sym.Array sh a -> Sym.Array (n,sh) b ->
   LLVM.Value (MarshalPtr sh) ->
   LLVM.Value (Ptr a) ->
   LLVM.CodeGenFunction r ()
foldOuterL f (Sym.Array _ initCode) (Sym.Array esh code) sptr ptr = do
   sh <- Shape.load (Expr.snd esh) sptr
   n <- fmap MultiValue.fst (unExp esh)
   _ <- writeArray sh initCode ptr

   let -- update one output cell in place and advance to the next one
       update k ix cur = do
          b <- code (MultiValue.zip k ix)
          old <- Storable.load cur
          combined <- Expr.unliftM2 f old b
          Storable.storeNext combined cur
       -- one full pass over the output for outer index k
       sweep k () = void (Shape.loop (update k) sh ptr)
   void (Shape.loop sweep n ())

{- |
Apply @f@ to every array element that satisfies the predicate @p@.
The results are stored contiguously, starting at @ptr@.
The return value is the number of stored elements, converted back into a
shape with 'Shape.sequenceShapeFromIndex'.

This requires a scalar shape type @n@, because the final element count must
become a shape again.
Scalar shape types could be distinguished from other shape types by the
property that every index can be converted into a shape.
-}
mapFilter ::
   (Shape.Sequence n, Marshal.C n,
    Storable.C b) =>
   (Exp a -> Exp b) ->
   (Exp a -> Exp Bool) ->
   Sym.Array n a ->
   LLVM.Value (MarshalPtr n) ->
   LLVM.Value (Ptr b) ->
   LLVM.CodeGenFunction r (MultiValue.T n)
mapFilter f p (Sym.Array esh code) sptr ptr = do
   n <- Shape.load esh sptr
   let -- store the mapped element and bump the counter
       emit a (dst, count) = do
          y <- Expr.unliftM1 f a
          dst' <- Storable.storeNext y dst
          count' <- MultiValue.inc count
          return (dst', count')
       visit ix state = do
          a <- code ix
          MultiValue.Cons keep <- Expr.unliftM1 p a
          C.ifThen keep state (emit a state)
   (_, total) <- Shape.loop visit n (ptr, MultiValue.zero)
   Shape.sequenceShapeFromIndex total

{- |
Keep the slices of an @(n, sh)@ array whose outer index @k@ is selected by
the boolean array.

The shape @(n, sh)@ is loaded from @sptr@.
Each selected slice, i.e. all of @sh@ at outer index @k@, is copied
consecutively to @ptr@.
The result shape pairs the number of kept slices (via
'Shape.sequenceShapeFromIndex') with the unchanged @sh@.

The selection array's own shape is not consulted.
Iteration runs over @n@ from the data array's shape.
-}
filterOuter ::
   (Shape.Sequence n, Marshal.C n,
    Shape.C sh, Marshal.C sh,
    Storable.C a) =>
   Sym.Array n Bool ->
   Sym.Array (n,sh) a ->
   LLVM.Value (MarshalPtr (n,sh)) ->
   LLVM.Value (Ptr a) ->
   LLVM.CodeGenFunction r (MultiValue.T (n,sh))
filterOuter (Sym.Array _eish selectCode) (Sym.Array esh code) sptr ptr = do
   (n, sh) <- fmap MultiValue.unzip (Shape.load esh sptr)
   let -- copy the whole slice at k and count it
       copySlice k (dst, kept) = do
          dst' <- writeArray sh (code . MultiValue.zip k) dst
          kept' <- MultiValue.inc kept
          return (dst', kept')
       step k state = do
          MultiValue.Cons c <- selectCode k
          C.ifThen c state (copySlice k state)
   (_, keptCount) <- Shape.loop step n (ptr, MultiValue.zero)
   finalN <- Shape.sequenceShapeFromIndex keptCount
   return (MultiValue.zip finalN sh)


{- |
Scatter with accumulation, skipping source elements that are @Nothing@.

The target shape @sh1@ is loaded from @sptr@.
The target at @ptr@ is first filled with the initial array.
Then, for each index of the source array, a @Just (jx, a)@ entry combines
target cell @jx@ as @accum old a@.
An entry of @Nothing@ leaves the target untouched.

NOTE(review): @jx@ is passed straight to 'getElementPtr'.
No range check is visible in this function.
-}
scatterMaybe ::
   (Shape.C sh0, Shape.Index sh0 ~ ix0,
    Shape.C sh1, Shape.Index sh1 ~ ix1,
    Marshal.C sh1,
    Storable.C a) =>
   (Exp a -> Exp a -> Exp a) ->
   Sym.Array sh1 a -> Sym.Array sh0 (Maybe (ix1, a)) ->
   LLVM.Value (MarshalPtr sh1) ->
   LLVM.Value (Ptr a) ->
   LLVM.CodeGenFunction r ()
scatterMaybe accum (Sym.Array esh codeInit) (Sym.Array eish codeMap)
      sptr ptr = do

   sh <- Shape.load esh sptr
   _ <- writeArray sh codeInit ptr

   ish <- unExp eish
   let -- read-modify-write of a single target cell
       accumulate jx a = do
          p <- getElementPtr sh ptr jx
          old <- Storable.load p
          combined <- Expr.unliftM2 (flip accum) a old
          Storable.store combined p
       fill ix () = do
          entry <- codeMap ix
          let (MultiValue.Cons c, (jx, a)) =
                 mapSnd MultiValue.unzip (MultiValue.splitMaybe entry)
          C.ifThen c () (accumulate jx a)
   Shape.loop fill ish ()

{- |
Scatter with accumulation.

The target shape @sh1@ is loaded from @sptr@.
The target at @ptr@ is first filled with the initial array.
Then, for each index of the source array, the pair @(jx, a)@ combines target
cell @jx@ as @accum old a@.

NOTE(review): @jx@ is passed straight to 'getElementPtr'.
No range check is visible in this function.
-}
scatter ::
   (Shape.C sh0, Shape.Index sh0 ~ ix0,
    Shape.C sh1, Shape.Index sh1 ~ ix1,
    Marshal.C sh1,
    Storable.C a) =>
   (Exp a -> Exp a -> Exp a) ->
   Sym.Array sh1 a ->
   Sym.Array sh0 (Shape.Index sh1, a) ->
   LLVM.Value (MarshalPtr sh1) ->
   LLVM.Value (Ptr a) ->
   LLVM.CodeGenFunction r ()
scatter accum (Sym.Array esh codeInit) (Sym.Array eish codeMap) sptr ptr = do
   sh <- Shape.load esh sptr
   _ <- writeArray sh codeInit ptr

   ish <- unExp eish
   let fill ix () = do
          pair <- codeMap ix
          let (jx, a) = MultiValue.unzip pair
          p <- getElementPtr sh ptr jx
          old <- Storable.load p
          combined <- Expr.unliftM2 (flip accum) a old
          Storable.store combined p
   Shape.loop fill ish ()

{- |
Extend an array by a new inner dimension of size @n@.

The combined shape @(sh, n)@ is loaded from @sptr@, using the shape
expression @Expr.zip esh en@.
For each index @ix@ of @sh@, the element @a@ at @ix@ is computed once.
Then @select k a@ is written for every index @k@ of @n@.
Each row of @n@ elements is contiguous, and rows follow each other
starting at @ptr@.
-}
addDimension ::
   (Shape.C n, Marshal.C n, Shape.Index n ~ k,
    Shape.C sh, Marshal.C sh,
    Storable.C b) =>
   Exp n ->
   (Exp k -> Exp a -> Exp b) ->
   Sym.Array sh a ->
   LLVM.Value (MarshalPtr (sh,n)) ->
   LLVM.Value (Ptr b) ->
   LLVM.CodeGenFunction r ()
addDimension en select (Sym.Array esh code) sptr ptr = do
   loaded <- Shape.load (Expr.zip esh en) sptr
   let (sh, n) = MultiValue.unzip loaded
       fill ix dst =
          code ix >>= \a ->
             writeArray n (\k -> Expr.unliftM2 select k a) dst
   _ <- Shape.loop fill sh ptr
   return ()