{-# LANGUAGE QuasiQuotes #-}
{-# LANGUAGE TupleSections #-}
-- | Code generation for CUDA.
module Futhark.CodeGen.Backends.CCUDA
  ( compileProg
  , GC.CParts(..)
  , GC.asLibrary
  , GC.asExecutable
  ) where

import Control.Monad
import Data.List (intercalate)
import Data.Maybe (catMaybes)
import qualified Language.C.Quote.OpenCL as C

import qualified Futhark.CodeGen.Backends.GenericC as GC
import qualified Futhark.CodeGen.ImpGen.CUDA as ImpGen
import Futhark.IR.KernelsMem
  hiding (GetSize, CmpSizeLe, GetSizeMax)
import Futhark.MonadFreshNames
import Futhark.CodeGen.ImpCode.OpenCL
import Futhark.CodeGen.Backends.COpenCL.Boilerplate (commonOptions)
import Futhark.CodeGen.Backends.CCUDA.Boilerplate
import Futhark.CodeGen.Backends.GenericC.Options

-- | Compile the program to C with calls to CUDA.
--
-- The program is first lowered by 'ImpGen.compileProg'; the CUDA source,
-- prelude, kernel names, size table and failure messages it yields are
-- passed to 'generateBoilerplate', and the remaining host-side definitions
-- are handed to the generic C backend together with the CUDA-specific
-- 'GC.Operations' defined below.
compileProg :: MonadFreshNames m => Prog KernelsMem -> m GC.CParts
compileProg prog = do
  (Program cuda_code cuda_prelude kernel_names _ sizes failures prog') <-
    ImpGen.compileProg prog
  let extra = generateBoilerplate cuda_code cuda_prelude
              kernel_names sizes failures
  GC.compileProg operations extra cuda_includes
    [Space "device", DefaultSpace] cliOptions prog'
  where
    -- The generic C operations with every memory-related hook and the
    -- host-op compiler replaced by the CUDA implementations in this module.
    operations :: GC.Operations OpenCL ()
    operations = GC.defaultOperations
                 { GC.opsWriteScalar = writeCUDAScalar
                 , GC.opsReadScalar  = readCUDAScalar
                 , GC.opsAllocate    = allocateCUDABuffer
                 , GC.opsDeallocate  = deallocateCUDABuffer
                 , GC.opsCopy        = copyCUDAMemory
                 , GC.opsStaticArray = staticCUDAArray
                 , GC.opsMemoryType  = cudaMemoryType
                 , GC.opsCompiler    = callKernel
                 , GC.opsFatMemory   = True
                 }
    -- Header lines passed to the generic backend as the extra-includes text.
    cuda_includes :: String
    cuda_includes = unlines [ "#include <cuda.h>"
                            , "#include <nvrtc.h>"
                            ]

-- | Command-line options of the generated executable: 'commonOptions'
-- followed by the CUDA-specific ones (@dump-cuda@, @load-cuda@,
-- @dump-ptx@, @load-ptx@ and @nvrtc-option@).  Each option takes a
-- required argument, has no short name, and forwards @optarg@ to the
-- corresponding @futhark_context_config_*@ function.  The two dump
-- options additionally set @entry_point@ to @NULL@.
cliOptions :: [Option]
cliOptions =
  commonOptions ++
  [ Option { optionLongName = "dump-cuda"
           , optionShortName = Nothing
           , optionArgument = RequiredArgument "FILE"
           , optionAction = [C.cstm|{futhark_context_config_dump_program_to(cfg, optarg);
                                     entry_point = NULL;}|]
           }
  , Option { optionLongName = "load-cuda"
           , optionShortName = Nothing
           , optionArgument = RequiredArgument "FILE"
           , optionAction = [C.cstm|futhark_context_config_load_program_from(cfg, optarg);|]
           }
  , Option { optionLongName = "dump-ptx"
           , optionShortName = Nothing
           , optionArgument = RequiredArgument "FILE"
           , optionAction = [C.cstm|{futhark_context_config_dump_ptx_to(cfg, optarg);
                                     entry_point = NULL;}|]
           }
  , Option { optionLongName = "load-ptx"
           , optionShortName = Nothing
           , optionArgument = RequiredArgument "FILE"
           , optionAction = [C.cstm|futhark_context_config_load_ptx_from(cfg, optarg);|]
           }
  , Option { optionLongName = "nvrtc-option"
           , optionShortName = Nothing
           , optionArgument = RequiredArgument "OPT"
           , optionAction = [C.cstm|futhark_context_config_add_nvrtc_option(cfg, optarg);|]
           }
  ]

-- | Emit C code that stores a scalar into @"device"@ memory.
--
-- The value is first assigned to a freshly named host temporary of type
-- @t@, which is then copied with @cuMemcpyHtoD@ to element @idx@ of
-- @mem@ (the offset is @idx * sizeof(t)@).  Any other memory space name
-- is rejected with 'error'.
writeCUDAScalar :: GC.WriteScalar OpenCL ()
writeCUDAScalar mem idx t "device" _ val = do
  val' <- newVName "write_tmp"
  GC.stm [C.cstm|{$ty:t $id:val' = $exp:val;
                  CUDA_SUCCEED(
                    cuMemcpyHtoD($exp:mem + $exp:idx * sizeof($ty:t),
                                 &$id:val',
                                 sizeof($ty:t)));
                 }|]
writeCUDAScalar _ _ _ space _ _ =
  error $ "Cannot write to '" ++ space ++ "' memory space."

-- | Emit C code that loads a scalar from @"device"@ memory and return an
-- expression naming the host variable that receives it.
--
-- A fresh variable of type @t@ is declared, filled with @cuMemcpyDtoH@
-- from element @idx@ of @mem@ (offset @idx * sizeof(t)@), and then
-- @futhark_context_sync(ctx)@ is called; a non-zero result makes the
-- generated code @return 1@.  Any other memory space name is rejected
-- with 'error'.
readCUDAScalar :: GC.ReadScalar OpenCL ()
readCUDAScalar mem idx t "device" _ = do
  val <- newVName "read_res"
  GC.decl [C.cdecl|$ty:t $id:val;|]
  GC.stm [C.cstm|CUDA_SUCCEED(
                   cuMemcpyDtoH(&$id:val,
                                $exp:mem + $exp:idx * sizeof($ty:t),
                                sizeof($ty:t)));
                |]
  GC.stm [C.cstm|if (futhark_context_sync(ctx) != 0) { return 1; }|]
  return [C.cexp|$id:val|]
readCUDAScalar _ _ _ space _ =
  -- This is the read path, so the diagnostic must say "read from"; it
  -- previously repeated the write-path wording.
  error $ "Cannot read from '" ++ space ++ "' memory space."

-- | Emit a call to @cuda_alloc@ on @ctx->cuda@ that allocates @size@
-- bytes under the label @tag@ and stores the result through @&mem@.
-- Only the @"device"@ space is supported; any other space name is
-- rejected with 'error'.
allocateCUDABuffer :: GC.Allocate OpenCL ()
allocateCUDABuffer mem size tag "device" =
  GC.stm [C.cstm|CUDA_SUCCEED(cuda_alloc(&ctx->cuda, $exp:size, $exp:tag, &$exp:mem));|]
allocateCUDABuffer _ _ _ space =
  error $ "Cannot allocate in '" ++ space ++ "' memory space."

-- | Emit a call to @cuda_free@ on @ctx->cuda@ for the buffer @mem@
-- with label @tag@.  Only the @"device"@ space is supported; any other
-- space name is rejected with 'error'.
deallocateCUDABuffer :: GC.Deallocate OpenCL ()
deallocateCUDABuffer mem tag "device" =
  GC.stm [C.cstm|CUDA_SUCCEED(cuda_free(&ctx->cuda, $exp:mem, $exp:tag));|]
deallocateCUDABuffer _ _ space =
  error $ "Cannot deallocate in '" ++ space ++ "' memory space."

-- | Emit a memory copy of @nbytes@ bytes from @srcmem + srcidx@ to
-- @dstmem + dstidx@, choosing the CUDA driver function from the pair of
-- memory spaces involved.
copyCUDAMemory :: GC.Copy OpenCL ()
copyCUDAMemory dstmem dstidx dstSpace srcmem srcidx srcSpace nbytes = do
  fn <- memcpyFun dstSpace srcSpace
  GC.stm [C.cstm|CUDA_SUCCEED(
                  $id:fn($exp:dstmem + $exp:dstidx,
                         $exp:srcmem + $exp:srcidx,
                         $exp:nbytes));
                |]
  where
    -- Arguments are (destination space, source space).  Host-to-host
    -- copies and unknown spaces fall through to the error case.
    memcpyFun DefaultSpace (Space "device")     = return "cuMemcpyDtoH"
    memcpyFun (Space "device") DefaultSpace     = return "cuMemcpyHtoD"
    memcpyFun (Space "device") (Space "device") = return "cuMemcpy"
    memcpyFun _ _ = error $ "Cannot copy to '" ++ show dstSpace
                           ++ "' from '" ++ show srcSpace ++ "'."

-- | Emit code for a statically known array that lives in @"device"@
-- memory.
--
-- A host-side @static@ C array (named after @name@ with a @_realtype@
-- suffix) is declared early, either with explicit initialisers or left
-- without an initialiser for 'ArrayZeros'.  A @struct memblock_device@
-- field named after the array is added to the context, and an
-- initialisation statement allocates the device buffer with
-- @cuMemAlloc@ and, when the array is non-empty, uploads the host data
-- with @cuMemcpyHtoD@.  Finally a local @struct memblock_device@
-- variable aliasing the context field is introduced.  Any other memory
-- space name is rejected with 'error'.
staticCUDAArray :: GC.StaticArray OpenCL ()
staticCUDAArray name "device" t vs = do
  let ct = GC.primTypeToCType t
  name_realtype <- newVName $ baseString name ++ "_realtype"
  num_elems <- case vs of
    ArrayValues vs' -> do
      let vs'' = [[C.cinit|$exp:v|] | v <- map GC.compilePrimValue vs']
      GC.earlyDecl [C.cedecl|static $ty:ct $id:name_realtype[$int:(length vs'')] = {$inits:vs''};|]
      return $ length vs''
    ArrayZeros n -> do
      GC.earlyDecl [C.cedecl|static $ty:ct $id:name_realtype[$int:n];|]
      return n
  -- Fake a memory block.
  GC.contextField (pretty name) [C.cty|struct memblock_device|] Nothing
  -- During startup, copy the data to where we need it.  At least one
  -- element is always requested from cuMemAlloc, and the upload is
  -- skipped for empty arrays.
  GC.atInit [C.cstm|{
    ctx->$id:name.references = NULL;
    ctx->$id:name.size = 0;
    CUDA_SUCCEED(cuMemAlloc(&ctx->$id:name.mem,
                            ($int:num_elems > 0 ? $int:num_elems : 1)*sizeof($ty:ct)));
    if ($int:num_elems > 0) {
      CUDA_SUCCEED(cuMemcpyHtoD(ctx->$id:name.mem, $id:name_realtype,
                                $int:num_elems*sizeof($ty:ct)));
    }
  }|]
  GC.item [C.citem|struct memblock_device $id:name = ctx->$id:name;|]
staticCUDAArray _ space _ _ =
  error $ "CUDA backend cannot create static array in '" ++ space
          ++ "' memory space"

-- | The C type used for memory in the given space: @CUdeviceptr@ for
-- @"device"@.  Any other space name is rejected with 'error'.
cudaMemoryType :: GC.MemoryType OpenCL ()
cudaMemoryType "device" = return [C.cty|typename CUdeviceptr|]
cudaMemoryType space =
  error $ "CUDA backend does not support '" ++ space ++ "' memory space."

-- | Compile a host-level operation to C.
--
-- * 'GetSize' reads a named entry of @ctx->sizes@.
--
-- * 'CmpSizeLe' compares a named entry of @ctx->sizes@ against a
--   compiled expression with @<=@.
--
-- * 'GetSizeMax' reads the @max_*@ field of @ctx->cuda@ that
--   corresponds to the size class.
--
-- * 'LaunchKernel' declares the kernel arguments, computes the
--   shared-memory offsets, and emits a guarded @cuLaunchKernel@ call
--   (with optional debug timing).  When the safety level is at least
--   'SafetyFull', @ctx->failure_is_an_option@ is set afterwards.
callKernel :: GC.OpCompiler OpenCL ()
callKernel (GetSize v key) =
  GC.stm [C.cstm|$id:v = ctx->sizes.$id:key;|]
callKernel (CmpSizeLe v key x) = do
  x' <- GC.compileExp x
  GC.stm [C.cstm|$id:v = ctx->sizes.$id:key <= $exp:x';|]
callKernel (GetSizeMax v size_class) =
  let field = "max_" ++ cudaSizeClass size_class
  in GC.stm [C.cstm|$id:v = ctx->cuda.$id:field;|]
  where
    -- Suffix of the @ctx->cuda.max_*@ field for each size class.
    cudaSizeClass (SizeThreshold _) = "threshold"
    cudaSizeClass SizeGroup = "block_size"
    cudaSizeClass SizeNumGroups = "grid_size"
    cudaSizeClass SizeTile = "tile_size"
    cudaSizeClass SizeLocalMemory = "shared_memory"
    cudaSizeClass (SizeBespoke x _) = pretty x
callKernel (LaunchKernel safety name args num_blocks block_size) = do
  args_arr <- newVName "kernel_args"
  time_start <- newVName "time_start"
  time_end <- newVName "time_end"
  (args', shared_vars) <- unzip <$> mapM mkArgs args
  let (shared_sizes, shared_offsets) = unzip $ catMaybes shared_vars
      -- Running (aligned) sums of the shared-memory sizes.  'scanl'
      -- always yields at least the seed, so 'last' is safe here and
      -- gives the total amount of shared memory to request.
      shared_offsets_sc = mkOffsets shared_sizes
      shared_args = zip shared_offsets shared_offsets_sc
      shared_tot = last shared_offsets_sc
  mapM_ (\(arg,offset) ->
           GC.decl [C.cdecl|unsigned int $id:arg = $exp:offset;|]
        ) shared_args

  (grid_x, grid_y, grid_z) <- mkDims <$> mapM GC.compileExp num_blocks
  (block_x, block_y, block_z) <- mkDims <$> mapM GC.compileExp block_size
  let -- Kernels launched with exactly three grid dimensions also
      -- receive pointers to the elements of the permutation array.
      perm_args
        | length num_blocks == 3 = [ [C.cinit|&perm[0]|], [C.cinit|&perm[1]|], [C.cinit|&perm[2]|] ]
        | otherwise = []
      -- The number of failure-related parameters depends on the
      -- safety level of the kernel.
      failure_args = take (numFailureParams safety)
                     [[C.cinit|&ctx->global_failure|],
                      [C.cinit|&ctx->failure_is_an_option|],
                      [C.cinit|&ctx->global_failure_args|]]
      args'' = perm_args ++ failure_args ++ [ [C.cinit|&$id:a|] | a <- args' ]
      sizes_nonzero = expsNotZero [grid_x, grid_y, grid_z,
                      block_x, block_y, block_z]

  GC.stm [C.cstm|
    if ($exp:sizes_nonzero) {
      int perm[3] = { 0, 1, 2 };

      if ($exp:grid_y > (1<<16)) {
        perm[1] = perm[0];
        perm[0] = 1;
      }

      if ($exp:grid_z > (1<<16)) {
        perm[2] = perm[0];
        perm[0] = 2;
      }

      size_t grid[3];
      grid[perm[0]] = $exp:grid_x;
      grid[perm[1]] = $exp:grid_y;
      grid[perm[2]] = $exp:grid_z;

      void *$id:args_arr[] = { $inits:args'' };
      typename int64_t $id:time_start = 0, $id:time_end = 0;
      if (ctx->debugging) {
        fprintf(stderr, "Launching %s with grid size (", $string:name);
        $stms:(printSizes [grid_x, grid_y, grid_z])
        fprintf(stderr, ") and block size (");
        $stms:(printSizes [block_x, block_y, block_z])
        fprintf(stderr, ").\n");
        $id:time_start = get_wall_time();
      }
      CUDA_SUCCEED(
        cuLaunchKernel(ctx->$id:name,
                       grid[0], grid[1], grid[2],
                       $exp:block_x, $exp:block_y, $exp:block_z,
                       $exp:shared_tot, NULL,
                       $id:args_arr, NULL));
      if (ctx->debugging) {
        CUDA_SUCCEED(cuCtxSynchronize());
        $id:time_end = get_wall_time();
        fprintf(stderr, "Kernel %s runtime: %ldus\n",
                $string:name, $id:time_end - $id:time_start);
      }
    }|]

  when (safety >= SafetyFull) $
    GC.stm [C.cstm|ctx->failure_is_an_option = 1;|]

  where
    -- Pad a list of dimensions to exactly three: missing dimensions
    -- become 1, an empty list becomes all zeroes, and anything past
    -- the third element is ignored.
    mkDims [] = ([C.cexp|0|] , [C.cexp|0|], [C.cexp|0|])
    mkDims [x] = (x, [C.cexp|1|], [C.cexp|1|])
    mkDims [x,y] = (x, y, [C.cexp|1|])
    mkDims (x:y:z:_) = (x, y, z)
    addExp x y = [C.cexp|$exp:x + $exp:y|]
    -- Round an expression up to the next multiple of 8.
    alignExp e = [C.cexp|$exp:e + ((8 - ($exp:e % 8)) % 8)|]
    -- Prefix sums of the aligned sizes, starting from 0.
    mkOffsets = scanl (\a b -> a `addExp` alignExp b) [C.cexp|0|]
    expNotZero e = [C.cexp|$exp:e != 0|]
    expAnd a b = [C.cexp|$exp:a && $exp:b|]
    -- Conjunction stating that every given expression is non-zero.
    expsNotZero = foldl expAnd [C.cexp|1|] . map expNotZero
    -- Declare one kernel argument.  The result is the name of the C
    -- variable holding the argument and, for shared-memory arguments
    -- only, the (size, offset) variable pair used to lay them out.
    mkArgs (ValueKArg e t) =
      (,Nothing) <$> GC.compileExpToName "kernel_arg" t e
    mkArgs (MemKArg v) = do
      v' <- GC.rawMem v
      arg <- newVName "kernel_arg"
      GC.decl [C.cdecl|typename CUdeviceptr $id:arg = $exp:v';|]
      return (arg, Nothing)
    mkArgs (SharedMemoryKArg (Count c)) = do
      num_bytes <- GC.compileExp c
      size <- newVName "shared_size"
      offset <- newVName "shared_offset"
      GC.decl [C.cdecl|unsigned int $id:size = $exp:num_bytes;|]
      return (offset, Just (size, offset))

    -- Debug output: the sizes separated by ", " on stderr.
    printSizes =
      intercalate [[C.cstm|fprintf(stderr, ", ");|]] . map printSize
    printSize e =
      [[C.cstm|fprintf(stderr, "%d", $exp:e);|]]