{-# LANGUAGE CPP #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE ViewPatterns #-}
module Data.Array.Accelerate.LLVM.PTX.CodeGen.Base (
DeviceProperties, KernelMetadata(..),
blockDim, gridDim, threadIdx, blockIdx, warpSize,
gridSize, globalThreadIdx,
laneId, warpId,
laneMask_eq, laneMask_lt, laneMask_le, laneMask_gt, laneMask_ge,
atomicAdd_f,
nanosleep,
__syncthreads, __syncthreads_count, __syncthreads_and, __syncthreads_or,
__syncwarp, __syncwarp_mask,
__threadfence_block, __threadfence_grid,
staticSharedMem,
dynamicSharedMem,
sharedMemAddrSpace,
(+++),
makeOpenAcc, makeOpenAccWith,
) where
import Data.Array.Accelerate.Error
import Data.Array.Accelerate.LLVM.CodeGen.Arithmetic as A
import Data.Array.Accelerate.LLVM.CodeGen.Base
import Data.Array.Accelerate.LLVM.CodeGen.Constant
import Data.Array.Accelerate.LLVM.CodeGen.IR
import Data.Array.Accelerate.LLVM.CodeGen.Module
import Data.Array.Accelerate.LLVM.CodeGen.Monad
import Data.Array.Accelerate.LLVM.CodeGen.Ptr
import Data.Array.Accelerate.LLVM.CodeGen.Sugar
import Data.Array.Accelerate.LLVM.PTX.Analysis.Launch
import Data.Array.Accelerate.LLVM.PTX.Target
import Data.Array.Accelerate.Representation.Array
import Data.Array.Accelerate.Representation.Elt
import Data.Array.Accelerate.Representation.Shape
import Data.Array.Accelerate.Representation.Type
import Foreign.CUDA.Analysis ( Compute(..), computeCapability )
import qualified Foreign.CUDA.Analysis as CUDA
import LLVM.AST.Type.AddrSpace
import LLVM.AST.Type.Constant
import LLVM.AST.Type.Downcast
import LLVM.AST.Type.Function
import LLVM.AST.Type.InlineAssembly
import LLVM.AST.Type.Instruction
import LLVM.AST.Type.Instruction.Volatile
import LLVM.AST.Type.Metadata
import LLVM.AST.Type.Name
import LLVM.AST.Type.Operand
import LLVM.AST.Type.Representation
import qualified LLVM.AST.Constant as LLVM hiding ( type' )
import qualified LLVM.AST.Global as LLVM
import qualified LLVM.AST.Linkage as LLVM
import qualified LLVM.AST.Name as LLVM
import qualified LLVM.AST.Type as LLVM
import Control.Applicative
import Control.Monad ( void )
import Control.Monad.State ( gets )
import Prelude as P
#if MIN_VERSION_llvm_hs(10,0,0)
import qualified LLVM.AST.Type.Instruction.RMW as RMW
import LLVM.AST.Type.Instruction.Atomic
#elif !MIN_VERSION_llvm_hs(9,0,0)
import Data.String
import Text.Printf
#endif
-- | Read a builtin PTX special register through the LLVM intrinsic named by
-- the given label.
--
-- The intrinsic is called with no arguments and yields an 'Int32'. The call is
-- emitted as a tail call with the 'NoUnwind' and 'ReadNone' attributes.
--
specialPTXReg :: Label -> CodeGen PTX (Operands Int32)
specialPTXReg f =
  call (Body type' (Just Tail) f) [NoUnwind, ReadNone]
-- | Thread and grid identifiers, each read from the x-component of the
-- corresponding PTX special register via 'specialPTXReg':
--
--   * 'blockDim'  reads @ntid.x@
--   * 'gridDim'   reads @nctaid.x@
--   * 'threadIdx' reads @tid.x@
--   * 'blockIdx'  reads @ctaid.x@
--   * 'warpSize'  reads @warpsize@
--
blockDim, gridDim, threadIdx, blockIdx, warpSize :: CodeGen PTX (Operands Int32)
blockDim    = specialPTXReg "llvm.nvvm.read.ptx.sreg.ntid.x"
gridDim     = specialPTXReg "llvm.nvvm.read.ptx.sreg.nctaid.x"
threadIdx   = specialPTXReg "llvm.nvvm.read.ptx.sreg.tid.x"
blockIdx    = specialPTXReg "llvm.nvvm.read.ptx.sreg.ctaid.x"
warpSize    = specialPTXReg "llvm.nvvm.read.ptx.sreg.warpsize"
-- | Read the @laneid@ special register via 'specialPTXReg'.
--
laneId :: CodeGen PTX (Operands Int32)
laneId      = specialPTXReg "llvm.nvvm.read.ptx.sreg.laneid"
-- | Read the @lanemask.{eq,lt,le,gt,ge}@ special registers via
-- 'specialPTXReg'. Each result is returned as an 'Int32'.
--
laneMask_eq, laneMask_lt, laneMask_le, laneMask_gt, laneMask_ge :: CodeGen PTX (Operands Int32)
laneMask_eq = specialPTXReg "llvm.nvvm.read.ptx.sreg.lanemask.eq"
laneMask_lt = specialPTXReg "llvm.nvvm.read.ptx.sreg.lanemask.lt"
laneMask_le = specialPTXReg "llvm.nvvm.read.ptx.sreg.lanemask.le"
laneMask_gt = specialPTXReg "llvm.nvvm.read.ptx.sreg.lanemask.gt"
laneMask_ge = specialPTXReg "llvm.nvvm.read.ptx.sreg.lanemask.ge"
-- | Index of the calling thread's warp within its thread block.
--
-- This is computed as 'threadIdx' @`quot`@ the warp size recorded in the
-- target's device properties, rather than being read from a special register
-- (compare '_warpId'). The warp size is a compile-time constant taken from
-- 'ptxDeviceProperties'.
--
warpId :: CodeGen PTX (Operands Int32)
warpId = do
  dev <- liftCodeGen $ gets ptxDeviceProperties
  tid <- threadIdx
  A.quot integralType tid (A.liftInt32 (P.fromIntegral (CUDA.warpSize dev)))
-- | Read the warp identifier through the @llvm.ptx.read.warpid@ intrinsic.
--
-- Not exported from this module; 'warpId' computes the value from the thread
-- index instead.
--
_warpId :: CodeGen PTX (Operands Int32)
_warpId = specialPTXReg "llvm.ptx.read.warpid"
-- | The product 'gridDim' * 'blockDim', i.e. the total number of threads in
-- the x-dimension of the launch. The multiplication is done at 'Int32'.
--
gridSize :: CodeGen PTX (Operands Int32)
gridSize = do
  ncta  <- gridDim
  nt    <- blockDim
  mul numType ncta nt
-- | The linear index of the calling thread over the whole launch:
--
-- > threadIdx + blockDim * blockIdx
--
-- All arithmetic is done at 'Int32'.
--
globalThreadIdx :: CodeGen PTX (Operands Int32)
globalThreadIdx = do
  ntid  <- blockDim
  ctaid <- blockIdx
  tid   <- threadIdx
  --
  u     <- mul numType ntid ctaid
  add numType tid u
-- | Call a barrier-style intrinsic that takes no arguments and returns
-- nothing.
--
-- The call carries the 'NoUnwind', 'NoDuplicate' and 'Convergent' attributes
-- and its (void) result is discarded.
--
barrier :: Label -> CodeGen PTX ()
barrier f = void $ call (Body VoidType (Just Tail) f) [NoUnwind, NoDuplicate, Convergent]
-- | Call a barrier-style intrinsic that takes a single 'Int32' argument and
-- returns an 'Int32' result.
--
-- The call carries the 'NoUnwind', 'NoDuplicate' and 'Convergent' attributes.
--
barrier_op :: Label -> Operands Int32 -> CodeGen PTX (Operands Int32)
barrier_op f x = call (Lam primType (op integralType x) (Body type' (Just Tail) f)) [NoUnwind, NoDuplicate, Convergent]
-- | Thread-block barrier: emits a call to @llvm.nvvm.barrier0@ via 'barrier'.
--
__syncthreads :: CodeGen PTX ()
__syncthreads = barrier "llvm.nvvm.barrier0"
-- | Barrier with a population-count reduction of the argument: emits a call
-- to @llvm.nvvm.barrier0.popc@ via 'barrier_op' and returns its result.
--
__syncthreads_count :: Operands Int32 -> CodeGen PTX (Operands Int32)
__syncthreads_count = barrier_op "llvm.nvvm.barrier0.popc"
-- | Barrier with an AND reduction of the argument: emits a call to
-- @llvm.nvvm.barrier0.and@ via 'barrier_op' and returns its result.
--
__syncthreads_and :: Operands Int32 -> CodeGen PTX (Operands Int32)
__syncthreads_and = barrier_op "llvm.nvvm.barrier0.and"
-- | Barrier with an OR reduction of the argument: emits a call to
-- @llvm.nvvm.barrier0.or@ via 'barrier_op' and returns its result.
--
__syncthreads_or :: Operands Int32 -> CodeGen PTX (Operands Int32)
__syncthreads_or = barrier_op "llvm.nvvm.barrier0.or"
-- | Warp-level barrier over all lanes: '__syncwarp_mask' applied to the mask
-- @0xffffffff@.
--
__syncwarp :: HasCallStack => CodeGen PTX ()
__syncwarp = __syncwarp_mask (liftWord32 0xffffffff)
-- | Warp-level barrier for the lanes selected by the given mask.
--
-- The behaviour depends on the compute capability recorded in the target's
-- device properties:
--
--   * below 7.0, no code is emitted at all;
--
--   * otherwise a call to @llvm.nvvm.bar.warp.sync@ is emitted with the mask
--     as its only argument, carrying the 'NoUnwind', 'NoDuplicate' and
--     'Convergent' attributes.
--
-- When built against an llvm-hs older than 6.0 the second case raises an
-- internal error instead, as the intrinsic is not available.
--
__syncwarp_mask :: HasCallStack => Operands Word32 -> CodeGen PTX ()
__syncwarp_mask mask = do
  dev <- liftCodeGen $ gets ptxDeviceProperties
  if computeCapability dev < Compute 7 0
    then return ()
    else
#if !MIN_VERSION_llvm_hs(6,0,0)
      internalError "LLVM-6.0 or above is required for Volta devices and later"
#else
      void $ call (Lam primType (op primType mask) (Body VoidType (Just Tail) "llvm.nvvm.bar.warp.sync")) [NoUnwind, NoDuplicate, Convergent]
#endif
-- | Memory fence: emits a call to @llvm.nvvm.membar.cta@ via 'barrier'.
--
__threadfence_block :: CodeGen PTX ()
__threadfence_block = barrier "llvm.nvvm.membar.cta"
-- | Memory fence: emits a call to @llvm.nvvm.membar.gl@ via 'barrier'.
--
__threadfence_grid :: CodeGen PTX ()
__threadfence_grid = barrier "llvm.nvvm.membar.gl"
-- | Atomically add a floating-point value to the location at the given
-- address. The value produced by the underlying instruction is discarded.
--
-- How the operation is emitted depends on the llvm-hs version this module is
-- compiled against:
--
--   * llvm-hs >= 10: a native @AtomicRMW@ instruction with the @FAdd@
--     operation, non-volatile, with cross-thread acquire-release ordering.
--
--   * llvm-hs 9, or older than 6: an inline PTX @atom@ instruction selected by
--     the floating-point width.
--
--   * otherwise (llvm-hs 6 up to but excluding 9): a call to the
--     @llvm.nvvm.atomic.load.add.f\<width\>.p\<addrspace\>f\<width\>@
--     intrinsic.
--
-- In the non-native cases the pointer type, element type and address space are
-- recovered from the type of the address operand. An operand that is not a
-- pointer to a scalar raises an internal error.
--
atomicAdd_f :: HasCallStack => FloatingType a -> Operand (Ptr a) -> Operand a -> CodeGen PTX ()
atomicAdd_f t addr val =
#if MIN_VERSION_llvm_hs(10,0,0)
  void . instr' $ AtomicRMW (FloatingNumType t) NonVolatile RMW.FAdd addr val (CrossThread, AcquireRelease)
#else
  let
      -- Bit width of the floating-point type; only used to build the
      -- intrinsic name in the final CPP branch below.
      _width :: Int
      _width =
        case t of
          TypeHalf    -> 16
          TypeFloat   -> 32
          TypeDouble  -> 64

      -- Take the address operand's type apart to recover the pointer type,
      -- the pointee's scalar type, and the numeric address space.
      (t_addr, t_val, _addrspace) =
        case typeOf addr of
          PrimType ta@(PtrPrimType (ScalarPrimType tv) (AddrSpace as))
            -> (ta, tv, as)
          _ -> internalError "unexpected operand type"

      t_ret = PrimType (ScalarPrimType t_val)
#if MIN_VERSION_llvm_hs(9,0,0) || !MIN_VERSION_llvm_hs(6,0,0)
      -- Inline PTX; the "l" constraint passes the address in a 64-bit
      -- register.
      asm   =
        case t of
          TypeHalf    -> InlineAssembly "atom.add.noftz.f16 $0, [$1], $2;" "=c,l,c" True False ATTDialect
          TypeFloat   -> InlineAssembly "atom.global.add.f32 $0, [$1], $2;" "=f,l,f" True False ATTDialect
          TypeDouble  -> InlineAssembly "atom.global.add.f64 $0, [$1], $2;" "=d,l,d" True False ATTDialect
  in
  void $ instr (Call (Lam t_addr addr (Lam (ScalarPrimType t_val) val (Body t_ret (Just Tail) (Left asm)))) [Right NoUnwind])
#else
      fun   = fromString $ printf "llvm.nvvm.atomic.load.add.f%d.p%df%d" _width (_addrspace :: Word32) _width
  in
  void $ call (Lam t_addr addr (Lam (ScalarPrimType t_val) val (Body t_ret (Just Tail) fun))) [NoUnwind]
#endif
#endif
-- | The address space (number 3) used for every shared-memory global and
-- pointer created in this module.
--
sharedMemAddrSpace :: AddrSpace
sharedMemAddrSpace = AddrSpace 3
-- | Volatility recorded on the shared-memory arrays built by
-- 'staticSharedMem' and 'dynamicSharedMem': always 'Volatile'.
--
sharedMemVolatility :: Volatility
sharedMemVolatility = Volatile
-- | Declare a statically sized array in shared memory and return it as a
-- one-dimensional 'IRArray'.
--
-- One global array of @n@ elements is declared for each scalar component of
-- the element type (unit components need no storage). The returned array
-- has the constant shape @n@, lives in 'sharedMemAddrSpace', and is accessed
-- with 'sharedMemVolatility'.
--
staticSharedMem
    :: TypeR e                            -- ^ element type of the array
    -> Word64                             -- ^ number of array elements
    -> CodeGen PTX (IRArray (Vector e))
staticSharedMem tp n = do
  ad    <- go tp
  return $ IRArray { irArrayRepr       = ArrayR dim1 tp
                   , irArrayShape      = OP_Pair OP_Unit $ OP_Int $ integral integralType $ P.fromIntegral n
                   , irArrayData       = ad
                   , irArrayAddrSpace  = sharedMemAddrSpace
                   , irArrayVolatility = sharedMemVolatility
                   }
  where
    -- Walk the element type, allocating one global per scalar component.
    go :: TypeR s -> CodeGen PTX (Operands s)
    go TupRunit          = return OP_Unit
    go (TupRpair t1 t2)  = OP_Pair <$> go t1 <*> go t2
    go tt@(TupRsingle t) = do
      -- Declare a fresh global of type @[n x t]@ in the shared-memory address
      -- space, and build a typed reference to it.
      nm <- freshName
      let sm = ConstantOperand $ GlobalReference (PrimType (PtrPrimType (ArrayPrimType n t) sharedMemAddrSpace)) nm
      declare $ LLVM.globalVariableDefaults
        { LLVM.addrSpace = sharedMemAddrSpace
        , LLVM.type'     = LLVM.ArrayType n (downcast t)
        , LLVM.linkage   = LLVM.External
        , LLVM.name      = downcast nm
        , LLVM.alignment = 4 `P.max` P.fromIntegral (bytesElt tt)   -- at least 4 bytes
        }

      -- Rather than handing back the global reference itself, return a
      -- pointer to its first element: GEP with two zero indices, then cast
      -- to a pointer to the scalar element type.
      p <- instr' $ GetElementPtr sm [num numType 0, num numType 0 :: Operand Int32]
      q <- instr' $ PtrCast (PtrPrimType (ScalarPrimType t) sharedMemAddrSpace) p

      return $ ir t (unPtr q)
-- | Declare the global that names the dynamically allocated shared memory
-- region and return a reference to it.
--
-- The global is called @__shared__@. It is declared as a zero-length array of
-- 8-bit integers in 'sharedMemAddrSpace', with external linkage and 4-byte
-- alignment. The returned operand is a constant reference to that global.
--
initialiseDynamicSharedMemory :: CodeGen PTX (Operand (Ptr Word8))
initialiseDynamicSharedMemory = do
  declare $ LLVM.globalVariableDefaults
    { LLVM.addrSpace = sharedMemAddrSpace
    , LLVM.type'     = LLVM.ArrayType 0 (LLVM.IntegerType 8)
    , LLVM.linkage   = LLVM.External
    , LLVM.name      = LLVM.Name "__shared__"
    , LLVM.alignment = 4
    }
  return $ ConstantOperand $ GlobalReference (PrimType (PtrPrimType (ArrayPrimType 0 scalarType) sharedMemAddrSpace)) "__shared__"
-- | Carve a one-dimensional array out of the dynamically allocated shared
-- memory region declared by 'initialiseDynamicSharedMemory'.
--
-- Each scalar component of the element type is given its own consecutive
-- slice of the region. A slice starts at the current byte offset and spans
-- @n * bytesElt component@ bytes. For a pair, the /second/ component is
-- placed first, at the lower offset.
--
-- The returned array has shape @n@ (converted to 'Int'), lives in
-- 'sharedMemAddrSpace', and is accessed with 'sharedMemVolatility'. The offset
-- one past the last slice is computed but not returned.
--
dynamicSharedMem
    :: forall e int.
       TypeR e                            -- ^ element type of the array
    -> IntegralType int                   -- ^ integral type of the size and offset operands
    -> Operands int                       -- ^ number of array elements
    -> Operands int                       -- ^ byte offset into the shared memory region at which the array starts
    -> CodeGen PTX (IRArray (Vector e))
dynamicSharedMem tp int n@(op int -> m) (op int -> offset)
  | IntegralDict <- integralDict int
  = do
    smem <- initialiseDynamicSharedMemory
    let
        numTp = IntegralNumType int

        -- Allocate the slices for one (sub)type starting at the given byte
        -- offset. Returns the offset just past the last slice together with
        -- the pointers to the start of each slice.
        go :: TypeR s -> Operand int -> CodeGen PTX (Operand int, Operands s)
        go TupRunit         i  = return (i, OP_Unit)
        go (TupRpair t2 t1) i0 = do
          (i1, p1) <- go t1 i0
          (i2, p2) <- go t2 i1
          return $ (i2, OP_Pair p2 p1)
        go (TupRsingle t)   i  = do
          -- The leading zero index steps through the pointer to the global;
          -- the second index selects byte @i@ of the region.
          p <- instr' $ GetElementPtr smem [num numTp 0, i]
          q <- instr' $ PtrCast (PtrPrimType (ScalarPrimType t) sharedMemAddrSpace) p
          -- Advance the offset by the size in bytes of this slice.
          a <- instr' $ Mul numTp m (integral int (P.fromIntegral (bytesElt (TupRsingle t))))
          b <- instr' $ Add numTp i a
          return (b, ir t (unPtr q))
    --
    (_, ad) <- go tp offset
    sz      <- A.fromIntegral int (numType :: NumType Int) n
    return   $ IRArray { irArrayRepr       = ArrayR dim1 tp
                       , irArrayShape      = OP_Pair OP_Unit sz
                       , irArrayData       = ad
                       , irArrayAddrSpace  = sharedMemAddrSpace
                       , irArrayVolatility = sharedMemVolatility
                       }
-- | Emit a PTX @nanosleep.u32@ instruction, taking the given 32-bit operand
-- as the requested delay.
--
-- The instruction is issued as an inline-assembly call with a single
-- register (@"r"@) input and no result. The call carries the 'NoUnwind' and
-- 'Convergent' function attributes.
--
-- NOTE(review): the PTX ISA documents @nanosleep@ as requiring @sm_70@ or
-- newer; nothing here checks the target's compute capability, so callers are
-- presumably expected to guard for that -- confirm.
--
nanosleep :: Operands Int32 -> CodeGen PTX ()
nanosleep ns =
  let
      attrs = [NoUnwind, Convergent]
      -- The two flags are passed as True / False in that order; presumably
      -- "has side effects" and "align stack" -- confirm against the
      -- 'InlineAssembly' definition.
      asm   = InlineAssembly "nanosleep.u32 $0;" "r" True False ATTDialect
  in
  void $ instr (Call (Lam primType (op integralType ns) (Body VoidType (Just Tail) (Left asm))) (map Right attrs))
-- | Per-kernel metadata for the PTX target: a single 'LaunchConfig' wrapped
-- in the 'KM_PTX' constructor.
--
data instance KernelMetadata PTX = KM_PTX LaunchConfig
-- | Combine two programs into one by concatenating their kernel lists; the
-- kernels of the left operand come first.
--
(+++) :: IROpenAcc PTX aenv a -> IROpenAcc PTX aenv a -> IROpenAcc PTX aenv a
IROpenAcc k1 +++ IROpenAcc k2 = IROpenAcc (k1 ++ k2)
-- | Create a single-kernel program using a launch configuration derived from
-- the current device.
--
-- The device properties are read from the 'PTX' target state and passed to
-- 'simpleLaunchConfig'; everything else is delegated to 'makeOpenAccWith'.
--
makeOpenAcc
    :: Label                                  -- ^ kernel name
    -> [LLVM.Parameter]                       -- ^ kernel parameters
    -> CodeGen PTX ()                         -- ^ code generator for the kernel body
    -> CodeGen PTX (IROpenAcc PTX aenv a)
makeOpenAcc name param kernel = do
  dev <- liftCodeGen $ gets ptxDeviceProperties
  makeOpenAccWith (simpleLaunchConfig dev) name param kernel
-- | Create a single-kernel program using the supplied launch configuration.
--
-- The kernel is built by 'makeKernel' and wrapped as the only element of the
-- resulting 'IROpenAcc'.
--
makeOpenAccWith
    :: LaunchConfig                           -- ^ launch configuration recorded with the kernel
    -> Label                                  -- ^ kernel name
    -> [LLVM.Parameter]                       -- ^ kernel parameters
    -> CodeGen PTX ()                         -- ^ code generator for the kernel body
    -> CodeGen PTX (IROpenAcc PTX aenv a)
makeOpenAccWith config name param kernel = do
  body <- makeKernel config name param kernel
  return $ IROpenAcc [body]
-- | Create a complete kernel function by running the given code generator.
--
-- Steps, in order:
--
--   1. run the generator for the kernel body;
--
--   2. collect the generated basic blocks with 'createBlocks';
--
--   3. add an @nvvm.annotations@ metadata entry consisting of a reference to
--      the function, the string @"kernel"@ and the 32-bit constant @1@
--      (per the NVVM IR specification this marks the function as a kernel
--      entry point);
--
--   4. return a void function with the given name, parameters and blocks,
--      paired with the launch configuration as its 'KernelMetadata'.
--
makeKernel
    :: LaunchConfig                           -- ^ launch configuration stored in the kernel metadata
    -> Label                                  -- ^ kernel name
    -> [LLVM.Parameter]                       -- ^ kernel parameters
    -> CodeGen PTX ()                         -- ^ code generator for the kernel body
    -> CodeGen PTX (Kernel PTX aenv a)
makeKernel config name@(Label l) param kernel = do
  _    <- kernel
  code <- createBlocks
  let
      -- Type used to reference the kernel from the annotation: a pointer (in
      -- address space 0) to a non-variadic void function over the parameter
      -- types.
      paramTypes = [ t | LLVM.Parameter t _ _ <- param ]
      kernelType = LLVM.PointerType (LLVM.FunctionType LLVM.VoidType paramTypes False) (AddrSpace 0)
  addMetadata "nvvm.annotations"
    [ Just . MetadataConstantOperand $ LLVM.GlobalReference kernelType (LLVM.Name l)
    , Just . MetadataStringOperand   $ "kernel"
    , Just . MetadataConstantOperand $ LLVM.Int 32 1
    ]
  return $ Kernel
    { kernelMetadata = KM_PTX config
    , unKernel       = LLVM.functionDefaults
                     { LLVM.returnType  = LLVM.VoidType
                     , LLVM.name        = downcast name
                     , LLVM.parameters  = (param, False)   -- False: not variadic
                     , LLVM.basicBlocks = code
                     }
    }