{-# LANGUAGE QuasiQuotes     #-}
{-# LANGUAGE TemplateHaskell #-}
-- |
-- Module      : Data.Array.Accelerate.LLVM.PTX.Analysis.Launch
-- Copyright   : [2008..2020] The Accelerate Team
-- License     : BSD3
--
-- Maintainer  : Trevor L. McDonell <trevor.mcdonell@gmail.com>
-- Stability   : experimental
-- Portability : non-portable (GHC extensions)
--

module Data.Array.Accelerate.LLVM.PTX.Analysis.Launch (

  DeviceProperties, Occupancy, LaunchConfig,
  simpleLaunchConfig, launchConfig,
  multipleOf, multipleOfQ,

) where

import Foreign.CUDA.Analysis                            as CUDA
import Language.Haskell.TH


-- | Given information about the resource usage of the compiled kernel,
-- determine the optimum launch parameters.
--
-- A 'LaunchConfig' is a function from three kernel resource figures to a
-- five-component tuple describing how the kernel should be launched.
--
type LaunchConfig
  =  Int                            -- maximum #threads per block
  -> Int                            -- #registers per thread
  -> Int                            -- #bytes of static shared memory
  -> ( Occupancy
     , Int                          -- thread block size
     , Int -> Int                   -- grid size required to process the given input size
     , Int                          -- #bytes dynamic shared memory
     , Q (TExp (Int -> Int))        -- typed Template Haskell expression of the grid size function
     )

-- | Analytics for a simple kernel which requires no additional shared memory
-- and has no other constraints on its launch configuration. The smallest
-- thread block size, in increments of a single warp, with the highest
-- occupancy is used.
--
-- This is 'launchConfig' specialised with:
--
--   * candidate thread block sizes taken from @decWarp dev@;
--
--   * a dynamic shared memory requirement of zero bytes for every block size;
--
--   * 'multipleOf' (and its typed-quotation twin 'multipleOfQ') as the grid
--     size function.
--
simpleLaunchConfig :: DeviceProperties -> LaunchConfig
simpleLaunchConfig dev = launchConfig dev (decWarp dev) (const 0) multipleOf multipleOfQ


-- | Determine the optimal kernel launch configuration for a kernel.
--
-- The first five arguments describe the device and the launch policy. The
-- remaining three arguments (maximum threads per block, registers per thread
-- and bytes of static shared memory) are those of the resulting
-- 'LaunchConfig' and describe the compiled kernel.
--
-- The result contains the occupancy and thread block size selected by
-- 'optimalBlockSizeOf', a function from input size to grid size, the dynamic
-- shared memory required at the selected block size, and a typed Template
-- Haskell expression equivalent to the grid size function.
--
launchConfig
    :: DeviceProperties             -- ^ Device architecture to optimise for
    -> [Int]                        -- ^ Thread block sizes to consider
    -> (Int -> Int)                 -- ^ Shared memory (#bytes) as a function of thread block size
    -> (Int -> Int -> Int)          -- ^ Determine grid size for input size 'n' (first arg) over thread blocks of size 'm' (second arg)
    -> Q (TExp (Int -> Int -> Int)) -- ^ Typed Template Haskell version of the grid size function (same argument order)
    -> LaunchConfig
launchConfig dev candidates dynamic_smem grid_size grid_sizeQ maxThreads registers static_smem =
  let
      -- Only block sizes the kernel can actually be launched with are
      -- considered. Register usage is the same for every block size, whereas
      -- shared memory usage depends on the block size.
      (cta, occ)  = optimalBlockSizeOf dev (filter (<= maxThreads) candidates) (const registers) smem

      -- Upper bound on the grid size: multiprocessor count multiplied by the
      -- active thread blocks reported by the occupancy calculation.
      maxGrid     = multiProcessorCount dev * activeThreadBlocks occ

      -- Grid size for an input of size 'n', clamped to 'maxGrid'.
      grid n      = maxGrid `min` grid_size n cta

      -- Total shared memory (static + dynamic) for a block of 'n' threads.
      smem n      = static_smem + dynamic_smem n

      -- The same computation as 'grid', as a typed quotation: 'maxGrid' and
      -- 'cta' are embedded as values, and the quoted grid size function is
      -- spliced in.
      gridQ       = [|| \n -> (maxGrid::Int) `min` $$grid_sizeQ (n::Int) (cta::Int) ||]
  in
  ( occ, cta, grid, dynamic_smem cta, gridQ )


-- | The number of chunks of size 'y' required to cover 'x' elements; that is,
-- 'x' divided by 'y', rounded up (for non-negative 'x' and positive 'y').
--
-- Note that the result is the /count/ of multiples, not the rounded-up value
-- itself: @multipleOf 10 3 == 4@, not @12@.
--
-- The division uses 'quot', so a 'y' of zero raises a divide-by-zero error.
--
multipleOf :: Int -> Int -> Int
multipleOf x y = (x + y - 1) `quot` y

-- | Typed Template Haskell quotation of 'multipleOf', for use where the grid
-- size function must be supplied as a spliceable expression (see the fifth
-- argument of 'launchConfig').
--
multipleOfQ :: Q (TExp (Int -> Int -> Int))
multipleOfQ = [|| multipleOf ||]