{-# LANGUAGE GADTs #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE ViewPatterns #-}
module Data.Array.Accelerate.LLVM.PTX.CodeGen.Base (
DeviceProperties, KernelMetadata(..),
blockDim, gridDim, threadIdx, blockIdx, warpSize,
gridSize, globalThreadIdx,
gangParam,
laneId, warpId,
laneMask_eq, laneMask_lt, laneMask_le, laneMask_gt, laneMask_ge,
atomicAdd_f,
__syncthreads,
__threadfence_block, __threadfence_grid,
staticSharedMem,
dynamicSharedMem,
sharedMemAddrSpace,
(+++),
makeOpenAcc, makeOpenAccWith,
) where
import LLVM.AST.Type.AddrSpace
import LLVM.AST.Type.Constant
import LLVM.AST.Type.Global
import LLVM.AST.Type.Instruction
import LLVM.AST.Type.Instruction.Volatile
import LLVM.AST.Type.Metadata
import LLVM.AST.Type.Name
import LLVM.AST.Type.Operand
import LLVM.AST.Type.Representation
import qualified LLVM.AST.Global as LLVM
import qualified LLVM.AST.Linkage as LLVM
import qualified LLVM.AST.Name as LLVM
import qualified LLVM.AST.Type as LLVM
import Data.Array.Accelerate.Analysis.Type
import Data.Array.Accelerate.Array.Sugar ( Elt, Vector, eltType )
import Data.Array.Accelerate.Error
import Data.Array.Accelerate.LLVM.CodeGen.Arithmetic as A
import Data.Array.Accelerate.LLVM.CodeGen.Base
import Data.Array.Accelerate.LLVM.CodeGen.Constant
import Data.Array.Accelerate.LLVM.CodeGen.Downcast
import Data.Array.Accelerate.LLVM.CodeGen.IR
import Data.Array.Accelerate.LLVM.CodeGen.Module
import Data.Array.Accelerate.LLVM.CodeGen.Monad
import Data.Array.Accelerate.LLVM.CodeGen.Ptr
import Data.Array.Accelerate.LLVM.CodeGen.Sugar
import Data.Array.Accelerate.LLVM.CodeGen.Type
import Data.Array.Accelerate.LLVM.PTX.Analysis.Launch
import Data.Array.Accelerate.LLVM.PTX.Context
import Data.Array.Accelerate.LLVM.PTX.Target
import Control.Applicative
import Control.Monad ( void )
import Text.Printf
import Prelude as P
-- | Emit a call to the named zero-argument intrinsic and return its 32-bit
-- integer result. This is the common building block for the special-register
-- readers defined below.
--
-- The call is annotated with the 'NoUnwind' and 'ReadNone' attributes.
--
specialPTXReg :: Label -> CodeGen (IR Int32)
specialPTXReg intrinsic = call (Body type' intrinsic) attributes
  where
    attributes = [NoUnwind, ReadNone]
-- | Read the @%ntid.x@ special register (CUDA's @blockDim.x@).
--
blockDim :: CodeGen (IR Int32)
blockDim = specialPTXReg "llvm.nvvm.read.ptx.sreg.ntid.x"

-- | Read the @%nctaid.x@ special register (CUDA's @gridDim.x@).
--
gridDim :: CodeGen (IR Int32)
gridDim = specialPTXReg "llvm.nvvm.read.ptx.sreg.nctaid.x"

-- | Read the @%tid.x@ special register (CUDA's @threadIdx.x@).
--
threadIdx :: CodeGen (IR Int32)
threadIdx = specialPTXReg "llvm.nvvm.read.ptx.sreg.tid.x"

-- | Read the @%ctaid.x@ special register (CUDA's @blockIdx.x@).
--
blockIdx :: CodeGen (IR Int32)
blockIdx = specialPTXReg "llvm.nvvm.read.ptx.sreg.ctaid.x"

-- | Read the @WARP_SZ@ value through the @warpsize@ intrinsic.
--
warpSize :: CodeGen (IR Int32)
warpSize = specialPTXReg "llvm.nvvm.read.ptx.sreg.warpsize"
-- | Read the @%laneid@ special register through the
-- @llvm.nvvm.read.ptx.sreg.laneid@ intrinsic.
--
laneId :: CodeGen (IR Int32)
laneId = specialPTXReg "llvm.nvvm.read.ptx.sreg.laneid"
-- | Read the @%lanemask_eq@ special register.
--
laneMask_eq :: CodeGen (IR Int32)
laneMask_eq = specialPTXReg "llvm.nvvm.read.ptx.sreg.lanemask.eq"

-- | Read the @%lanemask_lt@ special register.
--
laneMask_lt :: CodeGen (IR Int32)
laneMask_lt = specialPTXReg "llvm.nvvm.read.ptx.sreg.lanemask.lt"

-- | Read the @%lanemask_le@ special register.
--
laneMask_le :: CodeGen (IR Int32)
laneMask_le = specialPTXReg "llvm.nvvm.read.ptx.sreg.lanemask.le"

-- | Read the @%lanemask_gt@ special register.
--
laneMask_gt :: CodeGen (IR Int32)
laneMask_gt = specialPTXReg "llvm.nvvm.read.ptx.sreg.lanemask.gt"

-- | Read the @%lanemask_ge@ special register.
--
laneMask_ge :: CodeGen (IR Int32)
laneMask_ge = specialPTXReg "llvm.nvvm.read.ptx.sreg.lanemask.ge"
-- | The warp index of the calling thread, computed rather than read from a
-- register:
--
-- > threadIdx.x / warpSize
--
warpId :: CodeGen (IR Int32)
warpId =
  threadIdx >>= \thread ->
  warpSize  >>= \width  ->
  A.quot integralType thread width
-- | Read the @%warpid@ special register directly.
--
-- This uses the same @llvm.nvvm.read.ptx.sreg.*@ intrinsic family as every
-- other special-register reader in this module. The previous name,
-- @llvm.ptx.read.warpid@, belongs to the legacy PTX backend and is not an
-- NVPTX intrinsic.
--
-- NOTE(review): the PTX ISA describes @%warpid@ as volatile, whereas
-- 'specialPTXReg' tags the call 'ReadNone' -- confirm that is acceptable
-- before using this reader. 'warpId' derives the value from 'threadIdx' and
-- 'warpSize' instead.
--
_warpId :: CodeGen (IR Int32)
_warpId = specialPTXReg "llvm.nvvm.read.ptx.sreg.warpid"
-- | Total number of threads in the launch along x:
--
-- > gridDim.x * blockDim.x
--
gridSize :: CodeGen (IR Int32)
gridSize =
  gridDim  >>= \blocks  ->
  blockDim >>= \threads ->
  mul numType blocks threads
-- | Linear index of the calling thread across the whole grid:
--
-- > threadIdx.x + blockDim.x * blockIdx.x
--
globalThreadIdx :: CodeGen (IR Int32)
globalThreadIdx = do
  threadsPerBlock <- blockDim
  block           <- blockIdx
  thread          <- threadIdx
  -- index of the first thread belonging to this block
  blockStart      <- mul numType threadsPerBlock block
  add numType thread blockStart
-- | The pair of kernel parameters named @ix.start@ and @ix.end@.
--
-- Returns the two values as they are referred to from inside the kernel body,
-- together with the matching parameter declarations for the function header.
--
gangParam :: (IR Int32, IR Int32, [LLVM.Parameter])
gangParam = (local ty lo, local ty hi, [ scalarParameter ty lo, scalarParameter ty hi ])
  where
    ty = scalarType
    lo = "ix.start"
    hi = "ix.end"
-- | Emit a call to the named zero-argument, void-returning intrinsic and
-- discard the result.
--
-- The call is annotated 'NoUnwind', 'NoDuplicate' and 'Convergent'.
--
barrier :: Label -> CodeGen ()
barrier intrinsic = void (call (Body VoidType intrinsic) attributes)
  where
    attributes = [NoUnwind, NoDuplicate, Convergent]
-- | Emit a call to the @llvm.nvvm.barrier0@ intrinsic (via 'barrier'), the
-- counterpart of CUDA's @__syncthreads()@.
--
__syncthreads :: CodeGen ()
__syncthreads = barrier "llvm.nvvm.barrier0"
-- | Emit a call to the @llvm.nvvm.membar.cta@ intrinsic (via 'barrier'), the
-- counterpart of CUDA's @__threadfence_block()@.
--
__threadfence_block :: CodeGen ()
__threadfence_block = barrier "llvm.nvvm.membar.cta"
-- | Emit a call to the @llvm.nvvm.membar.gl@ intrinsic (via 'barrier').
--
-- NOTE(review): presumably the counterpart of CUDA's @__threadfence()@ --
-- confirm against the NVVM documentation.
--
__threadfence_grid :: CodeGen ()
__threadfence_grid = barrier "llvm.nvvm.membar.gl"
-- | Emit a call to the NVVM floating-point atomic-add intrinsic, adding @val@
-- to the value stored at @addr@. The value returned by the intrinsic is
-- discarded.
--
-- The intrinsic name is assembled from the bit width of the floating-point
-- type and the address space of the pointer operand:
--
-- > llvm.nvvm.atomic.load.add.f<width>.p<addrspace>f<width>
--
-- Raises an internal error if the type of @addr@ is not a pointer to a scalar.
--
atomicAdd_f :: FloatingType a -> Operand (Ptr a) -> Operand a -> CodeGen ()
atomicAdd_f t addr val =
  void $ call (Lam ptrTy addr (Lam (ScalarPrimType valTy) val (Body retTy intrinsic))) [NoUnwind]
  where
    -- bit width of the floating-point type
    bits :: Int
    bits =
      case t of
        TypeFloat{}   -> 32
        TypeDouble{}  -> 64
        TypeCFloat{}  -> 32
        TypeCDouble{} -> 64

    -- pointer type, pointee scalar type, and address-space number, all
    -- recovered from the type of the pointer operand
    space :: Word32
    (ptrTy, valTy, space) =
      case typeOf addr of
        PrimType ta@(PtrPrimType (ScalarPrimType tv) (AddrSpace asp))
          -> (ta, tv, asp)
        _ -> $internalError "atomicAdd" "unexpected operand type"

    retTy     = PrimType (ScalarPrimType valTy)
    intrinsic = Label $ printf "llvm.nvvm.atomic.load.add.f%d.p%df%d" bits space bits
-- | The address space (number 3) in which this module places every
-- shared-memory global and pointer.
--
sharedMemAddrSpace :: AddrSpace
sharedMemAddrSpace = AddrSpace 3
-- | The volatility recorded on the shared-memory arrays built by
-- 'staticSharedMem' and 'dynamicSharedMem': always 'Volatile'.
--
sharedMemVolatility :: Volatility
sharedMemVolatility = Volatile
-- | Declare a statically sized array of @n@ elements in the shared memory
-- address space.
--
-- One internal global of array type (@n@ elements of the component type) is
-- declared for every scalar component of the element representation. The
-- returned 'IRArray' has a one-dimensional shape of @n@ and holds a pointer
-- to the first element of each of those globals.
--
staticSharedMem
    :: forall e. Elt e
    => Word64                           -- ^ number of array elements
    -> CodeGen (IRArray (Vector e))
staticSharedMem n = do
  buffers <- alloc (eltType (undefined :: e))
  return IRArray
    { irArrayShape      = IR (OP_Pair OP_Unit (OP_Int (integral integralType (P.fromIntegral n))))
    , irArrayData       = IR buffers
    , irArrayAddrSpace  = sharedMemAddrSpace
    , irArrayVolatility = sharedMemVolatility
    }
  where
    -- Allocate storage for each scalar component of the element type.
    alloc :: TupleType s -> CodeGen (Operands s)
    alloc tt =
      case tt of
        UnitTuple       -> return OP_Unit
        PairTuple t1 t2 -> OP_Pair <$> alloc t1 <*> alloc t2
        SingleTuple t   -> do
          -- A freshly named global in the shared address space backs this
          -- component.
          nm <- freshName
          let sm = ConstantOperand $ GlobalReference (PrimType (PtrPrimType (ArrayType n t) sharedMemAddrSpace)) nm
          declare $ LLVM.globalVariableDefaults
            { LLVM.addrSpace = sharedMemAddrSpace
            , LLVM.type'     = LLVM.ArrayType n (downcast t)
            , LLVM.linkage   = LLVM.Internal
            , LLVM.name      = downcast nm
            , LLVM.alignment = P.max 4 (P.fromIntegral (sizeOf tt))
            }

          -- Index to element zero of the array, then view the result as a
          -- pointer to the scalar type.
          let zero = num numType 0 :: Operand Int32
          first <- instr' $ GetElementPtr sm [zero, zero]
          cast  <- instr' $ PtrCast (PtrPrimType (ScalarPrimType t) sharedMemAddrSpace) first
          return $ ir' t (unPtr cast)
-- | Declare the external, zero-length byte array named @__shared__@ in the
-- shared memory address space, and return a reference to it.
--
-- 'dynamicSharedMem' uses the returned operand as the base from which it
-- indexes by byte offset.
--
initialiseDynamicSharedMemory :: CodeGen (Operand (Ptr Word8))
initialiseDynamicSharedMemory = declare decl >> return ref
  where
    decl = LLVM.globalVariableDefaults
      { LLVM.addrSpace = sharedMemAddrSpace
      , LLVM.type'     = LLVM.ArrayType 0 (LLVM.IntegerType 8)
      , LLVM.linkage   = LLVM.External
      , LLVM.name      = LLVM.Name "__shared__"
      , LLVM.alignment = 4
      }

    ref :: Operand (Ptr Word8)
    ref = ConstantOperand (GlobalReference type' "__shared__")
-- | Create an array of @n@ elements that lives in the @__shared__@ byte array
-- declared by 'initialiseDynamicSharedMemory'.
--
-- Each scalar component of the element type receives its own region of
-- @n * sizeOf component@ bytes. Regions are laid out back to back beginning at
-- the given byte offset; for a pair, the right component is placed before the
-- left one.
--
dynamicSharedMem
    :: forall e int. (Elt e, IsIntegral int)
    => IR int                           -- ^ number of array elements
    -> IR int                           -- ^ byte offset into @__shared__@ at which this array begins
    -> CodeGen (IRArray (Vector e))
dynamicSharedMem n@(op integralType -> m) (op integralType -> offset) = do
  smem <- initialiseDynamicSharedMemory
  let
      -- Hand out a region to every scalar component, threading through the
      -- byte offset at which the next region starts.
      carve :: TupleType s -> Operand int -> CodeGen (Operand int, Operands s)
      carve UnitTuple         here = return (here, OP_Unit)
      carve (PairTuple tl tr) here = do
        (afterR, ptrR) <- carve tr here
        (afterL, ptrL) <- carve tl afterR
        return (afterL, OP_Pair ptrL ptrR)
      carve (SingleTuple t)   here = do
        -- the leading zero index steps through the pointer to the array itself
        base  <- instr' $ GetElementPtr smem [num numType 0, here]
        cast  <- instr' $ PtrCast (PtrPrimType (ScalarPrimType t) sharedMemAddrSpace) base
        bytes <- instr' $ Mul numType m (integral integralType (P.fromIntegral (sizeOf (SingleTuple t))))
        next  <- instr' $ Add numType here bytes
        return (next, ir' t (unPtr cast))
  --
  (_, ad) <- carve (eltType (undefined :: e)) offset
  IR sz   <- A.fromIntegral integralType (numType :: NumType Int) n
  return IRArray
    { irArrayShape      = IR $ OP_Pair OP_Unit sz
    , irArrayData       = IR ad
    , irArrayAddrSpace  = sharedMemAddrSpace
    , irArrayVolatility = sharedMemVolatility
    }
-- | Metadata attached to every PTX kernel: the 'LaunchConfig' it was built
-- with (see 'makeKernel').
--
data instance KernelMetadata PTX = KM_PTX LaunchConfig
-- | Merge two programs into one by concatenating their kernel lists; the
-- kernels of the left operand come first.
--
(+++) :: IROpenAcc PTX aenv a -> IROpenAcc PTX aenv a -> IROpenAcc PTX aenv a
(+++) (IROpenAcc front) (IROpenAcc back) = IROpenAcc (front ++ back)
-- | Build a single-kernel program, using the launch configuration that
-- 'simpleLaunchConfig' derives from the device properties of the target's
-- context.
--
makeOpenAcc
    :: PTX
    -> Label
    -> [LLVM.Parameter]
    -> CodeGen ()
    -> CodeGen (IROpenAcc PTX aenv a)
makeOpenAcc target = makeOpenAccWith (simpleLaunchConfig device)
  where
    device = deviceProperties (ptxContext target)
-- | Build a single-kernel program using the supplied launch configuration.
--
makeOpenAccWith
    :: LaunchConfig
    -> Label
    -> [LLVM.Parameter]
    -> CodeGen ()
    -> CodeGen (IROpenAcc PTX aenv a)
makeOpenAccWith config name param kernel =
  single <$> makeKernel config name param kernel
  where
    -- wrap one kernel as a complete program
    single k = IROpenAcc [k]
-- | Run the given code generator and package the basic blocks it produced as
-- a void-returning function with the given name and parameters.
--
-- An @nvvm.annotations@ metadata entry is also added, consisting of a
-- reference to the function, the string @kernel@, and the integer 1. The
-- launch configuration is stored in the kernel's metadata.
--
makeKernel :: LaunchConfig -> Label -> [LLVM.Parameter] -> CodeGen () -> CodeGen (Kernel PTX aenv a)
makeKernel config name@(Label l) param kernel = do
  kernel
  code <- createBlocks
  let annotations =
        [ Just (MetadataOperand (ConstantOperand (GlobalReference VoidType (Name l))))
        , Just (MetadataStringOperand "kernel")
        , Just (MetadataOperand (scalar scalarType (1::Int)))
        ]
  addMetadata "nvvm.annotations" annotations
  return Kernel
    { kernelMetadata = KM_PTX config
    , unKernel       = LLVM.functionDefaults
        { LLVM.returnType  = LLVM.VoidType
        , LLVM.name        = downcast name
        , LLVM.parameters  = (param, False)
        , LLVM.basicBlocks = code
        }
    }