{-# LANGUAGE TypeFamilies #-}

module Futhark.Pass.ExtractKernels.StreamKernel
  ( segThreadCapped,
  )
where

import Control.Monad
import Data.List ()
import Futhark.Analysis.PrimExp
import Futhark.IR
import Futhark.IR.GPU hiding
  ( BasicOp,
    Body,
    Exp,
    FParam,
    FunDef,
    LParam,
    Lambda,
    Pat,
    PatElem,
    Prog,
    RetType,
    Stm,
  )
import Futhark.MonadFreshNames
import Futhark.Pass.ExtractKernels.BlockedKernel
import Futhark.Pass.ExtractKernels.ToGPU
import Futhark.Tools
import Prelude hiding (quot)

-- | A pair of size-describing operands for a kernel.
--
-- NOTE(review): the integer widths mentioned on the fields come from
-- the original field comments; nothing in this module checks them.
data KernelSize = KernelSize
  { -- | Elements handled per thread (annotated as Int64 in the
    -- original source).
    kernelElementsPerThread :: SubExp,
    -- | Number of threads (annotated as Int32 in the original
    -- source -- TODO confirm this is still accurate).
    kernelNumThreads :: SubExp
  }
  deriving (Eq, Ord, Show)

-- | Emit the statements that compute a group count and the
-- corresponding total thread count.
--
-- Arguments, in order:
--
-- * @desc@: a descriptive prefix; @desc ++ "_num_groups"@ is used as
--   the base of a fresh name, whose pretty-printed form becomes the
--   key handed to 'CalcNumGroups'.
--
-- * @w@: the first operand passed to 'CalcNumGroups' (presumably the
--   total amount of work -- confirm against callers).
--
-- * @group_size@: the group size; passed to 'CalcNumGroups' and also
--   used as the multiplier for the thread count.
--
-- Returns @(num_groups, num_threads)@, where @num_threads@ is bound to
-- @num_groups * group_size@ computed as an 'Int64' multiplication with
-- 'OverflowUndef'.
numberOfGroups ::
  (MonadBuilder m, Op (Rep m) ~ HostOp inner (Rep m)) =>
  String ->
  SubExp ->
  SubExp ->
  m (SubExp, SubExp)
numberOfGroups desc w group_size = do
  -- The key is derived from a fresh 'VName', so every call gets its
  -- own distinct key even when @desc@ repeats.
  max_num_groups_key <-
    nameFromString . prettyString <$> newVName (desc ++ "_num_groups")
  num_groups <-
    letSubExp "num_groups" $
      Op $
        SizeOp $
          CalcNumGroups w max_num_groups_key group_size
  num_threads <-
    letSubExp "num_threads" $
      BasicOp $
        BinOp (Mul Int64 OverflowUndef) num_groups group_size
  pure (num_groups, num_threads)

-- | Like 'segThread', but cap the thread count to the input size.
-- This is more efficient for small kernels, e.g. summing a small
-- array.
--
-- The nest size is the 'Int64' product of all the widths @ws@
-- (starting from the constant 1), and the group size is obtained via
-- 'getSize' under the name @desc ++ "_group_size"@ with class
-- 'SizeGroup'.  The grid then depends on the recommendation @r@:
--
-- * 'ManyThreads': the group count is the nest size divided by the
--   group size, rounded up ('SDivUp'), and the level is built with
--   'SegNoVirt'.
--
-- * 'NoRecommendation': the group count comes from 'numberOfGroups',
--   and the supplied virtualisation setting is passed through as-is.
segThreadCapped :: MonadFreshNames m => MkSegLevel GPU m
segThreadCapped ws desc r = do
  w <-
    letSubExp "nest_size"
      =<< foldBinOp (Mul Int64 OverflowUndef) (intConst Int64 1) ws
  group_size <- getSize (desc ++ "_group_size") SizeGroup

  case r of
    ManyThreads -> do
      usable_groups <-
        letSubExp "segmap_usable_groups"
          =<< eBinOp
            (SDivUp Int64 Unsafe)
            (eSubExp w)
            -- Coerce the group size to Int64 so both operands of the
            -- division have the same integer type.
            (eSubExp =<< asIntS Int64 group_size)
      let grid = KernelGrid (Count usable_groups) (Count group_size)
      pure $ SegThread SegNoVirt (Just grid)
    NoRecommendation v -> do
      -- Only the group count is needed here; the thread count that
      -- 'numberOfGroups' also binds is ignored.
      (num_groups, _) <- numberOfGroups desc w group_size
      let grid = KernelGrid (Count num_groups) (Count group_size)
      pure $ SegThread v (Just grid)