-- | Code generation for Python with OpenCL.
module Futhark.CodeGen.Backends.PyOpenCL
  ( compileProg,
  )
where

import Control.Monad
import Data.Map qualified as M
import Data.Text qualified as T
import Futhark.CodeGen.Backends.GenericPython hiding (compileProg)
import Futhark.CodeGen.Backends.GenericPython qualified as GP
import Futhark.CodeGen.Backends.GenericPython.AST
import Futhark.CodeGen.Backends.GenericPython.Options
import Futhark.CodeGen.Backends.PyOpenCL.Boilerplate
import Futhark.CodeGen.ImpCode (Count (..))
import Futhark.CodeGen.ImpCode.OpenCL qualified as Imp
import Futhark.CodeGen.ImpGen.OpenCL qualified as ImpGen
import Futhark.CodeGen.RTS.Python (openclPy)
import Futhark.IR.GPUMem (GPUMem, Prog)
import Futhark.MonadFreshNames
import Futhark.Util (zEncodeText)
import Futhark.Util.Pretty (prettyString, prettyText)

-- | Compile the program to Python with calls to OpenCL.
--
-- The program is first lowered with the GPU 'ImpGen.compileProg'. The
-- resulting definitions are then handed to the generic Python backend
-- together with the OpenCL-specific imports, module-level defaults, class
-- constructor and command-line options defined here. Warnings produced by
-- ImpGen are returned alongside the generated Python source.
compileProg ::
  (MonadFreshNames m) =>
  CompilerMode ->
  String ->
  Prog GPUMem ->
  m (ImpGen.Warnings, T.Text)
compileProg mode class_name prog = do
  ( ws,
    Imp.Program
      opencl_code
      opencl_prelude
      macros
      kernels
      types
      sizes
      failures
      prog'
    ) <-
    ImpGen.compileProg prog

  -- One Python assignment per kernel, binding @self.<kernel>_var@ to the
  -- attribute of the same (z-encoded) name on @program@. The
  -- concatenated text is passed to 'openClInit' below.
  let assignKernel x =
        prettyString $
          Assign
            (Var $ T.unpack $ "self." <> zEncodeText (nameToText x) <> "_var")
            (Var $ T.unpack $ "program." <> zEncodeText (nameToText x))
      assign = unlines $ map assignKernel $ M.keys kernels

  -- Module-level defaults. The constructor's keyword defaults and the
  -- command-line options below refer to these names.
  let defines =
        [ Assign (Var "synchronous") $ Bool False,
          Assign (Var "preferred_platform") None,
          Assign (Var "build_options") $ List [],
          Assign (Var "preferred_device") None,
          Assign (Var "default_threshold") None,
          Assign (Var "default_group_size") None,
          Assign (Var "default_num_groups") None,
          Assign (Var "default_tile_size") None,
          Assign (Var "default_reg_tile_size") None,
          Assign (Var "fut_opencl_src") $
            RawStringLiteral $
              opencl_prelude <> opencl_code
        ]

  let imports =
        [ Import "sys" Nothing,
          Import "numpy" $ Just "np",
          Import "ctypes" $ Just "ct",
          Escape openclPy,
          Import "pyopencl.array" Nothing,
          Import "time" Nothing
        ]

  let constructor =
        Constructor
          [ "self",
            "build_options=build_options",
            "command_queue=None",
            "interactive=False",
            "platform_pref=preferred_platform",
            "device_pref=preferred_device",
            "default_group_size=default_group_size",
            "default_num_groups=default_num_groups",
            "default_tile_size=default_tile_size",
            "default_reg_tile_size=default_reg_tile_size",
            "default_threshold=default_threshold",
            "sizes=sizes"
          ]
          [Escape $ openClInit macros types assign sizes failures]
      options =
        [ storeOptarg "platform" (Just 'p') "str" "preferred_platform",
          storeOptarg "device" (Just 'd') "str" "preferred_device",
          -- Repeatable: each occurrence appends to build_options.
          Option
            { optionLongName = "build-option",
              optionShortName = Nothing,
              optionArgument = RequiredArgument "str",
              optionAction =
                [ Assign (Var "build_options") $
                    BinOp "+" (Var "build_options") $
                      List [Var "optarg"]
                ]
            },
          storeOptarg "default-threshold" Nothing "int" "default_threshold",
          storeOptarg "default-group-size" Nothing "int" "default_group_size",
          storeOptarg "default-num-groups" Nothing "int" "default_num_groups",
          storeOptarg "default-tile-size" Nothing "int" "default_tile_size",
          storeOptarg "default-reg-tile-size" Nothing "int" "default_reg_tile_size",
          -- Stores optarg[1] under key optarg[0] in the Python "params"
          -- mapping. NOTE(review): assumes the option parser yields a
          -- (name, value) pair for "param_assignment" - confirm in
          -- GenericPython.
          Option
            { optionLongName = "param",
              optionShortName = Nothing,
              optionArgument = RequiredArgument "param_assignment",
              optionAction =
                [ Assign
                    ( Index
                        (Var "params")
                        (IdxExp (Index (Var "optarg") (IdxExp (Integer 0))))
                    )
                    (Index (Var "optarg") (IdxExp (Integer 1)))
                ]
            }
        ]

  (ws,)
    <$> GP.compileProg
      mode
      class_name
      constructor
      imports
      defines
      operations
      ()
      [Exp $ simpleCall "sync" [Var "self"]]
      options
      prog'
  where
    -- The generic defaults, overridden with the OpenCL-specific
    -- operations defined in this module. The only copy added is
    -- device-to-device.
    operations :: Operations Imp.OpenCL ()
    operations =
      Operations
        { opsCompiler = callKernel,
          opsWriteScalar = writeOpenCLScalar,
          opsReadScalar = readOpenCLScalar,
          opsAllocate = allocateOpenCLBuffer,
          opsCopies =
            M.insert (Imp.Space "device", Imp.Space "device") copygpu2gpu $
              opsCopies defaultOperations,
          opsEntryOutput = packArrayOutput,
          opsEntryInput = unpackArrayInput
        }

    -- An option with a required argument whose action stores the
    -- argument verbatim in the given module-level Python variable.
    storeOptarg :: T.Text -> Maybe Char -> String -> String -> Option
    storeOptarg long short argty var =
      Option
        { optionLongName = long,
          optionShortName = short,
          optionArgument = RequiredArgument argty,
          optionAction = [Assign (Var var) $ Var "optarg"]
        }

