{-# LANGUAGE QuasiQuotes #-}
{-# LANGUAGE TupleSections #-}

-- | Code generation for CUDA.
module Futhark.CodeGen.Backends.CCUDA
  ( compileProg,
    GC.CParts (..),
    GC.asLibrary,
    GC.asExecutable,
    GC.asServer,
  )
where

import Control.Monad
import Data.List (intercalate)
import Data.Maybe (catMaybes)
import Futhark.CodeGen.Backends.CCUDA.Boilerplate
import Futhark.CodeGen.Backends.COpenCL.Boilerplate (commonOptions, sizeLoggingCode)
import qualified Futhark.CodeGen.Backends.GenericC as GC
import Futhark.CodeGen.Backends.GenericC.Options
import Futhark.CodeGen.ImpCode.OpenCL
import qualified Futhark.CodeGen.ImpGen.CUDA as ImpGen
import Futhark.IR.GPUMem hiding
  ( CmpSizeLe,
    GetSize,
    GetSizeMax,
  )
import Futhark.MonadFreshNames
import qualified Language.C.Quote.OpenCL as C

-- | Compile the program to C with calls to CUDA.
--
-- The program is first translated by 'ImpGen.compileProg'.  The resulting
-- kernel source, prelude, kernel table, size table and failure messages are
-- handed to 'generateBoilerplate', and the host-side definitions are then
-- compiled by 'GC.compileProg' using the CUDA-specific 'GC.Operations'
-- defined below.  Any warnings from the first stage are returned unchanged
-- alongside the generated C parts.
compileProg :: MonadFreshNames m => Prog GPUMem -> m (ImpGen.Warnings, GC.CParts)
compileProg prog = do
  (ws, Program cuda_code cuda_prelude kernels _ sizes failures prog') <-
    ImpGen.compileProg prog
  -- Names passed to the boilerplate generator as cost centres; the same
  -- names are given to 'profilingEnclosure' by the copy/scalar operations.
  let cost_centres =
        [ copyDevToDev,
          copyDevToHost,
          copyHostToDev,
          copyScalarToDev,
          copyScalarFromDev
        ]
      extra =
        generateBoilerplate
          cuda_code
          cuda_prelude
          cost_centres
          kernels
          sizes
          failures
  (ws,)
    <$> GC.compileProg
      "cuda"
      operations
      extra
      cuda_includes
      [Space "device", DefaultSpace]
      cliOptions
      prog'
  where
    -- The generic C backend's default operations, overridden with the
    -- CUDA-specific handlers defined in this module.
    operations :: GC.Operations OpenCL ()
    operations =
      GC.defaultOperations
        { GC.opsWriteScalar = writeCUDAScalar,
          GC.opsReadScalar = readCUDAScalar,
          GC.opsAllocate = allocateCUDABuffer,
          GC.opsDeallocate = deallocateCUDABuffer,
          GC.opsCopy = copyCUDAMemory,
          GC.opsStaticArray = staticCUDAArray,
          GC.opsMemoryType = cudaMemoryType,
          GC.opsCompiler = callKernel,
          GC.opsFatMemory = True,
          -- Pair of C item lists: the first pushes the CUDA context, the
          -- second pops it again.
          GC.opsCritical =
            ( [C.citems|CUDA_SUCCEED_FATAL(cuCtxPushCurrent(ctx->cuda.cu_ctx));|],
              [C.citems|CUDA_SUCCEED_FATAL(cuCtxPopCurrent(&ctx->cuda.cu_ctx));|]
            )
        }

    -- Header includes passed to 'GC.compileProg'.
    cuda_includes :: String
    cuda_includes =
      unlines
        [ "#include <cuda.h>",
          "#include <cuda_runtime.h>",
          "#include <nvrtc.h>"
        ]

-- | Command-line options of the generated program: the 'commonOptions'
-- followed by CUDA-specific options for dumping/loading kernel source and
-- PTX, passing extra NVRTC build options, and enabling profiling.
cliOptions :: [Option]
cliOptions =
  commonOptions
    ++ [ Option
           { optionLongName = "dump-cuda",
             optionShortName = Nothing,
             optionArgument = RequiredArgument "FILE",
             optionDescription = "Dump the embedded CUDA kernels to the indicated file.",
             -- Also clears entry_point in the generated option handler.
             optionAction =
               [C.cstm|{futhark_context_config_dump_program_to(cfg, optarg);
                                     entry_point = NULL;}|]
           },
         Option
           { optionLongName = "load-cuda",
             optionShortName = Nothing,
             optionArgument = RequiredArgument "FILE",
             optionDescription = "Instead of using the embedded CUDA kernels, load them from the indicated file.",
             optionAction = [C.cstm|futhark_context_config_load_program_from(cfg, optarg);|]
           },
         Option
           { optionLongName = "dump-ptx",
             optionShortName = Nothing,
             optionArgument = RequiredArgument "FILE",
             optionDescription = "Dump the PTX-compiled version of the embedded kernels to the indicated file.",
             -- Also clears entry_point in the generated option handler.
             optionAction =
               [C.cstm|{futhark_context_config_dump_ptx_to(cfg, optarg);
                                     entry_point = NULL;}|]
           },
         Option
           { optionLongName = "load-ptx",
             optionShortName = Nothing,
             optionArgument = RequiredArgument "FILE",
             optionDescription = "Load PTX code from the indicated file.",
             optionAction = [C.cstm|futhark_context_config_load_ptx_from(cfg, optarg);|]
           },
         Option
           { optionLongName = "nvrtc-option",
             optionShortName = Nothing,
             optionArgument = RequiredArgument "OPT",
             optionDescription = "Add an additional build option to the string passed to NVRTC.",
             optionAction = [C.cstm|futhark_context_config_add_nvrtc_option(cfg, optarg);|]
           },
         Option
           { optionLongName = "profile",
             optionShortName = Just 'P',
             optionArgument = NoArgument,
             optionDescription = "Gather profiling data while executing and print out a summary at the end.",
             optionAction = [C.cstm|futhark_context_config_set_profiling(cfg, 1);|]
           }
       ]

-- | Emit C code that writes a single scalar of C type @t@ to element @idx@
-- of the memory block @mem@ in the @"device"@ space.
--
-- The value is first stored in a fresh host-side temporary, which is then
-- copied with @cuMemcpyHtoD@ to byte offset @idx * sizeof(t)@.  The copy is
-- wrapped in the items produced by 'profilingEnclosure' for
-- 'copyScalarToDev'.  Any other memory space is a compiler error.
writeCUDAScalar :: GC.WriteScalar OpenCL ()
writeCUDAScalar mem idx t "device" _ val = do
  val' <- newVName "write_tmp"
  let (bef, aft) = profilingEnclosure copyScalarToDev
  GC.item
    [C.citem|{$ty:t $id:val' = $exp:val;
                  $items:bef
                  CUDA_SUCCEED_OR_RETURN(
                    cuMemcpyHtoD($exp:mem + $exp:idx * sizeof($ty:t),
                                 &$id:val',
                                 sizeof($ty:t)));
                  $items:aft
                 }|]
writeCUDAScalar _ _ _ space _ _ =
  error $ "Cannot write to '" ++ space ++ "' memory space."

-- | Emit C code that reads a single scalar of C type @t@ from element @idx@
-- of the memory block @mem@ in the @"device"@ space, and return a C
-- expression naming the host variable that receives it.
--
-- The generated code declares a fresh variable, copies @sizeof(t)@ bytes
-- from byte offset @idx * sizeof(t)@ with @cuMemcpyDtoH@ (wrapped in the
-- items produced by 'profilingEnclosure' for 'copyScalarFromDev'), and then
-- calls @futhark_context_sync@, returning 1 from the enclosing C function
-- if that call reports a non-zero status.  Any other memory space is a
-- compiler error.
readCUDAScalar :: GC.ReadScalar OpenCL ()
readCUDAScalar mem idx t "device" _ = do
  val <- newVName "read_res"
  let (bef, aft) = profilingEnclosure copyScalarFromDev
  mapM_
    GC.item
    [C.citems|
       $ty:t $id:val;
       {
       $items:bef
       CUDA_SUCCEED_OR_RETURN(
          cuMemcpyDtoH(&$id:val,
                       $exp:mem + $exp:idx * sizeof($ty:t),
                       sizeof($ty:t)));
        $items:aft
       }
       |]
  GC.stm [C.cstm|if (futhark_context_sync(ctx) != 0) { return 1; }|]
  return [C.cexp|$id:val|]
readCUDAScalar _ _ _ space _ =
  -- This is the read path, so the message must say "read from", not
  -- "write to".
  error $ "Cannot read from '" ++ space ++ "' memory space."

-- | Emit a call to @cuda_alloc@ that allocates @size@ bytes into @mem@ in
-- the @"device"@ space, passing @tag@ along to the allocator.  Any other
-- memory space is a compiler error.
allocateCUDABuffer :: GC.Allocate OpenCL ()
allocateCUDABuffer mem size tag "device" =
  GC.stm [C.cstm|CUDA_SUCCEED_OR_RETURN(cuda_alloc(&ctx->cuda, $exp:size, $exp:tag, &$exp:mem));|]
allocateCUDABuffer _ _ _ space =
  error $ "Cannot allocate in '" ++ space ++ "' memory space."

-- | Emit a call to @cuda_free@ that releases the @"device"@ memory block
-- @mem@, passing @tag@ along.  Any other memory space is a compiler error.
deallocateCUDABuffer :: GC.Deallocate OpenCL ()
deallocateCUDABuffer mem tag "device" =
  GC.stm [C.cstm|CUDA_SUCCEED_OR_RETURN(cuda_free(&ctx->cuda, $exp:mem, $exp:tag));|]
deallocateCUDABuffer _ _ space =
  error $ "Cannot deallocate in '" ++ space ++ "' memory space."

-- | Emit C code that copies @nbytes@ bytes from @srcmem + srcidx@ to
-- @dstmem + dstidx@.
--
-- The CUDA copy function and the profiling name are chosen from the
-- destination and source spaces.  Only device-to-host, host-to-device and
-- device-to-device copies are handled; every other combination is a
-- compiler error.  The copy is wrapped in the items produced by
-- 'profilingEnclosure' for the chosen name.
copyCUDAMemory :: GC.Copy OpenCL ()
copyCUDAMemory dstmem dstidx dstSpace srcmem srcidx srcSpace nbytes = do
  let (fn, prof) = memcpyFun dstSpace srcSpace
      (bef, aft) = profilingEnclosure prof
  GC.item
    [C.citem|{
                $items:bef
                CUDA_SUCCEED_OR_RETURN(
                  $id:fn($exp:dstmem + $exp:dstidx,
                         $exp:srcmem + $exp:srcidx,
                         $exp:nbytes));
                $items:aft
                }
                |]
  where
    -- Arguments are (destination space, source space); the result is the
    -- name of the CUDA function to call and the profiling name to use.
    memcpyFun DefaultSpace (Space "device") = ("cuMemcpyDtoH", copyDevToHost)
    memcpyFun (Space "device") DefaultSpace = ("cuMemcpyHtoD", copyHostToDev)
    memcpyFun (Space "device") (Space "device") = ("cuMemcpy", copyDevToDev)
    memcpyFun _ _ =
      error $
        "Cannot copy to '" ++ show dstSpace
          ++ "' from '"
          ++ show srcSpace
          ++ "'."

-- | Emit the code for a statically known array living in the @"device"@
-- space.
--
-- Three pieces of C are generated: an early host-side @static@ array that
-- holds the contents (either the given values or an uninitialised array of
-- @n@ elements), a context field of type @struct memblock_device@ named
-- after the array, and initialisation code that allocates device memory
-- and copies the host array into it.  Finally a local
-- @struct memblock_device@ with the array's name is bound to the context
-- field.  Any other memory space is a compiler error.
staticCUDAArray :: GC.StaticArray OpenCL ()
staticCUDAArray name "device" t vs = do
  let ct = GC.primTypeToCType t
  name_realtype <- newVName $ baseString name ++ "_realtype"
  num_elems <- case vs of
    ArrayValues vs' -> do
      let vs'' = [[C.cinit|$exp:v|] | v <- map GC.compilePrimValue vs']
      GC.earlyDecl [C.cedecl|static $ty:ct $id:name_realtype[$int:(length vs'')] = {$inits:vs''};|]
      return $ length vs''
    ArrayZeros n -> do
      GC.earlyDecl [C.cedecl|static $ty:ct $id:name_realtype[$int:n];|]
      return n
  -- Fake a memory block.
  GC.contextField (C.toIdent name mempty) [C.cty|struct memblock_device|] Nothing
  -- During startup, copy the data to where we need it.  At least one
  -- element's worth of memory is always allocated, and the copy is skipped
  -- when the array is empty.
  GC.atInit
    [C.cstm|{
    ctx->$id:name.references = NULL;
    ctx->$id:name.size = 0;
    CUDA_SUCCEED_FATAL(cuMemAlloc(&ctx->$id:name.mem,
                            ($int:num_elems > 0 ? $int:num_elems : 1)*sizeof($ty:ct)));
    if ($int:num_elems > 0) {
      CUDA_SUCCEED_FATAL(cuMemcpyHtoD(ctx->$id:name.mem, $id:name_realtype,
                                $int:num_elems*sizeof($ty:ct)));
    }
  }|]
  GC.item [C.citem|struct memblock_device $id:name = ctx->$id:name;|]
staticCUDAArray _ space _ _ =
  error $
    "CUDA backend cannot create static array in '" ++ space
      ++ "' memory space"

-- | The C type used for memory blocks in a given space: @CUdeviceptr@ for
-- @"device"@.  Any other memory space is a compiler error.
cudaMemoryType :: GC.MemoryType OpenCL ()
cudaMemoryType "device" = return [C.cty|typename CUdeviceptr|]
cudaMemoryType space =
  error $ "CUDA backend does not support '" ++ space ++ "' memory space."

-- | Compile a single host-level 'OpenCL' operation to C.
--
-- * 'GetSize' reads the named size from @ctx->sizes@.
--
-- * 'CmpSizeLe' compares the named size against a compiled expression,
--   stores the result, and emits logging via 'sizeLoggingCode'.
--
-- * 'GetSizeMax' reads the @max_*@ field of @ctx->cuda@ that corresponds to
--   the size class.
--
-- * 'LaunchKernel' declares the kernel arguments, computes shared-memory
--   offsets, and emits a guarded @cuLaunchKernel@ call with optional debug
--   logging and timing.
callKernel :: GC.OpCompiler OpenCL ()
callKernel (GetSize v key) =
  GC.stm [C.cstm|$id:v = ctx->sizes.$id:key;|]
callKernel (CmpSizeLe v key x) = do
  x' <- GC.compileExp x
  GC.stm [C.cstm|$id:v = ctx->sizes.$id:key <= $exp:x';|]
  sizeLoggingCode v key x'
callKernel (GetSizeMax v size_class) =
  let field = "max_" ++ cudaSizeClass size_class
   in GC.stm [C.cstm|$id:v = ctx->cuda.$id:field;|]
  where
    -- Suffix of the @ctx->cuda.max_*@ field for each size class.
    cudaSizeClass SizeThreshold {} = "threshold"
    cudaSizeClass SizeGroup = "block_size"
    cudaSizeClass SizeNumGroups = "grid_size"
    cudaSizeClass SizeTile = "tile_size"
    cudaSizeClass SizeRegTile = "reg_tile_size"
    cudaSizeClass SizeLocalMemory = "shared_memory"
    cudaSizeClass (SizeBespoke x _) = pretty x
callKernel (LaunchKernel safety kernel_name args num_blocks block_size) = do
  args_arr <- newVName "kernel_args"
  time_start <- newVName "time_start"
  time_end <- newVName "time_end"
  (args', shared_vars) <- unzip <$> mapM mkArgs args
  let (shared_sizes, shared_offsets) = unzip $ catMaybes shared_vars
      shared_offsets_sc = mkOffsets shared_sizes
      shared_args = zip shared_offsets shared_offsets_sc
      -- 'mkOffsets' is a 'scanl', so its result is never empty and 'last'
      -- is safe; the final element is the total (aligned) size.
      shared_tot = last shared_offsets_sc
  forM_ shared_args $ \(arg, offset) ->
    GC.decl [C.cdecl|unsigned int $id:arg = $exp:offset;|]

  (grid_x, grid_y, grid_z) <- mkDims <$> mapM GC.compileExp num_blocks
  (block_x, block_y, block_z) <- mkDims <$> mapM GC.compileExp block_size
  let -- Only kernels launched with exactly three grid dimensions receive
      -- the permutation array as leading arguments.
      perm_args
        | length num_blocks == 3 = [[C.cinit|&perm[0]|], [C.cinit|&perm[1]|], [C.cinit|&perm[2]|]]
        | otherwise = []
      -- The number of failure-related arguments depends on the kernel's
      -- safety level.
      failure_args =
        take
          (numFailureParams safety)
          [ [C.cinit|&ctx->global_failure|],
            [C.cinit|&ctx->failure_is_an_option|],
            [C.cinit|&ctx->global_failure_args|]
          ]
      args'' = perm_args ++ failure_args ++ [[C.cinit|&$id:a|] | a <- args']
      sizes_nonzero =
        expsNotZero
          [ grid_x,
            grid_y,
            grid_z,
            block_x,
            block_y,
            block_z
          ]
      (bef, aft) = profilingEnclosure kernel_name

  -- The launch is skipped entirely when any grid or block extent is zero.
  -- A y or z grid extent of at least 2^16 is swapped into slot 0 of the
  -- grid via the perm array (presumably because only the first grid
  -- dimension permits extents that large -- confirm against CUDA limits).
  -- The runtime difference is cast to long int so that it matches the %ld
  -- conversion regardless of how int64_t is defined.
  GC.stm
    [C.cstm|
    if ($exp:sizes_nonzero) {
      int perm[3] = { 0, 1, 2 };

      if ($exp:grid_y >= (1<<16)) {
        perm[1] = perm[0];
        perm[0] = 1;
      }

      if ($exp:grid_z >= (1<<16)) {
        perm[2] = perm[0];
        perm[0] = 2;
      }

      size_t grid[3];
      grid[perm[0]] = $exp:grid_x;
      grid[perm[1]] = $exp:grid_y;
      grid[perm[2]] = $exp:grid_z;

      void *$id:args_arr[] = { $inits:args'' };
      typename int64_t $id:time_start = 0, $id:time_end = 0;
      if (ctx->debugging) {
        fprintf(ctx->log, "Launching %s with grid size (", $string:(pretty kernel_name));
        $stms:(printSizes [grid_x, grid_y, grid_z])
        fprintf(ctx->log, ") and block size (");
        $stms:(printSizes [block_x, block_y, block_z])
        fprintf(ctx->log, ").\n");
        $id:time_start = get_wall_time();
      }
      $items:bef
      CUDA_SUCCEED_OR_RETURN(
        cuLaunchKernel(ctx->$id:kernel_name,
                       grid[0], grid[1], grid[2],
                       $exp:block_x, $exp:block_y, $exp:block_z,
                       $exp:shared_tot, NULL,
                       $id:args_arr, NULL));
      $items:aft
      if (ctx->debugging) {
        CUDA_SUCCEED_FATAL(cuCtxSynchronize());
        $id:time_end = get_wall_time();
        fprintf(ctx->log, "Kernel %s runtime: %ldus\n",
                $string:(pretty kernel_name), (long int)($id:time_end - $id:time_start));
      }
    }|]

  when (safety >= SafetyFull) $
    GC.stm [C.cstm|ctx->failure_is_an_option = 1;|]
  where
    -- Pad a list of extents to exactly three; missing extents become 1,
    -- except that an empty list yields all zeroes.  Extents beyond the
    -- third are ignored.
    mkDims [] = ([C.cexp|0|], [C.cexp|0|], [C.cexp|0|])
    mkDims [x] = (x, [C.cexp|1|], [C.cexp|1|])
    mkDims [x, y] = (x, y, [C.cexp|1|])
    mkDims (x : y : z : _) = (x, y, z)

    addExp x y = [C.cexp|$exp:x + $exp:y|]

    -- Round an expression up to the next multiple of 8.
    alignExp e = [C.cexp|$exp:e + ((8 - ($exp:e % 8)) % 8)|]

    -- Running sums of the aligned sizes, starting at 0: element i is the
    -- offset of the i'th shared-memory argument, and the final element is
    -- the total.
    mkOffsets = scanl (\a b -> a `addExp` alignExp b) [C.cexp|0|]

    expNotZero e = [C.cexp|$exp:e != 0|]

    expAnd a b = [C.cexp|$exp:a && $exp:b|]

    -- Conjunction asserting that every given expression is non-zero.
    expsNotZero = foldl expAnd [C.cexp|1|] . map expNotZero

    -- Declare the C variable holding one kernel argument.  The result is
    -- the name of that variable and, for shared-memory arguments only, the
    -- pair (size variable, offset variable).  For shared memory the offset
    -- variable itself is declared by the caller once all sizes are known.
    mkArgs (ValueKArg e t) =
      (,Nothing) <$> GC.compileExpToName "kernel_arg" t e
    mkArgs (MemKArg v) = do
      v' <- GC.rawMem v
      arg <- newVName "kernel_arg"
      GC.decl [C.cdecl|typename CUdeviceptr $id:arg = $exp:v';|]
      return (arg, Nothing)
    mkArgs (SharedMemoryKArg (Count c)) = do
      num_bytes <- GC.compileExp c
      size <- newVName "shared_size"
      offset <- newVName "shared_offset"
      GC.decl [C.cdecl|unsigned int $id:size = $exp:num_bytes;|]
      return (offset, Just (size, offset))

    -- Statements that print the given extents to the log, comma-separated.
    printSizes =
      intercalate [[C.cstm|fprintf(ctx->log, ", ");|]] . map printSize
    printSize e =
      [[C.cstm|fprintf(ctx->log, "%ld", (long int)$exp:e);|]]