{-# LANGUAGE TypeFamilies #-}

-- | Extract limited nested parallelism for execution inside
-- individual kernel workgroups.
module Futhark.Pass.ExtractKernels.Intragroup (intraGroupParallelise) where

import Control.Monad.Identity
import Control.Monad.RWS
import Control.Monad.Trans.Maybe
import Data.Map.Strict qualified as M
import Data.Set qualified as S
import Futhark.Analysis.PrimExp.Convert
import Futhark.IR.GPU hiding (HistOp)
import Futhark.IR.GPU.Op qualified as GPU
import Futhark.IR.SOACS
import Futhark.MonadFreshNames
import Futhark.Pass.ExtractKernels.BlockedKernel
import Futhark.Pass.ExtractKernels.DistributeNests
import Futhark.Pass.ExtractKernels.Distribution
import Futhark.Pass.ExtractKernels.ToGPU
import Futhark.Tools
import Futhark.Transform.FirstOrderTransform qualified as FOT
import Futhark.Util.Log
import Prelude hiding (log)

-- | Convert the statements inside a map nest to kernel statements,
-- attempting to parallelise any remaining (top-level) parallel
-- statements.  Anything that is not a map, scan or reduction will
-- simply be sequentialised.  This includes sequential loops that
-- contain maps, scans or reductions.  In the future, we could probably
-- do something more clever.  Make sure that the amount of parallelism
-- to be exploited does not exceed the group size.  Further, as a hack
-- we also consider the size of all intermediate arrays as
-- "parallelism to be exploited" to avoid exploding local memory.
--
-- We distinguish between "minimum group size" and "maximum
-- exploitable parallelism".
--
-- On success the result is a tuple of:
--
-- * the pair (minimum parallelism, available parallelism) -- both
--   components are currently the same value;
--
-- * the computed group size;
--
-- * the log returned by transforming the lambda body;
--
-- * the prelude statements, which bind the number of groups and the
--   group size that the kernel statement refers to;
--
-- * a single statement binding the nest's pattern to a group-level
--   'SegMap'.
--
-- The result is 'Nothing' ("Irregular parallelism") when any of the
-- collected parallelism sizes mentions a name that is not in the
-- outside scope, or that is one of the kernel inputs.
intraGroupParallelise ::
  (MonadFreshNames m, LocalScope GPU m) =>
  KernelNest ->
  Lambda SOACS ->
  m
    ( Maybe
        ( (SubExp, SubExp),
          SubExp,
          Log,
          Stms GPU,
          Stms GPU
        )
    )
intraGroupParallelise knest lam = runMaybeT $ do
  (ispace, inps) <- lift $ flatKernel knest

  -- The number of groups is the product of all dimensions of the
  -- flattened nest.
  (num_groups, w_stms) <-
    lift $
      runBuilder $
        letSubExp "intra_num_groups"
          =<< foldBinOp (Mul Int64 OverflowUndef) (intConst Int64 1) (map snd ispace)

  let body = lambdaBody lam

  -- The name is created up front so it can be used in the level passed
  -- to the body transformation; it is only bound further down, once
  -- the parallelism of the body is known.
  group_size <- newVName "computed_group_size"
  let intra_lvl = SegThread (Count num_groups) (Count $ Var group_size) SegNoVirt

  (wss_min, wss_avail, log, kbody) <-
    lift $
      localScope (scopeOfLParams $ lambdaParams lam) $
        intraGroupParalleliseBody intra_lvl body

  outside_scope <- lift askScope
  -- outside_scope may also contain the inputs, even though those are
  -- not actually available outside the kernel.
  let available v =
        v `M.member` outside_scope
          && v `notElem` map kernelInputName inps
  -- Failing inside MaybeT makes the whole function return Nothing.
  unless (all available $ namesToList $ freeIn (wss_min ++ wss_avail)) $
    fail "Irregular parallelism"

  ((intra_avail_par, kspace, read_input_stms), prelude_stms) <- lift $
    runBuilder $ do
      -- Like foldBinOp, but yields the constant 1 for an empty list.
      let foldBinOp' _ [] = eSubExp $ intConst Int64 1
          foldBinOp' bop (x : xs) = foldBinOp bop x xs
      ws_min <-
        mapM (letSubExp "one_intra_par_min" <=< foldBinOp' (Mul Int64 OverflowUndef)) $
          filter (not . null) wss_min
      ws_avail <-
        mapM (letSubExp "one_intra_par_avail" <=< foldBinOp' (Mul Int64 OverflowUndef)) $
          filter (not . null) wss_avail

      -- The amount of parallelism available *in the worst case* is
      -- equal to the smallest parallel loop, or *at least* 1.
      intra_avail_par <-
        letSubExp "intra_avail_par" =<< foldBinOp' (SMin Int64) ws_avail

      -- The group size is either the maximum of the minimum parallelism
      -- exploited, or the desired parallelism (bounded by the max group
      -- size) in case there is no minimum.
      letBindNames [group_size]
        =<< if null ws_min
          then
            eBinOp
              (SMin Int64)
              (eSubExp =<< letSubExp "max_group_size" (Op $ SizeOp $ GetSizeMax SizeGroup))
              (eSubExp intra_avail_par)
          else foldBinOp' (SMax Int64) ws_min

      -- Only inputs that the lambda body actually mentions are read.
      let inputIsUsed input = kernelInputName input `nameIn` freeIn body
          used_inps = filter inputIsUsed inps

      addStms w_stms
      -- The reads are collected separately (not added to the prelude)
      -- because they are prepended to the kernel body below.
      read_input_stms <- runBuilder_ $ mapM_ readGroupKernelInput used_inps
      space <- mkSegSpace ispace
      pure (intra_avail_par, space, read_input_stms)

  let kbody' = kbody {kernelBodyStms = read_input_stms <> kernelBodyStms kbody}

  let nested_pat = loopNestingPat first_nest
      -- Per-group result types: the pattern types with one array
      -- dimension stripped per dimension of the kernel space.
      rts = map (length ispace `stripArray`) $ patTypes nested_pat
      lvl = SegGroup (Count num_groups) (Count $ Var group_size) SegNoVirt
      kstm =
        Let nested_pat aux $
          Op $
            SegOp $
              SegMap lvl kspace rts kbody'

  -- The minimum parallelism reported is the same value as the
  -- available parallelism.
  let intra_min_par = intra_avail_par
  pure
    ( (intra_min_par, intra_avail_par),
      Var group_size,
      log,
      prelude_stms,
      oneStm kstm
    )
  where
    first_nest = fst knest
    aux = loopNestingAux first_nest

-- | Emit the statements that read one kernel input.
--
-- For an input of array type, the input is first read under a fresh
-- name (derived from the base string of the original name), and the
-- original input name is then bound to a 'Copy' of that value.  Any
-- other input is read directly with 'readKernelInput'.
readGroupKernelInput ::
  (DistRep (Rep m), MonadBuilder m) =>
  KernelInput ->
  m ()
readGroupKernelInput inp
  | Array {} <- kernelInputType inp = do
      v <- newVName $ baseString $ kernelInputName inp
      readKernelInput inp {kernelInputName = v}
      letBindNames [kernelInputName inp] $ BasicOp $ Copy v
  | otherwise =
      readKernelInput inp

-- | The writer output accumulated while transforming a body in
-- 'IntraGroupM'.
data IntraAcc = IntraAcc
  { -- | Size lists recorded as minimum parallelism (see 'parallelMin').
    accMinPar :: S.Set [SubExp],
    -- | Size lists recorded as available parallelism (see
    -- 'parallelMin').
    accAvailPar :: S.Set [SubExp],
    -- | Log messages emitted through the 'MonadLogger' instance.
    accLog :: Log
  }

-- | Accumulators are combined field by field using each field's own
-- '<>'.
instance Semigroup IntraAcc where
  IntraAcc min_x avail_x log_x <> IntraAcc min_y avail_y log_y =
    IntraAcc (min_x <> min_y) (avail_x <> avail_y) (log_x <> log_y)

-- | The empty accumulator has every field set to its own 'mempty'.
instance Monoid IntraAcc where
  mempty = IntraAcc mempty mempty mempty

-- | The monad used for the intra-group transformation: a 'BuilderT'
-- producing 'GPU' statements on top of an 'RWS' monad with a unit
-- reader environment, an 'IntraAcc' writer output, and a
-- 'VNameSource' as state.
type IntraGroupM =
  BuilderT GPU (RWS () IntraAcc VNameSource)

-- | Log messages are written to the 'accLog' field of the 'IntraAcc'
-- writer output; the other fields are left empty.
instance MonadLogger IntraGroupM where
  addLog log = tell mempty {accLog = log}

-- | Run an 'IntraGroupM' action in the scope of the calling monad.
--
-- The caller's name source is used as the initial state and replaced
-- by the final state afterwards.  Returns the accumulated 'IntraAcc'
-- together with the statements produced by the builder.
runIntraGroupM ::
  (MonadFreshNames m, HasScope GPU m) =>
  IntraGroupM () ->
  m (IntraAcc, Stms GPU)
runIntraGroupM m = do
  scope <- castScope <$> askScope
  modifyNameSource $ \src ->
    -- runRWS yields (result, final state, writer output).
    let (((), kstms), src', acc) = runRWS (runBuilderT m scope) () src
     in ((acc, kstms), src')

-- | Record the size list @ws@ in the writer output, adding it to both
-- the minimum-parallelism set and the available-parallelism set.
parallelMin :: [SubExp] -> IntraGroupM ()
parallelMin ws =
  tell
    mempty
      { accMinPar = S.singleton ws,
        accAvailPar = S.singleton ws
      }

-- | Transform a 'SOACS' body into a 'GPU' body: the statements are
-- passed through 'intraGroupStms' at the given level, the statements
-- this produces are collected, and a new body is built from them with
-- the original body's result.
intraGroupBody :: SegLevel -> Body SOACS -> IntraGroupM (Body GPU)
intraGroupBody lvl body = do
  stms <- collectStms_ $ intraGroupStms lvl $ bodyStms body
  pure $ mkBody stms $ bodyResult body

intraGroupStm :: SegLevel -> Stm SOACS -> IntraGroupM ()
intraGroupStm :: SegLevel -> Stm SOACS -> IntraGroupM ()
intraGroupStm SegLevel
lvl stm :: Stm SOACS
stm@(Let Pat (LetDec SOACS)
pat StmAux (ExpDec SOACS)
aux Exp SOACS
e) = do
  Scope GPU
scope <- forall {k} (rep :: k) (m :: * -> *).
HasScope rep m =>
m (Scope rep)
askScope
  let lvl' :: SegLevel
lvl' = Count NumGroups SubExp
-> Count GroupSize SubExp -> SegVirt -> SegLevel
SegThread (SegLevel -> Count NumGroups SubExp
segNumGroups SegLevel
lvl) (SegLevel -> Count GroupSize SubExp
segGroupSize SegLevel
lvl) SegVirt
SegNoVirt

  case Exp SOACS
e of
    DoLoop [(FParam SOACS, SubExp)]
merge LoopForm SOACS
form Body SOACS
loopbody ->
      forall {k} (rep :: k) (m :: * -> *) a.
LocalScope rep m =>
Scope rep -> m a -> m a
localScope (forall {k} (rep :: k) a. Scoped rep a => a -> Scope rep
scopeOf LoopForm GPU
form') forall a b. (a -> b) -> a -> b
$
        forall {k} (rep :: k) (m :: * -> *) a.
LocalScope rep m =>
Scope rep -> m a -> m a
localScope (forall {k} (rep :: k) dec.
(FParamInfo rep ~ dec) =>
[Param dec] -> Scope rep
scopeOfFParams forall a b. (a -> b) -> a -> b
$ forall a b. (a -> b) -> [a] -> [b]
map forall a b. (a, b) -> a
fst [(FParam SOACS, SubExp)]
merge) forall a b. (a -> b) -> a -> b
$ do
          Body GPU
loopbody' <- SegLevel -> Body SOACS -> IntraGroupM (Body GPU)
intraGroupBody SegLevel
lvl Body SOACS
loopbody
          forall (m :: * -> *) a. MonadBuilder m => Certs -> m a -> m a
certifying (forall dec. StmAux dec -> Certs
stmAuxCerts StmAux (ExpDec SOACS)
aux) forall a b. (a -> b) -> a -> b
$
            forall (m :: * -> *).
MonadBuilder m =>
Pat (LetDec (Rep m)) -> Exp (Rep m) -> m ()
letBind Pat (LetDec SOACS)
pat forall a b. (a -> b) -> a -> b
$
              forall {k} (rep :: k).
[(FParam rep, SubExp)] -> LoopForm rep -> Body rep -> Exp rep
DoLoop [(FParam SOACS, SubExp)]
merge LoopForm GPU
form' Body GPU
loopbody'
      where
        form' :: LoopForm GPU
form' = case LoopForm SOACS
form of
          ForLoop VName
i IntType
it SubExp
bound [(LParam SOACS, VName)]
inps -> forall {k} (rep :: k).
VName -> IntType -> SubExp -> [(LParam rep, VName)] -> LoopForm rep
ForLoop VName
i IntType
it SubExp
bound [(LParam SOACS, VName)]
inps
          WhileLoop VName
cond -> forall {k} (rep :: k). VName -> LoopForm rep
WhileLoop VName
cond
    Match [SubExp]
cond [Case (Body SOACS)]
cases Body SOACS
defbody MatchDec (BranchType SOACS)
ifdec -> do
      [Case (Body GPU)]
cases' <- forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
traverse forall a b. (a -> b) -> a -> b
$ SegLevel -> Body SOACS -> IntraGroupM (Body GPU)
intraGroupBody SegLevel
lvl) [Case (Body SOACS)]
cases
      Body GPU
defbody' <- SegLevel -> Body SOACS -> IntraGroupM (Body GPU)
intraGroupBody SegLevel
lvl Body SOACS
defbody
      forall (m :: * -> *) a. MonadBuilder m => Certs -> m a -> m a
certifying (forall dec. StmAux dec -> Certs
stmAuxCerts StmAux (ExpDec SOACS)
aux) forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall (m :: * -> *).
MonadBuilder m =>
Pat (LetDec (Rep m)) -> Exp (Rep m) -> m ()
letBind Pat (LetDec SOACS)
pat forall a b. (a -> b) -> a -> b
$
        forall {k} (rep :: k).
[SubExp]
-> [Case (Body rep)]
-> Body rep
-> MatchDec (BranchType rep)
-> Exp rep
Match [SubExp]
cond [Case (Body GPU)]
cases' Body GPU
defbody' MatchDec (BranchType SOACS)
ifdec
    Op Op SOACS
soac
      | Attr
"sequential_outer" Attr -> Attrs -> Bool
`inAttrs` forall dec. StmAux dec -> Attrs
stmAuxAttrs StmAux (ExpDec SOACS)
aux ->
          SegLevel -> Stms SOACS -> IntraGroupM ()
intraGroupStms SegLevel
lvl forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (forall {k} (rep :: k). Certs -> Stm rep -> Stm rep
certify (forall dec. StmAux dec -> Certs
stmAuxCerts StmAux (ExpDec SOACS)
aux))
            forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< forall {k1} {k2} (m :: * -> *) (somerep :: k1) (rep :: k2) a.
(MonadFreshNames m, HasScope somerep m, SameScope somerep rep) =>
Builder rep a -> m (Stms rep)
runBuilder_ (forall (m :: * -> *).
Transformer m =>
Pat (LetDec (Rep m)) -> SOAC (Rep m) -> m ()
FOT.transformSOAC Pat (LetDec SOACS)
pat Op SOACS
soac)
    Op (Screma SubExp
w [VName]
arrs ScremaForm SOACS
form)
      | Just Lambda SOACS
lam <- forall {k} (rep :: k). ScremaForm rep -> Maybe (Lambda rep)
isMapSOAC ScremaForm SOACS
form -> do
          let loopnest :: LoopNesting
loopnest = Pat Type
-> StmAux () -> SubExp -> [(Param Type, VName)] -> LoopNesting
MapNesting Pat (LetDec SOACS)
pat StmAux (ExpDec SOACS)
aux SubExp
w forall a b. (a -> b) -> a -> b
$ forall a b. [a] -> [b] -> [(a, b)]
zip (forall {k} (rep :: k). Lambda rep -> [LParam rep]
lambdaParams Lambda SOACS
lam) [VName]
arrs
              env :: DistEnv GPU IntraGroupM
env =
                DistEnv
                  { distNest :: Nestings
distNest =
                      Nesting -> Nestings
singleNesting forall a b. (a -> b) -> a -> b
$ Names -> LoopNesting -> Nesting
Nesting forall a. Monoid a => a
mempty LoopNesting
loopnest,
                    distScope :: Scope GPU
distScope =
                      forall {k} (rep :: k) dec.
(LetDec rep ~ dec) =>
Pat dec -> Scope rep
scopeOfPat Pat (LetDec SOACS)
pat
                        forall a. Semigroup a => a -> a -> a
<> Scope SOACS -> Scope GPU
scopeForGPU (forall {k} (rep :: k) a. Scoped rep a => a -> Scope rep
scopeOf Lambda SOACS
lam)
                        forall a. Semigroup a => a -> a -> a
<> Scope GPU
scope,
                    distOnInnerMap :: MapLoop -> DistAcc GPU -> DistNestT GPU IntraGroupM (DistAcc GPU)
distOnInnerMap =
                      forall (m :: * -> *) rep.
(MonadFreshNames m, LocalScope rep m, DistRep rep) =>
MapLoop -> DistAcc rep -> DistNestT rep m (DistAcc rep)
distributeMap,
                    distOnTopLevelStms :: Stms SOACS -> DistNestT GPU IntraGroupM (Stms GPU)
distOnTopLevelStms =
                      forall rep (m :: * -> *) a.
(LocalScope rep m, DistRep rep) =>
m a -> DistNestT rep m a
liftInner forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall (m :: * -> *) a. MonadBuilder m => m a -> m (Stms (Rep m))
collectStms_ forall b c a. (b -> c) -> (a -> b) -> a -> c
. SegLevel -> Stms SOACS -> IntraGroupM ()
intraGroupStms SegLevel
lvl,
                    distSegLevel :: MkSegLevel GPU IntraGroupM
distSegLevel = \[SubExp]
minw String
_ ThreadRecommendation
_ -> do
                      forall (t :: (* -> *) -> * -> *) (m :: * -> *) a.
(MonadTrans t, Monad m) =>
m a -> t m a
lift forall a b. (a -> b) -> a -> b
$ [SubExp] -> IntraGroupM ()
parallelMin [SubExp]
minw
                      forall (f :: * -> *) a. Applicative f => a -> f a
pure SegLevel
lvl,
                    distOnSOACSStms :: Stm SOACS -> BuilderT GPU (State VNameSource) (Stms GPU)
distOnSOACSStms =
                      forall (f :: * -> *) a. Applicative f => a -> f a
pure forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall {k} (rep :: k). Stm rep -> Stms rep
oneStm forall b c a. (b -> c) -> (a -> b) -> a -> c
. Stm SOACS -> Stm GPU
soacsStmToGPU,
                    distOnSOACSLambda :: Lambda SOACS -> Builder GPU (Lambda GPU)
distOnSOACSLambda =
                      forall (f :: * -> *) a. Applicative f => a -> f a
pure forall b c a. (b -> c) -> (a -> b) -> a -> c
. Lambda SOACS -> Lambda GPU
soacsLambdaToGPU
                  }
              acc :: DistAcc GPU
acc =
                DistAcc
                  { distTargets :: Targets
distTargets = Target -> Targets
singleTarget (Pat (LetDec SOACS)
pat, forall {k} (rep :: k). Body rep -> Result
bodyResult forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k). Lambda rep -> Body rep
lambdaBody Lambda SOACS
lam),
                    distStms :: Stms GPU
distStms = forall a. Monoid a => a
mempty
                  }

          forall (m :: * -> *). MonadBuilder m => Stms (Rep m) -> m ()
addStms
            forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< forall (m :: * -> *) rep.
(MonadLogger m, DistRep rep) =>
DistEnv rep m -> DistNestT rep m (DistAcc rep) -> m (Stms rep)
runDistNestT DistEnv GPU IntraGroupM
env (forall (m :: * -> *) rep.
(MonadFreshNames m, LocalScope rep m, DistRep rep) =>
DistAcc rep -> Stms SOACS -> DistNestT rep m (DistAcc rep)
distributeMapBodyStms DistAcc GPU
acc (forall {k} (rep :: k). Body rep -> Stms rep
bodyStms forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k). Lambda rep -> Body rep
lambdaBody Lambda SOACS
lam))
    Op (Screma SubExp
w [VName]
arrs ScremaForm SOACS
form)
      | Just ([Scan SOACS]
scans, Lambda SOACS
mapfun) <- forall {k} (rep :: k).
ScremaForm rep -> Maybe ([Scan rep], Lambda rep)
isScanomapSOAC ScremaForm SOACS
form,
        Scan Lambda SOACS
scanfun [SubExp]
nes <- forall {k} (rep :: k). Buildable rep => [Scan rep] -> Scan rep
singleScan [Scan SOACS]
scans -> do
          let scanfun' :: Lambda GPU
scanfun' = Lambda SOACS -> Lambda GPU
soacsLambdaToGPU Lambda SOACS
scanfun
              mapfun' :: Lambda GPU
mapfun' = Lambda SOACS -> Lambda GPU
soacsLambdaToGPU Lambda SOACS
mapfun
          forall (m :: * -> *) a. MonadBuilder m => Certs -> m a -> m a
certifying (forall dec. StmAux dec -> Certs
stmAuxCerts StmAux (ExpDec SOACS)
aux) forall a b. (a -> b) -> a -> b
$
            forall (m :: * -> *). MonadBuilder m => Stms (Rep m) -> m ()
addStms forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< forall (m :: * -> *) rep.
(MonadFreshNames m, DistRep rep, HasScope rep m) =>
SegOpLevel rep
-> Pat (LetDec rep)
-> Certs
-> SubExp
-> [SegBinOp rep]
-> Lambda rep
-> [VName]
-> [(VName, SubExp)]
-> [KernelInput]
-> m (Stms rep)
segScan SegLevel
lvl' Pat (LetDec SOACS)
pat forall a. Monoid a => a
mempty SubExp
w [forall {k} (rep :: k).
Commutativity -> Lambda rep -> [SubExp] -> Shape -> SegBinOp rep
SegBinOp Commutativity
Noncommutative Lambda GPU
scanfun' [SubExp]
nes forall a. Monoid a => a
mempty] Lambda GPU
mapfun' [VName]
arrs [] []
          [SubExp] -> IntraGroupM ()
parallelMin [SubExp
w]
    Op (Screma SubExp
w [VName]
arrs ScremaForm SOACS
form)
      | Just ([Reduce SOACS]
reds, Lambda SOACS
map_lam) <- forall {k} (rep :: k).
ScremaForm rep -> Maybe ([Reduce rep], Lambda rep)
isRedomapSOAC ScremaForm SOACS
form,
        Reduce Commutativity
comm Lambda SOACS
red_lam [SubExp]
nes <- forall {k} (rep :: k). Buildable rep => [Reduce rep] -> Reduce rep
singleReduce [Reduce SOACS]
reds -> do
          let red_lam' :: Lambda GPU
red_lam' = Lambda SOACS -> Lambda GPU
soacsLambdaToGPU Lambda SOACS
red_lam
              map_lam' :: Lambda GPU
map_lam' = Lambda SOACS -> Lambda GPU
soacsLambdaToGPU Lambda SOACS
map_lam
          forall (m :: * -> *) a. MonadBuilder m => Certs -> m a -> m a
certifying (forall dec. StmAux dec -> Certs
stmAuxCerts StmAux (ExpDec SOACS)
aux) forall a b. (a -> b) -> a -> b
$
            forall (m :: * -> *). MonadBuilder m => Stms (Rep m) -> m ()
addStms forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< forall (m :: * -> *) rep.
(MonadFreshNames m, DistRep rep, HasScope rep m) =>
SegOpLevel rep
-> Pat (LetDec rep)
-> Certs
-> SubExp
-> [SegBinOp rep]
-> Lambda rep
-> [VName]
-> [(VName, SubExp)]
-> [KernelInput]
-> m (Stms rep)
segRed SegLevel
lvl' Pat (LetDec SOACS)
pat forall a. Monoid a => a
mempty SubExp
w [forall {k} (rep :: k).
Commutativity -> Lambda rep -> [SubExp] -> Shape -> SegBinOp rep
SegBinOp Commutativity
comm Lambda GPU
red_lam' [SubExp]
nes forall a. Monoid a => a
mempty] Lambda GPU
map_lam' [VName]
arrs [] []
          [SubExp] -> IntraGroupM ()
parallelMin [SubExp
w]
    Op (Hist SubExp
w [VName]
arrs [HistOp SOACS]
ops Lambda SOACS
bucket_fun) -> do
      [HistOp GPU]
ops' <- forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM [HistOp SOACS]
ops forall a b. (a -> b) -> a -> b
$ \(HistOp Shape
num_bins SubExp
rf [VName]
dests [SubExp]
nes Lambda SOACS
op) -> do
        (Lambda SOACS
op', [SubExp]
nes', Shape
shape) <- forall (m :: * -> *).
MonadBuilder m =>
Lambda SOACS -> [SubExp] -> m (Lambda SOACS, [SubExp], Shape)
determineReduceOp Lambda SOACS
op [SubExp]
nes
        let op'' :: Lambda GPU
op'' = Lambda SOACS -> Lambda GPU
soacsLambdaToGPU Lambda SOACS
op'
        forall (f :: * -> *) a. Applicative f => a -> f a
pure forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k).
Shape
-> SubExp
-> [VName]
-> [SubExp]
-> Shape
-> Lambda rep
-> HistOp rep
GPU.HistOp Shape
num_bins SubExp
rf [VName]
dests [SubExp]
nes' Shape
shape Lambda GPU
op''

      let bucket_fun' :: Lambda GPU
bucket_fun' = Lambda SOACS -> Lambda GPU
soacsLambdaToGPU Lambda SOACS
bucket_fun
      forall (m :: * -> *) a. MonadBuilder m => Certs -> m a -> m a
certifying (forall dec. StmAux dec -> Certs
stmAuxCerts StmAux (ExpDec SOACS)
aux) forall a b. (a -> b) -> a -> b
$
        forall (m :: * -> *). MonadBuilder m => Stms (Rep m) -> m ()
addStms forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< forall rep (m :: * -> *).
(DistRep rep, MonadFreshNames m, HasScope rep m) =>
SegOpLevel rep
-> Pat Type
-> SubExp
-> [(VName, SubExp)]
-> [KernelInput]
-> [HistOp rep]
-> Lambda rep
-> [VName]
-> m (Stms rep)
segHist SegLevel
lvl' Pat (LetDec SOACS)
pat SubExp
w [] [] [HistOp GPU]
ops' Lambda GPU
bucket_fun' [VName]
arrs
      [SubExp] -> IntraGroupM ()
parallelMin [SubExp
w]
    Op (Stream SubExp
w [VName]
arrs [SubExp]
accs Lambda SOACS
lam)
      | LParam SOACS
chunk_size_param : [LParam SOACS]
_ <- forall {k} (rep :: k). Lambda rep -> [LParam rep]
lambdaParams Lambda SOACS
lam -> do
          Scope SOACS
types <- forall {k} (rep :: k) (m :: * -> *) a.
HasScope rep m =>
(Scope rep -> a) -> m a
asksScope forall {k1} {k2} (fromrep :: k1) (torep :: k2).
SameScope fromrep torep =>
Scope fromrep -> Scope torep
castScope
          ((), Stms SOACS
stream_stms) <-
            forall {k} (m :: * -> *) (rep :: k) a.
MonadFreshNames m =>
BuilderT rep m a -> Scope rep -> m (a, Stms rep)
runBuilderT (forall (m :: * -> *).
(MonadBuilder m, Buildable (Rep m)) =>
Pat (LetDec (Rep m))
-> SubExp -> [SubExp] -> Lambda (Rep m) -> [VName] -> m ()
sequentialStreamWholeArray Pat (LetDec SOACS)
pat SubExp
w [SubExp]
accs Lambda SOACS
lam [VName]
arrs) Scope SOACS
types
          let replace :: SubExp -> SubExp
replace (Var VName
v) | VName
v forall a. Eq a => a -> a -> Bool
== forall dec. Param dec -> VName
paramName LParam SOACS
chunk_size_param = SubExp
w
              replace SubExp
se = SubExp
se
              replaceSets :: IntraAcc -> IntraAcc
replaceSets (IntraAcc Set [SubExp]
x Set [SubExp]
y Log
log) =
                Set [SubExp] -> Set [SubExp] -> Log -> IntraAcc
IntraAcc (forall b a. Ord b => (a -> b) -> Set a -> Set b
S.map (forall a b. (a -> b) -> [a] -> [b]
map SubExp -> SubExp
replace) Set [SubExp]
x) (forall b a. Ord b => (a -> b) -> Set a -> Set b
S.map (forall a b. (a -> b) -> [a] -> [b]
map SubExp -> SubExp
replace) Set [SubExp]
y) Log
log
          forall w (m :: * -> *) a. MonadWriter w m => (w -> w) -> m a -> m a
censor IntraAcc -> IntraAcc
replaceSets forall a b. (a -> b) -> a -> b
$ SegLevel -> Stms SOACS -> IntraGroupM ()
intraGroupStms SegLevel
lvl Stms SOACS
stream_stms
    Op (Scatter SubExp
w [VName]
ivs Lambda SOACS
lam [(Shape, Int, VName)]
dests) -> do
      VName
write_i <- forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName String
"write_i"
      SegSpace
space <- forall (m :: * -> *).
MonadFreshNames m =>
[(VName, SubExp)] -> m SegSpace
mkSegSpace [(VName
write_i, SubExp
w)]

      let lam' :: Lambda GPU
lam' = Lambda SOACS -> Lambda GPU
soacsLambdaToGPU Lambda SOACS
lam
          ([Shape]
dests_ws, [Int]
_, [VName]
_) = forall a b c. [(a, b, c)] -> ([a], [b], [c])
unzip3 [(Shape, Int, VName)]
dests
          krets :: [KernelResult]
krets = do
            (Shape
a_w, VName
a, [(Result, SubExpRes)]
is_vs) <-
              forall array a.
[(Shape, Int, array)] -> [a] -> [(Shape, array, [([a], a)])]
groupScatterResults [(Shape, Int, VName)]
dests forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k). Body rep -> Result
bodyResult forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k). Lambda rep -> Body rep
lambdaBody Lambda GPU
lam'
            let cs :: Certs
