{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE ViewPatterns #-}
module Data.Array.Accelerate.LLVM.PTX.CodeGen.Queue
where
import Data.Array.Accelerate.LLVM.CodeGen.Arithmetic as A
import Data.Array.Accelerate.LLVM.CodeGen.Downcast
import Data.Array.Accelerate.LLVM.CodeGen.IR
import Data.Array.Accelerate.LLVM.CodeGen.Monad
import Data.Array.Accelerate.LLVM.CodeGen.Sugar
import Data.Array.Accelerate.LLVM.PTX.Analysis.Launch
import Data.Array.Accelerate.LLVM.PTX.CodeGen.Base
import Data.Array.Accelerate.LLVM.PTX.Target
import LLVM.AST.Type.Constant
import LLVM.AST.Type.Instruction
import LLVM.AST.Type.Instruction.Atomic
import LLVM.AST.Type.Instruction.Volatile
import LLVM.AST.Type.Operand
import LLVM.AST.Type.Representation
import qualified LLVM.AST.Global as LLVM
import qualified LLVM.AST.Linkage as LLVM
import qualified LLVM.AST.Name as LLVM
import qualified LLVM.AST.Type as LLVM
import qualified LLVM.AST.Type.Instruction.RMW as RMW
-- | Handle to the work queue as seen from generated code: a pair of pointers
-- to 32-bit integers. As produced by 'globalWorkQueue', the first component
-- refers to the global named @__queue__@ (the counter that 'dequeue' updates
-- atomically) and the second to a freshly named global declared in
-- 'sharedMemAddrSpace' (the slot through which 'dequeue' passes the fetched
-- value to the other threads).
--
type WorkQueue = (Operand (Ptr Int32), Operand (Ptr Int32))
-- | Declare the two global variables backing a 'WorkQueue' and return typed
-- references to them.
--
-- Two 32-bit integer globals with 4-byte alignment are declared, in this
-- order:
--
--   1. one with the fixed name @__queue__@, all other fields left at the
--      'LLVM.globalVariableDefaults' values; and
--
--   2. one with a fresh name, placed in 'sharedMemAddrSpace' and given
--      internal linkage.
--
-- The result pairs a reference to (1) with a reference to (2).
--
globalWorkQueue :: CodeGen WorkQueue
globalWorkQueue = do
  smemName <- freshName
  let
      -- The queue counter, always called "__queue__".
      counter = LLVM.globalVariableDefaults
        { LLVM.name      = LLVM.Name "__queue__"
        , LLVM.type'     = LLVM.IntegerType 32
        , LLVM.alignment = 4
        }

      -- The freshly named slot in the shared-memory address space.
      scratch = LLVM.globalVariableDefaults
        { LLVM.name      = downcast smemName
        , LLVM.addrSpace = sharedMemAddrSpace
        , LLVM.type'     = LLVM.IntegerType 32
        , LLVM.linkage   = LLVM.Internal
        , LLVM.alignment = 4
        }
  --
  declare counter
  declare scratch
  return ( ConstantOperand (GlobalReference type' "__queue__")
         , ConstantOperand (GlobalReference type' smemName) )
-- | Emit code that takes @n@ items from the work queue and hands the fetched
-- value to every thread.
--
-- The thread whose 'threadIdx' equals zero performs an atomic add of @n@ on
-- the queue counter and writes the value yielded by that instruction to
-- @smem@ with a volatile store. After a '__syncthreads', every thread does a
-- volatile load from @smem@ and returns what it read.
--
-- NOTE(review): there is no second barrier after the load. If this is emitted
-- more than once per kernel (e.g. inside a loop), presumably the caller must
-- make sure all threads have read @smem@ before thread zero stores to it
-- again -- confirm against the call sites.
--
dequeue :: WorkQueue -> IR Int32 -> CodeGen (IR Int32)
dequeue (queue, smem) n = do
  tid <- threadIdx
  when (A.eq scalarType tid (lift 0)) claimAndPublish
  --
  __syncthreads
  ir integralType <$> instr' (Load scalarType Volatile smem)
  where
    -- Executed by thread zero only: bump the counter by 'n' and stash the
    -- result of the atomic instruction where the other threads can load it.
    claimAndPublish = do
      fetched <- instr' (AtomicRMW integralType NonVolatile RMW.Add queue (op integralType n) (CrossThread, AcquireRelease))
      _       <- instr' (Store Volatile smem fetched)
      return ()
-- | Build the @qinit@ kernel, which (re)sets the work queue counter.
--
-- The kernel takes the gang parameters from 'gangParam', declares the queue
-- globals via 'globalWorkQueue', and writes the gang @start@ value to the
-- queue counter with a volatile store. The end-of-range gang value and the
-- second component of the queue are not used.
--
-- NOTE(review): the launch configuration arguments @[1]@, @\\_ -> 0@ and
-- @\\_ _ -> 1@ presumably mean "no dynamic shared memory, one-element grid"
-- -- confirm against 'launchConfig'.
--
mkQueueInit
    :: DeviceProperties
    -> CodeGen (IROpenAcc PTX aenv a)
mkQueueInit dev =
  makeOpenAccWith config "qinit" paramGang $ do
    (counter, _) <- globalWorkQueue
    _            <- instr' (Store Volatile counter (op integralType start))
    return_
  where
    (start, _, paramGang) = gangParam
    config                = launchConfig dev [1] (\_ -> 0) (\_ _ -> 1)