{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TupleSections #-}

module Futhark.CodeGen.Backends.PyOpenCL
  ( compileProg,
  )
where

import Control.Monad
import qualified Data.Map as M
import qualified Futhark.CodeGen.Backends.GenericPython as Py
import Futhark.CodeGen.Backends.GenericPython.AST
import Futhark.CodeGen.Backends.GenericPython.Definitions
import Futhark.CodeGen.Backends.GenericPython.Options
import Futhark.CodeGen.Backends.PyOpenCL.Boilerplate
import qualified Futhark.CodeGen.ImpCode.OpenCL as Imp
import qualified Futhark.CodeGen.ImpGen.OpenCL as ImpGen
import Futhark.IR.KernelsMem (KernelsMem, Prog)
import Futhark.MonadFreshNames
import Futhark.Util (zEncodeString)

-- Maybe pass the config file rather than multiple arguments.
-- | Compile a Futhark program in the 'KernelsMem' representation to the
-- source text of a Python module that runs its kernels through PyOpenCL.
--
-- The optional module name is passed through unchanged to
-- 'Py.compileProg'.  The result pairs the warnings reported by
-- 'ImpGen.compileProg' with the generated Python source.
compileProg ::
  MonadFreshNames m =>
  Maybe String ->
  Prog KernelsMem ->
  m (ImpGen.Warnings, String)
compileProg module_name prog = do
  (ws, Imp.Program opencl_code opencl_prelude kernels types sizes failures prog') <-
    ImpGen.compileProg prog

  let encodedKernel k = zEncodeString (nameToString k)
      -- One Python assignment per kernel, binding @self.<kernel>_var@ to
      -- @program.<kernel>@.  The resulting text is handed to 'openClInit'.
      assign =
        unlines
          [ pretty $
              Assign
                (Var $ "self." ++ encodedKernel k ++ "_var")
                (Var $ "program." ++ encodedKernel k)
            | k <- M.keys kernels
          ]

      -- Module-level Python variables that start out as None.
      noneDefaults =
        [ "preferred_platform",
          "preferred_device",
          "default_threshold",
          "default_group_size",
          "default_num_groups",
          "default_tile_size"
        ]
      defines =
        [Assign (Var "synchronous") (Bool False)]
          ++ [Assign (Var v) None | v <- noneDefaults]
          ++ [ Assign (Var "fut_opencl_src") (RawStringLiteral (opencl_prelude ++ opencl_code)),
               Escape pyValues,
               Escape pyFunctions,
               Escape pyPanic,
               Escape pyTuning
             ]

      imports =
        [ Import "sys" Nothing,
          Import "numpy" (Just "np"),
          Import "ctypes" (Just "ct"),
          Escape openClPrelude,
          Import "pyopencl.array" Nothing,
          Import "time" Nothing
        ]

      -- A constructor keyword parameter whose default is the module-level
      -- variable of the same name.
      selfDefault p = p ++ "=" ++ p
      constructor =
        Py.Constructor
          ( [ "self",
              "command_queue=None",
              "interactive=False",
              "platform_pref=preferred_platform",
              "device_pref=preferred_device"
            ]
              ++ map
                selfDefault
                [ "default_group_size",
                  "default_num_groups",
                  "default_tile_size",
                  "default_threshold",
                  "sizes"
                ]
          )
          [Escape $ openClInit types assign sizes failures]

      -- A command-line option whose required argument is stored verbatim
      -- (as @optarg@) in the given module-level Python variable.
      variableOption long short arg_kind var =
        Option
          { optionLongName = long,
            optionShortName = short,
            optionArgument = RequiredArgument arg_kind,
            optionAction = [Assign (Var var) (Var "optarg")]
          }
      optargAt n = Index (Var "optarg") (IdxExp (Integer n))
      options =
        [ variableOption "platform" (Just 'p') "str" "preferred_platform",
          variableOption "device" (Just 'd') "str" "preferred_device",
          variableOption "default-threshold" Nothing "int" "default_threshold",
          variableOption "default-group-size" Nothing "int" "default_group_size",
          variableOption "default-num-groups" Nothing "int" "default_num_groups",
          variableOption "default-tile-size" Nothing "int" "default_tile_size",
          -- Emits @sizes[optarg[0]] = optarg[1]@; assumes the
          -- "size_assignment" argument kind yields a key/value pair
          -- (TODO confirm against GenericPython.Options).
          Option
            { optionLongName = "size",
              optionShortName = Nothing,
              optionArgument = RequiredArgument "size_assignment",
              optionAction =
                [Assign (Index (Var "sizes") (IdxExp (optargAt 0))) (optargAt 1)]
            }
        ]

  (,) ws
    <$> Py.compileProg
      module_name
      constructor
      imports
      defines
      operations
      ()
      [Exp $ Py.simpleCall "sync" [Var "self"]]
      options
      prog'
  where
    -- The OpenCL-specific hooks used by the generic Python code generator.
    operations :: Py.Operations Imp.OpenCL ()
    operations =
      Py.Operations
        { Py.opsCompiler = callKernel,
          Py.opsWriteScalar = writeOpenCLScalar,
          Py.opsReadScalar = readOpenCLScalar,
          Py.opsAllocate = allocateOpenCLBuffer,
          Py.opsCopy = copyOpenCLMemory,
          Py.opsStaticArray = staticOpenCLArray,
          Py.opsEntryOutput = packArrayOutput,
          Py.opsEntryInput = unpackArrayInput
        }

-- | Wrap an index or size expression in a 64-bit integer cast in the
-- generated Python code.
--
-- We have many such casts because PyOpenCL may get confused by the 32-bit
-- numbers that ImpCode uses for offsets and the like.
--
-- The cast uses @np.int64@ rather than @np.long@ for two reasons:
--
-- * @np.long@ was deprecated in NumPy 1.20 and removed in NumPy 1.24, so
--   generated code that calls it fails with an @AttributeError@ there.
-- * NumPy 2.x re-added @np.long@ as the C @long@ type, which is only
--   32 bits wide on Windows and so defeats the purpose of the cast.
asLong :: PyExp -> PyExp
asLong x = Py.simpleCall "np.int64" [x]

-- | Compile one OpenCL-specific ImpCode operation to Python statements.
--
-- * 'Imp.GetSize' reads a tunable size from the @self.sizes@ dictionary.
-- * 'Imp.CmpSizeLe' compares such a size against an expression with @<=@.
-- * 'Imp.GetSizeMax' reads the @self.max_<size class>@ attribute.
-- * 'Imp.LaunchKernel' launches a kernel inside an @if@ that skips the
--   launch when the product of the global dimensions is zero.  For kernels
--   of safety 'Imp.SafetyFull' or higher, it afterwards sets
--   @self.failure_is_an_option@ to 1.
callKernel :: Py.OpCompiler Imp.OpenCL ()
callKernel op = case op of
  Imp.GetSize v key -> do
    v' <- Py.compileVar v
    Py.stm $ Assign v' $ sizeLookup key
  Imp.CmpSizeLe v key x -> do
    v' <- Py.compileVar v
    x' <- Py.compileExp x
    Py.stm $ Assign v' $ BinOp "<=" (sizeLookup key) x'
  Imp.GetSizeMax v size_class -> do
    v' <- Py.compileVar v
    Py.stm $ Assign v' $ Var $ "self.max_" ++ pretty size_class
  Imp.LaunchKernel safety name args num_workgroups workgroup_size -> do
    -- Every dimension is cast with 'asLong' before it reaches PyOpenCL.
    let compileDim = fmap asLong . Py.compileExp
    group_counts <- mapM compileDim num_workgroups
    group_dims <- mapM compileDim workgroup_size
    let global_dims = zipWith times group_counts group_dims
        -- Left fold starting from 1, so the product reads ((1 * d0) * d1)...
        any_work = BinOp "!=" (foldl times (Integer 1) global_dims) (Integer 0)
    launch <- Py.collect $ launchKernel name safety global_dims group_dims args
    Py.stm $ If any_work launch []
    when (safety >= Imp.SafetyFull) $
      Py.stm . Assign (Var "self.failure_is_an_option") $
        Py.compilePrimValue $ Imp.IntValue $ Imp.Int32Value 1
  where
    -- @self.sizes["<key>"]@
    sizeLookup key = Index (Var "self.sizes") (IdxExp (String (pretty key)))
    times = BinOp "*"

-- | Emit the Python statements that launch a single kernel:
--
-- 1. call @set_args@ on the kernel object, passing the failure-tracking
--    arguments first and then the kernel's own arguments;
-- 2. enqueue it with @cl.enqueue_nd_range_kernel@ on @self.queue@, using
--    the given global and work-group dimensions as Python tuples;
-- 3. run 'finishIfSynchronous'.
--
-- The number of failure-tracking arguments is
-- @'Imp.numFailureParams' safety@.  They are taken from the front of
-- @self.global_failure@, @self.failure_is_an_option@ and
-- @self.global_failure_args@.
launchKernel ::
  Imp.KernelName ->
  Imp.KernelSafety ->
  [PyExp] ->
  [PyExp] ->
  [Imp.KernelArg] ->
  Py.CompilerM op s ()
launchKernel kernel_name safety kernel_dims workgroup_dims args = do
  compiled_args <- mapM kernelArg args
  let failure_args =
        take (Imp.numFailureParams safety) $
          map
            Var
            [ "self.global_failure",
              "self.failure_is_an_option",
              "self.global_failure_args"
            ]
  Py.stm . Exp . Py.simpleCall (kernel_obj ++ ".set_args") $
    failure_args ++ compiled_args
  Py.stm . Exp $
    Py.simpleCall
      "cl.enqueue_nd_range_kernel"
      [Var "self.queue", Var kernel_obj, Tuple kernel_dims, Tuple workgroup_dims]
  finishIfSynchronous
  where
    -- Attribute holding the kernel object, matching the @self.<kernel>_var@
    -- names assigned when the program is initialised.
    kernel_obj = "self." ++ zEncodeString (nameToString kernel_name) ++ "_var"

    -- Translate one kernel argument to the Python value passed to @set_args@.
    kernelArg :: Imp.KernelArg -> Py.CompilerM op s PyExp
    kernelArg karg = case karg of
      -- Scalars are wrapped in the NumPy constructor for their type.
      Imp.ValueKArg e bt ->
        (\e' -> Py.simpleCall (Py.compilePrimToNp bt) [e']) <$> Py.compileExp e
      Imp.MemKArg v ->
        Py.compileVar v
      -- Local memory is requested by byte count via cl.LocalMemory.
      Imp.SharedMemoryKArg (Imp.Count num_bytes) ->
        (\n -> Py.simpleCall "cl.LocalMemory" [asLong n]) <$> Py.compileExp num_bytes

-- | Write one scalar value into device memory.
--
-- The value is wrapped in an @np.array@ of the element's dtype.  It is then
-- copied with @cl.enqueue_copy@ to byte offset @i * primByteSize bt@ of
-- @mem@, with @is_blocking@ set to the generated module's @synchronous@
-- variable.  Any memory space other than @"device"@ is rejected via 'error'.
writeOpenCLScalar :: Py.WriteScalar Imp.OpenCL ()
writeOpenCLScalar mem i bt space val
  | space == "device" =
      Py.stm . Exp $
        Call
          (Var "cl.enqueue_copy")
          [ Arg (Var "self.queue"),
            Arg mem,
            Arg host_value,
            ArgKeyword "device_offset" byte_offset,
            ArgKeyword "is_blocking" (Var "synchronous")
          ]
  | otherwise =
      error $ "Cannot write to '" ++ space ++ "' memory space."
  where
    host_value =
      Call (Var "np.array") [Arg val, ArgKeyword "dtype" (Var (Py.compilePrimType bt))]
    byte_offset = BinOp "*" (asLong i) (Integer (Imp.primByteSize bt))

-- | Read one scalar value out of device memory.
--
-- Emits code that:
--
-- 1. allocates a fresh one-element @np.empty@ host array of the element's
--    dtype, bound to a new @read_res@ variable;
-- 2. copies into it with @cl.enqueue_copy@ from byte offset
--    @i * primByteSize bt@ of @mem@;
-- 3. calls @sync(self)@, presumably to wait for the copy when it is
--    non-blocking (TODO confirm against the @sync@ helper).
--
-- The returned expression is element 0 of that array.  Any memory space
-- other than @"device"@ is rejected via 'error'.
readOpenCLScalar :: Py.ReadScalar Imp.OpenCL ()
readOpenCLScalar mem i bt space
  | space == "device" = do
      res_name <- newVName "read_res"
      let res = Var (pretty res_name)
      Py.stm . Assign res $
        Call
          (Var "np.empty")
          [Arg (Integer 1), ArgKeyword "dtype" (Var (Py.compilePrimType bt))]
      Py.stm . Exp $
        Call
          (Var "cl.enqueue_copy")
          [ Arg (Var "self.queue"),
            Arg res,
            Arg mem,
            ArgKeyword "device_offset" (BinOp "*" (asLong i) (Integer (Imp.primByteSize bt))),
            ArgKeyword "is_blocking" (Var "synchronous")
          ]
      Py.stm . Exp $ Py.simpleCall "sync" [Var "self"]
      pure $ Index res (IdxExp (Integer 0))
  | otherwise =
      error $ "Cannot read from '" ++ space ++ "' memory space."

-- | Allocate a device buffer of @size@ bytes and bind it to @mem@.
--
-- Emits @mem = opencl_alloc(self, size, "<mem>")@; the last argument is
-- the pretty-printed name of the memory block.  Any memory space other
-- than @"device"@ is rejected via 'error'.
allocateOpenCLBuffer :: Py.Allocate Imp.OpenCL ()
allocateOpenCLBuffer mem size space
  | space == "device" =
      Py.stm . Assign mem $
        Py.simpleCall "opencl_alloc" [Var "self", size, String (pretty mem)]
  | otherwise =
      error $ "Cannot allocate in '" ++ space ++ "' space"

-- | Generate Python code that copies @nbytes@ bytes from @srcmem@ (at
-- @srcidx@) to @destmem@ (at @destidx@). The strategy is chosen from the pair
-- of memory spaces involved:
--
-- * host ('Imp.DefaultSpace') from @"device"@: @cl.enqueue_copy@ into a
--   slice of @destmem@, with @srcidx@ passed as @device_offset@.
-- * @"device"@ from host: @cl.enqueue_copy@ from a slice of @srcmem@, with
--   @destidx@ passed as @device_offset@.
-- * @"device"@ from @"device"@: @cl.enqueue_copy@ with explicit
--   @dest_offset@, @src_offset@ and @byte_count@. This call passes no
--   @is_blocking@, so 'finishIfSynchronous' is emitted after it.
-- * host from host: delegated to 'Py.copyMemoryDefaultSpace'.
--
-- Every generated @enqueue_copy@ is guarded by 'ifNotZeroSize' on @nbytes@.
-- Any other combination of spaces is a compiler 'error'.
copyOpenCLMemory :: Py.Copy Imp.OpenCL ()
copyOpenCLMemory destmem destidx destspace srcmem srcidx srcspace nbytes bt =
  case (destspace, srcspace) of
    (Imp.DefaultSpace, Imp.Space "device") ->
      enqueueCopy
        [ Arg (hostSlice destmem destidx),
          Arg srcmem,
          ArgKeyword "device_offset" (asLong srcidx),
          blocking
        ]
    (Imp.Space "device", Imp.DefaultSpace) ->
      enqueueCopy
        [ Arg destmem,
          Arg (hostSlice srcmem srcidx),
          ArgKeyword "device_offset" (asLong destidx),
          blocking
        ]
    (Imp.Space "device", Imp.Space "device") -> do
      enqueueCopy
        [ Arg destmem,
          Arg srcmem,
          ArgKeyword "dest_offset" (asLong destidx),
          ArgKeyword "src_offset" (asLong srcidx),
          ArgKeyword "byte_count" (asLong nbytes)
        ]
      finishIfSynchronous
    (Imp.DefaultSpace, Imp.DefaultSpace) ->
      Py.copyMemoryDefaultSpace destmem destidx srcmem srcidx nbytes
    _ ->
      error $ "Cannot copy to " ++ show destspace ++ " from " ++ show srcspace
  where
    -- Emits: if nbytes != 0: cl.enqueue_copy(self.queue, <args>)
    enqueueCopy args =
      Py.stm . ifNotZeroSize nbytes . Exp $
        Call (Var "cl.enqueue_copy") (Arg (Var "self.queue") : args)
    -- Python slice mem[idx : idx + nbytes // <byte size of bt>].
    hostSlice mem idx =
      Index mem (IdxRange idx (BinOp "+" idx elemCount))
    elemCount = BinOp "//" nbytes (Integer (Imp.primByteSize bt))
    blocking = ArgKeyword "is_blocking" (Var "synchronous")

-- | Generate code for a static (constant) array in the @"device"@ space.
--
-- The following statements are collected and registered via 'Py.atInit':
--
-- 1. A NumPy array with the contents is built, using @np.array@ for explicit
--    values or @np.zeros@ for all-zero arrays, with dtype from @t@.
-- 2. A fresh device buffer of @element count * byte size of t@ bytes is
--    allocated with 'allocateOpenCLBuffer'.
-- 3. The NumPy array is copied into it with @cl.enqueue_copy@ (skipped when
--    the size is zero).
-- 4. The buffer is stored as @self.<name>@.
--
-- At the point of use, the Python name is then rebound to @self.<name>@.
-- Any other memory space is a compiler 'error'.
staticOpenCLArray :: Py.StaticArray Imp.OpenCL ()
staticOpenCLArray name "device" t vs = do
  initStmts <- Py.collect $ do
    -- Create host-side Numpy array with intended values.
    Py.stm $ Assign (Var pyName) hostArray

    -- Create memory block on the device.
    staticMem <- Py.compileName <$> newVName "static_mem"
    allocateOpenCLBuffer (Var staticMem) byteSize "device"

    -- Copy Numpy array to the device memory block.
    Py.stm . ifNotZeroSize byteSize . Exp $
      Call
        (Var "cl.enqueue_copy")
        [ Arg (Var "self.queue"),
          Arg (Var staticMem),
          Arg (Call (Var "normaliseArray") [Arg (Var pyName)]),
          ArgKeyword "is_blocking" (Var "synchronous")
        ]

    -- Store the memory block for later reference.
    Py.stm $ Assign (Field (Var "self") pyName) (Var staticMem)

  mapM_ Py.atInit initStmts
  Py.stm $ Assign (Var pyName) (Field (Var "self") pyName)
  where
    pyName = Py.compileName name
    dtypeArg = ArgKeyword "dtype" (Var (Py.compilePrimToNp t))
    -- The NumPy constructor expression and the number of elements it holds.
    (hostArray, elemCount) = case vs of
      Imp.ArrayValues vals ->
        ( Call (Var "np.array") [Arg (List (map Py.compilePrimValue vals)), dtypeArg],
          length vals
        )
      Imp.ArrayZeros n ->
        ( Call (Var "np.zeros") [Arg (Integer (fromIntegral n)), dtypeArg],
          n
        )
    byteSize = Integer (toInteger elemCount * Imp.primByteSize t)
staticOpenCLArray _ space _ _ =
  error $ "PyOpenCL backend cannot create static array in memory space '" ++ space ++ "'"

-- | Wrap a @"device"@ memory block as an entry-point result.
--
-- The block is wrapped in a
-- @cl.array.Array(self.queue, <dims tuple>, <dtype>, data=<mem>)@
-- expression, with dims from 'Py.compileDim' and dtype from
-- 'Py.compilePrimTypeExt'. Any other space is a compiler 'error'.
packArrayOutput :: Py.EntryOutput Imp.OpenCL ()
packArrayOutput mem "device" bt ept dims =
  asClArray <$> Py.compileVar mem
  where
    asClArray memExp =
      Call
        (Var "cl.array.Array")
        [ Arg (Var "self.queue"),
          Arg (Tuple (map Py.compileDim dims)),
          Arg (Var (Py.compilePrimTypeExt bt ept)),
          ArgKeyword "data" memExp
        ]
packArrayOutput _ sid _ _ _ =
  error $ "Cannot return array from " ++ sid ++ " space."

-- | Generate code that accepts an entry-point argument @e@ into the
-- @"device"@ memory block @mem@.
--
-- The emitted code does the following:
--
-- * It asserts that @type(e)@ is @np.ndarray@ or @cl.array.Array@ and that
--   @e.dtype@ matches @t@ / @s@.
-- * It unpacks each dimension with 'Py.unpackDim'.
-- * If @type(e) == cl.array.Array@, @mem@ is bound to @e.data@.
-- * Otherwise a buffer of @np.int64(e.nbytes)@ bytes is allocated and
--   @normaliseArray(e)@ is copied into it (skipped when the size is zero).
--
-- Any other space is a compiler 'error'.
unpackArrayInput :: Py.EntryInput Imp.OpenCL ()
unpackArrayInput mem "device" t s dims e = do
  Py.stm . Assert typeIsOk $ String "Parameter has unexpected type"

  zipWithM_ (Py.unpackDim e) dims [0 ..]

  hostArrayCase <- Py.collect $ do
    allocateOpenCLBuffer mem byteCount "device"
    Py.stm . ifNotZeroSize byteCount . Exp $
      Call
        (Var "cl.enqueue_copy")
        [ Arg (Var "self.queue"),
          Arg mem,
          Arg (Call (Var "normaliseArray") [Arg e]),
          ArgKeyword "is_blocking" (Var "synchronous")
        ]

  Py.stm $
    If
      (BinOp "==" pyType (Var "cl.array.Array"))
      [Assign mem (Field e "data")]
      hostArrayCase
  where
    pyType = Py.simpleCall "type" [e]
    acceptedTypes = List [Var "np.ndarray", Var "cl.array.Array"]
    dtypeMatches = BinOp "==" (Field e "dtype") (Var (Py.compilePrimToExtNp t s))
    typeIsOk = BinOp "and" (BinOp "in" pyType acceptedTypes) dtypeMatches
    byteCount = Py.simpleCall "np.int64" [Field e "nbytes"]
unpackArrayInput _ sid _ _ _ _ =
  error $ "Cannot accept array from " ++ sid ++ " space."

-- | Wrap a statement so that it only runs when the given size expression is
-- non-zero. The result is the Python statement @if size != 0: body@ with an
-- empty else branch.
ifNotZeroSize :: PyExp -> PyStmt -> PyStmt
ifNotZeroSize size body = If nonZero [body] []
  where
    nonZero = BinOp "!=" size (Integer 0)

-- | Emit @if synchronous: sync(self)@.
--
-- This is used after operations that do not themselves take an
-- @is_blocking@ flag.
finishIfSynchronous :: Py.CompilerM op s ()
finishIfSynchronous = Py.stm syncWhenRequested
  where
    syncWhenRequested = If (Var "synchronous") [waitForQueue] []
    waitForQueue = Exp (Py.simpleCall "sync" [Var "self"])