cs =
                  forall (t :: * -> *) m a.
(Foldable t, Monoid m) =>
(a -> m) -> t a -> m
foldMap (forall (t :: * -> *) m a.
(Foldable t, Monoid m) =>
(a -> m) -> t a -> m
foldMap SubExpRes -> Certs
resCerts forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall a b. (a, b) -> a
fst) [(Result, SubExpRes)]
is_vs
                    forall a. Semigroup a => a -> a -> a
<> forall (t :: * -> *) m a.
(Foldable t, Monoid m) =>
(a -> m) -> t a -> m
foldMap (SubExpRes -> Certs
resCerts forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall a b. (a, b) -> b
snd) [(Result, SubExpRes)]
is_vs
                is_vs' :: [(Slice SubExp, SubExp)]
is_vs' = [(forall d. [DimIndex d] -> Slice d
Slice forall a b. (a -> b) -> a -> b
$ forall a b. (a -> b) -> [a] -> [b]
map (forall d. d -> DimIndex d
DimFix forall b c a. (b -> c) -> (a -> b) -> a -> c
. SubExpRes -> SubExp
resSubExp) Result
is, SubExpRes -> SubExp
resSubExp SubExpRes
v) | (Result
is, SubExpRes
v) <- [(Result, SubExpRes)]
is_vs]
            forall (f :: * -> *) a. Applicative f => a -> f a
