-- | Some OpenCL platforms have a SIMD/warp/wavefront-based execution
-- model that executes groups of threads in lockstep, permitting us to
-- perform cross-thread synchronisation within each such group without
-- the use of barriers.  Unfortunately, there seems to be no reliable
-- way to query these sizes at runtime.  Instead, we use builtin
-- tables to figure out which size we should use for a specific
-- platform and device.  If nothing matches here, the wave size should
-- be set to one.
--
-- We also use this to select reasonable default group sizes and group
-- counts.
module Futhark.CodeGen.OpenCL.Heuristics
  ( SizeHeuristic (..),
    DeviceType (..),
    WhichSize (..),
    DeviceInfo (..),
    sizeHeuristicsTable,
  )
where

import Futhark.Analysis.PrimExp
import Futhark.Util.Pretty

-- | The kind of OpenCL device that a heuristic applies to.
data DeviceType
  = -- | A CPU device.
    DeviceCPU
  | -- | A GPU device.
    DeviceGPU

-- | A piece of device information that the value supplied by a
-- heuristic may depend on.  Each use is translated into a call to
-- @clGetDeviceInfo()@, so only request information that can be cast
-- to a scalar type.  The wrapped string names the queried property
-- (for example @MAX_COMPUTE_UNITS@).
newtype DeviceInfo
  = DeviceInfo String

-- | Renders a 'DeviceInfo' as @device_info(...)@, with the queried
-- property name pretty-printed as the argument.
instance Pretty DeviceInfo where
  ppr (DeviceInfo s) = text "device_info" <> parens (ppr s)

-- | The kinds of size for which a heuristic can supply a default.
data WhichSize
  = -- | Width of a group of threads executing in lockstep.
    LockstepWidth
  | -- | Number of groups.
    NumGroups
  | -- | Size of each group.
    GroupSize
  | -- | Tile size.
    TileSize
  | -- | Register tile size.
    RegTileSize
  | -- | A threshold value.
    Threshold

-- | A heuristic for setting the default value for something.
data SizeHeuristic = SizeHeuristic
  { -- | Name of the OpenCL platform this heuristic applies to.
    -- 'sizeHeuristicsTable' uses the empty string for entries that are
    -- not tied to a specific platform.
    platformName :: String,
    -- | The kind of device this heuristic applies to.
    deviceType :: DeviceType,
    -- | Which size this heuristic supplies a default for.
    heuristicSize :: WhichSize,
    -- | The default value.  It may refer to 'DeviceInfo' leaves, which
    -- are resolved by querying the device.
    heuristicValue :: TPrimExp Int32 DeviceInfo
  }

-- | All of our heuristics.
--
-- Platform-specific entries come before the generic ones, whose
-- 'platformName' is the empty string.
-- NOTE(review): this ordering presumably relies on the consumer picking
-- the first matching entry and treating @""@ as a wildcard.  That
-- matching logic is not in this module, so confirm it against the code
-- that reads this table.
sizeHeuristicsTable :: [SizeHeuristic]
sizeHeuristicsTable =
  [ -- Lockstep widths for known GPU platforms.
    SizeHeuristic "NVIDIA CUDA" DeviceGPU LockstepWidth 32,
    SizeHeuristic "AMD Accelerated Parallel Processing" DeviceGPU LockstepWidth 32,
    -- Generic GPU defaults.
    SizeHeuristic "" DeviceGPU LockstepWidth 1,
    -- We calculate the number of groups to aim for 1024 threads per
    -- compute unit if we also use the default group size (4 groups of
    -- 256 threads).  This seems to perform well in practice.
    SizeHeuristic "" DeviceGPU NumGroups $ 4 * max_compute_units,
    SizeHeuristic "" DeviceGPU GroupSize 256,
    SizeHeuristic "" DeviceGPU TileSize 32,
    SizeHeuristic "" DeviceGPU RegTileSize 2,
    SizeHeuristic "" DeviceGPU Threshold $ 32 * 1024,
    -- Generic CPU defaults.
    SizeHeuristic "" DeviceCPU LockstepWidth 1,
    SizeHeuristic "" DeviceCPU NumGroups max_compute_units,
    SizeHeuristic "" DeviceCPU GroupSize 32,
    SizeHeuristic "" DeviceCPU TileSize 4,
    SizeHeuristic "" DeviceCPU RegTileSize 1,
    SizeHeuristic "" DeviceCPU Threshold max_compute_units
  ]
  where
    -- The device's MAX_COMPUTE_UNITS property, as a 32-bit integer leaf
    -- that is resolved through a device-info query.
    max_compute_units :: TPrimExp t DeviceInfo
    max_compute_units =
      TPrimExp $ LeafExp (DeviceInfo "MAX_COMPUTE_UNITS") $ IntType Int32