{-# LANGUAGE FlexibleContexts #-}
module Futhark.CodeGen.Backends.PyOpenCL
  ( compileProg
  ) where

import Control.Monad
import qualified Data.Map as M

import Futhark.IR.KernelsMem (Prog, KernelsMem)
import Futhark.CodeGen.Backends.PyOpenCL.Boilerplate
import qualified Futhark.CodeGen.Backends.GenericPython as Py
import qualified Futhark.CodeGen.ImpCode.OpenCL as Imp
import qualified Futhark.CodeGen.ImpGen.OpenCL as ImpGen
import Futhark.CodeGen.Backends.GenericPython.AST
import Futhark.CodeGen.Backends.GenericPython.Options
import Futhark.CodeGen.Backends.GenericPython.Definitions
import Futhark.MonadFreshNames

-- | Compile a Futhark program to the source text of a Python module that
-- runs its kernels through PyOpenCL.
--
-- The program is first lowered with 'ImpGen.compileProg'.  The resulting
-- OpenCL source, kernel names, size parameters and failure messages are
-- spliced into the Python boilerplate.  Everything is then handed to the
-- generic 'Py.compileProg' together with the OpenCL-specific
-- 'Py.Operations' defined in the @where@ clause.
--
-- TODO: maybe pass a config record rather than multiple arguments.
compileProg :: MonadFreshNames m =>
               Maybe String -> Prog KernelsMem -> m String
compileProg module_name prog = do
  Imp.Program opencl_code opencl_prelude kernels types sizes failures prog' <-
    ImpGen.compileProg prog

  let -- Bind every compiled kernel as an instance attribute, i.e. emit
      -- @self.<k>_var = program.<k>@ for each kernel name @k@.  The lines
      -- are passed to 'openClInit' for the generated constructor.
      kernelBinding kname =
        pretty $ Assign (Var ("self." ++ kname ++ "_var"))
                        (Var ("program." ++ kname))
      assign = unlines [ kernelBinding kname | kname <- M.keys kernels ]

      -- Module-level Python globals.  Most start out as None; the
      -- command-line options below assign to them.
      unsetGlobal global = Assign (Var global) None
      defines =
        Assign (Var "synchronous") (Bool False)
        : map unsetGlobal [ "preferred_platform"
                          , "preferred_device"
                          , "default_threshold"
                          , "default_group_size"
                          , "default_num_groups"
                          , "default_tile_size"
                          ]
        ++ [ Assign (Var "fut_opencl_src")
               (RawStringLiteral (opencl_prelude ++ opencl_code))
           , Escape pyValues
           , Escape pyFunctions
           , Escape pyPanic
           , Escape pyTuning
           ]

      imports =
        [ Import "sys" Nothing
        , Import "numpy" (Just "np")
        , Import "ctypes" (Just "ct")
        , Escape openClPrelude
        , Import "pyopencl.array" Nothing
        , Import "time" Nothing
        ]

      constructor =
        Py.Constructor
          [ "self"
          , "command_queue=None"
          , "interactive=False"
          , "platform_pref=preferred_platform"
          , "device_pref=preferred_device"
          , "default_group_size=default_group_size"
          , "default_num_groups=default_num_groups"
          , "default_tile_size=default_tile_size"
          , "default_threshold=default_threshold"
          , "sizes=sizes"
          ]
          [Escape $ openClInit types assign sizes failures]

      -- A command-line option with a required argument that is assigned
      -- verbatim (as @optarg@) to the named Python global.
      globalOption longName shortName argKind global =
        Option { optionLongName = longName
               , optionShortName = shortName
               , optionArgument = RequiredArgument argKind
               , optionAction = [Assign (Var global) (Var "optarg")]
               }

      -- Element @n@ of the parsed option argument.
      optargAt n = Index (Var "optarg") (IdxExp (Integer n))

      -- @--size@ stores @optarg[1]@ under the key @optarg[0]@ in the
      -- @sizes@ dict.  NOTE(review): presumably @optarg@ is a parsed
      -- (name, value) pair for "size_assignment"; confirm in Options.
      sizeOption =
        Option { optionLongName = "size"
               , optionShortName = Nothing
               , optionArgument = RequiredArgument "size_assignment"
               , optionAction =
                   [ Assign (Index (Var "sizes") (IdxExp (optargAt 0)))
                            (optargAt 1)
                   ]
               }

      options =
        [ globalOption "platform" (Just 'p') "str" "preferred_platform"
        , globalOption "device" (Just 'd') "str" "preferred_device"
        , globalOption "default-threshold" Nothing "int" "default_threshold"
        , globalOption "default-group-size" Nothing "int" "default_group_size"
        , globalOption "default-num-groups" Nothing "int" "default_num_groups"
        , globalOption "default-tile-size" Nothing "int" "default_tile_size"
        , sizeOption
        ]

  Py.compileProg module_name constructor imports defines operations ()
    [Exp $ Py.simpleCall "sync" [Var "self"]] options prog'
  where
    -- OpenCL-specific hooks used by the generic Python code generator.
    operations :: Py.Operations Imp.OpenCL ()
    operations =
      Py.Operations
        { Py.opsCompiler = callKernel
        , Py.opsWriteScalar = writeOpenCLScalar
        , Py.opsReadScalar = readOpenCLScalar
        , Py.opsAllocate = allocateOpenCLBuffer
        , Py.opsCopy = copyOpenCLMemory
        , Py.opsStaticArray = staticOpenCLArray
        , Py.opsEntryOutput = packArrayOutput
        , Py.opsEntryInput = unpackArrayInput
        }

-- | Wrap a Python expression in a cast to a 64-bit NumPy integer.
--
-- We have many such casts because PyOpenCL may get confused by the
-- 32-bit numbers that ImpCode uses for offsets and the like.
--
-- The generated code deliberately uses @np.int64@ rather than @np.long@.
-- The @np.long@ alias was deprecated in NumPy 1.20 and removed in 1.24,
-- so calling it there raises AttributeError.  Where it exists again
-- (NumPy 2.x) it denotes C @long@, which is only 32 bits on Windows.
asLong :: PyExp -> PyExp
asLong x = Py.simpleCall "np.int64" [x]

-- | Compile a single OpenCL-specific ImpCode operation to Python.
--
-- * 'Imp.GetSize' and 'Imp.CmpSizeLe' read the tunable size @key@ from
--   the instance's @self.sizes@ dictionary.
-- * 'Imp.GetSizeMax' reads the corresponding @self.max_*@ attribute.
-- * 'Imp.LaunchKernel' emits the launch via 'launchKernel'.  The launch is
--   wrapped in an @if@ so it is skipped when the product of the global
--   work sizes is zero.  For kernels compiled with at least full safety,
--   @self.failure_is_an_option@ is then set to 1.
callKernel :: Py.OpCompiler Imp.OpenCL ()
callKernel op =
  case op of
    Imp.GetSize v key -> do
      v' <- Py.compileVar v
      Py.stm $ Assign v' $ sizeEntry key

    Imp.CmpSizeLe v key x -> do
      v' <- Py.compileVar v
      x' <- Py.compileExp x
      Py.stm $ Assign v' $ BinOp "<=" (sizeEntry key) x'

    Imp.GetSizeMax v size_class -> do
      v' <- Py.compileVar v
      Py.stm $ Assign v' $ Var $ "self.max_" ++ pretty size_class

    Imp.LaunchKernel safety name args num_workgroups workgroup_size -> do
      groups <- mapM compileDim num_workgroups
      group_size <- mapM compileDim workgroup_size
      -- Per-dimension global size, and the total number of work items.
      let global_size = zipWith times groups group_size
          work_items = foldl times (Integer 1) global_size
      launch <- Py.collect $
        launchKernel name safety global_size group_size args
      -- Do not enqueue a kernel with no work items.
      Py.stm $ If (BinOp "!=" work_items (Integer 0)) launch []
      when (safety >= Imp.SafetyFull) $
        Py.stm $ Assign (Var "self.failure_is_an_option") $
          Py.compilePrimValue $ Imp.IntValue $ Imp.Int32Value 1
  where
    -- @self.sizes["<key>"]@
    sizeEntry key = Index (Var "self.sizes") (IdxExp (String (pretty key)))
    -- Work-size dimensions are cast with 'asLong' before use.
    compileDim = fmap asLong . Py.compileExp
    times = BinOp "*"

-- | Emit the Python code that launches the compiled kernel
-- @self.<kernel_name>_var@.
--
-- Arguments are bound with @.set_args@, and the kernel is then enqueued on
-- @self.queue@ with the given global and local work sizes as tuples.
-- The first @Imp.numFailureParams safety@ of the failure-tracking
-- attributes (@self.global_failure@, @self.failure_is_an_option@,
-- @self.global_failure_args@) are passed before the kernel's own
-- arguments.  The emitted code ends with 'finishIfSynchronous'.
launchKernel :: String -> Imp.Safety -> [PyExp] -> [PyExp] -> [Imp.KernelArg]
             -> Py.CompilerM op s ()
launchKernel kernel_name safety kernel_dims workgroup_dims args = do
  kernel_args <- mapM processKernelArg args
  Py.stm $ Exp $
    Py.simpleCall (kernel_var ++ ".set_args") (failure_args ++ kernel_args)
  Py.stm $ Exp $
    Py.simpleCall "cl.enqueue_nd_range_kernel"
      [ Var "self.queue"
      , Var kernel_var
      , Tuple kernel_dims
      , Tuple workgroup_dims
      ]
  finishIfSynchronous
  where
    kernel_var = "self." ++ kernel_name ++ "_var"

    failure_args =
      take (Imp.numFailureParams safety) $
        map Var [ "self.global_failure"
                , "self.failure_is_an_option"
                , "self.global_failure_args"
                ]

    -- Scalar arguments are wrapped in the matching NumPy constructor,
    -- memory arguments are passed as their compiled variable, and
    -- shared-memory arguments become @cl.LocalMemory(<byte count>)@.
    processKernelArg :: Imp.KernelArg -> Py.CompilerM op s PyExp
    processKernelArg (Imp.ValueKArg e bt) =
      call1 (Py.compilePrimToNp bt) <$> Py.compileExp e
    processKernelArg (Imp.MemKArg v) =
      Py.compileVar v
    processKernelArg (Imp.SharedMemoryKArg (Imp.Count num_bytes)) =
      call1 "cl.LocalMemory" . asLong <$> Py.compileExp num_bytes

    -- Call a Python function with exactly one argument.
    call1 fun x = Py.simpleCall fun [x]

-- | Write the scalar @val@ to element @i@ (in units of @bt@) of the
-- buffer @mem@.
--
-- The value is converted with @np.array(val, dtype=...)@ and copied with
-- @cl.enqueue_copy@ at byte offset @i * sizeof(bt)@.  The copy blocks only
-- if the module-level @synchronous@ flag is set.  Any memory space other
-- than @"device"@ is rejected with 'error'.
writeOpenCLScalar :: Py.WriteScalar Imp.OpenCL ()
writeOpenCLScalar mem i bt space val
  | space /= "device" =
      error $ "Cannot write to '" ++ space ++ "' memory space."
  | otherwise =
      Py.stm $ Exp $ Call (Var "cl.enqueue_copy")
        [ Arg $ Var "self.queue"
        , Arg mem
        , Arg host_value
        , ArgKeyword "device_offset" byte_offset
        , ArgKeyword "is_blocking" $ Var "synchronous"
        ]
  where
    host_value =
      Call (Var "np.array")
        [Arg val, ArgKeyword "dtype" $ Var $ Py.compilePrimType bt]
    byte_offset = BinOp "*" (asLong i) (Integer $ Imp.primByteSize bt)

-- | Read element @i@ (in units of @bt@) of the buffer @mem@ and return a
-- Python expression for the scalar.
--
-- A one-element @np.empty@ array is bound to a fresh @read_res@ name.  It
-- is filled with @cl.enqueue_copy@ from byte offset @i * sizeof(bt)@,
-- blocking if @synchronous@ is set.  @sync(self)@ is then emitted, and the
-- result is element 0 of that array.  NOTE(review): the @sync@ call
-- presumably ensures the copy has finished when @synchronous@ is false.
-- Any memory space other than @"device"@ is rejected with 'error'.
readOpenCLScalar :: Py.ReadScalar Imp.OpenCL ()
readOpenCLScalar mem i bt space
  | space /= "device" =
      error $ "Cannot read from '" ++ space ++ "' memory space."
  | otherwise = do
      res <- Var . pretty <$> newVName "read_res"
      Py.stm $ Assign res $
        Call (Var "np.empty")
          [Arg $ Integer 1, ArgKeyword "dtype" $ Var $ Py.compilePrimType bt]
      Py.stm $ Exp $ Call (Var "cl.enqueue_copy")
        [ Arg $ Var "self.queue"
        , Arg res
        , Arg mem
        , ArgKeyword "device_offset" $
            BinOp "*" (asLong i) (Integer $ Imp.primByteSize bt)
        , ArgKeyword "is_blocking" $ Var "synchronous"
        ]
      Py.stm $ Exp $ Py.simpleCall "sync" [Var "self"]
      return $ Index res $ IdxExp $ Integer 0

-- | Bind @mem@ to a new device allocation of @size@.
--
-- The allocation is made through the generated @opencl_alloc(self, size,
-- tag)@ helper, where the tag is the pretty-printed @mem@ expression.
-- Any space other than @"device"@ is rejected with 'error'.
allocateOpenCLBuffer :: Py.Allocate Imp.OpenCL ()
allocateOpenCLBuffer mem size space
  | space /= "device" =
      error $ "Cannot allocate in '" ++ space ++ "' space"
  | otherwise =
      Py.stm $ Assign mem $
        Py.simpleCall "opencl_alloc" [Var "self", size, String $ pretty mem]

-- | Generate Python code that copies @nbytes@ bytes between two memory
-- blocks, choosing the strategy from the (destination, source) pair of
-- memory spaces:
--
-- * host <- device, device <- host: one @cl.enqueue_copy@ between a slice
--   of the host-side array and the device buffer, with @is_blocking@ taken
--   from the generated program's @synchronous@ variable.
-- * device <- device: @cl.enqueue_copy@ with explicit @dest_offset@,
--   @src_offset@ and @byte_count@, followed by 'finishIfSynchronous'.
-- * host <- host: delegated to 'Py.copyMemoryDefaultSpace'.
--
-- The OpenCL copies are wrapped in 'ifNotZeroSize' so nothing is enqueued
-- for empty copies. Any other pair of spaces is a compiler error.
copyOpenCLMemory :: Py.Copy Imp.OpenCL ()
copyOpenCLMemory destmem destidx destspace srcmem srcidx srcspace nbytes bt =
  case (destspace, srcspace) of
    (Imp.DefaultSpace, Imp.Space "device") ->
      Py.stm . ifNotZeroSize nbytes . enqueueCopy $
        [ Arg (hostSlice destmem destidx)
        , Arg srcmem
        , ArgKeyword "device_offset" (asLong srcidx)
        , blocking
        ]
    (Imp.Space "device", Imp.DefaultSpace) ->
      Py.stm . ifNotZeroSize nbytes . enqueueCopy $
        [ Arg destmem
        , Arg (hostSlice srcmem srcidx)
        , ArgKeyword "device_offset" (asLong destidx)
        , blocking
        ]
    (Imp.Space "device", Imp.Space "device") -> do
      Py.stm . ifNotZeroSize nbytes . enqueueCopy $
        [ Arg destmem
        , Arg srcmem
        , ArgKeyword "dest_offset" (asLong destidx)
        , ArgKeyword "src_offset" (asLong srcidx)
        , ArgKeyword "byte_count" (asLong nbytes)
        ]
      -- Device-to-device copies carry no is_blocking flag, so wait explicitly.
      finishIfSynchronous
    (Imp.DefaultSpace, Imp.DefaultSpace) ->
      Py.copyMemoryDefaultSpace destmem destidx srcmem srcidx nbytes
    _ ->
      error $ "Cannot copy to " ++ show destspace ++ " from " ++ show srcspace
  where
    -- cl.enqueue_copy(self.queue, <args>) as a Python expression statement.
    enqueueCopy args =
      Exp $ Call (Var "cl.enqueue_copy") (Arg (Var "self.queue") : args)

    -- Block on the copy only when the generated program runs synchronously.
    blocking = ArgKeyword "is_blocking" (Var "synchronous")

    -- mem[idx : idx + nbytes // sizeof(bt)]
    -- NOTE(review): the slice starts at the raw offset but its length is
    -- nbytes divided by the element size; confirm the offset's units agree
    -- with the host array's element type.
    hostSlice mem idx =
      Index mem . IdxRange idx . BinOp "+" idx $
        BinOp "//" nbytes (Integer (Imp.primByteSize bt))

-- | Define a constant array @name@ that lives in OpenCL device memory.
--
-- The statements registered with 'Py.atInit' do four things:
--
-- 1. Build a host-side NumPy array with the contents.
-- 2. Allocate a device buffer of @elements * sizeof(t)@ bytes.
-- 3. Upload the NumPy data into that buffer (skipped when the size is 0).
-- 4. Store the buffer as @self.<name>@.
--
-- At the current position a local Python variable @<name>@ is bound to
-- that field. Only the @"device"@ memory space is supported.
staticOpenCLArray :: Py.StaticArray Imp.OpenCL ()
staticOpenCLArray name "device" t vs = do
  initStmts <- Py.collect $ do
    -- Host-side NumPy array holding the intended values.
    Py.stm $ Assign (Var name') hostArray

    -- Device buffer for the array.
    static_mem <- newVName "static_mem"
    let devMem = Var (Py.compileName static_mem)
        size = Integer (toInteger numElems * Imp.primByteSize t)
    allocateOpenCLBuffer devMem size "device"

    -- Upload the host array into the device buffer.
    Py.stm . ifNotZeroSize size . Exp $
      Call (Var "cl.enqueue_copy")
        [ Arg (Var "self.queue")
        , Arg devMem
        , Arg (Call (Var "normaliseArray") [Arg (Var name')])
        , ArgKeyword "is_blocking" (Var "synchronous")
        ]

    -- Keep the buffer on the object for later reference.
    Py.stm $ Assign (Field (Var "self") name') devMem

  mapM_ Py.atInit initStmts
  Py.stm $ Assign (Var name') (Field (Var "self") name')
  where
    name' = Py.compileName name
    dtypeArg = ArgKeyword "dtype" (Var (Py.compilePrimToNp t))

    -- The NumPy constructor expression and the number of elements it holds.
    (hostArray, numElems) =
      case vs of
        Imp.ArrayValues vals ->
          ( Call (Var "np.array")
              [Arg (List (map Py.compilePrimValue vals)), dtypeArg]
          , length vals
          )
        Imp.ArrayZeros n ->
          ( Call (Var "np.zeros") [Arg (Integer (fromIntegral n)), dtypeArg]
          , n
          )
staticOpenCLArray _ space _ _ =
  error $ "PyOpenCL backend cannot create static array in memory space '" ++ space ++ "'"

-- | Wrap a device memory block as an entry-point return value:
-- @cl.array.Array(self.queue, <shape tuple>, <dtype>, data=<mem>)@.
-- Only arrays in the @"device"@ space can be returned.
packArrayOutput :: Py.EntryOutput Imp.OpenCL ()
packArrayOutput mem "device" bt ept dims =
  toClArray <$> Py.compileVar mem
  where
    toClArray memExp =
      Call (Var "cl.array.Array")
        [ Arg (Var "self.queue")
        , Arg (Tuple (map Py.compileDim dims))
        , Arg (Var (Py.compilePrimTypeExt bt ept))
        , ArgKeyword "data" memExp
        ]
packArrayOutput _ sid _ _ _ =
  error $ "Cannot return array from " ++ sid ++ " space."

-- | Turn an entry-point argument @e@ into device memory bound to @mem@.
--
-- The generated Python does the following:
--
-- * Asserts that @type(e)@ is @np.ndarray@ or @cl.array.Array@ and that its
--   dtype matches the expected element type.
-- * Unpacks the dimensions via 'Py.unpackDim'.
-- * For a @cl.array.Array@, reuses its @.data@ buffer directly.
-- * Otherwise allocates a device buffer of @e.nbytes@ bytes and uploads the
--   normalised host array into it.
--
-- Only the @"device"@ memory space is supported.
unpackArrayInput :: Py.EntryInput Imp.OpenCL ()
unpackArrayInput mem "device" t s dims e = do
  Py.stm $ Assert typeIsOk (String "Parameter has unexpected type")

  zipWithM_ (Py.unpackDim e) dims [0 ..]

  -- Host (NumPy) input: allocate device memory and upload the data.
  numpyArrayCase <- Py.collect $ do
    allocateOpenCLBuffer mem memsize "device"
    Py.stm . ifNotZeroSize memsize . Exp $
      Call (Var "cl.enqueue_copy")
        [ Arg (Var "self.queue")
        , Arg mem
        , Arg (Call (Var "normaliseArray") [Arg e])
        , ArgKeyword "is_blocking" (Var "synchronous")
        ]

  Py.stm $
    If (BinOp "==" pyType (Var "cl.array.Array"))
      [Assign mem (Field e "data")] -- already on the device: reuse its buffer
      numpyArrayCase
  where
    pyType = Py.simpleCall "type" [e]

    typeIsOk =
      BinOp "and"
        (BinOp "in" pyType (List [Var "np.ndarray", Var "cl.array.Array"]))
        (BinOp "==" (Field e "dtype") (Var (Py.compilePrimToExtNp t s)))

    memsize = Py.simpleCall "np.int64" [Field e "nbytes"]
unpackArrayInput _ sid _ _ _ _ =
  error $ "Cannot accept array from " ++ sid ++ " space."

-- | Wrap a Python statement in @if e != 0:@ with no else branch.
ifNotZeroSize :: PyExp -> PyStmt -> PyStmt
ifNotZeroSize e s = If isNonZero [s] []
  where
    isNonZero = BinOp "!=" e (Integer 0)

-- | Emit @if synchronous: sync(self)@ into the generated Python.
finishIfSynchronous :: Py.CompilerM op s ()
finishIfSynchronous = Py.stm $ If (Var "synchronous") [syncStmt] []
  where
    syncStmt = Exp (Py.simpleCall "sync" [Var "self"])