{-# LANGUAGE QuasiQuotes #-}

-- | Code generation for CUDA.
module Futhark.CodeGen.Backends.CCUDA
  ( compileProg,
    GC.CParts (..),
    GC.asLibrary,
    GC.asExecutable,
    GC.asServer,
  )
where

import Control.Monad
import Data.Maybe (catMaybes)
import Data.Text qualified as T
import Futhark.CodeGen.Backends.CCUDA.Boilerplate
import Futhark.CodeGen.Backends.COpenCL.Boilerplate (commonOptions, sizeLoggingCode)
import Futhark.CodeGen.Backends.GenericC qualified as GC
import Futhark.CodeGen.Backends.GenericC.Options
import Futhark.CodeGen.Backends.SimpleRep (primStorageType, toStorage)
import Futhark.CodeGen.ImpCode.OpenCL
import Futhark.CodeGen.ImpGen.CUDA qualified as ImpGen
import Futhark.IR.GPUMem hiding
  ( CmpSizeLe,
    GetSize,
    GetSizeMax,
  )
import Futhark.MonadFreshNames
import Language.C.Quote.OpenCL qualified as C
import Language.C.Syntax qualified as C
import NeatInterpolation (untrimming)

-- | Compile the program to C with calls to CUDA.
--
-- The program is first lowered with 'ImpGen.compileProg'.  The CUDA
-- source, prelude, kernel table, sizes and failure messages of the
-- resulting 'Program' are handed to 'generateBoilerplate' (together
-- with the names of the built-in copy cost centres), and the host
-- code is then produced by 'GC.compileProg' using the CUDA-specific
-- 'GC.Operations' defined below.  The @version@ text is forwarded to
-- 'GC.compileProg' unchanged.
compileProg :: MonadFreshNames m => T.Text -> Prog GPUMem -> m (ImpGen.Warnings, GC.CParts)
compileProg version prog = do
  -- The fourth field of 'Program' is not needed by this backend.
  (ws, Program cuda_code cuda_prelude kernels _ sizes failures prog') <-
    ImpGen.compileProg prog
  let cost_centres =
        [ copyDevToDev,
          copyDevToHost,
          copyHostToDev,
          copyScalarToDev,
          copyScalarFromDev
        ]
      extra =
        generateBoilerplate
          cuda_code
          cuda_prelude
          cost_centres
          kernels
          sizes
          failures
  (ws,)
    <$> GC.compileProg
      "cuda"
      version
      operations
      extra
      cuda_includes
      (Space "device", [Space "device", DefaultSpace])
      cliOptions
      prog'
  where
    -- The generic C operations, with every memory-related hook
    -- replaced by its CUDA counterpart from this module.
    operations :: GC.Operations OpenCL ()
    operations =
      GC.defaultOperations
        { GC.opsWriteScalar = writeCUDAScalar,
          GC.opsReadScalar = readCUDAScalar,
          GC.opsAllocate = allocateCUDABuffer,
          GC.opsDeallocate = deallocateCUDABuffer,
          GC.opsCopy = copyCUDAMemory,
          GC.opsStaticArray = staticCUDAArray,
          GC.opsMemoryType = cudaMemoryType,
          GC.opsCompiler = callKernel,
          GC.opsFatMemory = True,
          -- Statements run on entry to and exit from a critical
          -- section: push and pop the CUDA context stored in 'ctx'.
          GC.opsCritical =
            ( [C.citems|CUDA_SUCCEED_FATAL(cuCtxPushCurrent(ctx->cuda.cu_ctx));|],
              [C.citems|CUDA_SUCCEED_FATAL(cuCtxPopCurrent(&ctx->cuda.cu_ctx));|]
            )
        }
    -- Header includes passed to 'GC.compileProg' for the generated C.
    cuda_includes =
      [untrimming|
       #include <cuda.h>
       #include <cuda_runtime.h>
       #include <nvrtc.h>
      |]

-- | Command-line options of the generated executable: the options
-- shared with the OpenCL backend ('commonOptions') followed by the
-- CUDA-specific ones for dumping/loading CUDA and PTX code, passing
-- extra NVRTC options, and enabling profiling.
cliOptions :: [Option]
cliOptions =
  commonOptions
    ++ [ Option
           { optionLongName = "dump-cuda",
             optionShortName = Nothing,
             optionArgument = RequiredArgument "FILE",
             optionDescription = "Dump the embedded CUDA kernels to the indicated file.",
             -- Clearing entry_point after requesting the dump.
             optionAction =
               [C.cstm|{futhark_context_config_dump_program_to(cfg, optarg);
                                     entry_point = NULL;}|]
           },
         Option
           { optionLongName = "load-cuda",
             optionShortName = Nothing,
             optionArgument = RequiredArgument "FILE",
             optionDescription = "Instead of using the embedded CUDA kernels, load them from the indicated file.",
             optionAction = [C.cstm|futhark_context_config_load_program_from(cfg, optarg);|]
           },
         Option
           { optionLongName = "dump-ptx",
             optionShortName = Nothing,
             optionArgument = RequiredArgument "FILE",
             optionDescription = "Dump the PTX-compiled version of the embedded kernels to the indicated file.",
             -- Clearing entry_point after requesting the dump.
             optionAction =
               [C.cstm|{futhark_context_config_dump_ptx_to(cfg, optarg);
                                     entry_point = NULL;}|]
           },
         Option
           { optionLongName = "load-ptx",
             optionShortName = Nothing,
             optionArgument = RequiredArgument "FILE",
             optionDescription = "Load PTX code from the indicated file.",
             optionAction = [C.cstm|futhark_context_config_load_ptx_from(cfg, optarg);|]
           },
         Option
           { optionLongName = "nvrtc-option",
             optionShortName = Nothing,
             optionArgument = RequiredArgument "OPT",
             optionDescription = "Add an additional build option to the string passed to NVRTC.",
             optionAction = [C.cstm|futhark_context_config_add_nvrtc_option(cfg, optarg);|]
           },
         Option
           { optionLongName = "profile",
             optionShortName = Just 'P',
             optionArgument = NoArgument,
             optionDescription = "Gather profiling data while executing and print out a summary at the end.",
             optionAction = [C.cstm|futhark_context_config_set_profiling(cfg, 1);|]
           }
       ]

-- | Emit C code that writes a single scalar to @mem[idx]@ in the
-- @"device"@ memory space.  Any other space is a compiler bug and
-- raises an 'error'.
--
-- We detect the special case of writing a constant and turn it into a
-- non-blocking write.  This may be slightly faster, as it prevents
-- unnecessary synchronisation of the context, and writing a constant
-- is fairly common.  This is only possible because we can give the
-- constant infinite lifetime (with 'static'), which is not the case
-- for ordinary variables.
writeCUDAScalar :: GC.WriteScalar OpenCL ()
writeCUDAScalar mem idx t "device" _ val@C.Const {} = do
  val' <- newVName "write_static"
  -- Profiling statements to place before/after the copy.
  let (bef, aft) = profilingEnclosure copyScalarToDev
  GC.item
    [C.citem|{static $ty:t $id:val' = $exp:val;
              $items:bef
              CUDA_SUCCEED_OR_RETURN(
                cuMemcpyHtoDAsync($exp:mem + $exp:idx * sizeof($ty:t),
                                  &$id:val',
                                  sizeof($ty:t),
                                  0));
              $items:aft
             }|]
writeCUDAScalar mem idx t "device" _ val = do
  -- General case: stage the value in a temporary and copy it with the
  -- blocking variant, since the temporary does not outlive the block.
  val' <- newVName "write_tmp"
  let (bef, aft) = profilingEnclosure copyScalarToDev
  GC.item
    [C.citem|{$ty:t $id:val' = $exp:val;
                  $items:bef
                  CUDA_SUCCEED_OR_RETURN(
                    cuMemcpyHtoD($exp:mem + $exp:idx * sizeof($ty:t),
                                 &$id:val',
                                 sizeof($ty:t)));
                  $items:aft
                 }|]
writeCUDAScalar _ _ _ space _ _ =
  error $ "Cannot write to '" ++ space ++ "' memory space."

-- | Emit C code that reads a single scalar from @mem[idx]@ in the
-- @"device"@ memory space and return an expression naming the
-- host-side variable that receives it.
--
-- After the copy, the generated code checks
-- @ctx->failure_is_an_option@; if it is set, the context is
-- synchronised and the enclosing C function returns 1 when that
-- synchronisation reports a non-zero status.
--
-- Any other memory space is a compiler bug and raises an 'error'.
readCUDAScalar :: GC.ReadScalar OpenCL ()
readCUDAScalar mem idx t "device" _ = do
  val <- newVName "read_res"
  -- Profiling statements to place before/after the copy.
  let (bef, aft) = profilingEnclosure copyScalarFromDev
  mapM_
    GC.item
    [C.citems|
       $ty:t $id:val;
       {
       $items:bef
       CUDA_SUCCEED_OR_RETURN(
          cuMemcpyDtoH(&$id:val,
                       $exp:mem + $exp:idx * sizeof($ty:t),
                       sizeof($ty:t)));
       $items:aft
       }
       |]
  GC.stm
    [C.cstm|if (ctx->failure_is_an_option && futhark_context_sync(ctx) != 0)
            { return 1; }|]
  pure [C.cexp|$id:val|]
readCUDAScalar _ _ _ space _ =
  -- This is the read path, so report it as such (the message
  -- previously said "write").
  error $ "Cannot read from '" ++ space ++ "' memory space."

-- | Emit C code that allocates @size@ bytes in the @"device"@ memory
-- space via @cuda_alloc@, storing the result in @mem@ and the status
-- (wrapped in @CUDA_SUCCEED_NONFATAL@) in @ctx->error@.  The @tag@
-- expression is forwarded to @cuda_alloc@.
--
-- Any other memory space is a compiler bug and raises an 'error'.
allocateCUDABuffer :: GC.Allocate OpenCL ()
allocateCUDABuffer mem size tag "device" =
  GC.stm
    [C.cstm|ctx->error =
     CUDA_SUCCEED_NONFATAL(cuda_alloc(&ctx->cuda, ctx->log,
                                      (size_t)$exp:size, $exp:tag, &$exp:mem));|]
allocateCUDABuffer _ _ _ space =
  error $ "Cannot allocate in '" ++ space ++ "' memory space."

-- | Emit C code that releases the @"device"@ memory block @mem@ via
-- @cuda_free@, forwarding the @tag@ expression.  The call is wrapped
-- in @CUDA_SUCCEED_OR_RETURN@.
--
-- Any other memory space is a compiler bug and raises an 'error'.
deallocateCUDABuffer :: GC.Deallocate OpenCL ()
deallocateCUDABuffer mem tag "device" =
  GC.stm [C.cstm|CUDA_SUCCEED_OR_RETURN(cuda_free(&ctx->cuda, $exp:mem, $exp:tag));|]
deallocateCUDABuffer _ _ space =
  error $ "Cannot deallocate in '" ++ space ++ "' memory space."

-- | Emit C code that copies @nbytes@ bytes from @srcmem + srcidx@ to
-- @dstmem + dstidx@.  The CUDA driver function is selected from the
-- barrier mode and the pair of memory spaces:
--
-- * host <- device: @cuMemcpyDtoH@ (barrier) or @cuMemcpyDtoHAsync@;
--
-- * device <- host: @cuMemcpyHtoD@ (barrier) or @cuMemcpyHtoDAsync@;
--
-- * device <- device: @cuMemcpy@, regardless of the barrier mode.
--
-- The call is wrapped in the profiling statements of the matching
-- cost centre.  Any other combination of spaces raises an 'error'.
copyCUDAMemory :: GC.Copy OpenCL ()
copyCUDAMemory b dstmem dstidx dstSpace srcmem srcidx srcSpace nbytes = do
  let (copy, prof) = memcpyFun b dstSpace srcSpace
      (bef, aft) = profilingEnclosure prof
  GC.item
    [C.citem|{$items:bef CUDA_SUCCEED_OR_RETURN($exp:copy); $items:aft}|]
  where
    dst = [C.cexp|$exp:dstmem + $exp:dstidx|]
    src = [C.cexp|$exp:srcmem + $exp:srcidx|]
    -- Arguments are: barrier mode, destination space, source space.
    -- Returns the call expression and the cost centre to charge.
    memcpyFun GC.CopyBarrier DefaultSpace (Space "device") =
      ([C.cexp|cuMemcpyDtoH($exp:dst, $exp:src, $exp:nbytes)|], copyDevToHost)
    memcpyFun GC.CopyBarrier (Space "device") DefaultSpace =
      ([C.cexp|cuMemcpyHtoD($exp:dst, $exp:src, $exp:nbytes)|], copyHostToDev)
    memcpyFun _ (Space "device") (Space "device") =
      ([C.cexp|cuMemcpy($exp:dst, $exp:src, $exp:nbytes)|], copyDevToDev)
    memcpyFun GC.CopyNoBarrier DefaultSpace (Space "device") =
      ([C.cexp|cuMemcpyDtoHAsync($exp:dst, $exp:src, $exp:nbytes, 0)|], copyDevToHost)
    memcpyFun GC.CopyNoBarrier (Space "device") DefaultSpace =
      ([C.cexp|cuMemcpyHtoDAsync($exp:dst, $exp:src, $exp:nbytes, 0)|], copyHostToDev)
    memcpyFun _ _ _ =
      error $
        "Cannot copy to '"
          ++ show dstSpace
          ++ "' from '"
          ++ show srcSpace
          ++ "'."

-- | Emit C code for a statically known array living in the
-- @"device"@ memory space.
--
-- The contents are first declared as a host-side @static@ C array
-- (either with explicit initialisers or left without an initialiser
-- for 'ArrayZeros').  A @struct memblock_device@ context field named
-- after @name@ is registered, with @cuMemFree@ as its cleanup
-- statement.  At context initialisation the device buffer is
-- allocated and, when the array is non-empty, filled from the host
-- array.  Finally a local variable @name@ is bound to the context
-- field.
--
-- Any other memory space is a compiler bug and raises an 'error'.
staticCUDAArray :: GC.StaticArray OpenCL ()
staticCUDAArray name "device" t vs = do
  let ct = GC.primTypeToCType t
  name_realtype <- newVName $ baseString name ++ "_realtype"
  num_elems <- case vs of
    ArrayValues vs' -> do
      let vs'' = [[C.cinit|$exp:v|] | v <- vs']
      GC.earlyDecl [C.cedecl|static $ty:ct $id:name_realtype[$int:(length vs'')] = {$inits:vs''};|]
      pure $ length vs''
    ArrayZeros n -> do
      GC.earlyDecl [C.cedecl|static $ty:ct $id:name_realtype[$int:n];|]
      pure n
  -- Fake a memory block.
  GC.contextFieldDyn
    (C.toIdent name mempty)
    [C.cty|struct memblock_device|]
    Nothing
    [C.cstm|cuMemFree(ctx->$id:name.mem);|]
  -- During startup, copy the data to where we need it.  At least one
  -- element's worth of memory is requested even for an empty array,
  -- and the copy is skipped when there is nothing to copy.
  GC.atInit
    [C.cstm|{
    ctx->$id:name.references = NULL;
    ctx->$id:name.size = 0;
    CUDA_SUCCEED_FATAL(cuMemAlloc(&ctx->$id:name.mem,
                            ($int:num_elems > 0 ? $int:num_elems : 1)*sizeof($ty:ct)));
    if ($int:num_elems > 0) {
      CUDA_SUCCEED_FATAL(cuMemcpyHtoD(ctx->$id:name.mem, $id:name_realtype,
                                $int:num_elems*sizeof($ty:ct)));
    }
  }|]
  GC.item [C.citem|struct memblock_device $id:name = ctx->$id:name;|]
staticCUDAArray _ space _ _ =
  error $
    "CUDA backend cannot create static array in '"
      ++ space
      ++ "' memory space"

-- | The C type used to represent a memory block in the given space.
-- Only @"device"@ is supported, for which the type is @CUdeviceptr@;
-- any other space raises an 'error'.
cudaMemoryType :: GC.MemoryType OpenCL ()
cudaMemoryType "device" = pure [C.cty|typename CUdeviceptr|]
cudaMemoryType space =
  error $ "CUDA backend does not support '" ++ space ++ "' memory space."

-- | Compile a host-level 'OpenCL' operation to C.
--
-- * 'GetSize' reads a tuning parameter from @ctx->tuning_params@.
--
-- * 'CmpSizeLe' compares a tuning parameter against an expression
--   and emits the associated logging code via 'sizeLoggingCode'.
--
-- * 'GetSizeMax' reads the @max_*@ field of @ctx->cuda@ that
--   corresponds to the size class.
--
-- * 'LaunchKernel' marshals the kernel arguments, computes the
--   shared memory layout, and emits a guarded @cuLaunchKernel@ call
--   with optional debug logging and timing.
callKernel :: GC.OpCompiler OpenCL ()
callKernel (GetSize v key) =
  GC.stm [C.cstm|$id:v = *ctx->tuning_params.$id:key;|]
callKernel (CmpSizeLe v key x) = do
  x' <- GC.compileExp x
  GC.stm [C.cstm|$id:v = *ctx->tuning_params.$id:key <= $exp:x';|]
  sizeLoggingCode v key x'
callKernel (GetSizeMax v size_class) =
  let field = "max_" ++ cudaSizeClass size_class
   in GC.stm [C.cstm|$id:v = ctx->cuda.$id:field;|]
  where
    -- Suffix of the @ctx->cuda.max_*@ field for each size class.
    cudaSizeClass SizeThreshold {} = "threshold"
    cudaSizeClass SizeGroup = "block_size"
    cudaSizeClass SizeNumGroups = "grid_size"
    cudaSizeClass SizeTile = "tile_size"
    cudaSizeClass SizeRegTile = "reg_tile_size"
    cudaSizeClass SizeLocalMemory = "shared_memory"
    cudaSizeClass (SizeBespoke x _) = prettyString x
callKernel (LaunchKernel safety kernel_name args num_blocks block_size) = do
  args_arr <- newVName "kernel_args"
  time_start <- newVName "time_start"
  time_end <- newVName "time_end"
  (args', shared_vars) <- unzip <$> mapM mkArgs args
  let (shared_sizes, shared_offsets) = unzip $ catMaybes shared_vars
      -- Running (aligned) offsets of the shared memory arguments.
      -- 'scanl' yields one more element than its input, so the list
      -- is never empty and its last element is the total size.
      shared_offsets_sc = mkOffsets shared_sizes
      shared_args = zip shared_offsets shared_offsets_sc
      shared_tot = last shared_offsets_sc
  -- Declare the offset variables that 'mkArgs' handed out as the
  -- kernel arguments for shared memory.
  forM_ shared_args $ \(arg, offset) ->
    GC.decl [C.cdecl|unsigned int $id:arg = $exp:offset;|]

  (grid_x, grid_y, grid_z) <- mkDims <$> mapM GC.compileExp num_blocks
  (block_x, block_y, block_z) <- mkDims <$> mapM GC.compileExp block_size
  let -- The dimension permutation is only passed to the kernel when
      -- the grid is three-dimensional.
      perm_args
        | length num_blocks == 3 = [[C.cinit|&perm[0]|], [C.cinit|&perm[1]|], [C.cinit|&perm[2]|]]
        | otherwise = []
      -- How many of the failure-related pointers are passed depends
      -- on the safety level of the kernel.
      failure_args =
        take
          (numFailureParams safety)
          [ [C.cinit|&ctx->global_failure|],
            [C.cinit|&ctx->failure_is_an_option|],
            [C.cinit|&ctx->global_failure_args|]
          ]
      args'' = perm_args ++ failure_args ++ [[C.cinit|&$id:a|] | a <- args']
      -- The launch is skipped entirely if any dimension is zero.
      sizes_nonzero =
        expsNotZero
          [ grid_x,
            grid_y,
            grid_z,
            block_x,
            block_y,
            block_z
          ]
      (bef, aft) = profilingEnclosure kernel_name

  -- In the generated code, a y or z grid dimension of at least 1<<16
  -- is swapped into grid slot 0 via 'perm' (presumably because of
  -- CUDA's smaller limits on the y and z grid dimensions -- TODO
  -- confirm).  The runtime difference is cast to long int so that it
  -- matches the %ld conversion on every platform.
  GC.stm
    [C.cstm|
    if ($exp:sizes_nonzero) {
      int perm[3] = { 0, 1, 2 };

      if ($exp:grid_y >= (1<<16)) {
        perm[1] = perm[0];
        perm[0] = 1;
      }

      if ($exp:grid_z >= (1<<16)) {
        perm[2] = perm[0];
        perm[0] = 2;
      }

      size_t grid[3];
      grid[perm[0]] = $exp:grid_x;
      grid[perm[1]] = $exp:grid_y;
      grid[perm[2]] = $exp:grid_z;

      void *$id:args_arr[] = { $inits:args'' };
      typename int64_t $id:time_start = 0, $id:time_end = 0;
      if (ctx->debugging) {
        fprintf(ctx->log, "Launching %s with grid size [%ld, %ld, %ld] and block size [%ld, %ld, %ld]; shared memory: %d bytes.\n",
                $string:(prettyString kernel_name),
                (long int)$exp:grid_x, (long int)$exp:grid_y, (long int)$exp:grid_z,
                (long int)$exp:block_x, (long int)$exp:block_y, (long int)$exp:block_z,
                (int)$exp:shared_tot);
        $id:time_start = get_wall_time();
      }
      $items:bef
      CUDA_SUCCEED_OR_RETURN(
        cuLaunchKernel(ctx->$id:kernel_name,
                       grid[0], grid[1], grid[2],
                       $exp:block_x, $exp:block_y, $exp:block_z,
                       $exp:shared_tot, NULL,
                       $id:args_arr, NULL));
      $items:aft
      if (ctx->debugging) {
        CUDA_SUCCEED_FATAL(cuCtxSynchronize());
        $id:time_end = get_wall_time();
        fprintf(ctx->log, "Kernel %s runtime: %ldus\n",
                $string:(prettyString kernel_name), (long int)($id:time_end - $id:time_start));
      }
    }|]

  when (safety >= SafetyFull) $
    GC.stm [C.cstm|ctx->failure_is_an_option = 1;|]
  where
    -- Pad a list of dimensions to exactly three: no dimensions gives
    -- all zeroes (so the launch guard fails), missing trailing
    -- dimensions default to 1, and anything past the third is ignored.
    mkDims [] = ([C.cexp|0|], [C.cexp|0|], [C.cexp|0|])
    mkDims [x] = (x, [C.cexp|1|], [C.cexp|1|])
    mkDims [x, y] = (x, y, [C.cexp|1|])
    mkDims (x : y : z : _) = (x, y, z)
    addExp x y = [C.cexp|$exp:x + $exp:y|]
    -- Round an expression up to the next multiple of 8.
    alignExp e = [C.cexp|$exp:e + ((8 - ($exp:e % 8)) % 8)|]
    -- Prefix sums of the aligned sizes, starting from 0.
    mkOffsets = scanl (\a b -> a `addExp` alignExp b) [C.cexp|0|]
    expNotZero e = [C.cexp|$exp:e != 0|]
    expAnd a b = [C.cexp|$exp:a && $exp:b|]
    -- Conjunction of "e != 0" over all the given expressions.
    expsNotZero = foldl expAnd [C.cexp|1|] . map expNotZero
    -- Produce the name of the host variable holding a kernel
    -- argument; for shared memory arguments, additionally return the
    -- (size, offset) variable pair.
    mkArgs (ValueKArg e t@(FloatType Float16)) = do
      -- Half-precision values are passed in their storage
      -- representation.
      arg <- newVName "kernel_arg"
      e' <- GC.compileExp e
      GC.item [C.citem|$ty:(primStorageType t) $id:arg = $exp:(toStorage t e');|]
      pure (arg, Nothing)
    mkArgs (ValueKArg e t) =
      (,Nothing) <$> GC.compileExpToName "kernel_arg" t e
    mkArgs (MemKArg v) = do
      v' <- GC.rawMem v
      arg <- newVName "kernel_arg"
      GC.decl [C.cdecl|typename CUdeviceptr $id:arg = $exp:v';|]
      pure (arg, Nothing)
    mkArgs (SharedMemoryKArg (Count c)) = do
      num_bytes <- GC.compileExp c
      size <- newVName "shared_size"
      offset <- newVName "shared_offset"
      GC.decl [C.cdecl|unsigned int $id:size = $exp:num_bytes;|]
      -- The offset variable is the actual kernel argument; it is
      -- declared by the caller once all sizes are known.
      pure (offset, Just (size, offset))