{-# LANGUAGE GADTs               #-}
{-# LANGUAGE OverloadedStrings   #-}
{-# LANGUAGE RecordWildCards     #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell     #-}
module Data.Array.Accelerate.LLVM.Native.CodeGen.Permute
  where
import Data.Array.Accelerate.Array.Sugar                            ( Array, Vector, Shape, Elt, eltType )
import Data.Array.Accelerate.Error
import qualified Data.Array.Accelerate.Array.Sugar                  as S
import Data.Array.Accelerate.LLVM.CodeGen.Arithmetic                as A
import Data.Array.Accelerate.LLVM.CodeGen.Array
import Data.Array.Accelerate.LLVM.CodeGen.Base
import Data.Array.Accelerate.LLVM.CodeGen.Constant
import Data.Array.Accelerate.LLVM.CodeGen.Environment
import Data.Array.Accelerate.LLVM.CodeGen.Exp
import Data.Array.Accelerate.LLVM.CodeGen.IR
import Data.Array.Accelerate.LLVM.CodeGen.Monad
import Data.Array.Accelerate.LLVM.CodeGen.Permute
import Data.Array.Accelerate.LLVM.CodeGen.Ptr
import Data.Array.Accelerate.LLVM.CodeGen.Sugar
import Data.Array.Accelerate.LLVM.Native.Target                     ( Native )
import Data.Array.Accelerate.LLVM.Native.CodeGen.Base
import Data.Array.Accelerate.LLVM.Native.CodeGen.Loop
import LLVM.AST.Type.AddrSpace
import LLVM.AST.Type.Instruction
import LLVM.AST.Type.Instruction.Atomic
import LLVM.AST.Type.Instruction.RMW                                as RMW
import LLVM.AST.Type.Instruction.Volatile
import LLVM.AST.Type.Representation
import Control.Applicative
import Control.Monad                                                ( void )
import Data.Typeable
import Prelude
-- | Generate the code for a forward permutation.
--
-- The result bundles two kernels, joined with '(+++)':
--
--   * the sequential kernel produced by 'mkPermuteS' ("permuteS"), and
--   * the parallel kernel chosen by 'mkPermuteP' (either "permuteP_rmw" or
--     "permuteP_mutex").
--
-- The sequential kernel is generated first, then the parallel one.
--
mkPermute
    :: (Shape sh, Shape sh', Elt e)
    => Gamma aenv
    -> IRPermuteFun Native aenv (e -> e -> e)
    -> IRFun1       Native aenv (sh -> sh')
    -> IRDelayed    Native aenv (Array sh e)
    -> CodeGen (IROpenAcc Native aenv (Array sh' e))
mkPermute aenv combine project arr =
  liftA2 (+++) sequentialKernel parallelKernel
  where
    sequentialKernel = mkPermuteS aenv combine project arr
    parallelKernel   = mkPermuteP aenv combine project arr
-- | Sequential forward permutation kernel ("permuteS").
--
-- For every linear index @i@ that 'imapFromTo' visits between the gang
-- parameters, the kernel does the following:
--
--   1. Converts @i@ into a multidimensional index of the source extent.
--   2. Projects that index into the destination with @project@.
--   3. If the projected index is not the 'ignore' sentinel, combines the
--      source element (first argument of @combine@) with the element already
--      in the output (second argument) and stores the result back.
--
-- The output is read and written without any locking. Compare
-- 'mkPermuteP_mutex', which wraps the same read-combine-write in 'atomically'.
--
mkPermuteS
    :: forall aenv sh sh' e. (Shape sh, Shape sh', Elt e)
    => Gamma aenv
    -> IRPermuteFun Native aenv (e -> e -> e)
    -> IRFun1       Native aenv (sh -> sh')
    -> IRDelayed    Native aenv (Array sh e)
    -> CodeGen (IROpenAcc Native aenv (Array sh' e))
mkPermuteS aenv IRPermuteFun{..} project IRDelayed{..} =
  makeOpenAcc "permuteS" kernelParams $ do
    extent <- delayedExtent
    imapFromTo lo hi $ \i -> do
      srcIx <- indexOfInt extent i
      dstIx <- app1 project srcIx
      unless (ignore dstIx) $ do
        k   <- intOfIndex (irArrayShape out) dstIx
        new <- app1 delayedLinearIndex i
        old <- readArray out k
        writeArray out k =<< app2 combine new old
    return_
  where
    (lo, hi, paramGang) = gangParam
    (out, paramOut)     = mutableArray ("out" :: Name (Array sh' e))
    kernelParams        = paramGang ++ paramOut ++ envParam aenv
-- | Parallel forward permutation kernel.
--
-- If the combination function comes with an atomic read-modify-write
-- decomposition (@atomicRMW = Just (rmw, f)@), the kernel is generated by
-- 'mkPermuteP_rmw'. Otherwise it falls back to 'mkPermuteP_mutex', which
-- protects each output element with its own lock.
--
mkPermuteP
    :: forall aenv sh sh' e. (Shape sh, Shape sh', Elt e)
    => Gamma aenv
    -> IRPermuteFun Native aenv (e -> e -> e)
    -> IRFun1       Native aenv (sh -> sh')
    -> IRDelayed    Native aenv (Array sh e)
    -> CodeGen (IROpenAcc Native aenv (Array sh' e))
mkPermuteP aenv IRPermuteFun{..} project arr =
  maybe (mkPermuteP_mutex aenv combine project arr)
        (\(rmw, f) -> mkPermuteP_rmw aenv rmw f project arr)
        atomicRMW
-- | Parallel forward permutation kernel ("permuteP_rmw") for combination
-- functions that were decomposed into an atomic read-modify-write operation
-- @rmw@ and a function @update@ on the incoming element.
--
-- For each source element @x@ whose projected index @j@ is not the 'ignore'
-- sentinel, the kernel combines @r = update x@ into output element @j@:
--
--   * 'Exchange': @r@ is stored with a plain 'writeArray'.
--   * Integral scalar elements: a single LLVM @atomicrmw@ instruction with
--     'AcquireRelease' ordering.
--   * Other numeric scalars with 'RMW.Add' / 'RMW.Sub': a compare-and-swap
--     loop ('atomicCAS_rmw').
--   * Remaining scalars with 'RMW.Min' / 'RMW.Max': a compare-and-swap loop
--     ('atomicCAS_cmp').
--
-- Any other combination, including non-single-scalar element types with an
-- operation other than 'Exchange', raises @$internalError@.
--
mkPermuteP_rmw
    :: forall aenv sh sh' e. (Shape sh, Shape sh', Elt e)
    => Gamma aenv
    -> RMWOperation
    -> IRFun1    Native aenv (e -> e)
    -> IRFun1    Native aenv (sh -> sh')
    -> IRDelayed Native aenv (Array sh e)
    -> CodeGen (IROpenAcc Native aenv (Array sh' e))
mkPermuteP_rmw aenv rmw update project IRDelayed{..} =
  let
      (start, end, paramGang)   = gangParam
      (arrOut, paramOut)        = mutableArray ("out" :: Name (Array sh' e))
      paramEnv                  = envParam aenv
  in
  makeOpenAcc "permuteP_rmw" (paramGang ++ paramOut ++ paramEnv) $ do
    sh <- delayedExtent
    imapFromTo start end $ \i -> do
      ix  <- indexOfInt sh i
      ix' <- app1 project ix
      unless (ignore ix') $ do
        j <- intOfIndex (irArrayShape arrOut) ix'
        x <- app1 delayedLinearIndex i
        r <- app1 update x
        case rmw of
          -- The previous value is never read here, so a plain store suffices.
          Exchange
            -> writeArray arrOut j r

          -- Single scalar element type: operate directly on the address of
          -- output element j.
          _ | SingleTuple s <- eltType (undefined::e)
            , Just adata    <- gcast (irArrayData arrOut)
            , Just r'       <- gcast r
            -> do
                  addr <- instr' $ GetElementPtr (asPtr defaultAddrSpace (op s adata)) [op integralType j]

                  case s of
                    NumScalarType (IntegralNumType t) -> void . instr' $ AtomicRMW t NonVolatile rmw addr (op t r') (CrossThread, AcquireRelease)
                    NumScalarType t | RMW.Add <- rmw  -> atomicCAS_rmw s (A.add t r') addr
                    -- 'RMW.Sub' must mean the same here as the LLVM
                    -- @atomicrmw sub@ used for integral types above, which
                    -- computes @*addr - r'@. The stored value is therefore
                    -- the LEFT operand. (Previously this was @A.sub t r'@,
                    -- i.e. @r' - old@, so integral and floating-point
                    -- elements gave different results for the same operation.)
                    -- NOTE(review): assumes 'atomicCAS_rmw' applies the
                    -- function to the value currently stored at @addr@ —
                    -- confirm against its definition.
                    NumScalarType t | RMW.Sub <- rmw  -> atomicCAS_rmw s (\old -> A.sub t old r') addr
                    _ -> case rmw of
                           RMW.Min                    -> atomicCAS_cmp s A.lt addr (op s r')
                           RMW.Max                    -> atomicCAS_cmp s A.gt addr (op s r')
                           _                          -> $internalError "mkPermute_rmw" "unexpected transition"

          _ -> $internalError "mkPermute_rmw" "unexpected transition"
    return_
-- | Generic parallel forward permutation kernel ("permuteP_mutex").
--
-- The kernel takes an extra parameter, the @"lock"@ array of 'Word8'. It is
-- indexed with the same linear index @k@ as the output, so presumably it has
-- one entry per output element (TODO confirm at the call site).
--
-- For every source element whose projected index is not the 'ignore'
-- sentinel, the kernel does the following inside 'atomically' (a spin lock on
-- @lock[k]@):
--
--   1. Reads the current output element.
--   2. Combines the source element (first argument) with it (second argument).
--   3. Writes the result back.
--
mkPermuteP_mutex
    :: forall aenv sh sh' e. (Shape sh, Shape sh', Elt e)
    => Gamma aenv
    -> IRFun2    Native aenv (e -> e -> e)
    -> IRFun1    Native aenv (sh -> sh')
    -> IRDelayed Native aenv (Array sh e)
    -> CodeGen (IROpenAcc Native aenv (Array sh' e))
mkPermuteP_mutex aenv combine project IRDelayed{..} =
  makeOpenAcc "permuteP_mutex" kernelParams $ do
    extent <- delayedExtent
    imapFromTo lo hi $ \i -> do
      srcIx <- indexOfInt extent i
      dstIx <- app1 project srcIx
      unless (ignore dstIx) $ do
        k   <- intOfIndex (irArrayShape out) dstIx
        new <- app1 delayedLinearIndex i
        atomically locks k $ do
          old <- readArray out k
          writeArray out k =<< app2 combine new old
    return_
  where
    (lo, hi, paramGang) = gangParam
    (out,   paramOut)   = mutableArray ("out"  :: Name (Array sh' e))
    (locks, paramLock)  = mutableArray ("lock" :: Name (Vector Word8))
    kernelParams        = paramGang ++ paramOut ++ paramLock ++ envParam aenv
-- | Emit @action@ as a critical section guarded by the spin lock stored at
-- index @i@ of @barriers@.
--
-- A lock word of 0 means free and 1 means held. The generated control flow is:
--
--   * "spinlock.entry": atomically exchange 1 into the lock word with
--     'Acquire' ordering. Loop back here unless the previous value was 0.
--   * "spinlock.critical-section": run @action@, then release the lock by
--     atomically exchanging 0 back with 'Release' ordering.
--   * "spinlock.exit": code generation continues here.
--
-- Returns whatever @action@ returned.
--
-- NOTE(review): every spin iteration performs an atomic exchange (a write),
-- with no read-only test before it and no back-off. Under contention this
-- may cause extra cache-coherence traffic; a test-and-test-and-set loop
-- could reduce it.
--
atomically
    :: IRArray (Vector Word8)
    -> IR Int
    -> CodeGen a
    -> CodeGen a
atomically barriers i action = do
  let
      -- Values exchanged into the lock word: 1 acquires, 0 releases.
      lock      = integral integralType 1
      unlock    = integral integralType 0
      -- The previous lock value that means the acquire succeeded.
      unlocked  = lift 0
  -- Blocks: entry (spin loop), critical section, and continuation.
  spin <- newBlock "spinlock.entry"
  crit <- newBlock "spinlock.critical-section"
  exit <- newBlock "spinlock.exit"
  addr <- instr' $ GetElementPtr (asPtr defaultAddrSpace (op integralType (irArrayData barriers))) [op integralType i]
  _    <- br spin
  -- Spin until the exchange observes the lock as free. The acquire ordering
  -- keeps the critical section's memory accesses from being hoisted above
  -- the acquisition.
  setBlock spin
  old  <- instr $ AtomicRMW integralType NonVolatile Exchange addr lock   (CrossThread, Acquire)
  ok   <- A.eq scalarType old unlocked
  _    <- cbr ok crit spin
  -- The lock is held: emit the critical section, then release the lock with
  -- release ordering so the section's writes are visible before the lock
  -- reads as free again. The release uses an atomic exchange rather than a
  -- plain store.
  setBlock crit
  r    <- action
  _    <- instr $ AtomicRMW integralType NonVolatile Exchange addr unlock (CrossThread, Release)
  _    <- br exit
  setBlock exit
  return r
-- | Test whether an index equals the 'S.ignore' sentinel index.
--
-- The permutation kernels in this module use this test to skip source
-- elements that must not be written to the output.
--
-- The index is compared component by component with 'A.eq':
--
--   * the unit index compares equal;
--   * pairs are combined with 'land'', visiting the left component before
--     the right.
--
-- This function emits no branch of its own; whether 'land'' short-circuits
-- is up to its definition.
--
ignore :: forall ix. Shape ix => IR ix -> CodeGen (IR Bool)
ignore (IR ix) = matches (S.eltType (undefined::ix)) (S.fromElt (S.ignore::ix)) ix
  where
    matches :: TupleType t -> t -> Operands t -> CodeGen (IR Bool)
    matches UnitTuple () OP_Unit =
      return (lift True)
    matches (PairTuple tl tr) (sentinelL, sentinelR) (OP_Pair ixL ixR) =
      matches tl sentinelL ixL >>= \l ->
      matches tr sentinelR ixR >>= \r ->
      land' l r
    matches (SingleTuple t) sentinel component =
      A.eq t (ir t (scalar t sentinel)) (ir t (op' t component))