{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TypeFamilies #-}

-- | Extract limited nested parallelism for execution inside
-- individual kernel workgroups.
module Futhark.Pass.ExtractKernels.Intragroup (intraGroupParallelise) where

import Control.Monad.Identity
import Control.Monad.RWS
import Control.Monad.Trans.Maybe
import qualified Data.Map.Strict as M
import qualified Data.Set as S
import Futhark.Analysis.PrimExp.Convert
import qualified Futhark.IR.Kernels as Out
import Futhark.IR.Kernels.Kernel hiding (HistOp)
import Futhark.IR.SOACS
import Futhark.MonadFreshNames
import Futhark.Pass.ExtractKernels.BlockedKernel
import Futhark.Pass.ExtractKernels.DistributeNests
import Futhark.Pass.ExtractKernels.Distribution
import Futhark.Pass.ExtractKernels.ToKernels
import Futhark.Tools
import qualified Futhark.Transform.FirstOrderTransform as FOT
import Futhark.Util (chunks)
import Futhark.Util.Log
import Prelude hiding (log)

-- | Convert the statements inside a map nest to kernel statements,
-- attempting to parallelise any remaining (top-level) parallel
-- statements.  Anything that is not a map, scan or reduction will
-- simply be sequentialised.  This includes sequential loops that
-- contain maps, scans or reduction.  In the future, we could probably
-- do something more clever.  Make sure that the amount of parallelism
-- to be exploited does not exceed the group size.  Further, as a hack
-- we also consider the size of all intermediate arrays as
-- "parallelism to be exploited" to avoid exploding local memory.
--
-- We distinguish between "minimum group size" and "maximum
-- exploitable parallelism".
--
-- On success the result is a tuple of:
--
--   * the pair @(intra_min_par, intra_avail_par)@ (currently both
--     components are the same value);
--
--   * the computed group size;
--
--   * the log produced while transforming the body;
--
--   * the prelude statements that compute the sizes above and must
--     run before the kernel;
--
--   * the single group-level 'SegMap' statement.
--
-- The result is 'Nothing' ("Irregular parallelism") when any of the
-- collected parallel sizes mentions a name that is not in the
-- enclosing scope, or that is one of the kernel inputs.
intraGroupParallelise ::
  (MonadFreshNames m, LocalScope Out.Kernels m) =>
  KernelNest ->
  Lambda ->
  m
    ( Maybe
        ( (SubExp, SubExp),
          SubExp,
          Log,
          Out.Stms Out.Kernels,
          Out.Stms Out.Kernels
        )
    )
intraGroupParallelise knest lam = runMaybeT $ do
  (ispace, inps) <- lift $ flatKernel knest

  -- The number of groups is the product of the sizes of the flattened
  -- iteration space.
  (num_groups, w_stms) <-
    lift $
      runBinder $
        letSubExp "intra_num_groups"
          =<< foldBinOp (Mul Int64 OverflowUndef) (intConst Int64 1) (map snd ispace)

  let body = lambdaBody lam

  -- The name is created up front so that it can be referenced by the
  -- thread-level body; it is only bound (via 'letBindNames') further
  -- down, once the parallel sizes are known.
  group_size <- newVName "computed_group_size"
  let intra_lvl = SegThread (Count num_groups) (Count $ Var group_size) SegNoVirt

  (wss_min, wss_avail, log, kbody) <-
    lift $
      localScope (scopeOfLParams $ lambdaParams lam) $
        intraGroupParalleliseBody intra_lvl body

  outside_scope <- lift askScope
  -- outside_scope may also contain the inputs, even though those are
  -- not actually available outside the kernel.
  let available v =
        v `M.member` outside_scope
          && v `notElem` map kernelInputName inps
  unless (all available $ namesToList $ freeIn (wss_min ++ wss_avail)) $
    fail "Irregular parallelism"

  ((intra_avail_par, kspace, read_input_stms), prelude_stms) <- lift $
    runBinder $ do
      -- Like 'foldBinOp', but yields the constant 0 for an empty list
      -- instead of requiring an explicit neutral element.
      let foldBinOp' _ [] = eSubExp $ intConst Int64 0
          foldBinOp' bop (x : xs) = foldBinOp bop x xs
      ws_min <-
        mapM (letSubExp "one_intra_par_min" <=< foldBinOp' (Mul Int64 OverflowUndef)) $
          filter (not . null) wss_min
      ws_avail <-
        mapM (letSubExp "one_intra_par_avail" <=< foldBinOp' (Mul Int64 OverflowUndef)) $
          filter (not . null) wss_avail

      -- The amount of parallelism available *in the worst case* is
      -- equal to the smallest parallel loop.
      intra_avail_par <- letSubExp "intra_avail_par" =<< foldBinOp' (SMin Int64) ws_avail

      -- The group size is either the maximum of the minimum parallelism
      -- exploited, or the desired parallelism (bounded by the max group
      -- size) in case there is no minimum.
      letBindNames [group_size]
        =<< if null ws_min
          then
            eBinOp
              (SMin Int64)
              (eSubExp =<< letSubExp "max_group_size" (Op $ SizeOp $ Out.GetSizeMax Out.SizeGroup))
              (eSubExp intra_avail_par)
          else foldBinOp' (SMax Int64) ws_min

      -- Only read those kernel inputs that the body actually mentions.
      let inputIsUsed input = kernelInputName input `nameIn` freeIn body
          used_inps = filter inputIsUsed inps

      addStms w_stms
      -- The per-input results are all (), so there is nothing to collect.
      read_input_stms <- runBinder_ $ mapM_ readKernelInput used_inps
      space <- mkSegSpace ispace
      return (intra_avail_par, space, read_input_stms)

  -- The input reads must come before the transformed body statements.
  let kbody' = kbody {kernelBodyStms = read_input_stms <> kernelBodyStms kbody}

  let nested_pat = loopNestingPattern first_nest
      -- The per-group result types: strip one array dimension per
      -- dimension of the iteration space.
      rts = map (length ispace `stripArray`) $ patternTypes nested_pat
      lvl = SegGroup (Count num_groups) (Count $ Var group_size) SegNoVirt
      kstm =
        Let nested_pat aux $
          Op $ SegOp $ SegMap lvl kspace rts kbody'

  -- Currently the minimum and available parallelism are reported as
  -- the same value.
  let intra_min_par = intra_avail_par
  return
    ( (intra_min_par, intra_avail_par),
      Var group_size,
      log,
      prelude_stms,
      oneStm kstm
    )
  where
    first_nest = fst knest
    aux = loopNestingAux first_nest

-- | What is accumulated in the writer component of 'IntraGroupM'
-- while statements are being transformed.
data Acc = Acc
  { -- | Recorded size lists for "minimum" parallelism; 'parallelMin'
    -- adds one entry here.  (Presumably each list holds the
    -- dimensions of one parallel construct - confirm against the
    -- other writers.)
    accMinPar :: S.Set [SubExp],
    -- | Recorded size lists for "available" parallelism;
    -- 'parallelMin' also adds its argument here.
    accAvailPar :: S.Set [SubExp],
    -- | Log messages emitted through 'addLog'.
    accLog :: Log
  }

-- | Accumulators are combined field by field.
instance Semigroup Acc where
  Acc min_x avail_x log_x <> Acc min_y avail_y log_y =
    Acc (min_x <> min_y) (avail_x <> avail_y) (log_x <> log_y)

-- | The empty accumulator has every field empty.
instance Monoid Acc where
  mempty = Acc mempty mempty mempty

-- | The monad in which intra-group statements are built: a binder for
-- the 'Out.Kernels' representation on top of an 'RWS' with no reader
-- environment, an 'Acc' as writer output, and a 'VNameSource' as
-- state.
type IntraGroupM =
  BinderT Out.Kernels (RWS () Acc VNameSource)

-- | Log messages are written into the 'accLog' field of the 'Acc'
-- writer output; the other fields are left empty.
instance MonadLogger IntraGroupM where
  addLog log = tell mempty {accLog = log}

-- | Run an 'IntraGroupM' action in the current scope, threading the
-- name source of the surrounding monad through it.  Returns the
-- accumulated 'Acc' together with the statements that were added.
runIntraGroupM ::
  (MonadFreshNames m, HasScope Out.Kernels m) =>
  IntraGroupM () ->
  m (Acc, Out.Stms Out.Kernels)
runIntraGroupM m = do
  scope <- castScope <$> askScope
  modifyNameSource $ \src ->
    -- 'runRWS' yields (result, final state, writer output); here the
    -- state is the name source and the writer output is the 'Acc'.
    let (((), kstms), src', acc) = runRWS (runBinderT m scope) () src
     in ((acc, kstms), src')

-- | Record the size list @ws@ as both minimum and available
-- parallelism in the 'Acc' writer output.
parallelMin :: [SubExp] -> IntraGroupM ()
parallelMin ws =
  tell
    mempty
      { accMinPar = S.singleton ws,
        accAvailPar = S.singleton ws
      }

-- | Transform a SOACS body: its statements are run through
-- 'intraGroupStms' and the statements that produces are collected
-- into a new body with the same result as the original.
intraGroupBody :: SegLevel -> Body -> IntraGroupM (Out.Body Out.Kernels)
intraGroupBody lvl body = do
  stms <- collectStms_ $ intraGroupStms lvl $ bodyStms body
  return $ mkBody stms $ bodyResult body

intraGroupStm :: SegLevel -> Stm -> IntraGroupM ()
intraGroupStm :: SegLevel -> Stm -> IntraGroupM ()
intraGroupStm SegLevel
lvl stm :: Stm
stm@(Let Pattern SOACS
pat StmAux (ExpDec SOACS)
aux Exp SOACS
e) = do
  Scope Kernels
scope <- BinderT Kernels (RWST () Acc VNameSource Identity) (Scope Kernels)
forall lore (m :: * -> *). HasScope lore m => m (Scope lore)
askScope
  let lvl' :: SegLevel
lvl' = Count NumGroups SubExp
-> Count GroupSize SubExp -> SegVirt -> SegLevel
SegThread (SegLevel -> Count NumGroups SubExp
segNumGroups SegLevel
lvl) (SegLevel -> Count GroupSize SubExp
segGroupSize SegLevel
lvl) SegVirt
SegNoVirt

  case Exp SOACS
e of
    DoLoop [(FParam SOACS, SubExp)]
ctx [(FParam SOACS, SubExp)]
val LoopForm SOACS
form BodyT SOACS
loopbody ->
      Scope Kernels -> IntraGroupM () -> IntraGroupM ()
forall lore (m :: * -> *) a.
LocalScope lore m =>
Scope lore -> m a -> m a
localScope (LoopForm Kernels -> Scope Kernels
forall lore a. Scoped lore a => a -> Scope lore
scopeOf LoopForm Kernels
form') (IntraGroupM () -> IntraGroupM ())
-> IntraGroupM () -> IntraGroupM ()
forall a b. (a -> b) -> a -> b
$
        Scope Kernels -> IntraGroupM () -> IntraGroupM ()
forall lore (m :: * -> *) a.
LocalScope lore m =>
Scope lore -> m a -> m a
localScope ([Param DeclType] -> Scope Kernels
forall lore dec.
(FParamInfo lore ~ dec) =>
[Param dec] -> Scope lore
scopeOfFParams ([Param DeclType] -> Scope Kernels)
-> [Param DeclType] -> Scope Kernels
forall a b. (a -> b) -> a -> b
$ ((Param DeclType, SubExp) -> Param DeclType)
-> [(Param DeclType, SubExp)] -> [Param DeclType]
forall a b. (a -> b) -> [a] -> [b]
map (Param DeclType, SubExp) -> Param DeclType
forall a b. (a, b) -> a
fst ([(Param DeclType, SubExp)] -> [Param DeclType])
-> [(Param DeclType, SubExp)] -> [Param DeclType]
forall a b. (a -> b) -> a -> b
$ [(Param DeclType, SubExp)]
[(FParam SOACS, SubExp)]
ctx [(Param DeclType, SubExp)]
-> [(Param DeclType, SubExp)] -> [(Param DeclType, SubExp)]
forall a. [a] -> [a] -> [a]
++ [(Param DeclType, SubExp)]
[(FParam SOACS, SubExp)]
val) (IntraGroupM () -> IntraGroupM ())
-> IntraGroupM () -> IntraGroupM ()
forall a b. (a -> b) -> a -> b
$ do
          Body Kernels
loopbody' <- SegLevel -> BodyT SOACS -> IntraGroupM (Body Kernels)
intraGroupBody SegLevel
lvl BodyT SOACS
loopbody
          Certificates -> IntraGroupM () -> IntraGroupM ()
forall (m :: * -> *) a. MonadBinder m => Certificates -> m a -> m a
certifying (StmAux () -> Certificates
forall dec. StmAux dec -> Certificates
stmAuxCerts StmAux ()
StmAux (ExpDec SOACS)
aux) (IntraGroupM () -> IntraGroupM ())
-> IntraGroupM () -> IntraGroupM ()
forall a b. (a -> b) -> a -> b
$
            Pattern (Lore IntraGroupM)
-> Exp (Lore IntraGroupM) -> IntraGroupM ()
forall (m :: * -> *).
MonadBinder m =>
Pattern (Lore m) -> Exp (Lore m) -> m ()
letBind Pattern (Lore IntraGroupM)
Pattern SOACS
pat (Exp (Lore IntraGroupM) -> IntraGroupM ())
-> Exp (Lore IntraGroupM) -> IntraGroupM ()
forall a b. (a -> b) -> a -> b
$ [(FParam Kernels, SubExp)]
-> [(FParam Kernels, SubExp)]
-> LoopForm Kernels
-> Body Kernels
-> ExpT Kernels
forall lore.
[(FParam lore, SubExp)]
-> [(FParam lore, SubExp)]
-> LoopForm lore
-> BodyT lore
-> ExpT lore
DoLoop [(FParam SOACS, SubExp)]
[(FParam Kernels, SubExp)]
ctx [(FParam SOACS, SubExp)]
[(FParam Kernels, SubExp)]
val LoopForm Kernels
form' Body Kernels
loopbody'
      where
        form' :: LoopForm Kernels
form' = case LoopForm SOACS
form of
          ForLoop VName
i IntType
it SubExp
bound [(LParam SOACS, VName)]
inps -> VName
-> IntType
-> SubExp
-> [(LParam Kernels, VName)]
-> LoopForm Kernels
forall lore.
VName
-> IntType -> SubExp -> [(LParam lore, VName)] -> LoopForm lore
ForLoop VName
i IntType
it SubExp
bound [(LParam SOACS, VName)]
[(LParam Kernels, VName)]
inps
          WhileLoop VName
cond -> VName -> LoopForm Kernels
forall lore. VName -> LoopForm lore
WhileLoop VName
cond
    If SubExp
cond BodyT SOACS
tbody BodyT SOACS
fbody IfDec (BranchType SOACS)
ifdec -> do
      Body Kernels
tbody' <- SegLevel -> BodyT SOACS -> IntraGroupM (Body Kernels)
intraGroupBody SegLevel
lvl BodyT SOACS
tbody
      Body Kernels
fbody' <- SegLevel -> BodyT SOACS -> IntraGroupM (Body Kernels)
intraGroupBody SegLevel
lvl BodyT SOACS
fbody
      Certificates -> IntraGroupM () -> IntraGroupM ()
forall (m :: * -> *) a. MonadBinder m => Certificates -> m a -> m a
certifying (StmAux () -> Certificates
forall dec. StmAux dec -> Certificates
stmAuxCerts StmAux ()
StmAux (ExpDec SOACS)
aux) (IntraGroupM () -> IntraGroupM ())
-> IntraGroupM () -> IntraGroupM ()
forall a b. (a -> b) -> a -> b
$
        Pattern (Lore IntraGroupM)
-> Exp (Lore IntraGroupM) -> IntraGroupM ()
forall (m :: * -> *).
MonadBinder m =>
Pattern (Lore m) -> Exp (Lore m) -> m ()
letBind Pattern (Lore IntraGroupM)
Pattern SOACS
pat (Exp (Lore IntraGroupM) -> IntraGroupM ())
-> Exp (Lore IntraGroupM) -> IntraGroupM ()
forall a b. (a -> b) -> a -> b
$ SubExp
-> Body Kernels
-> Body Kernels
-> IfDec (BranchType Kernels)
-> ExpT Kernels
forall lore.
SubExp
-> BodyT lore -> BodyT lore -> IfDec (BranchType lore) -> ExpT lore
If SubExp
cond Body Kernels
tbody' Body Kernels
fbody' IfDec (BranchType SOACS)
IfDec (BranchType Kernels)
ifdec
    Op Op SOACS
soac
      | Attr
"sequential_outer" Attr -> Attrs -> Bool
`inAttrs` StmAux () -> Attrs
forall dec. StmAux dec -> Attrs
stmAuxAttrs StmAux ()
StmAux (ExpDec SOACS)
aux ->
        SegLevel -> Stms SOACS -> IntraGroupM ()
intraGroupStms SegLevel
lvl (Stms SOACS -> IntraGroupM ())
-> (Stms SOACS -> Stms SOACS) -> Stms SOACS -> IntraGroupM ()
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Stm -> Stm) -> Stms SOACS -> Stms SOACS
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (Certificates -> Stm -> Stm
forall lore. Certificates -> Stm lore -> Stm lore
certify (StmAux () -> Certificates
forall dec. StmAux dec -> Certificates
stmAuxCerts StmAux ()
StmAux (ExpDec SOACS)
aux))
          (Stms SOACS -> IntraGroupM ())
-> BinderT Kernels (RWST () Acc VNameSource Identity) (Stms SOACS)
-> IntraGroupM ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< Binder SOACS ()
-> BinderT Kernels (RWST () Acc VNameSource Identity) (Stms SOACS)
forall (m :: * -> *) somelore lore a.
(MonadFreshNames m, HasScope somelore m,
 SameScope somelore lore) =>
Binder lore a -> m (Stms lore)
runBinder_ (Pattern (Lore (BinderT SOACS (State VNameSource)))
-> SOAC (Lore (BinderT SOACS (State VNameSource)))
-> Binder SOACS ()
forall (m :: * -> *).
Transformer m =>
Pattern (Lore m) -> SOAC (Lore m) -> m ()
FOT.transformSOAC Pattern (Lore (BinderT SOACS (State VNameSource)))
Pattern SOACS
pat Op SOACS
SOAC (Lore (BinderT SOACS (State VNameSource)))
soac)
    Op (Screma w form arrs)
      | Just Lambda
lam <- ScremaForm SOACS -> Maybe Lambda
forall lore. ScremaForm lore -> Maybe (Lambda lore)
isMapSOAC ScremaForm SOACS
form -> do
        let loopnest :: LoopNesting
loopnest = PatternT Type
-> StmAux () -> SubExp -> [(Param Type, VName)] -> LoopNesting
MapNesting PatternT Type
Pattern SOACS
pat StmAux ()
StmAux (ExpDec SOACS)
aux SubExp
w ([(Param Type, VName)] -> LoopNesting)
-> [(Param Type, VName)] -> LoopNesting
forall a b. (a -> b) -> a -> b
$ [Param Type] -> [VName] -> [(Param Type, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip (Lambda -> [LParam SOACS]
forall lore. LambdaT lore -> [LParam lore]
lambdaParams Lambda
lam) [VName]
arrs
            env :: DistEnv Kernels IntraGroupM
env =
              DistEnv :: forall lore (m :: * -> *).
Nestings
-> Scope lore
-> (Stms SOACS -> DistNestT lore m (Stms lore))
-> (MapLoop -> DistAcc lore -> DistNestT lore m (DistAcc lore))
-> (Stm -> Binder lore (Stms lore))
-> (Lambda -> Binder lore (Lambda lore))
-> MkSegLevel lore m
-> DistEnv lore m
DistEnv
                { distNest :: Nestings
distNest =
                    Nesting -> Nestings
singleNesting (Nesting -> Nestings) -> Nesting -> Nestings
forall a b. (a -> b) -> a -> b
$ Names -> LoopNesting -> Nesting
Nesting Names
forall a. Monoid a => a
mempty LoopNesting
loopnest,
                  distScope :: Scope Kernels
distScope =
                    PatternT Type -> Scope Kernels
forall lore dec. (LetDec lore ~ dec) => PatternT dec -> Scope lore
scopeOfPattern PatternT Type
Pattern SOACS
pat
                      Scope Kernels -> Scope Kernels -> Scope Kernels
forall a. Semigroup a => a -> a -> a
<> Scope SOACS -> Scope Kernels
scopeForKernels (Lambda -> Scope SOACS
forall lore a. Scoped lore a => a -> Scope lore
scopeOf Lambda
lam)
                      Scope Kernels -> Scope Kernels -> Scope Kernels
forall a. Semigroup a => a -> a -> a
<> Scope Kernels
scope,
                  distOnInnerMap :: MapLoop
-> DistAcc Kernels
-> DistNestT Kernels IntraGroupM (DistAcc Kernels)
distOnInnerMap =
                    MapLoop
-> DistAcc Kernels
-> DistNestT Kernels IntraGroupM (DistAcc Kernels)
forall (m :: * -> *) lore.
(MonadFreshNames m, LocalScope lore m, DistLore lore) =>
MapLoop -> DistAcc lore -> DistNestT lore m (DistAcc lore)
distributeMap,
                  distOnTopLevelStms :: Stms SOACS -> DistNestT Kernels IntraGroupM (Stms Kernels)
distOnTopLevelStms =
                    BinderT Kernels (RWST () Acc VNameSource Identity) (Stms Kernels)
-> DistNestT Kernels IntraGroupM (Stms Kernels)
forall lore (m :: * -> *) a.
(LocalScope lore m, DistLore lore) =>
m a -> DistNestT lore m a
liftInner (BinderT Kernels (RWST () Acc VNameSource Identity) (Stms Kernels)
 -> DistNestT Kernels IntraGroupM (Stms Kernels))
-> (Stms SOACS
    -> BinderT
         Kernels (RWST () Acc VNameSource Identity) (Stms Kernels))
-> Stms SOACS
-> DistNestT Kernels IntraGroupM (Stms Kernels)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. IntraGroupM ()
-> BinderT
     Kernels (RWST () Acc VNameSource Identity) (Stms Kernels)
forall (m :: * -> *) a. MonadBinder m => m a -> m (Stms (Lore m))
collectStms_ (IntraGroupM ()
 -> BinderT
      Kernels (RWST () Acc VNameSource Identity) (Stms Kernels))
-> (Stms SOACS -> IntraGroupM ())
-> Stms SOACS
-> BinderT
     Kernels (RWST () Acc VNameSource Identity) (Stms Kernels)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SegLevel -> Stms SOACS -> IntraGroupM ()
intraGroupStms SegLevel
lvl,
                  distSegLevel :: MkSegLevel Kernels IntraGroupM
distSegLevel = \[SubExp]
minw String
_ ThreadRecommendation
_ -> do
                    IntraGroupM () -> BinderT Kernels IntraGroupM ()
forall (t :: (* -> *) -> * -> *) (m :: * -> *) a.
(MonadTrans t, Monad m) =>
m a -> t m a
lift (IntraGroupM () -> BinderT Kernels IntraGroupM ())
-> IntraGroupM () -> BinderT Kernels IntraGroupM ()
forall a b. (a -> b) -> a -> b
$ [SubExp] -> IntraGroupM ()
parallelMin [SubExp]
minw
                    SegLevel -> BinderT Kernels IntraGroupM SegLevel
forall (m :: * -> *) a. Monad m => a -> m a
return SegLevel
lvl,
                  distOnSOACSStms :: Stm -> BinderT Kernels (State VNameSource) (Stms Kernels)
distOnSOACSStms =
                    Stms Kernels -> BinderT Kernels (State VNameSource) (Stms Kernels)
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Stms Kernels
 -> BinderT Kernels (State VNameSource) (Stms Kernels))
-> (Stm -> Stms Kernels)
-> Stm
-> BinderT Kernels (State VNameSource) (Stms Kernels)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Stm Kernels -> Stms Kernels
forall lore. Stm lore -> Stms lore
oneStm (Stm Kernels -> Stms Kernels)
-> (Stm -> Stm Kernels) -> Stm -> Stms Kernels
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Stm -> Stm Kernels
soacsStmToKernels,
                  distOnSOACSLambda :: Lambda -> Binder Kernels (Lambda Kernels)
distOnSOACSLambda =
                    Lambda Kernels -> Binder Kernels (Lambda Kernels)
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Lambda Kernels -> Binder Kernels (Lambda Kernels))
-> (Lambda -> Lambda Kernels)
-> Lambda
-> Binder Kernels (Lambda Kernels)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Lambda -> Lambda Kernels
soacsLambdaToKernels
                }
            acc :: DistAcc Kernels
acc =
              DistAcc :: forall lore. Targets -> Stms lore -> DistAcc lore
DistAcc
                { distTargets :: Targets
distTargets = Target -> Targets
singleTarget (PatternT Type
Pattern SOACS
pat, BodyT SOACS -> [SubExp]
forall lore. BodyT lore -> [SubExp]
bodyResult (BodyT SOACS -> [SubExp]) -> BodyT SOACS -> [SubExp]
forall a b. (a -> b) -> a -> b
$ Lambda -> BodyT SOACS
forall lore. LambdaT lore -> BodyT lore
lambdaBody Lambda
lam),
                  distStms :: Stms Kernels
distStms = Stms Kernels
forall a. Monoid a => a
mempty
                }

        Stms Kernels -> IntraGroupM ()
forall (m :: * -> *). MonadBinder m => Stms (Lore m) -> m ()
addStms
          (Stms Kernels -> IntraGroupM ())
-> BinderT
     Kernels (RWST () Acc VNameSource Identity) (Stms Kernels)
-> IntraGroupM ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< DistEnv Kernels IntraGroupM
-> DistNestT Kernels IntraGroupM (DistAcc Kernels)
-> BinderT
     Kernels (RWST () Acc VNameSource Identity) (Stms Kernels)
forall (m :: * -> *) lore.
(MonadLogger m, DistLore lore) =>
DistEnv lore m -> DistNestT lore m (DistAcc lore) -> m (Stms lore)
runDistNestT DistEnv Kernels IntraGroupM
env (DistAcc Kernels
-> Stms SOACS -> DistNestT Kernels IntraGroupM (DistAcc Kernels)
forall (m :: * -> *) lore.
(MonadFreshNames m, LocalScope lore m, DistLore lore) =>
DistAcc lore -> Stms SOACS -> DistNestT lore m (DistAcc lore)
distributeMapBodyStms DistAcc Kernels
acc (BodyT SOACS -> Stms SOACS
forall lore. BodyT lore -> Stms lore
bodyStms (BodyT SOACS -> Stms SOACS) -> BodyT SOACS -> Stms SOACS
forall a b. (a -> b) -> a -> b
$ Lambda -> BodyT SOACS
forall lore. LambdaT lore -> BodyT lore
lambdaBody Lambda
lam))
    Op (Screma w form arrs)
      | Just ([Scan SOACS]
scans, Lambda
mapfun) <- ScremaForm SOACS -> Maybe ([Scan SOACS], Lambda)
forall lore. ScremaForm lore -> Maybe ([Scan lore], Lambda lore)
isScanomapSOAC ScremaForm SOACS
form,
        Scan Lambda
scanfun [SubExp]
nes <- [Scan SOACS] -> Scan SOACS
forall lore. Bindable lore => [Scan lore] -> Scan lore
singleScan [Scan SOACS]
scans -> do
        let scanfun' :: Lambda Kernels
scanfun' = Lambda -> Lambda Kernels
soacsLambdaToKernels Lambda
scanfun
            mapfun' :: Lambda Kernels
mapfun' = Lambda -> Lambda Kernels
soacsLambdaToKernels Lambda
mapfun
        Certificates -> IntraGroupM () -> IntraGroupM ()
forall (m :: * -> *) a. MonadBinder m => Certificates -> m a -> m a
certifying (StmAux () -> Certificates
forall dec. StmAux dec -> Certificates
stmAuxCerts StmAux ()
StmAux (ExpDec SOACS)
aux) (IntraGroupM () -> IntraGroupM ())
-> IntraGroupM () -> IntraGroupM ()
forall a b. (a -> b) -> a -> b
$
          Stms Kernels -> IntraGroupM ()
forall (m :: * -> *). MonadBinder m => Stms (Lore m) -> m ()
addStms (Stms Kernels -> IntraGroupM ())
-> BinderT
     Kernels (RWST () Acc VNameSource Identity) (Stms Kernels)
-> IntraGroupM ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< SegOpLevel Kernels
-> Pattern Kernels
-> SubExp
-> [SegBinOp Kernels]
-> Lambda Kernels
-> [VName]
-> [(VName, SubExp)]
-> [KernelInput]
-> BinderT
     Kernels (RWST () Acc VNameSource Identity) (Stms Kernels)
forall (m :: * -> *) lore.
(MonadFreshNames m, DistLore lore, HasScope lore m) =>
SegOpLevel lore
-> Pattern lore
-> SubExp
-> [SegBinOp lore]
-> Lambda lore
-> [VName]
-> [(VName, SubExp)]
-> [KernelInput]
-> m (Stms lore)
segScan SegOpLevel Kernels
SegLevel
lvl' Pattern SOACS
Pattern Kernels
pat SubExp
w [Commutativity
-> Lambda Kernels -> [SubExp] -> Shape -> SegBinOp Kernels
forall lore.
Commutativity -> Lambda lore -> [SubExp] -> Shape -> SegBinOp lore
SegBinOp Commutativity
Noncommutative Lambda Kernels
scanfun' [SubExp]
nes Shape
forall a. Monoid a => a
mempty] Lambda Kernels
mapfun' [VName]
arrs [] []
        [SubExp] -> IntraGroupM ()
parallelMin [SubExp
w]
    Op (Screma w form arrs)
      | Just ([Reduce SOACS]
reds, Lambda
map_lam) <- ScremaForm SOACS -> Maybe ([Reduce SOACS], Lambda)
forall lore. ScremaForm lore -> Maybe ([Reduce lore], Lambda lore)
isRedomapSOAC ScremaForm SOACS
form,
        Reduce Commutativity
comm Lambda
red_lam [SubExp]
nes <- [Reduce SOACS] -> Reduce SOACS
forall lore. Bindable lore => [Reduce lore] -> Reduce lore
singleReduce [Reduce SOACS]
reds -> do
        let red_lam' :: Lambda Kernels
red_lam' = Lambda -> Lambda Kernels
soacsLambdaToKernels Lambda
red_lam
            map_lam' :: Lambda Kernels
map_lam' = Lambda -> Lambda Kernels
soacsLambdaToKernels Lambda
map_lam
        Certificates -> IntraGroupM () -> IntraGroupM ()
forall (m :: * -> *) a. MonadBinder m => Certificates -> m a -> m a
certifying (StmAux () -> Certificates
forall dec. StmAux dec -> Certificates
stmAuxCerts StmAux ()
StmAux (ExpDec SOACS)
aux) (IntraGroupM () -> IntraGroupM ())
-> IntraGroupM () -> IntraGroupM ()
forall a b. (a -> b) -> a -> b
$
          Stms Kernels -> IntraGroupM ()
forall (m :: * -> *). MonadBinder m => Stms (Lore m) -> m ()
addStms (Stms Kernels -> IntraGroupM ())
-> BinderT
     Kernels (RWST () Acc VNameSource Identity) (Stms Kernels)
-> IntraGroupM ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< SegOpLevel Kernels
-> Pattern Kernels
-> SubExp
-> [SegBinOp Kernels]
-> Lambda Kernels
-> [VName]
-> [(VName, SubExp)]
-> [KernelInput]
-> BinderT
     Kernels (RWST () Acc VNameSource Identity) (Stms Kernels)
forall (m :: * -> *) lore.
(MonadFreshNames m, DistLore lore, HasScope lore m) =>
SegOpLevel lore
-> Pattern lore
-> SubExp
-> [SegBinOp lore]
-> Lambda lore
-> [VName]
-> [(VName, SubExp)]
-> [KernelInput]
-> m (Stms lore)
segRed SegOpLevel Kernels
SegLevel
lvl' Pattern SOACS
Pattern Kernels
pat SubExp
w [Commutativity
-> Lambda Kernels -> [SubExp] -> Shape -> SegBinOp Kernels
forall lore.
Commutativity -> Lambda lore -> [SubExp] -> Shape -> SegBinOp lore
SegBinOp Commutativity
comm Lambda Kernels
red_lam' [SubExp]
nes Shape
forall a. Monoid a => a
mempty] Lambda Kernels
map_lam' [VName]
arrs [] []
        [SubExp] -> IntraGroupM ()
parallelMin [SubExp
w]
    Op (Hist w ops bucket_fun arrs) -> do
      [HistOp Kernels]
ops' <- [HistOp SOACS]
-> (HistOp SOACS
    -> BinderT
         Kernels (RWST () Acc VNameSource Identity) (HistOp Kernels))
-> BinderT
     Kernels (RWST () Acc VNameSource Identity) [HistOp Kernels]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM [HistOp SOACS]
ops ((HistOp SOACS
  -> BinderT
       Kernels (RWST () Acc VNameSource Identity) (HistOp Kernels))
 -> BinderT
      Kernels (RWST () Acc VNameSource Identity) [HistOp Kernels])
-> (HistOp SOACS
    -> BinderT
         Kernels (RWST () Acc VNameSource Identity) (HistOp Kernels))
-> BinderT
     Kernels (RWST () Acc VNameSource Identity) [HistOp Kernels]
forall a b. (a -> b) -> a -> b
$ \(HistOp SubExp
num_bins SubExp
rf [VName]
dests [SubExp]
nes Lambda
op) -> do
        (Lambda
op', [SubExp]
nes', Shape
shape) <- Lambda
-> [SubExp]
-> BinderT
     Kernels
     (RWST () Acc VNameSource Identity)
     (Lambda, [SubExp], Shape)
forall (m :: * -> *) lore.
(MonadBinder m, Lore m ~ lore) =>
Lambda -> [SubExp] -> m (Lambda, [SubExp], Shape)
determineReduceOp Lambda
op [SubExp]
nes
        let op'' :: Lambda Kernels
op'' = Lambda -> Lambda Kernels
soacsLambdaToKernels Lambda
op'
        HistOp Kernels
-> BinderT
     Kernels (RWST () Acc VNameSource Identity) (HistOp Kernels)
forall (m :: * -> *) a. Monad m => a -> m a
return (HistOp Kernels
 -> BinderT
      Kernels (RWST () Acc VNameSource Identity) (HistOp Kernels))
-> HistOp Kernels
-> BinderT
     Kernels (RWST () Acc VNameSource Identity) (HistOp Kernels)
forall a b. (a -> b) -> a -> b
$ SubExp
-> SubExp
-> [VName]
-> [SubExp]
-> Shape
-> Lambda Kernels
-> HistOp Kernels
forall lore.
SubExp
-> SubExp
-> [VName]
-> [SubExp]
-> Shape
-> Lambda lore
-> HistOp lore
Out.HistOp SubExp
num_bins SubExp
rf [VName]
dests [SubExp]
nes' Shape
shape Lambda Kernels
op''

      let bucket_fun' :: Lambda Kernels
bucket_fun' = Lambda -> Lambda Kernels
soacsLambdaToKernels Lambda
bucket_fun
      Certificates -> IntraGroupM () -> IntraGroupM ()
forall (m :: * -> *) a. MonadBinder m => Certificates -> m a -> m a
certifying (StmAux () -> Certificates
forall dec. StmAux dec -> Certificates
stmAuxCerts StmAux ()
StmAux (ExpDec SOACS)
aux) (IntraGroupM () -> IntraGroupM ())
-> IntraGroupM () -> IntraGroupM ()
forall a b. (a -> b) -> a -> b
$
        Stms Kernels -> IntraGroupM ()
forall (m :: * -> *). MonadBinder m => Stms (Lore m) -> m ()
addStms (Stms Kernels -> IntraGroupM ())
-> BinderT
     Kernels (RWST () Acc VNameSource Identity) (Stms Kernels)
-> IntraGroupM ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< SegOpLevel Kernels
-> Pattern Kernels
-> SubExp
-> [(VName, SubExp)]
-> [KernelInput]
-> [HistOp Kernels]
-> Lambda Kernels
-> [VName]
-> BinderT
     Kernels (RWST () Acc VNameSource Identity) (Stms Kernels)
forall lore (m :: * -> *).
(DistLore lore, MonadFreshNames m, HasScope lore m) =>
SegOpLevel lore
-> Pattern lore
-> SubExp
-> [(VName, SubExp)]
-> [KernelInput]
-> [HistOp lore]
-> Lambda lore
-> [VName]
-> m (Stms lore)
segHist SegOpLevel Kernels
SegLevel
lvl' Pattern SOACS
Pattern Kernels
pat SubExp
w [] [] [HistOp Kernels]
ops' Lambda Kernels
bucket_fun' [VName]
arrs
      [SubExp] -> IntraGroupM ()
parallelMin [SubExp
w]
    Op (Stream w (Sequential accs) lam arrs)
      | LParam SOACS
chunk_size_param : [LParam SOACS]
_ <- Lambda -> [LParam SOACS]
forall lore. LambdaT lore -> [LParam lore]
lambdaParams Lambda
lam -> do
        Scope SOACS
types <- (Scope Kernels -> Scope SOACS)
-> BinderT Kernels (RWST () Acc VNameSource Identity) (Scope SOACS)
forall lore (m :: * -> *) a.
HasScope lore m =>
(Scope lore -> a) -> m a
asksScope Scope Kernels -> Scope SOACS
forall fromlore tolore.
SameScope fromlore tolore =>
Scope fromlore -> Scope tolore
castScope
        ((), Stms SOACS
stream_bnds) <-
          BinderT SOACS IntraGroupM ()
-> Scope SOACS
-> BinderT
     Kernels (RWST () Acc VNameSource Identity) ((), Stms SOACS)
forall (m :: * -> *) lore a.
MonadFreshNames m =>
BinderT lore m a -> Scope lore -> m (a, Stms lore)
runBinderT (Pattern (Lore (BinderT SOACS IntraGroupM))
-> SubExp
-> [SubExp]
-> LambdaT (Lore (BinderT SOACS IntraGroupM))
-> [VName]
-> BinderT SOACS IntraGroupM ()
forall (m :: * -> *).
(MonadBinder m, Bindable (Lore m)) =>
Pattern (Lore m)
-> SubExp -> [SubExp] -> LambdaT (Lore m) -> [VName] -> m ()
sequentialStreamWholeArray Pattern (Lore (BinderT SOACS IntraGroupM))
Pattern SOACS
pat SubExp
w [SubExp]
accs LambdaT (Lore (BinderT SOACS IntraGroupM))
Lambda
lam [VName]
arrs) Scope SOACS
types
        let replace :: SubExp -> SubExp
replace (Var VName
v) | VName
v VName -> VName -> Bool
forall a. Eq a => a -> a -> Bool
== Param Type -> VName
forall dec. Param dec -> VName
paramName Param Type
LParam SOACS
chunk_size_param = SubExp
w
            replace SubExp
se = SubExp
se
            replaceSets :: Acc -> Acc
replaceSets (Acc Set [SubExp]
x Set [SubExp]
y Log
log) =
              Set [SubExp] -> Set [SubExp] -> Log -> Acc
Acc (([SubExp] -> [SubExp]) -> Set [SubExp] -> Set [SubExp]
forall b a. Ord b => (a -> b) -> Set a -> Set b
S.map ((SubExp -> SubExp) -> [SubExp] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> SubExp
replace) Set [SubExp]
x) (([SubExp] -> [SubExp]) -> Set [SubExp] -> Set [SubExp]
forall b a. Ord b => (a -> b) -> Set a -> Set b
S.map ((SubExp -> SubExp) -> [SubExp] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> SubExp
replace) Set [SubExp]
y) Log
log
        (Acc -> Acc) -> IntraGroupM () -> IntraGroupM ()
forall w (m :: * -> *) a. MonadWriter w m => (w -> w) -> m a -> m a
censor Acc -> Acc
replaceSets (IntraGroupM () -> IntraGroupM ())
-> IntraGroupM () -> IntraGroupM ()
forall a b. (a -> b) -> a -> b
$ SegLevel -> Stms SOACS -> IntraGroupM ()
intraGroupStms SegLevel
lvl Stms SOACS
stream_bnds
    Op (Scatter w lam ivs dests) -> do
      VName
write_i <- String -> BinderT Kernels (RWST () Acc VNameSource Identity) VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName String
"write_i"
      SegSpace
space <- [(VName, SubExp)]
-> BinderT Kernels (RWST () Acc VNameSource Identity) SegSpace
forall (m :: * -> *).
MonadFreshNames m =>
[(VName, SubExp)] -> m SegSpace
mkSegSpace [(VName
write_i, SubExp
w)]

      let lam' :: Lambda Kernels
lam' = Lambda -> Lambda Kernels
soacsLambdaToKernels Lambda
lam
          ([SubExp]
dests_ws, [Int]
dests_ns, [VName]
dests_vs) = [(SubExp, Int, VName)] -> ([SubExp], [Int], [VName])
forall a b c. [(a, b, c)] -> ([a], [b], [c])
unzip3 [(SubExp, Int, VName)]
dests
          ([SubExp]
i_res, [SubExp]
v_res) = Int -> [SubExp] -> ([SubExp], [SubExp])
forall a. Int -> [a] -> ([a], [a])
splitAt ([Int] -> Int
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
sum [Int]
dests_ns) ([SubExp] -> ([SubExp], [SubExp]))
-> [SubExp] -> ([SubExp], [SubExp])
forall a b. (a -> b) -> a -> b
$ Body Kernels -> [SubExp]
forall lore. BodyT lore -> [SubExp]
bodyResult (Body Kernels -> [SubExp]) -> Body Kernels -> [SubExp]
forall a b. (a -> b) -> a -> b
$ Lambda Kernels -> Body Kernels
forall lore. LambdaT lore -> BodyT lore
lambdaBody Lambda Kernels
lam'
          krets :: [KernelResult]
krets = do
            (SubExp
a_w, VName
a, [(SubExp, SubExp)]
is_vs) <- [SubExp]
-> [VName]
-> [[(SubExp, SubExp)]]
-> [(SubExp, VName, [(SubExp, SubExp)])]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [SubExp]
dests_ws [VName]
dests_vs ([[(SubExp, SubExp)]] -> [(SubExp, VName, [(SubExp, SubExp)])])
-> [[(SubExp, SubExp)]] -> [(SubExp, VName, [(SubExp, SubExp)])]
forall a b. (a -> b) -> a -> b
$ [Int] -> [(SubExp, SubExp)] -> [[(SubExp, SubExp)]]
forall a. [Int] -> [a] -> [[a]]
chunks [Int]
dests_ns ([(SubExp, SubExp)] -> [[(SubExp, SubExp)]])
-> [(SubExp, SubExp)] -> [[(SubExp, SubExp)]]
forall a b. (a -> b) -> a -> b
$ [SubExp] -> [SubExp] -> [(SubExp, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [SubExp]
i_res [SubExp]
v_res
            KernelResult -> [KernelResult]
forall (m :: * -> *) a. Monad m => a -> m a
return (KernelResult -> [KernelResult]) -> KernelResult -> [KernelResult]
forall a b. (a -> b) -> a -> b
$ [SubExp] -> VName -> [(Slice SubExp, SubExp)] -> KernelResult
WriteReturns [SubExp
a_w] VName
a [([SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix SubExp
i], SubExp
v) | (SubExp
i, SubExp
v) <- [(SubExp, SubExp)]
is_vs]
          inputs :: [KernelInput]
inputs = do
            (Param Type
p, VName
p_a) <- [Param Type] -> [VName] -> [(Param Type, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip (Lambda Kernels -> [LParam Kernels]
forall lore. LambdaT lore -> [LParam lore]
lambdaParams Lambda Kernels
lam') [VName]
ivs
            KernelInput -> [KernelInput]
forall (m :: * -> *) a. Monad m => a -> m a
return (KernelInput -> [KernelInput]) -> KernelInput -> [KernelInput]
forall a b. (a -> b) -> a -> b
$ VName -> Type -> VName -> [SubExp] -> KernelInput
KernelInput (Param Type -> VName
forall dec. Param dec -> VName
paramName Param Type
p) (Param Type -> Type
forall dec. Typed dec => Param dec -> Type
paramType Param Type
p) VName
p_a [VName -> SubExp
Var VName
write_i]

      Stms Kernels
kstms <- BinderT Kernels (State VNameSource) ()
-> BinderT
     Kernels (RWST () Acc VNameSource Identity) (Stms Kernels)
forall (m :: * -> *) somelore lore a.
(MonadFreshNames m, HasScope somelore m,
 SameScope somelore lore) =>
Binder lore a -> m (Stms lore)
runBinder_ (BinderT Kernels (State VNameSource) ()
 -> BinderT
      Kernels (RWST () Acc VNameSource Identity) (Stms Kernels))
-> BinderT Kernels (State VNameSource) ()
-> BinderT
     Kernels (RWST () Acc VNameSource Identity) (Stms Kernels)
forall a b. (a -> b) -> a -> b
$
        Scope Kernels
-> BinderT Kernels (State VNameSource) ()
-> BinderT Kernels (State VNameSource) ()
forall lore (m :: * -> *) a.
LocalScope lore m =>
Scope lore -> m a -> m a
localScope (SegSpace -> Scope Kernels
forall lore. SegSpace -> Scope lore
scopeOfSegSpace SegSpace
space) (BinderT Kernels (State VNameSource) ()
 -> BinderT Kernels (State VNameSource) ())
-> BinderT Kernels (State VNameSource) ()
-> BinderT Kernels (State VNameSource) ()
forall a b. (a -> b) -> a -> b
$ do
          (KernelInput -> BinderT Kernels (State VNameSource) ())
-> [KernelInput] -> BinderT Kernels (State VNameSource) ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
(a -> m b) -> t a -> m ()
mapM_ KernelInput -> BinderT Kernels (State VNameSource) ()
forall (m :: * -> *).
(DistLore (Lore m), MonadBinder m) =>
KernelInput -> m ()
readKernelInput [KernelInput]
inputs
          Stms (Lore (BinderT Kernels (State VNameSource)))
-> BinderT Kernels (State VNameSource) ()
forall (m :: * -> *). MonadBinder m => Stms (Lore m) -> m ()
addStms (Stms (Lore (BinderT Kernels (State VNameSource)))
 -> BinderT Kernels (State VNameSource) ())
-> Stms (Lore (BinderT Kernels (State VNameSource)))
-> BinderT Kernels (State VNameSource) ()
forall a b. (a -> b) -> a -> b
$ Body Kernels -> Stms Kernels
forall lore. BodyT lore -> Stms lore
bodyStms (Body Kernels -> Stms Kernels) -> Body Kernels -> Stms Kernels
forall a b. (a -> b) -> a -> b
$ Lambda Kernels -> Body Kernels
forall lore. LambdaT lore -> BodyT lore
lambdaBody Lambda Kernels
lam'

      Certificates -> IntraGroupM () -> IntraGroupM ()
forall (m :: * -> *) a. MonadBinder m => Certificates -> m a -> m a
certifying (StmAux () -> Certificates
forall dec. StmAux dec -> Certificates
stmAuxCerts StmAux ()
StmAux (ExpDec SOACS)
aux) (IntraGroupM () -> IntraGroupM ())
-> IntraGroupM () -> IntraGroupM ()
forall a b. (a -> b) -> a -> b
$ do
        let ts :: [Type]
ts = (Type -> Type) -> [Type] -> [Type]
forall a b. (a -> b) -> [a] -> [b]
map Type -> Type
forall shape u.
ArrayShape shape =>
TypeBase shape u -> TypeBase shape u
rowType ([Type] -> [Type]) -> [Type] -> [Type]
forall a b. (a -> b) -> a -> b
$ PatternT Type -> [Type]
forall dec. Typed dec => PatternT dec -> [Type]
patternTypes PatternT Type
Pattern SOACS
pat
            body :: KernelBody Kernels
body = BodyDec Kernels
-> Stms Kernels -> [KernelResult] -> KernelBody Kernels
forall lore.
BodyDec lore -> Stms lore -> [KernelResult] -> KernelBody lore
KernelBody () Stms Kernels
kstms [KernelResult]
krets
        Pattern (Lore IntraGroupM)
-> Exp (Lore IntraGroupM) -> IntraGroupM ()
forall (m :: * -> *).
MonadBinder m =>
Pattern (Lore m) -> Exp (Lore m) -> m ()
letBind Pattern (Lore IntraGroupM)
Pattern SOACS
pat (Exp (Lore IntraGroupM) -> IntraGroupM ())
-> Exp (Lore IntraGroupM) -> IntraGroupM ()
forall a b. (a -> b) -> a -> b
$ Op Kernels -> ExpT Kernels
forall lore. Op lore -> ExpT lore
Op (Op Kernels -> ExpT Kernels) -> Op Kernels -> ExpT Kernels
forall a b. (a -> b) -> a -> b
$ SegOp SegLevel Kernels -> HostOp Kernels (SOAC Kernels)
forall lore op. SegOp SegLevel lore -> HostOp lore op
SegOp (SegOp SegLevel Kernels -> HostOp Kernels (SOAC Kernels))
-> SegOp SegLevel Kernels -> HostOp Kernels (SOAC Kernels)
forall a b. (a -> b) -> a -> b
$ SegLevel
-> SegSpace
-> [Type]
-> KernelBody Kernels
-> SegOp SegLevel Kernels
forall lvl lore.
lvl -> SegSpace -> [Type] -> KernelBody lore -> SegOp lvl lore
SegMap SegLevel
lvl' SegSpace
space [Type]
ts KernelBody Kernels
body

      [SubExp] -> IntraGroupM ()
parallelMin [SubExp
w]
    Exp SOACS
_ ->
      Stm (Lore IntraGroupM) -> IntraGroupM ()
forall (m :: * -> *). MonadBinder m => Stm (Lore m) -> m ()
addStm (Stm (Lore IntraGroupM) -> IntraGroupM ())
-> Stm (Lore IntraGroupM) -> IntraGroupM ()
forall a b. (a -> b) -> a -> b
$ Stm -> Stm Kernels
soacsStmToKernels Stm
stm

-- | Process a sequence of SOACS statements at the given level by
-- running 'intraGroupStm' on each statement in order.  Only the
-- effects in 'IntraGroupM' are kept; the per-statement results are
-- unit and are discarded.
intraGroupStms :: SegLevel -> Stms SOACS -> IntraGroupM ()
intraGroupStms lvl = mapM_ (intraGroupStm lvl)

-- | Translate the statements of a SOACS body with 'intraGroupStms'
-- (run via 'runIntraGroupM') and package the outcome as a kernel
-- body.
--
-- The result tuple contains, in order:
--
-- * the first set of the resulting 'Acc', as a list (bound to
--   @min_ws@ here; presumably the minimum required widths -- confirm
--   against the definition of 'Acc');
--
-- * the second set of the resulting 'Acc', as a list (bound to
--   @avail_ws@; presumably the available parallelism -- confirm
--   against the definition of 'Acc');
--
-- * the log carried by the 'Acc';
--
-- * a 'KernelBody' holding the produced kernel statements, whose
--   results are the results of the original body, each wrapped as
--   @'Returns' 'ResultMaySimplify'@.
intraGroupParalleliseBody ::
  (MonadFreshNames m, HasScope Out.Kernels m) =>
  SegLevel ->
  Body ->
  m ([[SubExp]], [[SubExp]], Log, Out.KernelBody Out.Kernels)
intraGroupParalleliseBody lvl body = do
  (Acc min_ws avail_ws log, kstms) <-
    runIntraGroupM $ intraGroupStms lvl $ bodyStms body
  -- The body results are passed through unchanged; only the
  -- statements were rewritten above.
  let kres = map (Returns ResultMaySimplify) $ bodyResult body
  pure
    ( S.toList min_ws,
      S.toList avail_ws,
      log,
      KernelBody () kstms kres
    )