{-# LANGUAGE QuasiQuotes #-}

-- | Code generation for C with OpenCL.
module Futhark.CodeGen.Backends.COpenCL
  ( compileProg,
    GC.CParts (..),
    GC.asLibrary,
    GC.asExecutable,
    GC.asServer,
  )
where

import Control.Monad hiding (mapM)
import Data.List (intercalate)
import Data.Text qualified as T
import Futhark.CodeGen.Backends.COpenCL.Boilerplate
import Futhark.CodeGen.Backends.GenericC qualified as GC
import Futhark.CodeGen.Backends.GenericC.Options
import Futhark.CodeGen.Backends.SimpleRep (primStorageType, toStorage)
import Futhark.CodeGen.ImpCode.OpenCL
import Futhark.CodeGen.ImpGen.OpenCL qualified as ImpGen
import Futhark.IR.GPUMem hiding
  ( CmpSizeLe,
    GetSize,
    GetSizeMax,
  )
import Futhark.MonadFreshNames
import Language.C.Quote.OpenCL qualified as C
import Language.C.Syntax qualified as C
import NeatInterpolation (untrimming)

-- | Compile the program to C with calls to OpenCL.
--
-- Runs the OpenCL ImpGen pass, then hands its result to the generic C
-- code generator together with the OpenCL-specific operations, the
-- boilerplate generated from the kernels/sizes/failures, and a prelude
-- that includes the OpenCL headers.  ImpGen warnings are returned
-- unchanged alongside the generated C parts.
compileProg :: MonadFreshNames m => T.Text -> Prog GPUMem -> m (ImpGen.Warnings, GC.CParts)
compileProg version prog = do
  ( ws,
    Program
      opencl_code
      opencl_prelude
      kernels
      types
      sizes
      failures
      prog'
    ) <-
    ImpGen.compileProg prog
  -- Cost centres for the host/device memory transfers; they are handed to
  -- the boilerplate generator alongside the kernels.
  let cost_centres =
        [ copyDevToDev,
          copyDevToHost,
          copyHostToDev,
          copyScalarToDev,
          copyScalarFromDev
        ]
  (ws,)
    <$> GC.compileProg
      "opencl"
      version
      operations
      ( generateBoilerplate
          opencl_code
          opencl_prelude
          cost_centres
          kernels
          types
          sizes
          failures
      )
      include_opencl_h
      (Space "device", [Space "device", DefaultSpace])
      cliOptions
      prog'
  where
    -- The generic C operations, with every memory/kernel hook replaced by
    -- its OpenCL implementation from this module.
    operations :: GC.Operations OpenCL ()
    operations =
      GC.defaultOperations
        { GC.opsCompiler = callKernel,
          GC.opsWriteScalar = writeOpenCLScalar,
          GC.opsReadScalar = readOpenCLScalar,
          GC.opsAllocate = allocateOpenCLBuffer,
          GC.opsDeallocate = deallocateOpenCLBuffer,
          GC.opsCopy = copyOpenCLMemory,
          GC.opsStaticArray = staticOpenCLArray,
          GC.opsMemoryType = openclMemoryType,
          GC.opsFatMemory = True
        }

    -- Targets OpenCL 1.2 (keeping the deprecated 1.2 APIs available) and
    -- includes the OpenCL header, which lives under a different path on
    -- Apple platforms.
    include_opencl_h :: T.Text
    include_opencl_h =
      [untrimming|
       #define CL_TARGET_OPENCL_VERSION 120
       #define CL_USE_DEPRECATED_OPENCL_1_2_APIS
       #ifdef __APPLE__
       #define CL_SILENCE_DEPRECATION
       #include <OpenCL/cl.h>
       #else
       #include <CL/cl.h>
       #endif
       |]

-- | Command-line options accepted by executables produced by the OpenCL
-- backend: the common options plus OpenCL-specific ones.  Options whose
-- action sets @entry_point = NULL@ (dumping the program/binary, listing
-- devices) clear the entry point instead of only configuring @cfg@.
cliOptions :: [Option]
cliOptions =
  commonOptions
    ++ [ Option
           { optionLongName = "platform",
             optionShortName = Just 'p',
             optionArgument = RequiredArgument "NAME",
             optionDescription = "Use the first OpenCL platform whose name contains the given string.",
             optionAction = [C.cstm|futhark_context_config_set_platform(cfg, optarg);|]
           },
         Option
           { optionLongName = "dump-opencl",
             optionShortName = Nothing,
             optionArgument = RequiredArgument "FILE",
             optionDescription = "Dump the embedded OpenCL program to the indicated file.",
             optionAction =
               [C.cstm|{futhark_context_config_dump_program_to(cfg, optarg);
                                     entry_point = NULL;}|]
           },
         Option
           { optionLongName = "load-opencl",
             optionShortName = Nothing,
             optionArgument = RequiredArgument "FILE",
             optionDescription = "Instead of using the embedded OpenCL program, load it from the indicated file.",
             optionAction = [C.cstm|futhark_context_config_load_program_from(cfg, optarg);|]
           },
         Option
           { optionLongName = "dump-opencl-binary",
             optionShortName = Nothing,
             optionArgument = RequiredArgument "FILE",
             optionDescription = "Dump the compiled version of the embedded OpenCL program to the indicated file.",
             optionAction =
               [C.cstm|{futhark_context_config_dump_binary_to(cfg, optarg);
                                     entry_point = NULL;}|]
           },
         Option
           { optionLongName = "load-opencl-binary",
             optionShortName = Nothing,
             optionArgument = RequiredArgument "FILE",
             optionDescription = "Load an OpenCL binary from the indicated file.",
             optionAction = [C.cstm|futhark_context_config_load_binary_from(cfg, optarg);|]
           },
         Option
           { optionLongName = "build-option",
             optionShortName = Nothing,
             optionArgument = RequiredArgument "OPT",
             optionDescription = "Add an additional build option to the string passed to clBuildProgram().",
             optionAction = [C.cstm|futhark_context_config_add_build_option(cfg, optarg);|]
           },
         Option
           { optionLongName = "profile",
             optionShortName = Just 'P',
             optionArgument = NoArgument,
             optionDescription = "Gather profiling data while executing and print out a summary at the end.",
             optionAction = [C.cstm|futhark_context_config_set_profiling(cfg, 1);|]
           },
         Option
           { optionLongName = "list-devices",
             optionShortName = Nothing,
             optionArgument = NoArgument,
             optionDescription = "List all OpenCL devices and platforms available on the system.",
             optionAction =
               [C.cstm|{futhark_context_config_list_devices(cfg);
                        entry_point = NULL;}|]
           }
       ]

-- | Write a single scalar of C type @t@ to element @i@ of a buffer in
-- @device@ memory.  Any other memory space is an internal error.
--
-- We detect the special case of writing a constant and turn it into a
-- non-blocking write.  This may be slightly faster, as it prevents
-- unnecessary synchronisation of the OpenCL command queue, and
-- writing a constant is fairly common.  This is only possible because
-- we can give the constant infinite lifetime (with 'static'), which
-- is not the case for ordinary variables.
writeOpenCLScalar :: GC.WriteScalar OpenCL ()
writeOpenCLScalar mem i t "device" _ val = do
  val' <- newVName "write_tmp"
  -- Choose the host-side temporary and whether clEnqueueWriteBuffer blocks.
  let (decl, blocking) =
        case val of
          C.Const {} -> ([C.citem|static $ty:t $id:val' = $exp:val;|], [C.cexp|CL_FALSE|])
          _ -> ([C.citem|$ty:t $id:val' = $exp:val;|], [C.cexp|CL_TRUE|])
  GC.stm
    [C.cstm|{$item:decl
                  OPENCL_SUCCEED_OR_RETURN(
                    clEnqueueWriteBuffer(ctx->opencl.queue, $exp:mem, $exp:blocking,
                                         $exp:i * sizeof($ty:t), sizeof($ty:t),
                                         &$id:val',
                                         0, NULL, $exp:(profilingEvent copyScalarToDev)));
                }|]
writeOpenCLScalar _ _ _ space _ _ =
  error $ "Cannot write to '" ++ space ++ "' memory space."

-- | Read a single scalar of C type @t@ from element @i@ of a buffer in
-- @device@ memory into a fresh host variable, returning an expression
-- naming that variable.  Any other memory space is an internal error.
--
-- It is often faster to do a blocking clEnqueueReadBuffer() than to
-- do an async clEnqueueReadBuffer() followed by a clFinish(), even
-- with an in-order command queue.  This is safe if and only if there
-- are no possible outstanding failures.  Hence the read only blocks
-- when @ctx->failure_is_an_option@ is unset; otherwise it is followed
-- by @futhark_context_sync@, and the generated code returns 1 if that
-- sync fails.
readOpenCLScalar :: GC.ReadScalar OpenCL ()
readOpenCLScalar mem i t "device" _ = do
  val <- newVName "read_res"
  GC.decl [C.cdecl|$ty:t $id:val;|]
  GC.stm
    [C.cstm|OPENCL_SUCCEED_OR_RETURN(
                   clEnqueueReadBuffer(ctx->opencl.queue, $exp:mem,
                                       ctx->failure_is_an_option ? CL_FALSE : CL_TRUE,
                                       $exp:i * sizeof($ty:t), sizeof($ty:t),
                                       &$id:val,
                                       0, NULL, $exp:(profilingEvent copyScalarFromDev)));
              |]
  GC.stm
    [C.cstm|if (ctx->failure_is_an_option && futhark_context_sync(ctx) != 0)
            { return 1; }|]
  pure [C.cexp|$id:val|]
readOpenCLScalar _ _ _ space _ =
  error $ "Cannot read from '" ++ space ++ "' memory space."

-- | Allocate @size@ bytes of @device@ memory into @mem@ via @opencl_alloc@,
-- tagged with @tag@.  The OpenCL status goes through
-- @OPENCL_SUCCEED_NONFATAL@ and is stored in @ctx->error@ rather than
-- returning from the generated function.  Other memory spaces are an
-- internal error.
allocateOpenCLBuffer :: GC.Allocate OpenCL ()
allocateOpenCLBuffer mem size tag "device" =
  GC.stm
    [C.cstm|ctx->error =
     OPENCL_SUCCEED_NONFATAL(opencl_alloc(&ctx->opencl, ctx->log,
                                          (size_t)$exp:size, $exp:tag, &$exp:mem));|]
allocateOpenCLBuffer _ _ _ space =
  error $ "Cannot allocate in '" ++ space ++ "' memory space."

-- | Release the @device@ buffer @mem@ (tagged with @tag@) via
-- @opencl_free@.  Other memory spaces are an internal error.
deallocateOpenCLBuffer :: GC.Deallocate OpenCL ()
deallocateOpenCLBuffer mem tag "device" =
  GC.stm [C.cstm|OPENCL_SUCCEED_OR_RETURN(opencl_free(&ctx->opencl, $exp:mem, $exp:tag));|]
deallocateOpenCLBuffer _ _ space =
  -- Worded like the other memory-space errors in this module.
  error $ "Cannot deallocate in '" ++ space ++ "' memory space."

-- | The OpenCL blocking flag corresponding to a copy barrier:
-- 'GC.CopyBarrier' becomes @CL_TRUE@ and 'GC.CopyNoBarrier' becomes
-- @CL_FALSE@.
syncArg :: GC.CopyBarrier -> C.Exp
syncArg GC.CopyBarrier = [C.cexp|CL_TRUE|]
syncArg GC.CopyNoBarrier = [C.cexp|CL_FALSE|]

-- | Copy @nbytes@ bytes between host ('DefaultSpace') and @device@ memory.
--
-- * device to host: @clEnqueueReadBuffer@, blocking according to the copy
--   barrier unless @ctx->failure_is_an_option@ is set;
-- * host to device: @clEnqueueWriteBuffer@, blocking according to the
--   copy barrier;
-- * device to device: @clEnqueueCopyBuffer@, followed by @clFinish@ when
--   @ctx->debugging@ is set;
-- * host to host: delegated to 'GC.copyMemoryDefaultSpace'.
--
-- Any other combination of spaces is an internal error.
--
-- The read/write/copy-buffer functions fail if the given offset is
-- out of bounds, even if asked to read zero bytes.  We protect with a
-- branch to avoid this.
copyOpenCLMemory :: GC.Copy OpenCL ()
copyOpenCLMemory b destmem destidx DefaultSpace srcmem srcidx (Space "device") nbytes =
  -- Device-to-host transfer, so it is attributed to 'copyDevToHost'.
  GC.stm
    [C.cstm|
    if ($exp:nbytes > 0) {
      typename cl_bool sync_call = $exp:(syncArg b);
      OPENCL_SUCCEED_OR_RETURN(
        clEnqueueReadBuffer(ctx->opencl.queue, $exp:srcmem,
                            ctx->failure_is_an_option ? CL_FALSE : sync_call,
                            (size_t)$exp:srcidx, (size_t)$exp:nbytes,
                            $exp:destmem + $exp:destidx,
                            0, NULL, $exp:(profilingEvent copyDevToHost)));
      if (sync_call &&
          ctx->failure_is_an_option &&
          futhark_context_sync(ctx) != 0) { return 1; }
   }
  |]
copyOpenCLMemory b destmem destidx (Space "device") srcmem srcidx DefaultSpace nbytes =
  -- Host-to-device transfer, so it is attributed to 'copyHostToDev'.
  GC.stm
    [C.cstm|
    if ($exp:nbytes > 0) {
      OPENCL_SUCCEED_OR_RETURN(
        clEnqueueWriteBuffer(ctx->opencl.queue, $exp:destmem, $exp:(syncArg b),
                             (size_t)$exp:destidx, (size_t)$exp:nbytes,
                             $exp:srcmem + $exp:srcidx,
                             0, NULL, $exp:(profilingEvent copyHostToDev)));
    }
  |]
copyOpenCLMemory _ destmem destidx (Space "device") srcmem srcidx (Space "device") nbytes =
  -- Be aware that OpenCL swaps the usual order of operands for
  -- memcpy()-like functions.  The order below is not a typo.
  GC.stm
    [C.cstm|{
    if ($exp:nbytes > 0) {
      OPENCL_SUCCEED_OR_RETURN(
        clEnqueueCopyBuffer(ctx->opencl.queue,
                            $exp:srcmem, $exp:destmem,
                            (size_t)$exp:srcidx, (size_t)$exp:destidx,
                            (size_t)$exp:nbytes,
                            0, NULL, $exp:(profilingEvent copyDevToDev)));
      if (ctx->debugging) {
        OPENCL_SUCCEED_FATAL(clFinish(ctx->opencl.queue));
      }
    }
  }|]
copyOpenCLMemory _ destmem destidx DefaultSpace srcmem srcidx DefaultSpace nbytes =
  GC.copyMemoryDefaultSpace destmem destidx srcmem srcidx nbytes
copyOpenCLMemory _ _ _ destspace _ _ srcspace _ =
  error $ "Cannot copy to " ++ show destspace ++ " from " ++ show srcspace

-- | The C type used for memory in a given space: @device@ memory is a
-- @cl_mem@ handle.  No other space is supported by this backend.
openclMemoryType :: GC.MemoryType OpenCL ()
openclMemoryType "device" = pure [C.cty|typename cl_mem|]
openclMemoryType space =
  error $ "OpenCL backend does not support '" ++ space ++ "' memory space."

-- | Declare a static array @name@ of element type @t@ in @device@ memory.
--
-- The initial contents are emitted as a static host-side C array
-- (@<name>_realtype@).  A context field of type
-- @struct memblock_device@ holds the device copy.  At initialisation it
-- is allocated with @clCreateBuffer@, with at least one element's worth
-- of bytes, and filled from the host array with a blocking
-- @clEnqueueWriteBuffer@ when the array is non-empty.  The statement
-- registered with 'GC.contextFieldDyn' releases the buffer with
-- @clReleaseMemObject@.  Finally a local @memblock_device@ named @name@
-- is bound to that field.
--
-- Only the @device@ space is supported.
staticOpenCLArray :: GC.StaticArray OpenCL ()
staticOpenCLArray name "device" t vs = do
  let ct = GC.primTypeToCType t
  name_realtype <- newVName $ baseString name ++ "_realtype"
  -- Declare a zero-initialised host array for @n@ elements.  ISO C forbids
  -- zero-length arrays, so at least one element is always reserved; the
  -- returned element count is still @n@, which keeps the runtime guards
  -- below accurate.
  let declareZeros n = do
        GC.earlyDecl [C.cedecl|static $ty:ct $id:name_realtype[$int:(max 1 n)];|]
        pure n
  num_elems <- case vs of
    ArrayValues vs'
      | not (null vs') -> do
          let vs'' = [[C.cinit|$exp:v|] | v <- vs']
          GC.earlyDecl [C.cedecl|static $ty:ct $id:name_realtype[$int:(length vs'')] = {$inits:vs''};|]
          pure $ length vs''
      | otherwise -> declareZeros 0
    ArrayZeros n -> declareZeros n
  -- Fake a memory block.
  GC.contextFieldDyn
    (C.toIdent name mempty)
    [C.cty|struct memblock_device|]
    Nothing
    [C.cstm|OPENCL_SUCCEED_FATAL(clReleaseMemObject(ctx->$id:name.mem));|]
  -- During startup, copy the data to where we need it.
  GC.atInit
    [C.cstm|{
    typename cl_int success;
    ctx->$id:name.references = NULL;
    ctx->$id:name.size = 0;
    ctx->$id:name.mem =
      clCreateBuffer(ctx->opencl.ctx, CL_MEM_READ_WRITE,
                     ($int:num_elems > 0 ? $int:num_elems : 1)*sizeof($ty:ct), NULL,
                     &success);
    OPENCL_SUCCEED_OR_RETURN(success);
    if ($int:num_elems > 0) {
      OPENCL_SUCCEED_OR_RETURN(
        clEnqueueWriteBuffer(ctx->opencl.queue, ctx->$id:name.mem, CL_TRUE,
                             0, $int:num_elems*sizeof($ty:ct),
                             $id:name_realtype,
                             0, NULL, NULL));
    }
  }|]
  GC.item [C.citem|struct memblock_device $id:name = ctx->$id:name;|]
staticOpenCLArray _ space _ _ =
  error $ "OpenCL backend cannot create static array in memory space '" ++ space ++ "'"

-- | Compile the OpenCL-specific host operations.
--
-- * 'GetSize' reads a tuning parameter from @ctx->tuning_params@.
-- * 'CmpSizeLe' compares a tuning parameter against an expression and
--   also emits size-logging code.
-- * 'GetSizeMax' reads the @max_<size class>@ field of @ctx->opencl@.
-- * 'LaunchKernel' sets the kernel arguments, sums the sizes of the
--   shared-memory arguments, and enqueues the kernel via 'launchKernel'.
callKernel :: GC.OpCompiler OpenCL ()
callKernel (GetSize v key) =
  GC.stm [C.cstm|$id:v = *ctx->tuning_params.$id:key;|]
callKernel (CmpSizeLe v key x) = do
  x' <- GC.compileExp x
  GC.stm [C.cstm|$id:v = *ctx->tuning_params.$id:key <= $exp:x';|]
  sizeLoggingCode v key x'
callKernel (GetSizeMax v size_class) =
  let field = "max_" ++ prettyString size_class
   in GC.stm [C.cstm|$id:v = ctx->opencl.$id:field;|]
callKernel (LaunchKernel safety name args num_workgroups workgroup_size) = do
  -- For fully safe kernels, argument 1 is refreshed with the current
  -- value of ctx->failure_is_an_option before every launch.  The other
  -- failure args are set automatically when the kernel is first created.
  when (safety == SafetyFull) $
    GC.stm
      [C.cstm|
      OPENCL_SUCCEED_OR_RETURN(clSetKernelArg(ctx->$id:name, 1,
                                              sizeof(ctx->failure_is_an_option),
                                              &ctx->failure_is_an_option));
    |]

  -- Ordinary arguments are numbered after the failure parameters.
  zipWithM_ setKernelArg [numFailureParams safety ..] args
  num_workgroups' <- mapM GC.compileExp num_workgroups
  workgroup_size' <- mapM GC.compileExp workgroup_size
  local_bytes <- foldM localBytes [C.cexp|0|] args

  launchKernel name num_workgroups' workgroup_size' local_bytes

  when (safety >= SafetyFull) $
    GC.stm [C.cstm|ctx->failure_is_an_option = 1;|]
  where
    -- Emit the clSetKernelArg call for the argument at index @i@.
    setKernelArg i (ValueKArg e pt) = do
      v <- case pt of
        -- We always transfer f16 values to the kernel as 16 bits, but
        -- the actual host type may be typedef'd to a 32-bit float.
        -- This requires some care.
        FloatType Float16 -> do
          v <- newVName "kernel_arg"
          e' <- toStorage pt <$> GC.compileExp e
          GC.decl [C.cdecl|$ty:(primStorageType pt) $id:v = $e';|]
          pure v
        _ -> GC.compileExpToName "kernel_arg" pt e
      GC.stm
        [C.cstm|
            OPENCL_SUCCEED_OR_RETURN(clSetKernelArg(ctx->$id:name, $int:i, sizeof($id:v), &$id:v));
          |]
    setKernelArg i (MemKArg v) = do
      v' <- GC.rawMem v
      GC.stm
        [C.cstm|
            OPENCL_SUCCEED_OR_RETURN(clSetKernelArg(ctx->$id:name, $int:i, sizeof($exp:v'), &$exp:v'));
          |]
    setKernelArg i (SharedMemoryKArg num_bytes) = do
      -- Shared memory is passed by size only, with a NULL value pointer.
      num_bytes' <- GC.compileExp $ unCount num_bytes
      GC.stm
        [C.cstm|
            OPENCL_SUCCEED_OR_RETURN(clSetKernelArg(ctx->$id:name, $int:i, (size_t)$exp:num_bytes', NULL));
            |]

    -- Accumulate the total number of shared-memory bytes requested.
    localBytes cur (SharedMemoryKArg num_bytes) = do
      num_bytes' <- GC.compileExp $ unCount num_bytes
      pure [C.cexp|$exp:cur + $exp:num_bytes'|]
    localBytes cur _ = pure cur

-- | Emit host code that enqueues @kernel_name@ on @ctx->opencl.queue@.
--
-- The NDRange has one dimension per element of @num_workgroups@.  The
-- global size in each dimension is @num_workgroups[d] * workgroup_dims[d]@
-- and the local size is @workgroup_dims[d]@, both cast to @size_t@.  The
-- launch is skipped entirely when the product of the global sizes is
-- zero.  When @ctx->debugging@ is set, the launch parameters are logged
-- beforehand; afterwards the queue is drained with @clFinish@ and the
-- measured wall-clock runtime is logged.
launchKernel ::
  C.ToExp a =>
  KernelName ->
  [a] ->
  [a] ->
  a ->
  GC.CompilerM op s ()
launchKernel kernel_name num_workgroups workgroup_dims local_bytes = do
  global_work_size <- newVName "global_work_size"
  time_start <- newVName "time_start"
  time_end <- newVName "time_end"
  time_diff <- newVName "time_diff"
  local_work_size <- newVName "local_work_size"

  let (debug_str, debug_args) = debugPrint global_work_size local_work_size

  GC.stm
    [C.cstm|
    if ($exp:total_elements != 0) {
      const size_t $id:global_work_size[$int:kernel_rank] = {$inits:kernel_dims'};
      const size_t $id:local_work_size[$int:kernel_rank] = {$inits:workgroup_dims'};
      typename int64_t $id:time_start = 0, $id:time_end = 0;
      if (ctx->debugging) {
        fprintf(ctx->log, $string:debug_str, $args:debug_args);
        $id:time_start = get_wall_time();
      }
      OPENCL_SUCCEED_OR_RETURN(
        clEnqueueNDRangeKernel(ctx->opencl.queue, ctx->$id:kernel_name, $int:kernel_rank, NULL,
                               $id:global_work_size, $id:local_work_size,
                               0, NULL, $exp:(profilingEvent kernel_name)));
      if (ctx->debugging) {
        OPENCL_SUCCEED_FATAL(clFinish(ctx->opencl.queue));
        $id:time_end = get_wall_time();
        long int $id:time_diff = $id:time_end - $id:time_start;
        fprintf(ctx->log, "kernel %s runtime: %ldus\n",
                $string:(prettyString kernel_name), $id:time_diff);
      }
    }|]
  where
    kernel_rank :: Int
    kernel_rank = length kernel_dims

    -- Global work size per dimension: workgroups times workgroup size.
    kernel_dims :: [C.Exp]
    kernel_dims = zipWith multExp (map toSize num_workgroups) (map toSize workgroup_dims)

    kernel_dims' :: [C.Initializer]
    kernel_dims' = map toInit kernel_dims

    workgroup_dims' :: [C.Initializer]
    workgroup_dims' = map (toInit . toSize) workgroup_dims

    -- Total number of work items; a zero here suppresses the launch.
    total_elements :: C.Exp
    total_elements = foldl multExp [C.cexp|1|] kernel_dims

    toInit :: C.Exp -> C.Initializer
    toInit e = [C.cinit|$exp:e|]

    multExp :: C.Exp -> C.Exp -> C.Exp
    multExp x y = [C.cexp|$exp:x * $exp:y|]

    toSize :: C.ToExp e => e -> C.Exp
    toSize e = [C.cexp|(size_t)$exp:e|]

    -- printf format string and arguments describing the launch.
    debugPrint :: VName -> VName -> (String, [C.Exp])
    debugPrint global_work_size local_work_size =
      ( "Launching %s with global work size "
          ++ dims
          ++ " and local work size "
          ++ dims
          ++ "; local memory: %d bytes.\n",
        [C.cexp|$string:(prettyString kernel_name)|]
          : map (kernelDim global_work_size) [0 .. kernel_rank - 1]
          ++ map (kernelDim local_work_size) [0 .. kernel_rank - 1]
          ++ [[C.cexp|(int)$exp:local_bytes|]]
      )
      where
        dims = "[" ++ intercalate ", " (replicate kernel_rank "%zu") ++ "]"
        kernelDim :: VName -> Int -> C.Exp
        kernelDim arr i = [C.cexp|$id:arr[$int:i]|]