{-# LANGUAGE TypeFamilies #-}
module Futhark.Pass.ExtractKernels.StreamKernel
( segThreadCapped,
)
where
import Control.Monad
import Data.List ()
import Futhark.Analysis.PrimExp
import Futhark.IR
import Futhark.IR.GPU hiding
( BasicOp,
Body,
Exp,
FParam,
FunDef,
LParam,
Lambda,
Pat,
PatElem,
Prog,
RetType,
Stm,
)
import Futhark.MonadFreshNames
import Futhark.Pass.ExtractKernels.BlockedKernel
import Futhark.Pass.ExtractKernels.ToGPU
import Futhark.Tools
import Prelude hiding (quot)
-- | A pair of size expressions for a kernel.  Both fields are plain
-- 'SubExp's; nothing in this declaration constrains their integer
-- types.
--
-- NOTE(review): this type is neither exported nor referenced by the
-- definitions visible in this module - confirm whether it is still
-- needed.
data KernelSize = KernelSize
  { -- | Presumably the number of elements handled by each thread
    -- (going by the field name; confirm against users).
    kernelElementsPerThread :: SubExp,
    -- | Presumably the total number of threads (going by the field
    -- name; confirm against users).
    kernelNumThreads :: SubExp
  }
  deriving (Eq, Ord, Show)
-- | Bind the number of groups and the total number of threads for a
-- kernel, returning them as @(num_groups, num_threads)@.
--
-- The arguments are, in order: a description string @desc@, the width
-- @w@, and the group size @group_size@.
--
-- * @num_groups@ is bound to a 'CalcNumGroups' size operation over @w@
--   and @group_size@.  Its key is derived from a fresh 'VName' created
--   from @desc ++ "_num_groups"@, so each call gets its own key.
--
-- * @num_threads@ is bound to the 'Int64' product of @num_groups@ and
--   @group_size@, using 'OverflowUndef'.
numberOfGroups ::
  (MonadBuilder m, Op (Rep m) ~ HostOp inner (Rep m)) =>
  String ->
  SubExp ->
  SubExp ->
  m (SubExp, SubExp)
numberOfGroups desc w group_size = do
  -- The size key is the pretty-printed fresh name, which keeps it
  -- unique per call site.
  max_num_groups_key <-
    nameFromString . prettyString <$> newVName (desc ++ "_num_groups")
  num_groups <-
    letSubExp "num_groups" $
      Op $
        SizeOp $
          CalcNumGroups w max_num_groups_key group_size
  num_threads <-
    letSubExp "num_threads" $
      BasicOp $
        BinOp (Mul Int64 OverflowUndef) num_groups group_size
  pure (num_groups, num_threads)
-- | A 'MkSegLevel' that always yields a 'SegThread' level carrying an
-- explicit 'KernelGrid'.
--
-- The nest size is the 'Int64' product of the widths @ws@, folded
-- starting from the constant 1.  The group size is obtained with
-- 'getSize' under the key @desc ++ "_group_size"@ and class
-- 'SizeGroup'.  The number of groups then depends on the thread
-- recommendation @r@:
--
-- * 'ManyThreads': the group count is @nest_size@ divided by the group
--   size using 'SDivUp', and the level uses 'SegNoVirt'.
--
-- * @'NoRecommendation' v@: the group count comes from
--   'numberOfGroups' (its thread count is discarded), and the supplied
--   virtualisation @v@ is passed through unchanged.
segThreadCapped :: (MonadFreshNames m) => MkSegLevel GPU m
segThreadCapped ws desc r = do
  w <-
    letSubExp "nest_size"
      =<< foldBinOp (Mul Int64 OverflowUndef) (intConst Int64 1) ws
  group_size <- getSize (desc ++ "_group_size") SizeGroup

  case r of
    ManyThreads -> do
      -- The group size is converted with 'asIntS' so both operands of
      -- the division are Int64.
      usable_groups <-
        letSubExp "segmap_usable_groups"
          =<< eBinOp
            (SDivUp Int64 Unsafe)
            (eSubExp w)
            (eSubExp =<< asIntS Int64 group_size)
      let grid = KernelGrid (Count usable_groups) (Count group_size)
      pure $ SegThread SegNoVirt (Just grid)
    NoRecommendation v -> do
      (num_groups, _) <- numberOfGroups desc w group_size
      let grid = KernelGrid (Count num_groups) (Count group_size)
      pure $ SegThread v (Just grid)