{-# LANGUAGE QuasiQuotes #-}
{-# LANGUAGE TupleSections #-}

-- | Code generation for CUDA.
module Futhark.CodeGen.Backends.CCUDA
  ( compileProg,
    GC.CParts (..),
    GC.asLibrary,
    GC.asExecutable,
    GC.asServer,
  )
where

import Control.Monad
import Data.List (intercalate)
import Data.Maybe (catMaybes)
import Futhark.CodeGen.Backends.CCUDA.Boilerplate
import Futhark.CodeGen.Backends.COpenCL.Boilerplate (commonOptions)
import qualified Futhark.CodeGen.Backends.GenericC as GC
import Futhark.CodeGen.Backends.GenericC.Options
import Futhark.CodeGen.ImpCode.OpenCL
import qualified Futhark.CodeGen.ImpGen.CUDA as ImpGen
import Futhark.IR.KernelsMem hiding
  ( CmpSizeLe,
    GetSize,
    GetSizeMax,
  )
import Futhark.MonadFreshNames
import qualified Language.C.Quote.OpenCL as C

-- | Compile the program to C with calls to CUDA.
--
-- The program is first lowered with 'ImpGen.compileProg'.  The resulting
-- definitions are then passed to 'GC.compileProg' together with the
-- CUDA-specific operations, the generated boilerplate, the CUDA include
-- lines and the command line options.  The warnings produced by the first
-- step are returned next to the generated C parts.
compileProg :: MonadFreshNames m => Prog KernelsMem -> m (ImpGen.Warnings, GC.CParts)
compileProg prog = do
  (warnings, Program cuda_code cuda_prelude kernels _ sizes failures defs) <-
    ImpGen.compileProg prog
  let boilerplate =
        generateBoilerplate
          cuda_code
          cuda_prelude
          copy_cost_centres
          kernels
          sizes
          failures
  cparts <-
    GC.compileProg
      "cuda"
      operations
      boilerplate
      cuda_includes
      memory_spaces
      cliOptions
      defs
  return (warnings, cparts)
  where
    -- Cost centres for the memory copies; passed to the boilerplate
    -- generator in addition to the kernels themselves.
    copy_cost_centres =
      [ copyDevToDev,
        copyDevToHost,
        copyHostToDev,
        copyScalarToDev,
        copyScalarFromDev
      ]

    -- The memory spaces handed to the generic C code generator.
    memory_spaces = [Space "device", DefaultSpace]

    -- The generic C operations, overridden with the CUDA implementations
    -- defined in this module.
    operations :: GC.Operations OpenCL ()
    operations =
      GC.defaultOperations
        { GC.opsWriteScalar = writeCUDAScalar,
          GC.opsReadScalar = readCUDAScalar,
          GC.opsAllocate = allocateCUDABuffer,
          GC.opsDeallocate = deallocateCUDABuffer,
          GC.opsCopy = copyCUDAMemory,
          GC.opsStaticArray = staticCUDAArray,
          GC.opsMemoryType = cudaMemoryType,
          GC.opsCompiler = callKernel,
          GC.opsFatMemory = True,
          -- Statements emitted on entry to and exit from a critical
          -- section: push and pop the CUDA context.
          GC.opsCritical =
            ( [C.citems|CUDA_SUCCEED(cuCtxPushCurrent(ctx->cuda.cu_ctx));|],
              [C.citems|CUDA_SUCCEED(cuCtxPopCurrent(&ctx->cuda.cu_ctx));|]
            )
        }

    -- One @#include <...>@ line per header, each terminated by a newline.
    cuda_includes =
      unlines
        [ "#include <" ++ header ++ ">"
          | header <- ["cuda.h", "cuda_runtime.h", "nvrtc.h"]
        ]

-- | The command line options of generated CUDA programs: the
-- 'commonOptions' followed by the CUDA-specific ones defined here.
cliOptions :: [Option]
cliOptions = commonOptions ++ cuda_options
  where
    -- An option that has no short name and takes a required argument,
    -- shown as @metavar@.
    longOption long_name metavar description action =
      Option
        { optionLongName = long_name,
          optionShortName = Nothing,
          optionArgument = RequiredArgument metavar,
          optionDescription = description,
          optionAction = action
        }

    cuda_options =
      [ longOption
          "dump-cuda"
          "FILE"
          "Dump the embedded CUDA kernels to the indicated file."
          [C.cstm|{futhark_context_config_dump_program_to(cfg, optarg);
                   entry_point = NULL;}|],
        longOption
          "load-cuda"
          "FILE"
          "Instead of using the embedded CUDA kernels, load them from the indicated file."
          [C.cstm|futhark_context_config_load_program_from(cfg, optarg);|],
        longOption
          "dump-ptx"
          "FILE"
          "Dump the PTX-compiled version of the embedded kernels to the indicated file."
          [C.cstm|{futhark_context_config_dump_ptx_to(cfg, optarg);
                   entry_point = NULL;}|],
        longOption
          "load-ptx"
          "FILE"
          "Load PTX code from the indicated file."
          [C.cstm|futhark_context_config_load_ptx_from(cfg, optarg);|],
        longOption
          "nvrtc-option"
          "OPT"
          "Add an additional build option to the string passed to NVRTC."
          [C.cstm|futhark_context_config_add_nvrtc_option(cfg, optarg);|],
        -- The only option with a short name and without an argument.
        Option
          { optionLongName = "profile",
            optionShortName = Just 'P',
            optionArgument = NoArgument,
            optionDescription = "Gather profiling data while executing and print out a summary at the end.",
            optionAction = [C.cstm|futhark_context_config_set_profiling(cfg, 1);|]
          }
      ]

-- | Emit C code that writes a scalar into device memory.
--
-- The value is first stored in a fresh host-side temporary, which is then
-- copied to @mem + idx * sizeof(t)@ with @cuMemcpyHtoD@.  The copy is
-- wrapped in the profiling enclosure of 'copyScalarToDev'.  Only the
-- @"device"@ memory space is supported; any other space is an error.
writeCUDAScalar :: GC.WriteScalar OpenCL ()
writeCUDAScalar mem idx t space _ val
  | space == "device" = do
    tmp <- newVName "write_tmp"
    let (prologue, epilogue) = profilingEnclosure copyScalarToDev
    GC.item
      [C.citem|{$ty:t $id:tmp = $exp:val;
                $items:prologue
                CUDA_SUCCEED(
                  cuMemcpyHtoD($exp:mem + $exp:idx * sizeof($ty:t),
                               &$id:tmp,
                               sizeof($ty:t)));
                $items:epilogue
               }|]
  | otherwise =
    error $ concat ["Cannot write to '", space, "' memory space."]

-- | Emit C code that reads a scalar from device memory and return a C
-- expression referring to the value read.
--
-- A fresh host-side variable is declared and filled from
-- @mem + idx * sizeof(t)@ with @cuMemcpyDtoH@, wrapped in the profiling
-- enclosure of 'copyScalarFromDev'.  Afterwards the generated code calls
-- @futhark_context_sync@ and returns 1 from the enclosing C function if
-- that fails.  Only the @"device"@ memory space is supported; any other
-- space is an error.
readCUDAScalar :: GC.ReadScalar OpenCL ()
readCUDAScalar mem idx t "device" _ = do
  val <- newVName "read_res"
  let (bef, aft) = profilingEnclosure copyScalarFromDev
  mapM_
    GC.item
    [C.citems|
       $ty:t $id:val;
       {
       $items:bef
       CUDA_SUCCEED(
          cuMemcpyDtoH(&$id:val,
                       $exp:mem + $exp:idx * sizeof($ty:t),
                       sizeof($ty:t)));
        $items:aft
       }
       |]
  GC.stm [C.cstm|if (futhark_context_sync(ctx) != 0) { return 1; }|]
  return [C.cexp|$id:val|]
readCUDAScalar _ _ _ space _ =
  -- This is the read path, so the message must say "read from" (it used to
  -- repeat the "write to" wording of 'writeCUDAScalar').
  error $ "Cannot read from '" ++ space ++ "' memory space."

-- | Emit C code that allocates @size@ bytes in the @"device"@ memory space
-- via @cuda_alloc@, storing the result in @mem@ and passing @tag@ along.
-- Allocation in any other space is an error.
allocateCUDABuffer :: GC.Allocate OpenCL ()
allocateCUDABuffer mem size tag space
  | space == "device" =
    GC.stm [C.cstm|CUDA_SUCCEED(cuda_alloc(&ctx->cuda, $exp:size, $exp:tag, &$exp:mem));|]
  | otherwise =
    error $ concat ["Cannot allocate in '", space, "' memory space."]

-- | Emit C code that releases the @"device"@ memory block @mem@ via
-- @cuda_free@, passing @tag@ along.  Deallocation in any other space is an
-- error.
deallocateCUDABuffer :: GC.Deallocate OpenCL ()
deallocateCUDABuffer mem tag space
  | space == "device" =
    GC.stm [C.cstm|CUDA_SUCCEED(cuda_free(&ctx->cuda, $exp:mem, $exp:tag));|]
  | otherwise =
    error $ concat ["Cannot deallocate in '", space, "' memory space."]

-- | Emit C code that copies @nbytes@ bytes from @srcmem + srcidx@ to
-- @dstmem + dstidx@.
--
-- The CUDA copy function and the profiling cost centre are chosen from the
-- pair of destination and source spaces: device-to-host, host-to-device and
-- device-to-device are supported, and every other combination is an error.
copyCUDAMemory :: GC.Copy OpenCL ()
copyCUDAMemory dstmem dstidx dstSpace srcmem srcidx srcSpace nbytes =
  GC.item
    [C.citem|{
               $items:prologue
               CUDA_SUCCEED(
                 $id:copy_fn($exp:dstmem + $exp:dstidx,
                             $exp:srcmem + $exp:srcidx,
                             $exp:nbytes));
               $items:epilogue
             }|]
  where
    (prologue, epilogue) = profilingEnclosure cost_centre

    -- Name of the driver API function to call, and the cost centre the
    -- copy is accounted under.
    (copy_fn, cost_centre) =
      case (dstSpace, srcSpace) of
        (DefaultSpace, Space "device") -> ("cuMemcpyDtoH", copyDevToHost)
        (Space "device", DefaultSpace) -> ("cuMemcpyHtoD", copyHostToDev)
        (Space "device", Space "device") -> ("cuMemcpy", copyDevToDev)
        _ ->
          error $
            concat
              [ "Cannot copy to '",
                show dstSpace,
                "' from '",
                show srcSpace,
                "'."
              ]

-- | Emit C code for a static array living in the @"device"@ memory space.
--
-- The contents are declared as a host-side @static@ C array.  A
-- @struct memblock_device@ field named after the array is added to the
-- context, and at initialisation time device memory is allocated for it
-- and the host array is copied over.  Finally a local variable of the same
-- name is bound to the context field.  Any other memory space is an error.
staticCUDAArray :: GC.StaticArray OpenCL ()
staticCUDAArray name "device" t vs = do
  let ct = GC.primTypeToCType t
  name_realtype <- newVName (baseString name ++ "_realtype")
  num_elems <- case vs of
    ArrayValues elems -> do
      let inits = [[C.cinit|$exp:e|] | e <- map GC.compilePrimValue elems]
          count = length inits
      GC.earlyDecl [C.cedecl|static $ty:ct $id:name_realtype[$int:count] = {$inits:inits};|]
      return count
    ArrayZeros n -> do
      -- No initialiser is given for an all-zero array.
      GC.earlyDecl [C.cedecl|static $ty:ct $id:name_realtype[$int:n];|]
      return n
  -- Fake a memory block.
  GC.contextField (C.toIdent name mempty) [C.cty|struct memblock_device|] Nothing
  -- During startup, copy the data to where we need it.  At least one
  -- element is always allocated, and the copy is skipped for empty arrays.
  GC.atInit
    [C.cstm|{
    ctx->$id:name.references = NULL;
    ctx->$id:name.size = 0;
    CUDA_SUCCEED(cuMemAlloc(&ctx->$id:name.mem,
                            ($int:num_elems > 0 ? $int:num_elems : 1)*sizeof($ty:ct)));
    if ($int:num_elems > 0) {
      CUDA_SUCCEED(cuMemcpyHtoD(ctx->$id:name.mem, $id:name_realtype,
                                $int:num_elems*sizeof($ty:ct)));
    }
  }|]
  GC.item [C.citem|struct memblock_device $id:name = ctx->$id:name;|]
staticCUDAArray _ space _ _ =
  error $
    concat
      [ "CUDA backend cannot create static array in '",
        space,
        "' memory space"
      ]

-- | The C type used for memory in a given space: @CUdeviceptr@ for the
-- @"device"@ space.  Any other space is an error.
cudaMemoryType :: GC.MemoryType OpenCL ()
cudaMemoryType space
  | space == "device" = return [C.cty|typename CUdeviceptr|]
  | otherwise =
    error $ concat ["CUDA backend does not support '", space, "' memory space."]

-- | Compile a host-level operation to C code that uses the CUDA driver API.
--
-- * 'GetSize' assigns the size named @key@ in @ctx->sizes@ to @v@.
--
-- * 'CmpSizeLe' assigns to @v@ whether that size is @<=@ the compiled
--   expression.
--
-- * 'GetSizeMax' assigns to @v@ the @max_*@ field of @ctx->cuda@ that
--   corresponds to the size class.
--
-- * 'LaunchKernel' declares the kernel arguments and emits a
--   @cuLaunchKernel@ call that runs only when all grid and block
--   dimensions are non-zero.  When @ctx->debugging@ is set, the generated
--   code also logs the launch configuration and the measured wall time.
callKernel :: GC.OpCompiler OpenCL ()
callKernel (GetSize v key) =
  GC.stm [C.cstm|$id:v = ctx->sizes.$id:key;|]
callKernel (CmpSizeLe v key x) = do
  x' <- GC.compileExp x
  GC.stm [C.cstm|$id:v = ctx->sizes.$id:key <= $exp:x';|]
callKernel (GetSizeMax v size_class) =
  let field = "max_" ++ cudaSizeClass size_class
   in GC.stm [C.cstm|$id:v = ctx->cuda.$id:field;|]
  where
    -- Suffix of the @ctx->cuda.max_*@ field for each size class.
    cudaSizeClass SizeThreshold {} = "threshold"
    cudaSizeClass SizeGroup = "block_size"
    cudaSizeClass SizeNumGroups = "grid_size"
    cudaSizeClass SizeTile = "tile_size"
    cudaSizeClass SizeRegTile = "reg_tile_size"
    cudaSizeClass SizeLocalMemory = "shared_memory"
    cudaSizeClass (SizeBespoke x _) = pretty x
callKernel (LaunchKernel safety kernel_name args num_blocks block_size) = do
  args_arr <- newVName "kernel_args"
  time_start <- newVName "time_start"
  time_end <- newVName "time_end"
  -- Declare one C variable per kernel argument.  Shared memory arguments
  -- additionally report their (size, offset) variable pair.
  (args', shared_vars) <- unzip <$> mapM mkArgs args
  let (shared_sizes, shared_offsets) = unzip $ catMaybes shared_vars
      -- Running sums of the aligned sizes, starting at 0.  Because of the
      -- initial element the list is never empty, so 'last' is safe; its
      -- final element is the total amount of shared memory.
      shared_offsets_sc = mkOffsets shared_sizes
      shared_args = zip shared_offsets shared_offsets_sc
      shared_tot = last shared_offsets_sc
  forM_ shared_args $ \(arg, offset) ->
    GC.decl [C.cdecl|unsigned int $id:arg = $exp:offset;|]

  (grid_x, grid_y, grid_z) <- mkDims <$> mapM GC.compileExp num_blocks
  (block_x, block_y, block_z) <- mkDims <$> mapM GC.compileExp block_size
  let -- Only kernels launched with exactly three grid dimensions receive
      -- pointers to the permutation entries as leading arguments.
      perm_args
        | length num_blocks == 3 = [[C.cinit|&perm[0]|], [C.cinit|&perm[1]|], [C.cinit|&perm[2]|]]
        | otherwise = []
      -- The number of failure-related arguments depends on the safety
      -- level of the kernel.
      failure_args =
        take
          (numFailureParams safety)
          [ [C.cinit|&ctx->global_failure|],
            [C.cinit|&ctx->failure_is_an_option|],
            [C.cinit|&ctx->global_failure_args|]
          ]
      args'' = perm_args ++ failure_args ++ [[C.cinit|&$id:a|] | a <- args']
      sizes_nonzero =
        expsNotZero
          [ grid_x,
            grid_y,
            grid_z,
            block_x,
            block_y,
            block_z
          ]
      (bef, aft) = profilingEnclosure kernel_name

  -- In the generated code, @perm@ starts out as the identity.  If the y or
  -- z grid size is at least 2^16, that dimension is moved to slot 0 and the
  -- previous occupant of slot 0 takes its place.
  -- NOTE(review): presumably this works around CUDA's smaller limit on the
  -- y and z grid dimensions - confirm against the CUDA documentation.
  --
  -- The elapsed time is an @int64_t@ but is printed with @%ld@, so it is
  -- cast to @long int@ (as 'printSize' does for sizes) to keep the format
  -- string and the argument type in agreement on every platform.
  GC.stm
    [C.cstm|
    if ($exp:sizes_nonzero) {
      int perm[3] = { 0, 1, 2 };

      if ($exp:grid_y >= (1<<16)) {
        perm[1] = perm[0];
        perm[0] = 1;
      }

      if ($exp:grid_z >= (1<<16)) {
        perm[2] = perm[0];
        perm[0] = 2;
      }

      size_t grid[3];
      grid[perm[0]] = $exp:grid_x;
      grid[perm[1]] = $exp:grid_y;
      grid[perm[2]] = $exp:grid_z;

      void *$id:args_arr[] = { $inits:args'' };
      typename int64_t $id:time_start = 0, $id:time_end = 0;
      if (ctx->debugging) {
        fprintf(ctx->log, "Launching %s with grid size (", $string:(pretty kernel_name));
        $stms:(printSizes [grid_x, grid_y, grid_z])
        fprintf(ctx->log, ") and block size (");
        $stms:(printSizes [block_x, block_y, block_z])
        fprintf(ctx->log, ").\n");
        $id:time_start = get_wall_time();
      }
      $items:bef
      CUDA_SUCCEED(
        cuLaunchKernel(ctx->$id:kernel_name,
                       grid[0], grid[1], grid[2],
                       $exp:block_x, $exp:block_y, $exp:block_z,
                       $exp:shared_tot, NULL,
                       $id:args_arr, NULL));
      $items:aft
      if (ctx->debugging) {
        CUDA_SUCCEED(cuCtxSynchronize());
        $id:time_end = get_wall_time();
        fprintf(ctx->log, "Kernel %s runtime: %ldus\n",
                $string:(pretty kernel_name), (long int)($id:time_end - $id:time_start));
      }
    }|]

  -- Once a kernel with full safety has been launched, record that a
  -- failure may have occurred.
  when (safety >= SafetyFull) $
    GC.stm [C.cstm|ctx->failure_is_an_option = 1;|]
  where
    -- Pad a list of dimensions to three.  Missing trailing dimensions are
    -- 1; no dimensions at all gives zeroes, which disables the launch via
    -- the non-zero check.
    mkDims [] = ([C.cexp|0|], [C.cexp|0|], [C.cexp|0|])
    mkDims [x] = (x, [C.cexp|1|], [C.cexp|1|])
    mkDims [x, y] = (x, y, [C.cexp|1|])
    mkDims (x : y : z : _) = (x, y, z)

    addExp x y = [C.cexp|$exp:x + $exp:y|]

    -- Round an expression up to the next multiple of 8.
    alignExp e = [C.cexp|$exp:e + ((8 - ($exp:e % 8)) % 8)|]

    -- Running sums of the aligned sizes, including the initial 0.
    mkOffsets = scanl (\a b -> a `addExp` alignExp b) [C.cexp|0|]

    expNotZero e = [C.cexp|$exp:e != 0|]
    expAnd a b = [C.cexp|$exp:a && $exp:b|]

    -- Conjunction stating that every given expression is non-zero.
    expsNotZero = foldl expAnd [C.cexp|1|] . map expNotZero

    -- Declare the C variable holding a kernel argument and return its
    -- name.  For shared memory arguments the argument is the offset
    -- variable, and the (size, offset) pair is returned as well so the
    -- caller can compute and declare the offset.
    mkArgs (ValueKArg e t) =
      (,Nothing) <$> GC.compileExpToName "kernel_arg" t e
    mkArgs (MemKArg v) = do
      v' <- GC.rawMem v
      arg <- newVName "kernel_arg"
      GC.decl [C.cdecl|typename CUdeviceptr $id:arg = $exp:v';|]
      return (arg, Nothing)
    mkArgs (SharedMemoryKArg (Count c)) = do
      num_bytes <- GC.compileExp c
      size <- newVName "shared_size"
      offset <- newVName "shared_offset"
      GC.decl [C.cdecl|unsigned int $id:size = $exp:num_bytes;|]
      return (offset, Just (size, offset))

    -- Statements that print the sizes to the log, separated by ", ".
    printSizes =
      intercalate [[C.cstm|fprintf(ctx->log, ", ");|]] . map printSize
    printSize e =
      [[C.cstm|fprintf(ctx->log, "%ld", (long int)$exp:e);|]]