{-# LANGUAGE TypeFamilies #-}
module Futhark.Pass.ExtractKernels.Intragroup (intraGroupParallelise) where
import Control.Monad.Identity
import Control.Monad.RWS
import Control.Monad.Trans.Maybe
import Data.Map.Strict qualified as M
import Data.Set qualified as S
import Futhark.Analysis.PrimExp.Convert
import Futhark.IR.GPU hiding (HistOp)
import Futhark.IR.GPU.Op qualified as GPU
import Futhark.IR.SOACS
import Futhark.MonadFreshNames
import Futhark.Pass.ExtractKernels.BlockedKernel
import Futhark.Pass.ExtractKernels.DistributeNests
import Futhark.Pass.ExtractKernels.Distribution
import Futhark.Pass.ExtractKernels.ToGPU
import Futhark.Tools
import Futhark.Transform.FirstOrderTransform qualified as FOT
import Futhark.Util.Log
import Prelude hiding (log)
-- | Try to turn the body of a map nest into a single group-level
-- 'SegMap' whose inner parallelism is handled within each group.
--
-- The nest is flattened with 'flatKernel', the lambda body is
-- translated by 'intraGroupParalleliseBody', and the resulting kernel
-- body is wrapped in a 'SegMap' at 'SegGroup' level.
--
-- On success the result is @Just@ a tuple of:
--
-- * the pair @(intra_min_par, intra_avail_par)@ (both components are
--   currently the same 'SubExp');
--
-- * the computed group size;
--
-- * the log produced while translating the body;
--
-- * prelude statements that compute the sizes used by the kernel;
--
-- * the single statement binding the kernel itself.
--
-- The result is @Nothing@ when a width recorded for the body depends
-- on a name that is not bound in the enclosing scope or that is one of
-- the kernel inputs (reported internally as "Irregular parallelism").
intraGroupParallelise ::
  (MonadFreshNames m, LocalScope GPU m) =>
  KernelNest ->
  Lambda SOACS ->
  m
    ( Maybe
        ( (SubExp, SubExp),
          SubExp,
          Log,
          Stms GPU,
          Stms GPU
        )
    )
intraGroupParallelise knest lam = runMaybeT $ do
  (ispace, inps) <- lift $ flatKernel knest

  -- The number of groups is the product of the dimensions of the
  -- flattened iteration space.
  (num_groups, w_stms) <-
    lift $
      runBuilder $
        letSubExp "intra_num_groups"
          =<< foldBinOp (Mul Int64 OverflowUndef) (intConst Int64 1) (map snd ispace)

  let body = lambdaBody lam

  group_size <- newVName "computed_group_size"
  (wss_min, wss_avail, log, kbody) <-
    lift . localScope (scopeOfLParams $ lambdaParams lam) $
      intraGroupParalleliseBody body

  outside_scope <- lift askScope
  -- A name only counts as available if it is bound in the enclosing
  -- scope and is not the name of a kernel input.
  let available v =
        v `M.member` outside_scope
          && v `notElem` map kernelInputName inps
  unless (all available $ namesToList $ freeIn (wss_min ++ wss_avail)) $
    fail "Irregular parallelism"

  ((intra_avail_par, kspace, read_input_stms), prelude_stms) <- lift $
    runBuilder $ do
      -- Like 'foldBinOp', but yields the constant 1 for an empty list
      -- instead of requiring an explicit neutral element.
      let foldBinOp' _ [] = eSubExp $ intConst Int64 1
          foldBinOp' bop (x : xs) = foldBinOp bop x xs

      -- Each recorded width list is collapsed to a single size by
      -- multiplying its dimensions; empty lists are dropped.
      ws_min <-
        mapM (letSubExp "one_intra_par_min" <=< foldBinOp' (Mul Int64 OverflowUndef)) $
          filter (not . null) wss_min
      ws_avail <-
        mapM (letSubExp "one_intra_par_avail" <=< foldBinOp' (Mul Int64 OverflowUndef)) $
          filter (not . null) wss_avail

      -- Available parallelism is the smallest of the collapsed
      -- widths, or 1 if there are none.
      intra_avail_par <-
        letSubExp "intra_avail_par" =<< foldBinOp' (SMin Int64) ws_avail

      -- Without any minimum width the group size is the available
      -- parallelism capped by the maximum group size; otherwise it is
      -- the largest of the minimum widths.
      letBindNames [group_size]
        =<< if null ws_min
          then
            eBinOp
              (SMin Int64)
              (eSubExp =<< letSubExp "max_group_size" (Op $ SizeOp $ GetSizeMax SizeGroup))
              (eSubExp intra_avail_par)
          else foldBinOp' (SMax Int64) ws_min

      -- Only inputs that occur free in the body are read.
      let inputIsUsed input = kernelInputName input `nameIn` freeIn body
          used_inps = filter inputIsUsed inps

      addStms w_stms
      -- The per-input results are unit and 'runBuilder_' discards
      -- them, so traverse for effect only.
      read_input_stms <- runBuilder_ $ mapM_ readGroupKernelInput used_inps
      space <- mkSegSpace ispace
      pure (intra_avail_par, space, read_input_stms)

  -- The input reads go in front of the translated body.
  let kbody' = kbody {kernelBodyStms = read_input_stms <> kernelBodyStms kbody}

  let nested_pat = loopNestingPat first_nest
      rts = map (length ispace `stripArray`) $ patTypes nested_pat
      grid = KernelGrid (Count num_groups) (Count $ Var group_size)
      lvl = SegGroup SegNoVirt (Just grid)
      kstm =
        Let nested_pat aux $ Op $ SegOp $ SegMap lvl kspace rts kbody'

  let intra_min_par = intra_avail_par
  pure
    ( (intra_min_par, intra_avail_par),
      Var group_size,
      log,
      prelude_stms,
      oneStm kstm
    )
  where
    first_nest = fst knest
    aux = loopNestingAux first_nest
-- | Read a kernel input inside a group-level kernel body.
--
-- For an array-typed input, the input is first read under a fresh
-- name (derived from the base string of the original name) and the
-- original name is then bound to a 'Copy' of it.  Any other input is
-- read directly with 'readKernelInput'.
readGroupKernelInput ::
  (DistRep (Rep m), MonadBuilder m) =>
  KernelInput ->
  m ()
readGroupKernelInput inp
  | Array {} <- kernelInputType inp = do
      v <- newVName $ baseString $ kernelInputName inp
      readKernelInput inp {kernelInputName = v}
      letBindNames [kernelInputName inp] $ BasicOp $ Copy v
  | otherwise =
      readKernelInput inp
-- | What is accumulated (through the writer part of 'IntraGroupM')
-- while translating a body.
data IntraAcc = IntraAcc
  { -- | Width lists recorded as minimum parallelism.
    accMinPar :: S.Set [SubExp],
    -- | Width lists recorded as available parallelism.
    accAvailPar :: S.Set [SubExp],
    -- | Log messages emitted during translation.
    accLog :: Log
  }
-- | Accumulators combine field by field, using the 'Semigroup'
-- instance of each field.
instance Semigroup IntraAcc where
  IntraAcc min_x avail_x log_x <> IntraAcc min_y avail_y log_y =
    IntraAcc (min_x <> min_y) (avail_x <> avail_y) (log_x <> log_y)
-- | The empty accumulator has every field set to its own 'mempty'.
instance Monoid IntraAcc where
  mempty = IntraAcc mempty mempty mempty
-- | The monad in which bodies are translated: a GPU statement builder
-- on top of an 'RWS' with no reader environment, an 'IntraAcc' writer
-- and a 'VNameSource' as state.
type IntraGroupM =
  BuilderT GPU (RWS () IntraAcc VNameSource)
-- | Log messages are written to the 'accLog' field of the accumulator;
-- the other fields are left empty.
instance MonadLogger IntraGroupM where
  addLog log = tell mempty {accLog = log}
-- | Run an 'IntraGroupM' action in the scope of the calling monad.
--
-- Returns the accumulated 'IntraAcc' together with the statements the
-- action produced.  The caller's name source is threaded through the
-- action, so names generated inside are drawn from it and the updated
-- source is written back.
runIntraGroupM ::
  (MonadFreshNames m, HasScope GPU m) =>
  IntraGroupM () ->
  m (IntraAcc, Stms GPU)
runIntraGroupM m = do
  scope <- castScope <$> askScope
  modifyNameSource $ \src ->
    let (((), kstms), src', acc) = runRWS (runBuilderT m scope) () src
     in ((acc, kstms), src')
-- | Record a list of widths in the accumulator, as both minimum and
-- available parallelism.
parallelMin :: [SubExp] -> IntraGroupM ()
parallelMin ws =
  tell
    mempty
      { accMinPar = S.singleton ws,
        accAvailPar = S.singleton ws
      }
-- | Translate a SOACS body to a GPU body: the statements go through
-- 'intraGroupStms' and the result of the original body is kept as is.
intraGroupBody :: Body SOACS -> IntraGroupM (Body GPU)
intraGroupBody body = do
  stms <- collectStms_ $ intraGroupStms $ bodyStms body
  pure $ mkBody stms $ bodyResult body
-- | Translate a single SOACS statement into GPU statements, which are
-- added to the builder.
--
-- Loops and matches are translated by recursing into their bodies.
-- SOACs are turned into segmented operations at 'SegThread' level (or
-- sequentialised, see the individual cases), and the widths of the
-- segmented operations created are recorded with 'parallelMin'.  Any
-- other statement is converted directly with 'soacsStmToGPU'.
--
-- The order of the @Op@ alternatives matters: the @sequential_outer@
-- attribute is checked before any SOAC-specific case.
intraGroupStm :: Stm SOACS -> IntraGroupM ()
intraGroupStm stm@(Let pat aux e) = do
  scope <- askScope
  -- The level used for every segmented operation created here.
  let lvl = SegThread SegNoVirt Nothing

  case e of
    -- Loop: translate the body with the loop form and the merge
    -- parameters in scope, then rebuild the loop.
    DoLoop merge form loopbody ->
      localScope (scopeOf form') $
        localScope (scopeOfFParams $ map fst merge) $ do
          loopbody' <- intraGroupBody loopbody
          certifying (stmAuxCerts aux) $
            letBind pat $
              DoLoop merge form' loopbody'
      where
        -- Same loop form, rebuilt constructor by constructor so that
        -- it is typed in the GPU representation.
        form' = case form of
          ForLoop i it bound inps -> ForLoop i it bound inps
          WhileLoop cond -> WhileLoop cond
    -- Match: translate every case body and the default body.
    Match cond cases defbody ifdec -> do
      cases' <- mapM (traverse intraGroupBody) cases
      defbody' <- intraGroupBody defbody
      certifying (stmAuxCerts aux) . letBind pat $
        Match cond cases' defbody' ifdec
    -- A SOAC tagged "sequential_outer" is first-order transformed;
    -- the resulting statements get the certificates of the original
    -- statement and are then translated in turn.
    Op soac
      | "sequential_outer" `inAttrs` stmAuxAttrs aux ->
          intraGroupStms . fmap (certify (stmAuxCerts aux))
            =<< runBuilder_ (FOT.transformSOAC pat soac)
    -- Map: distribute the body of the map as a one-level nest.
    Op (Screma w arrs form)
      | Just lam <- isMapSOAC form -> do
          let loopnest = MapNesting pat aux w $ zip (lambdaParams lam) arrs
              env =
                DistEnv
                  { distNest =
                      singleNesting $ Nesting mempty loopnest,
                    distScope =
                      scopeOfPat pat
                        <> scopeForGPU (scopeOf lam)
                        <> scope,
                    distOnInnerMap =
                      distributeMap,
                    distOnTopLevelStms =
                      liftInner . collectStms_ . intraGroupStms,
                    -- Whenever a level is requested, record the
                    -- widths and hand out the thread level.
                    distSegLevel = \minw _ _ -> do
                      lift $ parallelMin minw
                      pure lvl,
                    distOnSOACSStms =
                      pure . oneStm . soacsStmToGPU,
                    distOnSOACSLambda =
                      pure . soacsLambdaToGPU
                  }
              acc =
                DistAcc
                  { distTargets = singleTarget (pat, bodyResult $ lambdaBody lam),
                    distStms = mempty
                  }

          addStms
            =<< runDistNestT env (distributeMapBodyStms acc (bodyStms $ lambdaBody lam))
    -- Scan (with fused map): the scans are combined by 'singleScan'
    -- and emitted as one segmented scan of width w.
    Op (Screma w arrs form)
      | Just (scans, mapfun) <- isScanomapSOAC form,
        Scan scanfun nes <- singleScan scans -> do
          let scanfun' = soacsLambdaToGPU scanfun
              mapfun' = soacsLambdaToGPU mapfun
          certifying (stmAuxCerts aux) $
            addStms
              =<< segScan
                lvl
                pat
                mempty
                w
                [SegBinOp Noncommutative scanfun' nes mempty]
                mapfun'
                arrs
                []
                []
          parallelMin [w]
    -- Reduction (with fused map): the reductions are combined by
    -- 'singleReduce' and emitted as one segmented reduction of
    -- width w.
    Op (Screma w arrs form)
      | Just (reds, map_lam) <- isRedomapSOAC form,
        Reduce comm red_lam nes <- singleReduce reds -> do
          let red_lam' = soacsLambdaToGPU red_lam
              map_lam' = soacsLambdaToGPU map_lam
          certifying (stmAuxCerts aux) $
            addStms
              =<< segRed
                lvl
                pat
                mempty
                w
                [SegBinOp comm red_lam' nes mempty]
                map_lam'
                arrs
                []
                []
          parallelMin [w]
    -- Histogram: convert every operator, then emit a segmented
    -- histogram of width w.
    Op (Hist w arrs ops bucket_fun) -> do
      ops' <- forM ops $ \(HistOp num_bins rf dests nes op) -> do
        (op', nes', shape) <- determineReduceOp op nes
        let op'' = soacsLambdaToGPU op'
        pure $ GPU.HistOp num_bins rf dests nes' shape op''

      let bucket_fun' = soacsLambdaToGPU bucket_fun
      certifying (stmAuxCerts aux) $
        addStms =<< segHist lvl pat w [] [] ops' bucket_fun' arrs
      parallelMin [w]
    -- Stream: expand it with 'sequentialStreamWholeArray' and
    -- translate the resulting statements.  In the widths recorded
    -- while doing so, the chunk size parameter is replaced by w.
    Op (Stream w arrs accs lam)
      | chunk_size_param : _ <- lambdaParams lam -> do
          types <- asksScope castScope
          ((), stream_stms) <-
            runBuilderT (sequentialStreamWholeArray pat w accs lam arrs) types
          let replace (Var v) | v == paramName chunk_size_param = w
              replace se = se
              replaceSets (IntraAcc x y log) =
                IntraAcc (S.map (map replace) x) (S.map (map replace) y) log
          censor replaceSets $ intraGroupStms stream_stms
    -- Scatter: a SegMap of width w whose results are 'WriteReturns'.
    Op (Scatter w ivs lam dests) -> do
      write_i <- newVName "write_i"
      space <- mkSegSpace [(write_i, w)]

      let lam' = soacsLambdaToGPU lam
          (dests_ws, _, _) = unzip3 dests
          -- One 'WriteReturns' per destination, carrying the
          -- certificates of all its index and value results.
          krets = do
            (a_w, a, is_vs) <-
              groupScatterResults dests $ bodyResult $ lambdaBody lam'
            let cs =
                  foldMap (foldMap resCerts . fst) is_vs
                    <> foldMap (resCerts . snd) is_vs
                is_vs' =
                  [ (Slice $ map (DimFix . resSubExp) is, resSubExp v)
                    | (is, v) <- is_vs
                  ]
            pure $ WriteReturns cs a_w a is_vs'
          -- Each lambda parameter is read from its input array at
          -- index write_i.
          inputs = do
            (p, p_a) <- zip (lambdaParams lam') ivs
            pure $ KernelInput (paramName p) (paramType p) p_a [Var write_i]

      kstms <- runBuilder_ $
        localScope (scopeOfSegSpace space) $ do
          mapM_ readKernelInput inputs
          addStms $ bodyStms $ lambdaBody lam'

      certifying (stmAuxCerts aux) $ do
        let ts = zipWith (stripArray . length) dests_ws $ patTypes pat
            body = KernelBody () kstms krets
        letBind pat $ Op $ SegOp $ SegMap lvl space ts body

      parallelMin [w]
    -- Anything else is converted as is.
    _ ->
      addStm $ soacsStmToGPU stm
-- | Process a sequence of SOACS statements by running 'intraGroupStm' on
-- each one, in order, and discarding the individual results.
intraGroupStms :: Stms SOACS -> IntraGroupM ()
intraGroupStms stms = mapM_ intraGroupStm stms
-- | Run 'intraGroupStms' over the statements of a SOACS body inside
-- 'runIntraGroupM' and package the outcome.
--
-- The returned tuple holds, in order:
--
-- * the first set carried by the resulting 'IntraAcc', as a list;
-- * the second set carried by the resulting 'IntraAcc', as a list;
-- * the log carried by the resulting 'IntraAcc';
-- * a 'KernelBody' made of the produced GPU statements, whose results
--   are the results of the original body, each wrapped by 'Returns'
--   with 'ResultMaySimplify'.
intraGroupParalleliseBody ::
  (MonadFreshNames m, HasScope GPU m) =>
  Body SOACS ->
  m ([[SubExp]], [[SubExp]], Log, KernelBody GPU)
intraGroupParalleliseBody body = do
  (IntraAcc min_set avail_set acc_log, gpu_stms) <-
    runIntraGroupM (intraGroupStms (bodyStms body))
  let kernel_results = map asKernelResult (bodyResult body)
      kernel_body = KernelBody () gpu_stms kernel_results
  pure (S.toList min_set, S.toList avail_set, acc_log, kernel_body)
  where
    -- Keep the certificates and value of each body result unchanged.
    asKernelResult :: SubExpRes -> KernelResult
    asKernelResult (SubExpRes cs se) = Returns ResultMaySimplify cs se