-- We have many casts to 'long', because PyOpenCL may get confused at
-- the 32-bit numbers that ImpCode uses for offsets and the like.
--
-- | Wrap a Python expression in a call to @np.int64@.
asLong :: PyExp -> PyExp
asLong x = simpleCall "np.int64" [x]

-- | Python expression that looks up a tunable parameter by name in the
-- @self.sizes@ mapping. The pretty-printed name is used as the key.
getParamByKey :: Name -> PyExp
getParamByKey key = Index (Var "self.sizes") (IdxExp $ String $ prettyText key)

-- | Translate a kernel constant into a Python expression.
--
-- A named size constant becomes a lookup in @self.sizes@. A size-class
-- maximum becomes the attribute @self.max_<class>@, with the class
-- pretty-printed into the attribute name.
kernelConstToExp :: Imp.KernelConst -> PyExp
kernelConstToExp (Imp.SizeConst key _) =
  getParamByKey key
kernelConstToExp (Imp.SizeMaxConst size_class) =
  Var $ "self.max_" <> prettyString size_class

-- | Compile one workgroup dimension.
--
-- An expression dimension is compiled and cast with 'asLong'. A kernel
-- constant is translated with 'kernelConstToExp' and not cast.
compileGroupDim :: Imp.GroupDim -> CompilerM op s PyExp
compileGroupDim (Left e) = asLong <$> compileExp e
compileGroupDim (Right kc) = pure $ kernelConstToExp kc

-- | Compile an OpenCL-specific ImpCode operation to Python statements.
callKernel :: OpCompiler Imp.OpenCL ()
-- Read a tunable parameter into a variable.
callKernel (Imp.GetSize v key) = do
  v' <- compileVar v
  stm $ Assign v' $ getParamByKey key
-- Compare a tunable parameter against an expression: v = (param <= x).
callKernel (Imp.CmpSizeLe v key x) = do
  v' <- compileVar v
  x' <- compileExp x
  stm $ Assign v' $ BinOp "<=" (getParamByKey key) x'
-- Read the maximum allowed value for a size class.
callKernel (Imp.GetSizeMax v size_class) = do
  v' <- compileVar v
  stm $ Assign v' $ kernelConstToExp $ Imp.SizeMaxConst size_class