pure forall a b. (a -> b) -> a -> b
$ Certs -> Shape -> VName -> [(Slice SubExp, SubExp)] -> KernelResult
WriteReturns Certs
cs Shape
a_w VName
a [(Slice SubExp, SubExp)]
is_vs'
          inputs :: [KernelInput]
inputs = do
            (Param Type
p, VName
p_a) <- forall a b. [a] -> [b] -> [(a, b)]
zip (forall {k} (rep :: k). Lambda rep -> [LParam rep]
lambdaParams Lambda GPU
lam') [VName]
ivs
            forall (f :: * -> *) a. Applicative f => a -> f a
pure forall a b. (a -> b) -> a -> b
$ VName -> Type -> VName -> [SubExp] -> KernelInput
KernelInput (forall dec. Param dec -> VName
paramName Param Type
p) (forall dec. Typed dec => Param dec -> Type
paramType Param Type
p) VName
p_a [VName -> SubExp
Var VName
write_i]

      Stms GPU
kstms <- forall {k1} {k2} (m :: * -> *) (somerep :: k1) (rep :: k2) a.
(MonadFreshNames m, HasScope somerep m, SameScope somerep rep) =>
Builder rep a -> m (Stms rep)
runBuilder_ forall a b. (a -> b) -> a -> b
$
        forall {k} (rep :: k) (m :: * -> *) a.
LocalScope rep m =>
Scope rep -> m a -> m a
localScope (forall {k} (rep :: k). SegSpace -> Scope rep
scopeOfSegSpace SegSpace
space) forall a b. (a -> b) -> a -> b
$ do
          forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
(a -> m b) -> t a -> m ()
mapM_ forall (m :: * -> *).
(DistRep (Rep m), MonadBuilder m) =>
KernelInput -> m ()
readKernelInput [KernelInput]
inputs
          forall (m :: * -> *). MonadBuilder m => Stms (Rep m) -> m ()
addStms forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k). Body rep -> Stms rep
bodyStms forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k). Lambda rep -> Body rep
lambdaBody Lambda GPU
lam'

      forall (m :: * -> *) a. MonadBuilder m => Certs -> m a -> m a
