{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE ConstraintKinds #-}
module Futhark.Pass.ExtractKernels.StreamKernel
( segThreadCapped
, streamRed
, streamMap
)
where
import Control.Monad
import Control.Monad.Writer
import Data.List ()
import Prelude hiding (quot)
import Futhark.Analysis.PrimExp
import Futhark.IR
import Futhark.IR.Kernels
hiding (Prog, Body, Stm, Pattern, PatElem,
BasicOp, Exp, Lambda, FunDef, FParam, LParam, RetType)
import Futhark.Pass.ExtractKernels.BlockedKernel
import Futhark.Pass.ExtractKernels.ToKernels
import Futhark.MonadFreshNames
import Futhark.Tools
-- | The launch configuration computed for a chunked stream kernel.
data KernelSize = KernelSize
  { -- | How many input elements each thread processes.  As built by
    -- 'blockedKernelSize' this is the result of a 64-bit division;
    -- consumers narrow it with @asIntS Int32@ before use.
    kernelElementsPerThread :: SubExp
  , -- | Total number of threads.  As built by 'blockedKernelSize' this
    -- is the 32-bit product of the number of groups and the group size.
    kernelNumThreads :: SubExp
  }
  deriving (Eq, Ord, Show)
-- | Emit the statements that compute how many groups (and thereby how
-- many threads) to launch.
--
-- * @desc@ is used to derive the name of the tunable size key: a fresh
--   name based on @desc ++ "_num_groups"@ is generated and its
--   pretty-printed form becomes the key.
-- * @w64@ and @group_size@ are passed straight to 'CalcNumGroups'.
--
-- Returns @(num_groups, num_threads)@, where @num_threads@ is bound to
-- the 32-bit product @num_groups * group_size@ (overflow is declared
-- undefined via 'OverflowUndef').
numberOfGroups :: (MonadBinder m, Op (Lore m) ~ HostOp (Lore m) inner) =>
                  String -> SubExp -> SubExp -> m (SubExp, SubExp)
numberOfGroups desc w64 group_size = do
  max_num_groups_key <-
    nameFromString . pretty <$> newVName (desc ++ "_num_groups")
  num_groups <-
    letSubExp "num_groups" $
      Op $ SizeOp $ CalcNumGroups w64 max_num_groups_key group_size
  num_threads <-
    letSubExp "num_threads" $
      BasicOp $ BinOp (Mul Int32 OverflowUndef) num_groups group_size
  return (num_groups, num_threads)
-- | Compute a 'KernelSize' for processing @w@ elements.
--
-- The group size is obtained as a tunable size named
-- @desc ++ "_group_size"@.  @w@ is sign-extended from 32 to 64 bits
-- before being handed to 'numberOfGroups'; the number of elements per
-- thread is then the rounded-up 64-bit quotient of the widened @w@ by
-- the (widened) thread count.
blockedKernelSize :: (MonadBinder m, Lore m ~ Kernels) =>
                     String -> SubExp -> m KernelSize
blockedKernelSize desc w = do
  group_size <- getSize (desc ++ "_group_size") SizeGroup

  w64 <- letSubExp "w64" $ BasicOp $ ConvOp (SExt Int32 Int64) w
  -- Only the thread count is needed here; the group count is dropped.
  (_, num_threads) <- numberOfGroups desc w64 group_size

  per_thread_elements <-
    letSubExp "per_thread_elements" =<<
      eBinOp
        (SDivUp Int64 Unsafe)
        (eSubExp w64)
        (toExp =<< asIntS Int64 num_threads)

  return $ KernelSize per_thread_elements num_threads
-- | Bind the per-thread chunk size and the per-thread slices of the
-- input arrays.
--
-- @splitArrays chunk_size split_bound ordering w i elems_per_i arrs@
-- first binds @chunk_size@ to @SplitSpace ordering w i elems_per_i@.
-- Then, pairing names in @split_bound@ with arrays in @arrs@
-- positionally, each name is bound to a slice of the corresponding
-- array whose outermost dimension is:
--
-- * 'SplitContiguous': start @i * elems_per_i@ (32-bit multiply),
--   length @chunk_size@, stride 1;
-- * 'SplitStrided' @stride@: start @i@, length @chunk_size@, stride
--   @stride@.
--
-- Remaining dimensions are taken in full (via 'fullSlice').
splitArrays :: (MonadBinder m, Lore m ~ Kernels) =>
               VName -> [VName]
            -> SplitOrdering -> SubExp -> SubExp -> SubExp -> [VName]
            -> m ()
splitArrays chunk_size split_bound ordering w i elems_per_i arrs = do
  letBindNames [chunk_size] $
    Op $ SizeOp $ SplitSpace ordering w i elems_per_i
  case ordering of
    SplitContiguous -> do
      offset <-
        letSubExp "slice_offset" $
          BasicOp $ BinOp (Mul Int32 OverflowUndef) i elems_per_i
      bindSlices $ DimSlice offset (Var chunk_size) (constant (1 :: Int32))
    SplitStrided stride ->
      bindSlices $ DimSlice i (Var chunk_size) stride
  where
    -- Bind every name in @split_bound@ to its array sliced along the
    -- outermost dimension by @dim@.
    bindSlices dim = zipWithM_ bindSlice split_bound arrs
      where
        bindSlice slice_name arr = do
          arr_t <- lookupType arr
          letBindNames [slice_name] $
            BasicOp $ Index arr $ fullSlice arr_t [dim]
-- | Split the parameter list of a chunked kernel fold lambda.
--
-- The first parameter is taken to be the thread index (only its name
-- is returned), the second the chunk size parameter, the next
-- @num_accs@ the accumulator parameters, and everything after that the
-- array parameters.
--
-- Calls 'error' if the list has fewer than two parameters.
partitionChunkedKernelFoldParameters :: Int -> [Param dec]
                                     -> (VName, Param dec, [Param dec], [Param dec])
partitionChunkedKernelFoldParameters num_accs (i_param : chunk_param : params) =
  (paramName i_param, chunk_param, acc_params, arr_params)
  where
    (acc_params, arr_params) = splitAt num_accs params
partitionChunkedKernelFoldParameters _ _ =
  error "partitionChunkedKernelFoldParameters: lambda takes too few parameters"
-- | Emit the per-thread part of a chunked stream: split the inputs for
-- thread @thread_gtid@, inline the body of @lam@, and bind its results
-- to fresh pattern elements.
--
-- The first @num_nonconcat@ results of @lam@ are bound as-is to
-- elements named @chunk_fold_red@; the remaining results are bound to
-- elements named @chunk_fold_map@, typed as arrays of the chunk size
-- over the row type of the corresponding lambda return type.
--
-- Returns the two lists of pattern elements, in that order.
--
-- @lam@ is expected to have the shape produced by 'kerneliseLambda':
-- thread index, chunk size, then array parameters, with no accumulator
-- parameters in between.
blockedPerThread :: (MonadBinder m, Lore m ~ Kernels) =>
                    VName -> SubExp -> KernelSize -> StreamOrd -> Lambda (Lore m)
                 -> Int -> [VName]
                 -> m ([PatElemT Type], [PatElemT Type])
blockedPerThread thread_gtid w kernel_size ordering lam num_nonconcat arrs = do
  let -- Splitting off zero accumulators always yields the empty list,
      -- so the @[]@ pattern cannot fail.
      (_, chunk_size, [], arr_params) =
        partitionChunkedKernelFoldParameters 0 $ lambdaParams lam

      ordering' =
        case ordering of
          InOrder -> SplitContiguous
          Disorder -> SplitStrided $ kernelNumThreads kernel_size

      red_ts = take num_nonconcat $ lambdaReturnType lam
      map_ts = map rowType $ drop num_nonconcat $ lambdaReturnType lam

  per_thread <- asIntS Int32 $ kernelElementsPerThread kernel_size
  splitArrays
    (paramName chunk_size)
    (map paramName arr_params)
    ordering'
    w
    (Var thread_gtid)
    per_thread
    arrs

  chunk_red_pes <- forM red_ts $ \red_t ->
    PatElem <$> newVName "chunk_fold_red" <*> pure red_t
  chunk_map_pes <- forM map_ts $ \map_t -> do
    pe_name <- newVName "chunk_fold_map"
    return $ PatElem pe_name $ map_t `arrayOfRow` Var (paramName chunk_size)

  let (chunk_red_ses, chunk_map_ses) =
        splitAt num_nonconcat $ bodyResult $ lambdaBody lam

  -- Inline the lambda body, then copy each of its results into the
  -- corresponding fresh pattern element (reduction results first).
  addStms $
    bodyStms (lambdaBody lam)
      <> stmsFromList
        [ Let (Pattern [] [pe]) (defAux ()) $ BasicOp $ SubExp se
        | (pe, se) <-
            zip chunk_red_pes chunk_red_ses
              ++ zip chunk_map_pes chunk_map_ses
        ]

  return (chunk_red_pes, chunk_map_pes)
-- | Turn a chunked fold lambda into the form used inside a kernel.
--
-- Given the neutral elements @nes@, the resulting lambda
--
-- * takes a fresh @int32@ @thread_index@ parameter, followed by the
--   chunk parameter and the input parameters of the original lambda
--   (the accumulator parameters are removed from the parameter list);
-- * starts its body with statements binding each former accumulator
--   parameter to its neutral element.  When the neutral element is a
--   variable and the parameter is not of primitive type, the binding is
--   a 'Copy' of that variable; otherwise it is the sub-expression
--   itself.
kerneliseLambda :: MonadFreshNames m =>
                   [SubExp] -> Lambda Kernels -> m (Lambda Kernels)
kerneliseLambda nes lam = do
  thread_index <- newVName "thread_index"
  let thread_index_param = Param thread_index $ Prim int32
      (fold_chunk_param, fold_acc_params, fold_inp_params) =
        partitionChunkedFoldParameters (length nes) $ lambdaParams lam

      mkAccInit p (Var v)
        | not $ primType $ paramType p =
            mkLet [] [paramIdent p] $ BasicOp $ Copy v
      mkAccInit p x =
        mkLet [] [paramIdent p] $ BasicOp $ SubExp x

      acc_init_bnds = stmsFromList $ zipWith mkAccInit fold_acc_params nes
  return
    lam
      { lambdaBody = insertStms acc_init_bnds $ lambdaBody lam
      , lambdaParams =
          thread_index_param : fold_chunk_param : fold_inp_params
      }
-- | Build the kernel space and kernel body for a chunked stream.
--
-- * @size@ gives elements-per-thread and thread count.
-- * @ispace@ is the outer index space; a fresh @gtid@ dimension of
--   extent @num_threads@ is appended to it.
-- * @comm@ selects how the input is split: 'Commutative' uses a strided
--   split with stride @num_threads@, 'Noncommutative' a contiguous one.
-- * @fold_lam@ is kernelised with 'kerneliseLambda' using @nes@.
--
-- The kernel body returns the first @length nes@ per-thread results
-- with 'Returns' and the remaining ones with 'ConcatReturns'.
--
-- Returns @(num_threads, space, ts, kbody)@, where @ts@ is the return
-- types of @fold_lam@ with the non-reduction ones replaced by their row
-- types.
prepareStream :: (MonadBinder m, Lore m ~ Kernels) =>
                 KernelSize
              -> [(VName, SubExp)]
              -> SubExp
              -> Commutativity
              -> Lambda Kernels
              -> [SubExp]
              -> [VName]
              -> m (SubExp, SegSpace, [Type], KernelBody Kernels)
prepareStream size ispace w comm fold_lam nes arrs = do
  let KernelSize elems_per_thread num_threads = size
      (ordering, split_ordering) =
        case comm of
          Commutative -> (Disorder, SplitStrided num_threads)
          Noncommutative -> (InOrder, SplitContiguous)

  fold_lam' <- kerneliseLambda nes fold_lam

  -- Bound outside the kernel body, so the conversion is emitted in the
  -- enclosing binder rather than inside the kernel.
  elems_per_thread_32 <- asIntS Int32 elems_per_thread

  gtid <- newVName "gtid"
  space <- mkSegSpace $ ispace ++ [(gtid, num_threads)]

  (kres, kstms) <-
    runBinder $
      localScope (scopeOfSegSpace space) $ do
        (chunk_red_pes, chunk_map_pes) <-
          blockedPerThread gtid w size ordering fold_lam' (length nes) arrs
        let concatReturns pe =
              ConcatReturns split_ordering w elems_per_thread_32 $
                patElemName pe
        return $
          map (Returns ResultMaySimplify . Var . patElemName) chunk_red_pes
            ++ map concatReturns chunk_map_pes
  let kbody = KernelBody () kstms kres

      (redout_ts, mapout_ts) =
        splitAt (length nes) $ lambdaReturnType fold_lam
      ts = redout_ts ++ map rowType mapout_ts

  return (num_threads, space, ts, kbody)
-- | Build the statements for a stream reduction expressed as a single
-- 'SegRed'.
--
-- Arguments, in order:
--
--   * @mk_lvl@: produces the 'SegLevel' for the kernel; it is called
--     with @[w]@, the description @"stream_red"@ and
--     @'NoRecommendation' 'SegNoVirt'@.
--   * @pat@: the result pattern.  Its first @length nes@ elements are
--     treated as the reduction results, the rest as map-style outputs.
--   * @w@: the size handed to 'blockedKernelSize' and 'prepareStream'.
--   * @comm@, @red_lam@: commutativity and operator of the reduction;
--     together with @nes@ they form the single 'SegBinOp'.
--   * @fold_lam@: the per-chunk lambda passed on to 'prepareStream'.
--   * @nes@: the neutral elements.
--   * @arrs@: the input arrays.
--
-- The result is the statements accumulated while running the builder;
-- the builder's own result is discarded by @runBinderT'_@.
streamRed :: (MonadFreshNames m, HasScope Kernels m) =>
             MkSegLevel Kernels m
          -> Pattern Kernels
          -> SubExp
          -> Commutativity
          -> Lambda Kernels -> Lambda Kernels
          -> [SubExp] -> [VName]
          -> m (Stms Kernels)
streamRed mk_lvl pat w comm red_lam fold_lam nes arrs = runBinderT'_ $ do
  -- Decide the kernel size for this stream before building the body.
  size <- blockedKernelSize "stream_red" w

  -- The leading 'length nes' pattern elements belong to the reduction.
  let (redout_pes, mapout_pes) = splitAt (length nes) $ patternElements pat

  -- 'dummyDim' yields a replacement pattern for the reduction results,
  -- an index space to hand to 'prepareStream', and an action that is run
  -- after the kernel has been bound.
  -- NOTE(review): presumably that action reads the real results back out
  -- of the dummy dimension -- confirm against 'dummyDim'.
  (redout_pat, ispace, read_dummy) <- dummyDim $ Pattern [] redout_pes
  let pat' = Pattern [] $ patternElements redout_pat ++ mapout_pes

  -- The thread count returned by 'prepareStream' is not needed here.
  (_, kspace, ts, kbody) <- prepareStream size ispace w comm fold_lam nes arrs

  lvl <- mk_lvl [w] "stream_red" $ NoRecommendation SegNoVirt
  letBind pat' $ Op $ SegOp $
    SegRed lvl kspace [SegBinOp comm red_lam nes mempty] ts kbody

  read_dummy
-- | Build the statements for a stream expressed as a single 'SegMap',
-- i.e. like 'streamRed' but with no reduction operator applied to the
-- accumulator results.
--
-- Arguments, in order:
--
--   * @mk_lvl@: produces the 'SegLevel' for the kernel; it is called
--     with @[w]@, the description @"stream_map"@ and
--     @'NoRecommendation' 'SegNoVirt'@.
--   * @out_desc@: name hints for the freshly created accumulator
--     outputs.  They are zipped with the first @length nes@ kernel
--     result types, so the shorter of the two lists determines how many
--     outputs are created.
--   * @mapout_pes@: pattern elements for the remaining outputs.
--   * @w@: the size handed to 'blockedKernelSize' and 'prepareStream'.
--   * @comm@, @fold_lam@, @nes@, @arrs@: passed on unchanged to
--     'prepareStream'.
--
-- Returns the thread count reported by 'prepareStream' paired with the
-- names of the new accumulator outputs, together with the statements
-- that were generated.  Each accumulator output has the corresponding
-- kernel result type extended with an outer dimension of that thread
-- count.
streamMap :: (MonadFreshNames m, HasScope Kernels m) =>
             MkSegLevel Kernels m
          -> [String] -> [PatElem Kernels] -> SubExp
          -> Commutativity -> Lambda Kernels -> [SubExp] -> [VName]
          -> m ((SubExp, [VName]), Stms Kernels)
streamMap mk_lvl out_desc mapout_pes w comm fold_lam nes arrs = runBinderT' $ do
  size <- blockedKernelSize "stream_map" w

  -- No extra index space is supplied here, hence the empty list.
  (threads, kspace, ts, kbody) <- prepareStream size [] w comm fold_lam nes arrs

  -- Only the leading 'length nes' result types are accumulator results.
  let redout_ts = take (length nes) ts

  -- One fresh output per accumulator, of type @[threads]t@.
  redout_pes <- forM (zip out_desc redout_ts) $ \(desc, t) ->
    PatElem <$> newVName desc <*> pure (t `arrayOfRow` threads)

  let pat = Pattern [] $ redout_pes ++ mapout_pes
  lvl <- mk_lvl [w] "stream_map" $ NoRecommendation SegNoVirt
  letBind pat $ Op $ SegOp $ SegMap lvl kspace ts kbody

  return (threads, map patElemName redout_pes)
-- | A 'MkSegLevel' that always produces a 'SegThread' level whose group
-- count is derived from the total size of the nest.
--
-- Arguments (supplied through the 'MkSegLevel' type):
--
--   * @ws@: the nest widths.  They are converted with @'asIntS' 'Int64'@
--     and multiplied together, starting from the constant 1, to give the
--     total size, bound as @"nest_size"@.
--   * @desc@: a description; the group size is obtained via 'getSize'
--     under the name @desc ++ "_group_size"@ with class 'SizeGroup'.
--   * @r@: the thread recommendation.
--
-- For 'ManyThreads' the group count is the total size divided by the
-- group size, rounded up, and the level uses 'SegNoVirt'.  For
-- @'NoRecommendation' v@ the group count comes from 'numberOfGroups'
-- and the level uses the virtualisation setting @v@.
segThreadCapped :: MonadFreshNames m => MkSegLevel Kernels m
segThreadCapped ws desc r = do
  -- Total number of elements across all nest dimensions, as a 64-bit
  -- value.
  w64 <- letSubExp "nest_size" =<<
         foldBinOp (Mul Int64 OverflowUndef) (intConst Int64 1) =<<
         mapM (asIntS Int64) ws
  group_size <- getSize (desc ++ "_group_size") SizeGroup

  case r of
    ManyThreads -> do
      -- ceil(w64 / group_size) computed in 64 bits, then converted with
      -- 'SExt Int64 Int32'.
      -- NOTE(review): this reads as a 64-to-32-bit conversion of the
      -- group count -- confirm that is the intended width.
      usable_groups <- letSubExp "segmap_usable_groups" .
                       BasicOp . ConvOp (SExt Int64 Int32) =<<
                       letSubExp "segmap_usable_groups_64" =<<
                       eBinOp (SDivUp Int64 Unsafe) (eSubExp w64)
                       (eSubExp =<< asIntS Int64 group_size)
      return $ SegThread (Count usable_groups) (Count group_size) SegNoVirt

    NoRecommendation v -> do
      -- The second component returned by 'numberOfGroups' is unused.
      (num_groups, _) <- numberOfGroups desc w64 group_size
      return $ SegThread (Count num_groups) (Count group_size) v