-- Launch a kernel.
callKernel (Imp.LaunchKernel safety name local_memory args num_workgroups workgroup_size) = do
  num_workgroups' <- mapM (fmap asLong . compileExp) num_workgroups
  workgroup_size' <- mapM compileGroupDim workgroup_size
  -- The global size passed to the kernel is, per dimension, the number
  -- of workgroups times the workgroup size.
  let kernel_size = zipWith mult_exp num_workgroups' workgroup_size'
      total_elements = foldl mult_exp (Integer 1) kernel_size
      -- Skip the launch entirely when there are no work items.
      cond = BinOp "!=" total_elements (Integer 0)
  local_memory' <- compileExp $ Imp.untyped $ Imp.unCount local_memory
  body <-
    collect $
      launchKernel name safety kernel_size workgroup_size' local_memory' args
  stm $ If cond body []

  -- After launching a fully-safe kernel, mark that a failure may now
  -- have been recorded.
  when (safety >= Imp.SafetyFull) $
    stm $
      Assign (Var "self.failure_is_an_option") $
        compilePrimValue (Imp.IntValue (Imp.Int32Value 1))
  where
    mult_exp :: PyExp -> PyExp -> PyExp
    mult_exp = BinOp "*"

-- | Emit the Python statements that launch one kernel.
--
-- The emitted code does the following, in order:
--
-- * Calls @self.<kernel>_var.set_args@ with a @cl.LocalMemory@ buffer of
--   @max(local_memory, 1)@ bytes, then the first
--   'Imp.numFailureParams' failure arguments, then the kernel arguments.
-- * Enqueues the kernel on @self.queue@ with the given global and
--   workgroup dimensions.
-- * Syncs if @synchronous@ is set.
launchKernel ::
  Imp.KernelName ->
  Imp.KernelSafety ->
  [PyExp] ->
  [PyExp] ->
  PyExp ->
  [Imp.KernelArg] ->
  CompilerM op s ()