certifying (forall dec. StmAux dec -> Certs
stmAuxCerts StmAux (ExpDec SOACS)
aux) forall a b. (a -> b) -> a -> b
$ do
        let ts :: [Type]
ts = forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith (forall u. Int -> TypeBase Shape u -> TypeBase Shape u
stripArray forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall (t :: * -> *) a. Foldable t => t a -> Int
length) [Shape]
dests_ws forall a b. (a -> b) -> a -> b
$ forall dec. Typed dec => Pat dec -> [Type]
patTypes Pat (LetDec SOACS)
pat
            body :: KernelBody GPU
body = forall {k} (rep :: k).
BodyDec rep -> Stms rep -> [KernelResult] -> KernelBody rep
KernelBody () Stms GPU
kstms [KernelResult]
krets
        forall (m :: * -> *).
MonadBuilder m =>
Pat (LetDec (Rep m)) -> Exp (Rep m) -> m ()
letBind Pat (LetDec SOACS)
pat forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k). Op rep -> Exp rep
Op forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k) op. SegOp SegLevel rep -> HostOp rep op
SegOp forall a b. (a -> b) -> a -> b
$ forall {k} lvl (rep :: k).
lvl -> SegSpace -> [Type] -> KernelBody rep -> SegOp lvl rep
SegMap SegLevel
lvl' SegSpace
space [Type]
ts KernelBody GPU
body

      [SubExp] -> IntraGroupM ()
