{-# LANGUAGE CPP                 #-}
{-# LANGUAGE GADTs               #-}
{-# LANGUAGE OverloadedStrings   #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell     #-}
{-# LANGUAGE TypeApplications    #-}
{-# LANGUAGE TypeFamilies        #-}
{-# LANGUAGE ViewPatterns        #-}
-- |
-- Module      : Data.Array.Accelerate.LLVM.PTX.CodeGen.Base
-- Copyright   : [2014..2020] The Accelerate Team
-- License     : BSD3
--
-- Maintainer  : Trevor L. McDonell <trevor.mcdonell@gmail.com>
-- Stability   : experimental
-- Portability : non-portable (GHC extensions)
--

module Data.Array.Accelerate.LLVM.PTX.CodeGen.Base (

  -- Types
  DeviceProperties, KernelMetadata(..),

  -- Thread identifiers
  blockDim, gridDim, threadIdx, blockIdx, warpSize,
  gridSize, globalThreadIdx,

  -- Other intrinsics
  laneId, warpId,
  laneMask_eq, laneMask_lt, laneMask_le, laneMask_gt, laneMask_ge,
  atomicAdd_f,
  nanosleep,

  -- Barriers and synchronisation
  __syncthreads, __syncthreads_count, __syncthreads_and, __syncthreads_or,
  __syncwarp, __syncwarp_mask,
  __threadfence_block, __threadfence_grid,

  -- Shared memory
  staticSharedMem,
  dynamicSharedMem,
  sharedMemAddrSpace,

  -- Kernel definitions
  (+++),
  makeOpenAcc, makeOpenAccWith,

) where

import Data.Array.Accelerate.Error
import Data.Array.Accelerate.LLVM.CodeGen.Arithmetic                as A
import Data.Array.Accelerate.LLVM.CodeGen.Base
import Data.Array.Accelerate.LLVM.CodeGen.Constant
import Data.Array.Accelerate.LLVM.CodeGen.IR
import Data.Array.Accelerate.LLVM.CodeGen.Module
import Data.Array.Accelerate.LLVM.CodeGen.Monad
import Data.Array.Accelerate.LLVM.CodeGen.Ptr
import Data.Array.Accelerate.LLVM.CodeGen.Sugar
import Data.Array.Accelerate.LLVM.PTX.Analysis.Launch
import Data.Array.Accelerate.LLVM.PTX.Target
import Data.Array.Accelerate.Representation.Array
import Data.Array.Accelerate.Representation.Elt
import Data.Array.Accelerate.Representation.Shape
import Data.Array.Accelerate.Representation.Type

import Foreign.CUDA.Analysis                                        ( Compute(..), computeCapability )
import qualified Foreign.CUDA.Analysis                              as CUDA

import LLVM.AST.Type.AddrSpace
import LLVM.AST.Type.Constant
import LLVM.AST.Type.Downcast
import LLVM.AST.Type.Function
import LLVM.AST.Type.InlineAssembly
import LLVM.AST.Type.Instruction
import LLVM.AST.Type.Instruction.Volatile
import LLVM.AST.Type.Metadata
import LLVM.AST.Type.Name
import LLVM.AST.Type.Operand
import LLVM.AST.Type.Representation
import qualified LLVM.AST.Constant                                  as LLVM hiding ( type' )
import qualified LLVM.AST.Global                                    as LLVM
import qualified LLVM.AST.Linkage                                   as LLVM
import qualified LLVM.AST.Name                                      as LLVM
import qualified LLVM.AST.Type                                      as LLVM

import Control.Applicative
import Control.Monad                                                ( void )
import Control.Monad.State                                          ( gets )
import Prelude                                                      as P

#if MIN_VERSION_llvm_hs(10,0,0)
import qualified LLVM.AST.Type.Instruction.RMW                      as RMW
import LLVM.AST.Type.Instruction.Atomic
#elif !MIN_VERSION_llvm_hs(9,0,0)
import Data.String
import Text.Printf
#endif


-- Thread identifiers
-- ------------------

-- | Read one of the builtin registers that store CUDA thread and grid
-- identifiers.
--
-- The register is read by emitting a tail call to the zero-argument intrinsic
-- named by the given label, returning an 'Int32'. The call is tagged with the
-- 'NoUnwind' and 'ReadNone' function attributes.
--
-- <https://github.com/llvm-mirror/llvm/blob/master/include/llvm/IR/IntrinsicsNVVM.td>
--
specialPTXReg :: Label -> CodeGen PTX (Operands Int32)
specialPTXReg f =
  call (Body type' (Just Tail) f) [NoUnwind, ReadNone]

-- Thread and grid identifiers. Each is read via 'specialPTXReg' from the NVVM
-- special-register intrinsic named in its definition (the @.x@ component for
-- the dimension and index registers).
--
blockDim, gridDim, threadIdx, blockIdx, warpSize :: CodeGen PTX (Operands Int32)
blockDim    = specialPTXReg "llvm.nvvm.read.ptx.sreg.ntid.x"
gridDim     = specialPTXReg "llvm.nvvm.read.ptx.sreg.nctaid.x"
threadIdx   = specialPTXReg "llvm.nvvm.read.ptx.sreg.tid.x"
blockIdx    = specialPTXReg "llvm.nvvm.read.ptx.sreg.ctaid.x"
warpSize    = specialPTXReg "llvm.nvvm.read.ptx.sreg.warpsize"

-- Read the @laneid@ special register via 'specialPTXReg'.
--
laneId :: CodeGen PTX (Operands Int32)
laneId      = specialPTXReg "llvm.nvvm.read.ptx.sreg.laneid"

-- Read the @lanemask.{eq,lt,le,gt,ge}@ special registers via 'specialPTXReg'.
--
laneMask_eq, laneMask_lt, laneMask_le, laneMask_gt, laneMask_ge :: CodeGen PTX (Operands Int32)
laneMask_eq = specialPTXReg "llvm.nvvm.read.ptx.sreg.lanemask.eq"
laneMask_lt = specialPTXReg "llvm.nvvm.read.ptx.sreg.lanemask.lt"
laneMask_le = specialPTXReg "llvm.nvvm.read.ptx.sreg.lanemask.le"
laneMask_gt = specialPTXReg "llvm.nvvm.read.ptx.sreg.lanemask.gt"
laneMask_ge = specialPTXReg "llvm.nvvm.read.ptx.sreg.lanemask.ge"


-- | The index of the warp within the thread block, computed as
--
-- > threadIdx.x / warpSize
--
-- where the warp size is taken from the device properties of the current
-- target rather than read from a register.
--
-- NOTE: The special register %warpid is a volatile value and is not guaranteed
--       to be constant over the lifetime of a thread or thread block, which is
--       why it is not read directly here.
--
-- http://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#sm-id-and-warp-id
--
-- http://docs.nvidia.com/cuda/parallel-thread-execution/index.html#special-registers-warpid
--
warpId :: CodeGen PTX (Operands Int32)
warpId = do
  dev <- liftCodeGen $ gets ptxDeviceProperties
  tid <- threadIdx
  A.quot integralType tid (A.liftInt32 (P.fromIntegral (CUDA.warpSize dev)))

-- Read the warp identifier directly from the special register. Not exported;
-- see the note on 'warpId' regarding the volatility of this register.
--
_warpId :: CodeGen PTX (Operands Int32)
_warpId = specialPTXReg "llvm.ptx.read.warpid"


-- | The size of the thread grid, i.e. the product of the number of thread
-- blocks and the number of threads per block:
--
-- > gridDim.x * blockDim.x
--
gridSize :: CodeGen PTX (Operands Int32)
gridSize = do
  ncta  <- gridDim
  nt    <- blockDim
  mul numType ncta nt


-- | The global thread index
--
-- > blockDim.x * blockIdx.x + threadIdx.x
--
globalThreadIdx :: CodeGen PTX (Operands Int32)
globalThreadIdx = do
  ntid  <- blockDim
  ctaid <- blockIdx
  tid   <- threadIdx
  --
  u     <- mul numType ntid ctaid   -- index of the first thread of this block
  add numType tid u


{--
-- | Generate function parameters that will specify the first and last (linear)
-- index of the array this kernel should evaluate.
--
gangParam :: (Operands Int, Operands Int, [LLVM.Parameter])
gangParam =
  let start = "ix.start"
      end   = "ix.end"
  in
  (local start, local end, parameter start ++ parameter end )
--}


-- Barriers and synchronisation
-- ----------------------------

-- | Call a built-in CUDA synchronisation intrinsic
--
-- The intrinsic named by the label takes no arguments and returns nothing. The
-- call is tagged 'NoUnwind', 'NoDuplicate' and 'Convergent'.
--
barrier :: Label -> CodeGen PTX ()
barrier f = void $ call (Body VoidType (Just Tail) f) [NoUnwind, NoDuplicate, Convergent]

-- | Call a built-in CUDA synchronisation intrinsic which takes a single
-- 'Int32' argument and returns an 'Int32' result. The call carries the same
-- attributes as 'barrier'.
--
barrier_op :: Label -> Operands Int32 -> CodeGen PTX (Operands Int32)
barrier_op f x = call (Lam primType (op integralType x) (Body type' (Just Tail) f)) [NoUnwind, NoDuplicate, Convergent]


-- | Wait until all threads in the thread block have reached this point, and all
-- global and shared memory accesses made by these threads prior to the
-- __syncthreads() are visible to all threads in the block.
--
-- <http://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#synchronization-functions>
--
__syncthreads :: CodeGen PTX ()
__syncthreads = barrier "llvm.nvvm.barrier0"

-- | Identical to __syncthreads() with the additional feature that it returns
-- the number of threads in the block for which the predicate evaluates to
-- non-zero.
--
__syncthreads_count :: Operands Int32 -> CodeGen PTX (Operands Int32)
__syncthreads_count = barrier_op "llvm.nvvm.barrier0.popc"

-- | Identical to __syncthreads() with the additional feature that it returns
-- non-zero iff the predicate evaluates to non-zero for all threads in the
-- block.
--
__syncthreads_and :: Operands Int32 -> CodeGen PTX (Operands Int32)
__syncthreads_and = barrier_op "llvm.nvvm.barrier0.and"

-- | Identical to __syncthreads() with the additional feature that it returns
-- non-zero iff the predicate evaluates to non-zero for any thread in the block.
--
__syncthreads_or :: Operands Int32 -> CodeGen PTX (Operands Int32)
__syncthreads_or = barrier_op "llvm.nvvm.barrier0.or"


-- | Wait until all warp lanes have reached this point.
--
-- Equivalent to '__syncwarp_mask' with every bit of the 32-bit mask set.
--
__syncwarp :: HasCallStack => CodeGen PTX ()
__syncwarp = __syncwarp_mask (liftWord32 0xffffffff)

-- | Wait until all warp lanes named in the mask have executed a __syncwarp()
-- with the same mask. All non-exited threads named in the mask must execute
-- a corresponding __syncwarp with the same mask, or the result is undefined.
--
-- This guarantees memory ordering among threads participating in the barrier.
--
-- On devices with compute capability below 7.0 no code is generated. On later
-- devices the warp-level barrier intrinsic is called; this requires LLVM-6.0
-- or higher, and an internal error is raised when built against an older
-- version.
--
__syncwarp_mask :: HasCallStack => Operands Word32 -> CodeGen PTX ()
__syncwarp_mask mask = do
  dev <- liftCodeGen $ gets ptxDeviceProperties
  if computeCapability dev < Compute 7 0
    then return ()    -- only required for devices of SM7 and later
    else
#if !MIN_VERSION_llvm_hs(6,0,0)
         internalError "LLVM-6.0 or above is required for Volta devices and later"
#else
         void $ call (Lam primType (op primType mask) (Body VoidType (Just Tail) "llvm.nvvm.bar.warp.sync")) [NoUnwind, NoDuplicate, Convergent]
#endif


-- | Ensure that all writes to shared and global memory before the call to
-- __threadfence_block() are observed by all threads in the *block* of the
-- calling thread as occurring before all writes to shared and global memory
-- made by the calling thread after the call.
--
-- <http://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#memory-fence-functions>
--
__threadfence_block :: CodeGen PTX ()
__threadfence_block = barrier "llvm.nvvm.membar.cta"


-- | As __threadfence_block(), but the synchronisation is for *all* thread blocks.
-- In CUDA this is known simply as __threadfence().
--
__threadfence_grid :: CodeGen PTX ()
__threadfence_grid = barrier "llvm.nvvm.membar.gl"


-- Atomic functions
-- ----------------

-- LLVM provides atomic instructions for integer arguments only. CUDA provides
-- additional support for atomic add on floating point types, which can be
-- accessed through the following intrinsics.
--
-- Double precision is supported on Compute 6.0 devices and later. Half
-- precision is supported on Compute 7.0 devices and later.
--
-- LLVM-4.0 currently lacks support for this intrinsic, however it is
-- accessible via inline assembly.
--
-- LLVM-9 integrated floating-point atomic operations into the AtomicRMW
-- instruction, but this functionality is missing from llvm-hs-9. We access
-- it via inline assembly.
--
-- <https://github.com/AccelerateHS/accelerate/issues/363>
--
-- | Atomically add the given floating-point value to the value stored at the
-- given address. The result of the underlying operation is discarded.
--
-- How the operation is emitted depends on the version of llvm-hs this module
-- is compiled against:
--
--   * llvm-hs-10 and later: the 'AtomicRMW' instruction with 'RMW.FAdd'
--
--   * llvm-hs-9, or earlier than llvm-hs-6: inline PTX assembly
--
--   * otherwise: the @llvm.nvvm.atomic.load.add.*@ intrinsic, whose name is
--     built from the bit width of the type and the address space of the
--     pointer operand
--
atomicAdd_f :: HasCallStack => FloatingType a -> Operand (Ptr a) -> Operand a -> CodeGen PTX ()
atomicAdd_f t addr val =
#if MIN_VERSION_llvm_hs(10,0,0)
  void . instr' $ AtomicRMW (FloatingNumType t) NonVolatile RMW.FAdd addr val (CrossThread, AcquireRelease)
#else
  let
      -- Bit width of the floating-point type; only used to build the
      -- intrinsic name
      _width :: Int
      _width =
        case t of
          TypeHalf    -> 16
          TypeFloat   -> 32
          TypeDouble  -> 64

      -- Recover the pointer type, element type and address space from the
      -- type of the address operand
      (t_addr, t_val, _addrspace) =
        case typeOf addr of
          PrimType ta@(PtrPrimType (ScalarPrimType tv) (AddrSpace as))
            -> (ta, tv, as)
          _ -> internalError "unexpected operand type"

      t_ret = PrimType (ScalarPrimType t_val)
#if MIN_VERSION_llvm_hs(9,0,0) || !MIN_VERSION_llvm_hs(6,0,0)
      asm   =
        case t of
          -- assuming .address_size 64
          TypeHalf   -> InlineAssembly "atom.add.noftz.f16  $0, [$1], $2;" "=c,l,c" True False ATTDialect
          TypeFloat  -> InlineAssembly "atom.global.add.f32 $0, [$1], $2;" "=f,l,f" True False ATTDialect
          TypeDouble -> InlineAssembly "atom.global.add.f64 $0, [$1], $2;" "=d,l,d" True False ATTDialect
  in
  void $ instr (Call (Lam t_addr addr (Lam (ScalarPrimType t_val) val (Body t_ret (Just Tail) (Left asm)))) [Right NoUnwind])
#else
      fun   = fromString $ printf "llvm.nvvm.atomic.load.add.f%d.p%df%d" _width (_addrspace :: Word32) _width
  in
  void $ call (Lam t_addr addr (Lam (ScalarPrimType t_val) val (Body t_ret (Just Tail) fun))) [NoUnwind]
#endif
#endif


-- Shared memory
-- -------------

-- The address space used for arrays placed in __shared__ memory.
--
sharedMemAddrSpace :: AddrSpace
sharedMemAddrSpace = AddrSpace 3

-- Arrays placed in __shared__ memory are marked as volatile.
--
sharedMemVolatility :: Volatility
sharedMemVolatility = Volatile


-- Declare a new statically allocated array in the __shared__ memory address
-- space, with enough storage to contain the given number of elements.
--
-- One global array is declared for each scalar component of the element type;
-- tuples of scalars are thus stored as separate arrays. The returned array is
-- one dimensional, with its extent fixed to the given number of elements.
--
staticSharedMem
    :: TypeR e
    -> Word64
    -> CodeGen PTX (IRArray (Vector e))
staticSharedMem tp n = do
  ad    <- go tp
  return $ IRArray { irArrayRepr       = ArrayR dim1 tp
                   , irArrayShape      = OP_Pair OP_Unit $ OP_Int $ integral integralType $ P.fromIntegral n
                   , irArrayData       = ad
                   , irArrayAddrSpace  = sharedMemAddrSpace
                   , irArrayVolatility = sharedMemVolatility
                   }
  where
    -- Walk the structure of the element type, allocating one array for each
    -- scalar leaf.
    go :: TypeR s -> CodeGen PTX (Operands s)
    go TupRunit          = return OP_Unit
    go (TupRpair t1 t2)  = OP_Pair <$> go t1 <*> go t2
    go tt@(TupRsingle t) = do
      -- Declare a new global reference for the statically allocated array
      -- located in the __shared__ memory space.
      nm <- freshName
      let sm = ConstantOperand $ GlobalReference (PrimType (PtrPrimType (ArrayPrimType n t) sharedMemAddrSpace)) nm
      declare $ LLVM.globalVariableDefaults
        { LLVM.addrSpace = sharedMemAddrSpace
        , LLVM.type'     = LLVM.ArrayType n (downcast t)
        , LLVM.linkage   = LLVM.External
        , LLVM.name      = downcast nm
        , LLVM.alignment = 4 `P.max` P.fromIntegral (bytesElt tt)   -- never less than 4
        }

      -- Return a pointer to the first element of the __shared__ memory array.
      -- We do this rather than just returning the global reference directly due
      -- to how __shared__ memory needs to be indexed with the GEP instruction.
      p <- instr' $ GetElementPtr sm [num numType 0, num numType 0 :: Operand Int32]
      q <- instr' $ PtrCast (PtrPrimType (ScalarPrimType t) sharedMemAddrSpace) p

      return $ ir t (unPtr q)


-- External declaration in shared memory address space. This must be declared in
-- order to access memory allocated dynamically by the CUDA driver. This results
-- in the following global declaration:
--
-- > @__shared__ = external addrspace(3) global [0 x i8]
--
-- The result is a reference to that global, i.e. a pointer to a zero-length
-- array of bytes in the shared memory address space.
--
initialiseDynamicSharedMemory :: CodeGen PTX (Operand (Ptr Word8))
initialiseDynamicSharedMemory = do
  declare $ LLVM.globalVariableDefaults
    { LLVM.addrSpace = sharedMemAddrSpace
    , LLVM.type'     = LLVM.ArrayType 0 (LLVM.IntegerType 8)
    , LLVM.linkage   = LLVM.External
    , LLVM.name      = LLVM.Name "__shared__"
    , LLVM.alignment = 4
    }
  return $ ConstantOperand $ GlobalReference (PrimType (PtrPrimType (ArrayPrimType 0 scalarType) sharedMemAddrSpace)) "__shared__"


-- Declare a new dynamically allocated array in the __shared__ memory space
-- with enough space to contain the given number of elements.
--
dynamicSharedMem
    :: forall e int.
       TypeR e
    -> IntegralType int
    -> Operands int                                 -- number of array elements
    -> Operands int                                 -- #bytes of shared memory the have already been allocated
    -> CodeGen PTX (IRArray (Vector e))
dynamicSharedMem :: TypeR e
-> IntegralType int
-> Operands int
-> Operands int
-> CodeGen PTX (IRArray (Vector e))
dynamicSharedMem TypeR e
tp IntegralType int
int n :: Operands int
n@(IntegralType int -> Operands int -> Operand int
forall (dict :: * -> *) a.
(IROP dict, HasCallStack) =>
dict a -> Operands a -> Operand a
op IntegralType int
int -> Operand int
m) (IntegralType int -> Operands int -> Operand int
forall (dict :: * -> *) a.
(IROP dict, HasCallStack) =>
dict a -> Operands a -> Operand a
op IntegralType int
int -> Operand int
offset)
  | IntegralDict int
IntegralDict <- IntegralType int -> IntegralDict int
forall a. IntegralType a -> IntegralDict a
integralDict IntegralType int
int = do
    Operand (Ptr Word8)
smem         <- CodeGen PTX (Operand (Ptr Word8))
initialiseDynamicSharedMemory
    let
        numTp :: NumType int
numTp = IntegralType int -> NumType int
forall a. IntegralType a -> NumType a
IntegralNumType IntegralType int
int

        go :: TypeR s -> Operand int -> CodeGen PTX (Operand int, Operands s)
        go :: TypeR s -> Operand int -> CodeGen PTX (Operand int, Operands s)
go TypeR s
TupRunit         Operand int
i  = (Operand int, Operands ())
-> CodeGen PTX (Operand int, Operands ())
forall (m :: * -> *) a. Monad m => a -> m a
return (Operand int
i, Operands ()
OP_Unit)
        go (TupRpair TupR ScalarType a1
t2 TupR ScalarType b
t1) Operand int
i0 = do
          (Operand int
i1, Operands b
p1) <- TupR ScalarType b
-> Operand int -> CodeGen PTX (Operand int, Operands b)
forall s.
TypeR s -> Operand int -> CodeGen PTX (Operand int, Operands s)
go TupR ScalarType b
t1 Operand int
i0
          (Operand int
i2, Operands a1
p2) <- TupR ScalarType a1
-> Operand int -> CodeGen PTX (Operand int, Operands a1)
forall s.
TypeR s -> Operand int -> CodeGen PTX (Operand int, Operands s)
go TupR ScalarType a1
t2 Operand int
i1
          (Operand int, Operands (a1, b))
-> CodeGen PTX (Operand int, Operands (a1, b))
forall (m :: * -> *) a. Monad m => a -> m a
return ((Operand int, Operands (a1, b))
 -> CodeGen PTX (Operand int, Operands (a1, b)))
-> (Operand int, Operands (a1, b))
-> CodeGen PTX (Operand int, Operands (a1, b))
forall a b. (a -> b) -> a -> b
$ (Operand int
i2, Operands a1 -> Operands b -> Operands (a1, b)
forall a b. Operands a -> Operands b -> Operands (a, b)
OP_Pair Operands a1
p2 Operands b
p1)
        go (TupRsingle ScalarType s
t)   Operand int
i  = do
          Operand (Ptr Word8)
p <- Instruction (Ptr Word8) -> CodeGen PTX (Operand (Ptr Word8))
forall a arch.
HasCallStack =>
Instruction a -> CodeGen arch (Operand a)
instr' (Instruction (Ptr Word8) -> CodeGen PTX (Operand (Ptr Word8)))
-> Instruction (Ptr Word8) -> CodeGen PTX (Operand (Ptr Word8))
forall a b. (a -> b) -> a -> b
$ Operand (Ptr Word8) -> [Operand int] -> Instruction (Ptr Word8)
forall a1 i.
Operand (Ptr a1) -> [Operand i] -> Instruction (Ptr a1)
GetElementPtr Operand (Ptr Word8)
smem [NumType int -> int -> Operand int
forall a. NumType a -> a -> Operand a
num NumType int
numTp int
0, Operand int
i] -- TLM: note initial zero index!!
          Operand (Ptr s)
q <- Instruction (Ptr s) -> CodeGen PTX (Operand (Ptr s))
forall a arch.
HasCallStack =>
Instruction a -> CodeGen arch (Operand a)
instr' (Instruction (Ptr s) -> CodeGen PTX (Operand (Ptr s)))
-> Instruction (Ptr s) -> CodeGen PTX (Operand (Ptr s))
forall a b. (a -> b) -> a -> b
$ PrimType (Ptr s) -> Operand (Ptr Word8) -> Instruction (Ptr s)
forall b a1.
PrimType (Ptr b) -> Operand (Ptr a1) -> Instruction (Ptr b)
PtrCast (PrimType s -> AddrSpace -> PrimType (Ptr s)
forall a1. PrimType a1 -> AddrSpace -> PrimType (Ptr a1)
PtrPrimType (ScalarType s -> PrimType s
forall a. ScalarType a -> PrimType a
ScalarPrimType ScalarType s
t) AddrSpace
sharedMemAddrSpace) Operand (Ptr Word8)
p
          Operand int
a <- Instruction int -> CodeGen PTX (Operand int)
forall a arch.
HasCallStack =>
Instruction a -> CodeGen arch (Operand a)
instr' (Instruction int -> CodeGen PTX (Operand int))
-> Instruction int -> CodeGen PTX (Operand int)
forall a b. (a -> b) -> a -> b
$ NumType int -> Operand int -> Operand int -> Instruction int
forall a. NumType a -> Operand a -> Operand a -> Instruction a
Mul NumType int
numTp Operand int
m (IntegralType int -> int -> Operand int
forall a. IntegralType a -> a -> Operand a
integral IntegralType int
int (Int -> int
forall a b. (Integral a, Num b) => a -> b
P.fromIntegral (TypeR s -> Int
forall e. TypeR e -> Int
bytesElt (ScalarType s -> TypeR s
forall (s :: * -> *) a. s a -> TupR s a
TupRsingle ScalarType s
t))))
          Operand int
b <- Instruction int -> CodeGen PTX (Operand int)
forall a arch.
HasCallStack =>
Instruction a -> CodeGen arch (Operand a)
instr' (Instruction int -> CodeGen PTX (Operand int))
-> Instruction int -> CodeGen PTX (Operand int)
forall a b. (a -> b) -> a -> b
$ NumType int -> Operand int -> Operand int -> Instruction int
forall a. NumType a -> Operand a -> Operand a -> Instruction a
Add NumType int
numTp Operand int
i Operand int
a
          (Operand int, Operands s) -> CodeGen PTX (Operand int, Operands s)
forall (m :: * -> *) a. Monad m => a -> m a
return (Operand int
b, ScalarType s -> Operand s -> Operands s
forall (dict :: * -> *) a.
(IROP dict, HasCallStack) =>
dict a -> Operand a -> Operands a
ir ScalarType s
t (Operand (Ptr s) -> Operand s
forall t. HasCallStack => Operand (Ptr t) -> Operand t
unPtr Operand (Ptr s)
q))
    --
    (Operand int
_, Operands e
ad) <- TypeR e -> Operand int -> CodeGen PTX (Operand int, Operands e)
forall s.
TypeR s -> Operand int -> CodeGen PTX (Operand int, Operands s)
go TypeR e
tp Operand int
offset
    Operands Int
sz      <- IntegralType int
-> NumType Int -> Operands int -> CodeGen PTX (Operands Int)
forall arch a b.
IntegralType a
-> NumType b -> Operands a -> CodeGen arch (Operands b)
A.fromIntegral IntegralType int
int (NumType Int
forall a. IsNum a => NumType a
numType :: NumType Int) Operands int
n
    IRArray (Vector e) -> CodeGen PTX (IRArray (Vector e))
forall (m :: * -> *) a. Monad m => a -> m a
return   (IRArray (Vector e) -> CodeGen PTX (IRArray (Vector e)))
-> IRArray (Vector e) -> CodeGen PTX (IRArray (Vector e))
forall a b. (a -> b) -> a -> b
$ IRArray :: forall sh e.
ArrayR (Array sh e)
-> Operands sh
-> Operands e
-> AddrSpace
-> Volatility
-> IRArray (Array sh e)
IRArray { irArrayRepr :: ArrayR (Vector e)
irArrayRepr       = ShapeR DIM1 -> TypeR e -> ArrayR (Vector e)
forall sh e. ShapeR sh -> TypeR e -> ArrayR (Array sh e)
ArrayR ShapeR DIM1
dim1 TypeR e
tp
                       , irArrayShape :: Operands DIM1
irArrayShape      = Operands () -> Operands Int -> Operands DIM1
forall a b. Operands a -> Operands b -> Operands (a, b)
OP_Pair Operands ()
OP_Unit Operands Int
sz
                       , irArrayData :: Operands e
irArrayData       = Operands e
ad
                       , irArrayAddrSpace :: AddrSpace
irArrayAddrSpace  = AddrSpace
sharedMemAddrSpace
                       , irArrayVolatility :: Volatility
irArrayVolatility = Volatility
sharedMemVolatility
                       }


-- Other functions
-- ---------------

-- Suspend the calling thread for (approximately) the given number of
-- nanoseconds, by emitting the 'nanosleep.u32' instruction as inline
-- assembly. Requires compute capability >= 7.0
--
nanosleep :: Operands Int32 -> CodeGen PTX ()
nanosleep ns = void (instr (Call callee attrs))
  where
    -- The single 32-bit operand is passed in a register (constraint "r").
    -- NOTE(review): the two Bool flags are presumably has-side-effects and
    -- align-stack respectively; confirm against the InlineAssembly definition.
    asm    = InlineAssembly "nanosleep.u32 $0;" "r" True False ATTDialect
    delay  = op integralType ns
    callee = Lam primType delay (Body VoidType (Just Tail) (Left asm))
    attrs  = map Right [NoUnwind, Convergent]


-- Global kernel definitions
-- -------------------------

-- Kernel metadata for the PTX target: a wrapper around the 'LaunchConfig'
-- associated with the kernel.
--
data instance KernelMetadata PTX = KM_PTX LaunchConfig

-- | Combine kernels into a single program. The kernels of the left operand
-- precede those of the right operand in the result.
--
(+++) :: IROpenAcc PTX aenv a -> IROpenAcc PTX aenv a -> IROpenAcc PTX aenv a
(+++) (IROpenAcc ls) (IROpenAcc rs) = IROpenAcc (ls ++ rs)


-- | Create a single kernel program with the default launch configuration,
-- which is derived via 'simpleLaunchConfig' from the properties of the
-- device held in the current 'PTX' state.
--
makeOpenAcc
    :: Label
    -> [LLVM.Parameter]
    -> CodeGen PTX ()
    -> CodeGen PTX (IROpenAcc PTX aenv a)
makeOpenAcc name param kernel =
  liftCodeGen (gets ptxDeviceProperties) >>= \dev ->
    makeOpenAccWith (simpleLaunchConfig dev) name param kernel

-- | Create a single kernel program with the given launch analysis
-- information. The result contains exactly one kernel.
--
makeOpenAccWith
    :: LaunchConfig
    -> Label
    -> [LLVM.Parameter]
    -> CodeGen PTX ()
    -> CodeGen PTX (IROpenAcc PTX aenv a)
makeOpenAccWith config name param kernel =
  fmap (\body -> IROpenAcc [body]) (makeKernel config name param kernel)

-- | Create a complete kernel function by running the code generation process
-- specified in the final parameter.
--
-- After the body has been generated and its basic blocks collected, an
-- entry is added to the @nvvm.annotations@ metadata consisting of a reference
-- to the function, the string @"kernel"@, and the 32-bit integer constant 1.
-- The resulting function returns void and takes the given parameters.
--
makeKernel
    :: LaunchConfig
    -> Label
    -> [LLVM.Parameter]
    -> CodeGen PTX ()
    -> CodeGen PTX (Kernel PTX aenv a)
makeKernel config name@(Label l) param kernel = do
  kernel
  code <- createBlocks
  addMetadata "nvvm.annotations" annotation
  return Kernel
    { kernelMetadata = KM_PTX config
    , unKernel       = LLVM.functionDefaults
        { LLVM.returnType  = LLVM.VoidType
        , LLVM.name        = downcast name
        , LLVM.parameters  = (param, False)
        , LLVM.basicBlocks = code
        }
    }
  where
    -- types of the kernel arguments, in declaration order
    argTypes   = [ t | LLVM.Parameter t _ _ <- param ]

    -- pointer (in address space 0) to a non-variadic void function over
    -- the argument types
    kernelType = LLVM.PointerType (LLVM.FunctionType LLVM.VoidType argTypes False) (AddrSpace 0)

    annotation =
      [ Just (MetadataConstantOperand (LLVM.GlobalReference kernelType (LLVM.Name l)))
      , Just (MetadataStringOperand "kernel")
      , Just (MetadataConstantOperand (LLVM.Int 32 1))
      ]