-- | Some OpenCL platforms have a SIMD/warp/wavefront-based execution
-- model that executes groups of threads in lockstep, permitting us to
-- perform cross-thread synchronisation within each such group without
-- the use of barriers.  Unfortunately, there seems to be no reliable
-- way to query these sizes at runtime.  Instead, we use builtin
-- tables to figure out which size we should use for a specific
-- platform and device.  If nothing matches here, the wave size should
-- be set to one.
--
-- We also use this to select reasonable default group sizes and group
-- counts.
module Futhark.CodeGen.OpenCL.Heuristics
       ( SizeHeuristic (..)
       , DeviceType (..)
       , WhichSize (..)
       , DeviceInfo (..)
       , sizeHeuristicsTable
       )
       where

import Futhark.Analysis.PrimExp
import Futhark.Util.Pretty

-- | The type of OpenCL device that a heuristic applies to.
data DeviceType
  = DeviceCPU -- ^ The heuristic applies to CPU devices.
  | DeviceGPU -- ^ The heuristic applies to GPU devices.

-- | The value supplied by a heuristic can depend on some device
-- information.  This will be translated into a call to
-- @clGetDeviceInfo()@.  Make sure to only request info that can be
-- cast to a scalar type.
--
-- The wrapped 'String' names the piece of device information being
-- requested.  The only name used in this module is
-- @MAX_COMPUTE_UNITS@.  NOTE(review): presumably this is the suffix
-- of the corresponding @CL_DEVICE_*@ parameter -- confirm against the
-- code that consumes these values.
newtype DeviceInfo = DeviceInfo String

-- | Renders a 'DeviceInfo' as the text @device_info@ immediately
-- followed by the pretty-printed name of the requested information in
-- parentheses.
instance Pretty DeviceInfo where
  ppr (DeviceInfo name) = text "device_info" <> parens (ppr name)

-- | A size that can be assigned a default.
data WhichSize
  = LockstepWidth
    -- ^ The width of a group of threads that execute in lockstep (the
    -- SIMD/warp/wavefront size).
  | NumGroups
    -- ^ The number of groups.
  | GroupSize
    -- ^ The size of a group.
  | TileSize
    -- ^ A tile size.  NOTE(review): the exact meaning is not evident
    -- from this module -- confirm with the code that consumes it.
  | Threshold
    -- ^ A threshold.  NOTE(review): the exact meaning is not evident
    -- from this module -- confirm with the code that consumes it.

-- | A heuristic for setting the default value for something.  Each
-- heuristic says which platform and device type it applies to, which
-- size it provides a default for, and how to compute that default.
data SizeHeuristic =
    SizeHeuristic { platformName :: String
                    -- ^ Name of the OpenCL platform this heuristic
                    -- applies to.  'sizeHeuristicsTable' uses the
                    -- empty string for entries that are not tied to
                    -- a specific platform; how names are matched is
                    -- up to the consumer of the table.
                  , deviceType :: DeviceType
                    -- ^ The kind of device this heuristic applies to.
                  , heuristicSize :: WhichSize
                    -- ^ The size that this heuristic provides a
                    -- default for.
                  , heuristicValue :: PrimExp DeviceInfo
                    -- ^ The default value, as an expression whose
                    -- leaves may be pieces of device information.
                  }

-- | All of our heuristics.
--
-- Entries naming a specific platform are listed before the entries
-- with an empty platform name.  NOTE(review): presumably consumers
-- use the first entry that matches, which would make the order
-- significant -- confirm before reordering.
sizeHeuristicsTable :: [SizeHeuristic]
sizeHeuristicsTable =
  [ -- Lockstep widths for specific GPU platforms.
    SizeHeuristic "NVIDIA CUDA" DeviceGPU LockstepWidth $ constant 32
  , SizeHeuristic "AMD Accelerated Parallel Processing" DeviceGPU LockstepWidth $
      constant 32

    -- GPU entries with an empty platform name.
  , SizeHeuristic "" DeviceGPU LockstepWidth $ constant 1
    -- We calculate the number of groups to aim for 1024 threads per
    -- compute unit if we also use the default group size.  This seems
    -- to perform well in practice.
  , SizeHeuristic "" DeviceGPU NumGroups $ 4 * maxComputeUnits
  , SizeHeuristic "" DeviceGPU GroupSize $ constant 256
  , SizeHeuristic "" DeviceGPU TileSize $ constant 32
  , SizeHeuristic "" DeviceGPU Threshold $ constant (32 * 1024)

    -- CPU entries with an empty platform name.
  , SizeHeuristic "" DeviceCPU LockstepWidth $ constant 1
  , SizeHeuristic "" DeviceCPU NumGroups maxComputeUnits
  , SizeHeuristic "" DeviceCPU GroupSize $ constant 32
  , SizeHeuristic "" DeviceCPU TileSize $ constant 4
  , SizeHeuristic "" DeviceCPU Threshold maxComputeUnits
  ]
  where
    -- Wraps a 32-bit integer as a constant expression.  The type is
    -- left to inference so that no extra names need to be in scope.
    constant = ValueExp . IntValue . Int32Value

    -- The MAX_COMPUTE_UNITS device information, typed as a 32-bit
    -- integer.
    maxComputeUnits :: PrimExp DeviceInfo
    maxComputeUnits = LeafExp (DeviceInfo "MAX_COMPUTE_UNITS") (IntType Int32)