parallelMin [SubExp
w]
    Exp SOACS
_ ->
      forall (m :: * -> *). MonadBuilder m => Stm (Rep m) -> m ()
addStm forall a b. (a -> b) -> a -> b
$ Stm SOACS -> Stm GPU
soacsStmToGPU Stm SOACS
stm

-- | Process a sequence of SOACS statements by applying 'intraGroupStm'
-- to each one, in order, at the given level.  The per-statement
-- results are discarded; only the effects performed in 'IntraGroupM'
-- remain.
intraGroupStms :: SegLevel -> Stms SOACS -> IntraGroupM ()
intraGroupStms lvl = mapM_ (intraGroupStm lvl)

-- | Translate the statements of a SOACS body with 'intraGroupStms' and
-- package the outcome as a GPU kernel body.
--
-- The result is a 4-tuple consisting of:
--
-- * the first size set of the collected 'IntraAcc', as a list
--   (presumably the "minimum group size" requirements mentioned in the
--   module documentation -- confirm against 'IntraAcc');
--
-- * the second size set of the collected 'IntraAcc', as a list
--   (presumably the "available parallelism" -- confirm likewise);
--
-- * the log accumulated while translating;
--
-- * a 'KernelBody' holding the produced GPU statements, whose results
--   are the results of the original body.
intraGroupParalleliseBody ::
  (MonadFreshNames m, HasScope GPU m) =>
  SegLevel ->
  Body SOACS ->
  m ([[SubExp]], [[SubExp]], Log, KernelBody GPU)
intraGroupParalleliseBody lvl body = do
  (IntraAcc min_ws avail_ws log, kstms) <-
    runIntraGroupM $ intraGroupStms lvl $ bodyStms body
  pure
    ( S.toList min_ws,
      S.toList avail_ws,
      log,
      KernelBody () kstms $ map ret $ bodyResult body
    )
  where
    -- Each result of the original body becomes a plain kernel return,
    -- keeping its certificates and tagged 'ResultMaySimplify'.
    ret (SubExpRes cs se) = Returns ResultMaySimplify cs se