{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE QuasiQuotes #-}
{-# LANGUAGE TupleSections #-}

-- | Code generation for C with OpenCL.
module Futhark.CodeGen.Backends.COpenCL
  ( compileProg,
    GC.CParts (..),
    GC.asLibrary,
    GC.asExecutable,
    GC.asServer,
  )
where

import Control.Monad hiding (mapM)
import Data.List (intercalate)
import Futhark.CodeGen.Backends.COpenCL.Boilerplate
import qualified Futhark.CodeGen.Backends.GenericC as GC
import Futhark.CodeGen.Backends.GenericC.Options
import Futhark.CodeGen.Backends.SimpleRep (primStorageType, toStorage)
import Futhark.CodeGen.ImpCode.OpenCL
import qualified Futhark.CodeGen.ImpGen.OpenCL as ImpGen
import Futhark.IR.GPUMem hiding
  ( CmpSizeLe,
    GetSize,
    GetSizeMax,
  )
import Futhark.MonadFreshNames
import qualified Language.C.Quote.OpenCL as C
import qualified Language.C.Syntax as C
import NeatInterpolation (untrimming)

-- | Compile the program to C with calls to OpenCL.
--
-- The program is first lowered by 'ImpGen.compileProg'; the resulting
-- OpenCL code, prelude, kernel table, types, sizes and failure messages
-- are passed to 'generateBoilerplate', and the host-side definitions
-- are passed to the generic C backend ('GC.compileProg') together with
-- the OpenCL-specific operations and command-line options defined in
-- this module.  The warnings produced by the lowering are returned
-- unchanged alongside the generated C parts.
compileProg :: MonadFreshNames m => Prog GPUMem -> m (ImpGen.Warnings, GC.CParts)
compileProg prog = do
  ( ws,
    Program
      opencl_code
      opencl_prelude
      kernels
      types
      sizes
      failures
      prog'
    ) <-
    ImpGen.compileProg prog
  -- Cost centres for the built-in transfer operations.  These are the
  -- same names that are handed to 'profilingEvent' by the scalar and
  -- memory copy operations in this module.
  let cost_centres =
        [ copyDevToDev,
          copyDevToHost,
          copyHostToDev,
          copyScalarToDev,
          copyScalarFromDev
        ]
  (ws,)
    <$> GC.compileProg
      "opencl"
      operations
      ( generateBoilerplate
          opencl_code
          opencl_prelude
          cost_centres
          kernels
          types
          sizes
          failures
      )
      include_opencl_h
      [Space "device", DefaultSpace]
      cliOptions
      prog'
  where
    -- The generic C backend's default operations, overridden with the
    -- OpenCL-specific implementations defined below.
    operations :: GC.Operations OpenCL ()
    operations =
      GC.defaultOperations
        { GC.opsCompiler = callKernel,
          GC.opsWriteScalar = writeOpenCLScalar,
          GC.opsReadScalar = readOpenCLScalar,
          GC.opsAllocate = allocateOpenCLBuffer,
          GC.opsDeallocate = deallocateOpenCLBuffer,
          GC.opsCopy = copyOpenCLMemory,
          GC.opsStaticArray = staticOpenCLArray,
          GC.opsMemoryType = openclMemoryType,
          GC.opsFatMemory = True
        }
    -- C preprocessor preamble that selects the OpenCL header (the
    -- header path differs on Apple platforms).
    include_opencl_h =
      [untrimming|
       #define CL_TARGET_OPENCL_VERSION 120
       #define CL_USE_DEPRECATED_OPENCL_1_2_APIS
       #ifdef __APPLE__
       #define CL_SILENCE_DEPRECATION
       #include <OpenCL/cl.h>
       #else
       #include <CL/cl.h>
       #endif
       |]

-- Command-line options understood by executables generated with the
-- OpenCL backend: the backend-independent 'commonOptions' followed by
-- the OpenCL-specific ones.  Each option's action is a C statement
-- that calls a @futhark_context_config_*@ function on @cfg@.  The
-- dump/list options additionally set @entry_point@ to @NULL@.
cliOptions :: [Option]
cliOptions =
  commonOptions
    ++ [ Option
           { optionLongName = "platform",
             optionShortName = Just 'p',
             optionArgument = RequiredArgument "NAME",
             optionDescription = "Use the first OpenCL platform whose name contains the given string.",
             optionAction = [C.cstm|futhark_context_config_set_platform(cfg, optarg);|]
           },
         Option
           { optionLongName = "dump-opencl",
             optionShortName = Nothing,
             optionArgument = RequiredArgument "FILE",
             optionDescription = "Dump the embedded OpenCL program to the indicated file.",
             optionAction =
               [C.cstm|{futhark_context_config_dump_program_to(cfg, optarg);
                                     entry_point = NULL;}|]
           },
         Option
           { optionLongName = "load-opencl",
             optionShortName = Nothing,
             optionArgument = RequiredArgument "FILE",
             optionDescription = "Instead of using the embedded OpenCL program, load it from the indicated file.",
             optionAction = [C.cstm|futhark_context_config_load_program_from(cfg, optarg);|]
           },
         Option
           { optionLongName = "dump-opencl-binary",
             optionShortName = Nothing,
             optionArgument = RequiredArgument "FILE",
             optionDescription = "Dump the compiled version of the embedded OpenCL program to the indicated file.",
             optionAction =
               [C.cstm|{futhark_context_config_dump_binary_to(cfg, optarg);
                                     entry_point = NULL;}|]
           },
         Option
           { optionLongName = "load-opencl-binary",
             optionShortName = Nothing,
             optionArgument = RequiredArgument "FILE",
             optionDescription = "Load an OpenCL binary from the indicated file.",
             optionAction = [C.cstm|futhark_context_config_load_binary_from(cfg, optarg);|]
           },
         Option
           { optionLongName = "build-option",
             optionShortName = Nothing,
             optionArgument = RequiredArgument "OPT",
             optionDescription = "Add an additional build option to the string passed to clBuildProgram().",
             optionAction = [C.cstm|futhark_context_config_add_build_option(cfg, optarg);|]
           },
         Option
           { optionLongName = "profile",
             optionShortName = Just 'P',
             optionArgument = NoArgument,
             optionDescription = "Gather profiling data while executing and print out a summary at the end.",
             optionAction = [C.cstm|futhark_context_config_set_profiling(cfg, 1);|]
           },
         Option
           { optionLongName = "list-devices",
             optionShortName = Nothing,
             optionArgument = NoArgument,
             optionDescription = "List all OpenCL devices and platforms available on the system.",
             optionAction =
               [C.cstm|{futhark_context_config_list_devices(cfg);
                        entry_point = NULL;}|]
           }
       ]

-- Generate code for writing a single scalar into device memory.
--
-- We detect the special case of writing a constant and turn it into a
-- non-blocking write.  This may be slightly faster, as it prevents
-- unnecessary synchronisation of the OpenCL command queue, and
-- writing a constant is fairly common.  This is only possible because
-- we can give the constant infinite lifetime (with 'static'), which
-- is not the case for ordinary variables.
--
-- Only the "device" memory space is supported; any other space name
-- is a compiler error.
writeOpenCLScalar :: GC.WriteScalar OpenCL ()
writeOpenCLScalar mem i t "device" _ val = do
  val' <- newVName "write_tmp"
  -- The value is first stored in a temporary so that its address can
  -- be passed to clEnqueueWriteBuffer.
  let (decl, blocking) =
        case val of
          C.Const {} -> ([C.citem|static $ty:t $id:val' = $exp:val;|], [C.cexp|CL_FALSE|])
          _ -> ([C.citem|$ty:t $id:val' = $exp:val;|], [C.cexp|CL_TRUE|])
  GC.stm
    [C.cstm|{$item:decl
                  OPENCL_SUCCEED_OR_RETURN(
                    clEnqueueWriteBuffer(ctx->opencl.queue, $exp:mem, $exp:blocking,
                                         $exp:i * sizeof($ty:t), sizeof($ty:t),
                                         &$id:val',
                                         0, NULL, $exp:(profilingEvent copyScalarToDev)));
                }|]
writeOpenCLScalar _ _ _ space _ _ =
  error $ "Cannot write to '" ++ space ++ "' memory space."

-- Generate code for reading a single scalar from device memory,
-- returning a C expression that refers to the value read.
--
-- It is often faster to do a blocking clEnqueueReadBuffer() than to
-- do an async clEnqueueReadBuffer() followed by a clFinish(), even
-- with an in-order command queue.  This is safe if and only if there
-- are no possible outstanding failures.
--
-- Consequently, the generated code uses a blocking read unless
-- @ctx->failure_is_an_option@ is set, in which case the read is
-- non-blocking and is followed by a call to futhark_context_sync()
-- whose nonzero result makes the generated code return 1.
--
-- Only the "device" memory space is supported; any other space name
-- is a compiler error.
readOpenCLScalar :: GC.ReadScalar OpenCL ()
readOpenCLScalar mem i t "device" _ = do
  val <- newVName "read_res"
  GC.decl [C.cdecl|$ty:t $id:val;|]
  GC.stm
    [C.cstm|OPENCL_SUCCEED_OR_RETURN(
                   clEnqueueReadBuffer(ctx->opencl.queue, $exp:mem,
                                       ctx->failure_is_an_option ? CL_FALSE : CL_TRUE,
                                       $exp:i * sizeof($ty:t), sizeof($ty:t),
                                       &$id:val,
                                       0, NULL, $exp:(profilingEvent copyScalarFromDev)));
              |]
  GC.stm
    [C.cstm|if (ctx->failure_is_an_option &&
                     futhark_context_sync(ctx) != 0) { return 1; }|]
  return [C.cexp|$id:val|]
readOpenCLScalar _ _ _ space _ =
  error $ "Cannot read from '" ++ space ++ "' memory space."

-- Generate code that allocates @size@ bytes in the "device" memory
-- space by calling opencl_alloc(), storing the result in @mem@ and
-- passing @tag@ along.  Any other space name is a compiler error.
allocateOpenCLBuffer :: GC.Allocate OpenCL ()
allocateOpenCLBuffer mem size tag "device" =
  GC.stm [C.cstm|OPENCL_SUCCEED_OR_RETURN(opencl_alloc(&ctx->opencl, (size_t)$exp:size, $exp:tag, &$exp:mem));|]
allocateOpenCLBuffer _ _ _ space =
  error $ "Cannot allocate in '" ++ space ++ "' space."

-- Generate code that releases the "device" memory block @mem@ by
-- calling opencl_free(), passing @tag@ along.  Any other space name
-- is a compiler error.
deallocateOpenCLBuffer :: GC.Deallocate OpenCL ()
deallocateOpenCLBuffer mem tag "device" =
  GC.stm [C.cstm|OPENCL_SUCCEED_OR_RETURN(opencl_free(&ctx->opencl, $exp:mem, $exp:tag));|]
deallocateOpenCLBuffer _ _ space =
  error $ "Cannot deallocate in '" ++ space ++ "' space"

-- Generate code for copying @nbytes@ bytes between memory blocks.
-- The supported combinations of (destination, source) space are:
--
--   * host <- device: clEnqueueReadBuffer(), recorded under the
--     'copyDevToHost' cost centre;
--
--   * device <- host: clEnqueueWriteBuffer(), recorded under the
--     'copyHostToDev' cost centre;
--
--   * device <- device: clEnqueueCopyBuffer(), recorded under the
--     'copyDevToDev' cost centre;
--
--   * host <- host: delegated to 'GC.copyMemoryDefaultSpace'.
--
-- Any other combination is a compiler error.
--
-- The read/write/copy-buffer functions fail if the given offset is
-- out of bounds, even if asked to read zero bytes.  We protect with a
-- branch to avoid this.
copyOpenCLMemory :: GC.Copy OpenCL ()
copyOpenCLMemory destmem destidx DefaultSpace srcmem srcidx (Space "device") nbytes =
  -- Device-to-host.  As for scalar reads, the read is blocking unless
  -- failures may be outstanding, in which case we enqueue a
  -- non-blocking read and then synchronise the context.
  GC.stm
    [C.cstm|
    if ($exp:nbytes > 0) {
      OPENCL_SUCCEED_OR_RETURN(
        clEnqueueReadBuffer(ctx->opencl.queue, $exp:srcmem,
                            ctx->failure_is_an_option ? CL_FALSE : CL_TRUE,
                            (size_t)$exp:srcidx, (size_t)$exp:nbytes,
                            $exp:destmem + $exp:destidx,
                            0, NULL, $exp:(profilingEvent copyDevToHost)));
      if (ctx->failure_is_an_option &&
          futhark_context_sync(ctx) != 0) { return 1; }
    }
  |]
copyOpenCLMemory destmem destidx (Space "device") srcmem srcidx DefaultSpace nbytes =
  -- Host-to-device; always a blocking write.
  GC.stm
    [C.cstm|
    if ($exp:nbytes > 0) {
      OPENCL_SUCCEED_OR_RETURN(
        clEnqueueWriteBuffer(ctx->opencl.queue, $exp:destmem, CL_TRUE,
                             (size_t)$exp:destidx, (size_t)$exp:nbytes,
                             $exp:srcmem + $exp:srcidx,
                             0, NULL, $exp:(profilingEvent copyHostToDev)));
    }
  |]
copyOpenCLMemory destmem destidx (Space "device") srcmem srcidx (Space "device") nbytes =
  -- Be aware that OpenCL swaps the usual order of operands for
  -- memcpy()-like functions.  The order below is not a typo.
  GC.stm
    [C.cstm|{
    if ($exp:nbytes > 0) {
      OPENCL_SUCCEED_OR_RETURN(
        clEnqueueCopyBuffer(ctx->opencl.queue,
                            $exp:srcmem, $exp:destmem,
                            (size_t)$exp:srcidx, (size_t)$exp:destidx,
                            (size_t)$exp:nbytes,
                            0, NULL, $exp:(profilingEvent copyDevToDev)));
      if (ctx->debugging) {
        OPENCL_SUCCEED_FATAL(clFinish(ctx->opencl.queue));
      }
    }
  }|]
copyOpenCLMemory destmem destidx DefaultSpace srcmem srcidx DefaultSpace nbytes =
  GC.copyMemoryDefaultSpace destmem destidx srcmem srcidx nbytes
copyOpenCLMemory _ _ destspace _ _ srcspace _ =
  error $ "Cannot copy to " ++ show destspace ++ " from " ++ show srcspace

-- The C type used to represent a memory block in the given memory
-- space.  "device" memory is a @cl_mem@; any other space name is a
-- compiler error.
openclMemoryType :: GC.MemoryType OpenCL ()
openclMemoryType "device" = pure [C.cty|typename cl_mem|]
openclMemoryType space =
  error $ "OpenCL backend does not support '" ++ space ++ "' memory space."

-- Generate code for a statically known array residing in "device"
-- memory.  The contents are emitted as a static host-side C array
-- (either explicitly initialised or left zero-initialised), a
-- @struct memblock_device@ field is added to the context, and at
-- context initialisation a buffer is created and, if the array is
-- non-empty, filled from the host-side array.  Finally a local
-- variable named after the array is bound to the context field.
--
-- Any other memory space is a compiler error.
staticOpenCLArray :: GC.StaticArray OpenCL ()
staticOpenCLArray name "device" t vs = do
  let ct = GC.primTypeToCType t
  name_realtype <- newVName $ baseString name ++ "_realtype"
  num_elems <- case vs of
    ArrayValues vs' -> do
      let vs'' = [[C.cinit|$exp:v|] | v <- vs']
      GC.earlyDecl [C.cedecl|static $ty:ct $id:name_realtype[$int:(length vs'')] = {$inits:vs''};|]
      return $ length vs''
    ArrayZeros n -> do
      GC.earlyDecl [C.cedecl|static $ty:ct $id:name_realtype[$int:n];|]
      return n
  -- Fake a memory block.
  GC.contextField (C.toIdent name mempty) [C.cty|struct memblock_device|] Nothing
  -- During startup, copy the data to where we need it.  The buffer is
  -- created with room for at least one element even when the array is
  -- empty.
  GC.atInit
    [C.cstm|{
    typename cl_int success;
    ctx->$id:name.references = NULL;
    ctx->$id:name.size = 0;
    ctx->$id:name.mem =
      clCreateBuffer(ctx->opencl.ctx, CL_MEM_READ_WRITE,
                     ($int:num_elems > 0 ? $int:num_elems : 1)*sizeof($ty:ct), NULL,
                     &success);
    OPENCL_SUCCEED_OR_RETURN(success);
    if ($int:num_elems > 0) {
      OPENCL_SUCCEED_OR_RETURN(
        clEnqueueWriteBuffer(ctx->opencl.queue, ctx->$id:name.mem, CL_TRUE,
                             0, $int:num_elems*sizeof($ty:ct),
                             $id:name_realtype,
                             0, NULL, NULL));
    }
  }|]
  GC.item [C.citem|struct memblock_device $id:name = ctx->$id:name;|]
staticOpenCLArray _ space _ _ =
  error $ "OpenCL backend cannot create static array in memory space '" ++ space ++ "'"

-- Generate host code for a single OpenCL-specific operation:
--
--   * 'GetSize': read a tuning parameter from the context.
--
--   * 'CmpSizeLe': compare a tuning parameter against an expression
--     and emit logging code for the comparison via 'sizeLoggingCode'.
--
--   * 'GetSizeMax': read the @max_<class>@ field of the OpenCL
--     context.
--
--   * 'LaunchKernel': set the kernel arguments and enqueue the kernel
--     via 'launchKernel'.
callKernel :: GC.OpCompiler OpenCL ()
callKernel (GetSize v key) =
  GC.stm [C.cstm|$id:v = *ctx->tuning_params.$id:key;|]
callKernel (CmpSizeLe v key x) = do
  x' <- GC.compileExp x
  GC.stm [C.cstm|$id:v = *ctx->tuning_params.$id:key <= $exp:x';|]
  sizeLoggingCode v key x'
callKernel (GetSizeMax v size_class) =
  let field = "max_" ++ pretty size_class
   in GC.stm [C.cstm|$id:v = ctx->opencl.$id:field;|]
callKernel (LaunchKernel safety name args num_workgroups workgroup_size) = do
  -- The other failure args are set automatically when the kernel is
  -- first created.
  when (safety == SafetyFull) $
    GC.stm
      [C.cstm|
      OPENCL_SUCCEED_OR_RETURN(clSetKernelArg(ctx->$id:name, 1,
                                              sizeof(ctx->failure_is_an_option),
                                              &ctx->failure_is_an_option));
    |]

  -- The ordinary arguments are numbered after the failure parameters.
  zipWithM_ setKernelArg [numFailureParams safety ..] args
  num_workgroups' <- mapM GC.compileExp num_workgroups
  workgroup_size' <- mapM GC.compileExp workgroup_size
  local_bytes <- foldM localBytes [C.cexp|0|] args

  launchKernel name num_workgroups' workgroup_size' local_bytes

  -- Once a kernel with full safety has been enqueued, record in the
  -- context that a failure may be outstanding.
  when (safety >= SafetyFull) $
    GC.stm [C.cstm|ctx->failure_is_an_option = 1;|]
  where
    -- Emit a clSetKernelArg() call binding kernel argument number 'i'.
    setKernelArg i (ValueKArg e pt) = do
      v <- case pt of
        -- We always transfer f16 values to the kernel as 16 bits, but
        -- the actual host type may be typedef'd to a 32-bit float.
        -- This requires some care.
        FloatType Float16 -> do
          v <- newVName "kernel_arg"
          e' <- toStorage pt <$> GC.compileExp e
          GC.decl [C.cdecl|$ty:(primStorageType pt) $id:v = $e';|]
          pure v
        _ -> GC.compileExpToName "kernel_arg" pt e
      GC.stm
        [C.cstm|
            OPENCL_SUCCEED_OR_RETURN(clSetKernelArg(ctx->$id:name, $int:i, sizeof($id:v), &$id:v));
          |]
    setKernelArg i (MemKArg v) = do
      v' <- GC.rawMem v
      GC.stm
        [C.cstm|
            OPENCL_SUCCEED_OR_RETURN(clSetKernelArg(ctx->$id:name, $int:i, sizeof($exp:v'), &$exp:v'));
          |]
    setKernelArg i (SharedMemoryKArg num_bytes) = do
      num_bytes' <- GC.compileExp $ unCount num_bytes
      GC.stm
        [C.cstm|
            OPENCL_SUCCEED_OR_RETURN(clSetKernelArg(ctx->$id:name, $int:i, (size_t)$exp:num_bytes', NULL));
            |]

    -- Accumulate a C expression summing the sizes of all
    -- 'SharedMemoryKArg' arguments; other arguments contribute nothing.
    localBytes cur (SharedMemoryKArg num_bytes) = do
      num_bytes' <- GC.compileExp $ unCount num_bytes
      return [C.cexp|$exp:cur + $exp:num_bytes'|]
    localBytes cur _ = return cur

-- Generate code that enqueues the named kernel with
-- clEnqueueNDRangeKernel().
--
-- The global work size in each dimension is the product of the
-- corresponding entries of @num_workgroups@ and @workgroup_dims@
-- (each cast to @size_t@).  The launch is skipped entirely when the
-- product of all global dimensions is zero.  When @ctx->debugging@ is
-- set, the generated code logs the launch configuration and the local
-- memory size (@local_bytes@), waits for the queue to finish, and
-- logs the measured wall-clock runtime.
launchKernel ::
  C.ToExp a =>
  KernelName ->
  [a] ->
  [a] ->
  a ->
  GC.CompilerM op s ()
launchKernel kernel_name num_workgroups workgroup_dims local_bytes = do
  global_work_size <- newVName "global_work_size"
  time_start <- newVName "time_start"
  time_end <- newVName "time_end"
  time_diff <- newVName "time_diff"
  local_work_size <- newVName "local_work_size"

  let (debug_str, debug_args) = debugPrint global_work_size local_work_size

  GC.stm
    [C.cstm|
    if ($exp:total_elements != 0) {
      const size_t $id:global_work_size[$int:kernel_rank] = {$inits:kernel_dims'};
      const size_t $id:local_work_size[$int:kernel_rank] = {$inits:workgroup_dims'};
      typename int64_t $id:time_start = 0, $id:time_end = 0;
      if (ctx->debugging) {
        fprintf(ctx->log, $string:debug_str, $args:debug_args);
        $id:time_start = get_wall_time();
      }
      OPENCL_SUCCEED_OR_RETURN(
        clEnqueueNDRangeKernel(ctx->opencl.queue, ctx->$id:kernel_name, $int:kernel_rank, NULL,
                               $id:global_work_size, $id:local_work_size,
                               0, NULL, $exp:(profilingEvent kernel_name)));
      if (ctx->debugging) {
        OPENCL_SUCCEED_FATAL(clFinish(ctx->opencl.queue));
        $id:time_end = get_wall_time();
        long int $id:time_diff = $id:time_end - $id:time_start;
        fprintf(ctx->log, "kernel %s runtime: %ldus\n",
                $string:(pretty kernel_name), $id:time_diff);
      }
    }|]
  where
    -- Number of dimensions of the launch grid.
    kernel_rank = length kernel_dims
    -- Global work size per dimension: workgroups times workgroup size.
    kernel_dims = zipWith multExp (map toSize num_workgroups) (map toSize workgroup_dims)
    kernel_dims' = map toInit kernel_dims
    workgroup_dims' = map (toInit . toSize) workgroup_dims
    -- Product of all global dimensions; zero means nothing to launch.
    total_elements = foldl multExp [C.cexp|1|] kernel_dims

    toInit e = [C.cinit|$exp:e|]
    multExp x y = [C.cexp|$exp:x * $exp:y|]
    toSize e = [C.cexp|(size_t)$exp:e|]

    -- The format string and matching argument list for the debug
    -- message printed before the launch.  One "%zu" is produced per
    -- dimension for both the global and the local work size.
    debugPrint :: VName -> VName -> (String, [C.Exp])
    debugPrint global_work_size local_work_size =
      ( "Launching %s with global work size "
          ++ dims
          ++ " and local work size "
          ++ dims
          ++ "; local memory: %d bytes.\n",
        [C.cexp|$string:(pretty kernel_name)|] :
        map (kernelDim global_work_size) [0 .. kernel_rank - 1]
          ++ map (kernelDim local_work_size) [0 .. kernel_rank - 1]
          ++ [[C.cexp|(int)$exp:local_bytes|]]
      )
      where
        dims = "[" ++ intercalate ", " (replicate kernel_rank "%zu") ++ "]"
        kernelDim arr i = [C.cexp|$id:arr[$int:i]|]