launchKernel kernel_name safety kernel_dims workgroup_dims local_memory args = do
  let kernel_dims' = Tuple kernel_dims
      workgroup_dims' = Tuple workgroup_dims
      kernel_name' = "self." <> zEncodeText (nameToText kernel_name) <> "_var"
  args' <- mapM processKernelArg args
  -- How many of these are passed depends on the kernel's safety level.
  let failure_args =
        take
          (Imp.numFailureParams safety)
          [ Var "self.global_failure",
            Var "self.failure_is_an_option",
            Var "self.global_failure_args"
          ]
  -- The local memory size is clamped to at least 1 byte. NOTE(review):
  -- presumably because a zero-sized local buffer is rejected by OpenCL -
  -- confirm.
  stm . Exp $
    simpleCall (T.unpack $ kernel_name' <> ".set_args") $
      [simpleCall "cl.LocalMemory" [simpleCall "max" [local_memory, Integer 1]]]
        ++ failure_args
        ++ args'
  stm . Exp $
    simpleCall
      "cl.enqueue_nd_range_kernel"
      [Var "self.queue", Var (T.unpack kernel_name'), kernel_dims', workgroup_dims']
  finishIfSynchronous
  where
    -- Scalar arguments are converted to their storage representation.
    -- Memory arguments are passed as the compiled variable.
    processKernelArg :: Imp.KernelArg -> CompilerM op s PyExp
    processKernelArg (Imp.ValueKArg e bt) = toStorage bt <$> compileExp e
    processKernelArg (Imp.MemKArg v) = compileVar v

-- | Write a single scalar into a device buffer.
--
-- The value is wrapped in a one-off numpy array of the right dtype. That
-- array is copied with @cl.enqueue_copy@ to byte offset
-- @i * sizeof(bt)@. The copy blocks only if @synchronous@ is set. Any
-- memory space other than @"device"@ is an internal error.
writeOpenCLScalar :: WriteScalar Imp.OpenCL ()
writeOpenCLScalar mem i bt "device" val = do
  let nparr =
        Call
          (Var "np.array")
          [Arg val, ArgKeyword "dtype" $ Var $ compilePrimType bt]
      byte_offset = BinOp "*" (asLong i) (Integer $ Imp.primByteSize bt)
  stm $
    Exp $
      Call
        (Var "cl.enqueue_copy")
        [ Arg $ Var "self.queue",
          Arg mem,
          Arg nparr,
          ArgKeyword "dst_offset" byte_offset,
          ArgKeyword "is_blocking" $ Var "synchronous"
        ]
writeOpenCLScalar _ _ _ space _ =
  error $ "Cannot write to '" ++ space ++ "' memory space."

readOpenCLScalar :: ReadScalar Imp.OpenCL ()
readOpenCLScalar :: ReadScalar OpenCL ()
readOpenCLScalar PyExp
mem PyExp
i PrimType
bt SpaceId
"device" = do
  VName
val <- forall (m :: * -> *). MonadFreshNames m => SpaceId -> m VName
newVName SpaceId
"read_res"
  let val' :: PyExp
val' = SpaceId -> PyExp
Var forall a b. (a -> b) -> a -> b
$ forall a. Pretty a => a -> SpaceId
prettyString VName
val
  let nparr :: PyExp
nparr =
        PyExp -> [PyArg] -> PyExp
Call
          (SpaceId -> PyExp
Var SpaceId
"np.empty")
          [ PyExp -> PyArg
Arg forall a b. (a -> b) -> a -> b
$ Integer -> PyExp
Integer Integer
1,
            SpaceId -> PyExp -> PyArg
ArgKeyword SpaceId
"dtype" (SpaceId -> PyExp
Var forall a b. (a -> b) -> a -> b
$ PrimType -> SpaceId
compilePrimType PrimType
bt)
          ]
  forall op s. PyStmt -> CompilerM op s ()
stm forall a b. (a -> b) -> a -> b
$ PyExp -> PyExp -> PyStmt
Assign PyExp
val' PyExp
nparr
  forall op s. PyStmt -> CompilerM op s ()
stm forall a b. (a -> b) -> a -> b
$
    PyExp -> PyStmt
Exp forall a b. (a -> b) -> a -> b
$
      PyExp -> [PyArg] -> PyExp
Call
        (SpaceId -> PyExp
Var SpaceId
"cl.enqueue_copy")
        [ PyExp -> PyArg
Arg forall a b. (a -> b) -> a -> b
$ SpaceId -> PyExp
Var SpaceId
"self.queue",
          PyExp -> PyArg
Arg PyExp
val',
          PyExp -> PyArg
Arg PyExp
mem,
          SpaceId -> PyExp -> PyArg
ArgKeyword SpaceId
"src_offset" forall a b. (a -> b) -> a -> b
$ SpaceId -> PyExp -> PyExp -> PyExp
BinOp SpaceId
"*" (PyExp -> PyExp
asLong PyExp
i) (Integer -> PyExp
Integer forall a b. (a -> b) -> a -> b
$ forall a. Num a => PrimType -> a
Imp.primByteSize PrimType
bt),
          SpaceId -> PyExp -> PyArg
ArgKeyword SpaceId
"is_blocking" forall a b. (a -> b) -> a -> b
$ SpaceId -> PyExp
Var SpaceId
"synchronous"
        ]
  forall op s. PyStmt -> CompilerM op s ()
stm forall a b. (a -> b) -> a -> b
$ PyExp -> PyStmt
Exp forall a b. (a -> b) -> a -> b
$ SpaceId -> [PyExp] -> PyExp
simpleCall SpaceId
"sync" [SpaceId -> PyExp
Var SpaceId
"self"]
  forall (f :: * -> *) a. Applicative f => a -> f a
pure forall a b. (a -> b) -> a -> b
$ PyExp -> PyIdx -> PyExp
Index PyExp
val' forall a b. (a -> b) -> a -> b
$ PyExp -> PyIdx
IdxExp forall a b. (a -> b) -> a -> b
$ Integer -> PyExp
Integer Integer
0
readOpenCLScalar PyExp
_ PyExp
_ PrimType
_ SpaceId
space =
  forall a. HasCallStack => SpaceId -> a
error forall a b. (a -> b) -> a -> b
$ SpaceId
"Cannot read from '" forall a. [a] -> [a] -> [a]
++ SpaceId
space forall a. [a] -> [a] -> [a]
++ SpaceId
"' memory space."

-- | Allocate a device buffer of @size@ bytes into @mem@.
--
-- The allocation goes through the runtime's @opencl_alloc@ helper, which
-- is given @self@, the size, and the pretty-printed name of @mem@ as a
-- string. NOTE(review): presumably the name is for diagnostics; confirm
-- in the Python RTS. Any memory space other than @"device"@ is an
-- internal error.
allocateOpenCLBuffer :: Allocate Imp.OpenCL ()
allocateOpenCLBuffer mem size "device" =
  stm $
    Assign mem $
      simpleCall "opencl_alloc" [Var "self", size, String $ prettyText mem]
allocateOpenCLBuffer _ _ space =
  error $ "Cannot allocate in '" ++ space ++ "' space"

-- | Wrap a device buffer returned from an entry point as a
-- @cl.array.Array@ on @self.queue@.
--
-- The array uses the compiled dimensions and the external numpy type for
-- @bt@/@ept@. For 'Imp.Unit' element types an extra trailing 0
-- dimension is appended. The array is backed by @mem@ via the @data@
-- keyword. Any memory space other than @"device"@ is an internal error.
packArrayOutput :: EntryOutput Imp.OpenCL ()
packArrayOutput mem "device" bt ept dims = do
  mem' <- compileVar mem
  dims' <- mapM compileDim dims
  let shape = dims' <> [Integer 0 | bt == Imp.Unit]
  pure $
    Call
      (Var "cl.array.Array")
      [ Arg $ Var "self.queue",
        Arg $ Tuple shape,
        Arg $ Var $ compilePrimToExtNp bt ept,
        ArgKeyword "data" mem'
      ]
packArrayOutput _ sid _ _ _ =
  error $ "Cannot return array from " ++ sid ++ " space."

-- | Accept an entry-point array argument @e@ into the device buffer
-- @mem@.
--
-- The emitted code does the following:
--
-- * Asserts that @type(e)@ is either @np.ndarray@ or @cl.array.Array@,
--   and that its dtype matches the expected external type.
-- * Unpacks each dimension with 'unpackDim'.
-- * If @e@ is a @cl.array.Array@, reuses its @data@ buffer directly.
-- * Otherwise allocates a device buffer of @e.nbytes@ bytes. When that
--   size is nonzero, it copies @normaliseArray(e)@ into the buffer; the
--   copy blocks only if @synchronous@ is set.
--
-- Any memory space other than @"device"@ is an internal error.
unpackArrayInput :: EntryInput Imp.OpenCL ()
unpackArrayInput mem "device" t s dims e = do
  let e_type = simpleCall "type" [e]
      type_is_ok =
        BinOp
          "and"
          (BinOp "in" e_type (List [Var "np.ndarray", Var "cl.array.Array"]))
          (BinOp "==" (Field e "dtype") (Var (compilePrimToExtNp t s)))
  stm $ Assert type_is_ok $ String "Parameter has unexpected type"

  zipWithM_ (unpackDim e) dims [0 ..]

  let memsize' = simpleCall "np.int64" [Field e "nbytes"]
      pyOpenCLArrayCase =
        [Assign mem $ Field e "data"]
  numpyArrayCase <- collect $ do
    allocateOpenCLBuffer mem memsize' "device"
    stm $
      ifNotZeroSize memsize' $
        Exp $
          Call
            (Var "cl.enqueue_copy")
            [ Arg $ Var "self.queue",
              Arg mem,
              Arg $ Call (Var "normaliseArray") [Arg e],
              ArgKeyword "is_blocking" $ Var "synchronous"
            ]

  stm $
    If
      (BinOp "==" e_type (Var "cl.array.Array"))
      pyOpenCLArrayCase
      numpyArrayCase
unpackArrayInput _ sid _ _ _ _ =
  error $ "Cannot accept array from " ++ sid ++ " space."

-- | Guard a statement so that it only runs when @e != 0@. There is no
-- else branch.
ifNotZeroSize :: PyExp -> PyStmt -> PyStmt
ifNotZeroSize e s =
  If (BinOp "!=" e (Integer 0)) [s] []

-- | Emit @if synchronous: sync(self)@, so that the preceding enqueued
-- work is waited for when the module-level @synchronous@ flag is set.
finishIfSynchronous :: CompilerM op s ()
finishIfSynchronous =
  stm $ If (Var "synchronous") [Exp $ simpleCall "sync" [Var "self"]] []

-- | Device-to-device LMAD copy, delegated to the runtime's
-- @lmad_copy_gpu2gpu@ helper.
--
-- The helper is called with these arguments, in order:
--
-- * @self@ and the element type;
-- * the destination buffer, offset and strides;
-- * the source buffer, offset and strides;
-- * the shape.
--
-- All offsets, strides and the shape are passed as element counts.
copygpu2gpu :: DoLMADCopy op s
copygpu2gpu t shape dst (dstoffset, dststride) src (srcoffset, srcstride) =
  stm . Exp . simpleCall "lmad_copy_gpu2gpu" $
    [ Var "self",
      Var (compilePrimType t),
      dst,
      unCount dstoffset,
      List (map unCount dststride),
      src,
      unCount srcoffset,
      List (map unCount srcstride),
      List (map unCount shape)
    ]