{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TypeFamilies #-}

-- | Extract limited nested parallelism for execution inside
-- individual kernel workgroups.
module Futhark.Pass.ExtractKernels.Intragroup (intraGroupParallelise) where

import Control.Monad.Identity
import Control.Monad.RWS
import Control.Monad.Trans.Maybe
import qualified Data.Map.Strict as M
import qualified Data.Set as S
import Futhark.Analysis.PrimExp.Convert
import qualified Futhark.IR.GPU as Out
import Futhark.IR.GPU.Kernel hiding (HistOp)
import Futhark.IR.SOACS
import Futhark.MonadFreshNames
import Futhark.Pass.ExtractKernels.BlockedKernel
import Futhark.Pass.ExtractKernels.DistributeNests
import Futhark.Pass.ExtractKernels.Distribution
import Futhark.Pass.ExtractKernels.ToGPU
import Futhark.Tools
import qualified Futhark.Transform.FirstOrderTransform as FOT
import Futhark.Util.Log
import Prelude hiding (log)

-- | Convert the statements inside a map nest to kernel statements,
-- attempting to parallelise any remaining (top-level) parallel
-- statements.  Anything that is not a map, scan or reduction will
-- simply be sequentialised.  This includes sequential loops that
-- contain maps, scans or reductions.  In the future, we could probably
-- do something more clever.  Make sure that the amount of parallelism
-- to be exploited does not exceed the group size.  Further, as a hack
-- we also consider the size of all intermediate arrays as
-- "parallelism to be exploited" to avoid exploding local memory.
--
-- We distinguish between "minimum group size" and "maximum
-- exploitable parallelism".
intraGroupParallelise ::
  (MonadFreshNames m, LocalScope Out.GPU m) =>
  KernelNest ->
  Lambda ->
  m
    ( Maybe
        ( -- (minimum, available) intra-group parallelism.
          (SubExp, SubExp),
          -- Group size.
          SubExp,
          -- Log produced while transforming the lambda body.
          Log,
          -- Prelude statements (group count, parallelism and group
          -- size computations) to run before the kernel.
          Out.Stms Out.GPU,
          -- The group-level kernel statement itself.
          Out.Stms Out.GPU
        )
    )
intraGroupParallelise knest lam = runMaybeT $ do
  (ispace, inps) <- lift $ flatKernel knest

  -- The number of groups is the product of the sizes of the
  -- flattened outer iteration space.
  (num_groups, w_stms) <-
    lift $
      runBinder $
        letSubExp "intra_num_groups"
          =<< foldBinOp (Mul Int64 OverflowUndef) (intConst Int64 1) (map snd ispace)

  let body = lambdaBody lam

  -- The group size is only known once the body has been analysed, so
  -- a name is reserved for it here and bound in the prelude below.
  group_size <- newVName "computed_group_size"
  let intra_lvl = SegThread (Count num_groups) (Count $ Var group_size) SegNoVirt

  (wss_min, wss_avail, log, kbody) <-
    lift $
      localScope (scopeOfLParams $ lambdaParams lam) $
        intraGroupParalleliseBody intra_lvl body

  outside_scope <- lift askScope
  -- outside_scope may also contain the inputs, even though those are
  -- not actually available outside the kernel.
  let available v =
        v `M.member` outside_scope
          && v `notElem` map kernelInputName inps
  -- The inner parallelism sizes must be computable outside the
  -- kernel; otherwise give up (the MaybeT yields 'Nothing').
  unless (all available $ namesToList $ freeIn (wss_min ++ wss_avail)) $
    fail "Irregular parallelism"

  ((intra_avail_par, kspace, read_input_stms), prelude_stms) <- lift $
    runBinder $ do
      -- Like 'foldBinOp', but an empty list folds to the constant 0.
      let foldBinOp' _ [] = eSubExp $ intConst Int64 0
          foldBinOp' bop (x : xs) = foldBinOp bop x xs
      ws_min <-
        mapM (letSubExp "one_intra_par_min" <=< foldBinOp' (Mul Int64 OverflowUndef)) $
          filter (not . null) wss_min
      ws_avail <-
        mapM (letSubExp "one_intra_par_avail" <=< foldBinOp' (Mul Int64 OverflowUndef)) $
          filter (not . null) wss_avail

      -- The amount of parallelism available *in the worst case* is
      -- equal to the smallest parallel loop.
      intra_avail_par <- letSubExp "intra_avail_par" =<< foldBinOp' (SMin Int64) ws_avail

      -- The group size is either the maximum of the minimum parallelism
      -- exploited, or the desired parallelism (bounded by the max group
      -- size) in case there is no minimum.
      letBindNames [group_size]
        =<< if null ws_min
          then
            eBinOp
              (SMin Int64)
              (eSubExp =<< letSubExp "max_group_size" (Op $ SizeOp $ Out.GetSizeMax Out.SizeGroup))
              (eSubExp intra_avail_par)
          else foldBinOp' (SMax Int64) ws_min

      -- Only the kernel inputs that the body refers to are read.
      let inputIsUsed input = kernelInputName input `nameIn` freeIn body
          used_inps = filter inputIsUsed inps

      addStms w_stms
      -- Only the emitted statements matter here, so discard the results.
      read_input_stms <- runBinder_ $ mapM_ readGroupKernelInput used_inps
      space <- mkSegSpace ispace
      return (intra_avail_par, space, read_input_stms)

  let kbody' = kbody {kernelBodyStms = read_input_stms <> kernelBodyStms kbody}

  let nested_pat = loopNestingPattern first_nest
      -- The SegMap result types are per-iteration of the flattened
      -- nest, so strip the outer (length ispace) dimensions.
      rts = map (length ispace `stripArray`) $ patternTypes nested_pat
      lvl = SegGroup (Count num_groups) (Count $ Var group_size) SegNoVirt
      kstm =
        Let nested_pat aux $
          Op $ SegOp $ SegMap lvl kspace rts kbody'

  -- NOTE: the minimum parallelism is currently reported as equal to
  -- the available parallelism.
  let intra_min_par = intra_avail_par
  return
    ( (intra_min_par, intra_avail_par),
      Var group_size,
      log,
      prelude_stms,
      oneStm kstm
    )
  where
    first_nest = fst knest
    aux = loopNestingAux first_nest

-- | Read a kernel input inside a group-level kernel.
--
-- An array-typed input is first read (via 'readKernelInput') into a
-- freshly named variable, which is then bound to the input's own name
-- with a 'Copy'.  Any other input is read directly with
-- 'readKernelInput'.
-- NOTE(review): the extra copy presumably gives each group its own
-- copy of the array; confirm against the code generator.
readGroupKernelInput ::
  (DistRep (Rep m), MonadBinder m) =>
  KernelInput ->
  m ()
readGroupKernelInput inp =
  case kernelInputType inp of
    Array {} -> do
      let orig_name = kernelInputName inp
      tmp <- newVName (baseString orig_name)
      readKernelInput inp {kernelInputName = tmp}
      letBindNames [orig_name] (BasicOp (Copy tmp))
    _ -> readKernelInput inp

-- | Information accumulated (via the writer of 'IntraGroupM') while
-- transforming the body of an intra-group kernel.
data IntraAcc = IntraAcc
  { -- | Sizes counted towards the minimum parallelism; each element is
    -- a list of sizes (presumably the dimensions of one parallel
    -- construct).
    accMinPar :: S.Set [SubExp],
    -- | Sizes counted towards the available parallelism.
    accAvailPar :: S.Set [SubExp],
    -- | Log messages.
    accLog :: Log
  }

-- Accumulators are combined component-wise.
instance Semigroup IntraAcc where
  x <> y =
    IntraAcc
      { accMinPar = accMinPar x <> accMinPar y,
        accAvailPar = accAvailPar x <> accAvailPar y,
        accLog = accLog x <> accLog y
      }

-- The empty accumulator: no sizes and an empty log.
instance Monoid IntraAcc where
  mempty = IntraAcc S.empty S.empty mempty

-- | Builder monad for intra-group kernel bodies: produces 'Out.GPU'
-- statements over an RWS monad whose writer collects an 'IntraAcc'
-- and whose state is the name source.
type IntraGroupM = BinderT Out.GPU (RWS () IntraAcc VNameSource)

-- Log messages are written to the 'accLog' component of the
-- accumulator; both parallelism sets are left empty.
instance MonadLogger IntraGroupM where
  addLog = tell . IntraAcc S.empty S.empty

-- | Run an 'IntraGroupM' action in the current scope, threading the
-- name source through it.  Returns the accumulated 'IntraAcc' together
-- with the statements the action produced.
runIntraGroupM ::
  (MonadFreshNames m, HasScope Out.GPU m) =>
  IntraGroupM () ->
  m (IntraAcc, Out.Stms Out.GPU)
runIntraGroupM action = do
  scope <- fmap castScope askScope
  let run = runRWS (runBinderT action scope) ()
      reorder (((), stms), src', acc) = ((acc, stms), src')
  modifyNameSource $ reorder . run

-- | Record the given sizes both as minimum parallelism and as
-- available parallelism.
parallelMin :: [SubExp] -> IntraGroupM ()
parallelMin ws = tell $ IntraAcc sizes sizes mempty
  where
    sizes = S.singleton ws

-- | Convert a SOACS body to a GPU body at the given level.  The
-- statements are transformed by 'intraGroupStms'; the body result is
-- kept as is.
intraGroupBody :: SegLevel -> Body -> IntraGroupM (Out.Body Out.GPU)
intraGroupBody lvl body =
  (`mkBody` res) <$> collectStms_ (intraGroupStms lvl stms)
  where
    stms = bodyStms body
    res = bodyResult body

intraGroupStm :: SegLevel -> Stm -> IntraGroupM ()
intraGroupStm :: SegLevel -> Stm -> IntraGroupM ()
intraGroupStm SegLevel
lvl stm :: Stm
stm@(Let Pattern SOACS
pat StmAux (ExpDec SOACS)
aux Exp SOACS
e) = do
  Scope GPU
scope <- BinderT GPU (RWST () IntraAcc VNameSource Identity) (Scope GPU)
forall rep (m :: * -> *). HasScope rep m => m (Scope rep)
askScope
  let lvl' :: SegLevel
lvl' = Count NumGroups SubExp
-> Count GroupSize SubExp -> SegVirt -> SegLevel
SegThread (SegLevel -> Count NumGroups SubExp
segNumGroups SegLevel
lvl) (SegLevel -> Count GroupSize SubExp
segGroupSize SegLevel
lvl) SegVirt
SegNoVirt

  case Exp SOACS
e of
    DoLoop [(FParam SOACS, SubExp)]
ctx [(FParam SOACS, SubExp)]
val LoopForm SOACS
form BodyT SOACS
loopbody ->
      Scope GPU -> IntraGroupM () -> IntraGroupM ()
forall rep (m :: * -> *) a.
LocalScope rep m =>
Scope rep -> m a -> m a
localScope (LoopForm GPU -> Scope GPU
forall rep a. Scoped rep a => a -> Scope rep
scopeOf LoopForm GPU
form') (IntraGroupM () -> IntraGroupM ())
-> IntraGroupM () -> IntraGroupM ()
forall a b. (a -> b) -> a -> b
$
        Scope GPU -> IntraGroupM () -> IntraGroupM ()
forall rep (m :: * -> *) a.
LocalScope rep m =>
Scope rep -> m a -> m a
localScope ([Param DeclType] -> Scope GPU
forall rep dec. (FParamInfo rep ~ dec) => [Param dec] -> Scope rep
scopeOfFParams ([Param DeclType] -> Scope GPU) -> [Param DeclType] -> Scope GPU
forall a b. (a -> b) -> a -> b
$ ((Param DeclType, SubExp) -> Param DeclType)
-> [(Param DeclType, SubExp)] -> [Param DeclType]
forall a b. (a -> b) -> [a] -> [b]
map (Param DeclType, SubExp) -> Param DeclType
forall a b. (a, b) -> a
fst ([(Param DeclType, SubExp)] -> [Param DeclType])
-> [(Param DeclType, SubExp)] -> [Param DeclType]
forall a b. (a -> b) -> a -> b
$ [(Param DeclType, SubExp)]
[(FParam SOACS, SubExp)]
ctx [(Param DeclType, SubExp)]
-> [(Param DeclType, SubExp)] -> [(Param DeclType, SubExp)]
forall a. [a] -> [a] -> [a]
++ [(Param DeclType, SubExp)]
[(FParam SOACS, SubExp)]
val) (IntraGroupM () -> IntraGroupM ())
-> IntraGroupM () -> IntraGroupM ()
forall a b. (a -> b) -> a -> b
$ do
          Body GPU
loopbody' <- SegLevel -> BodyT SOACS -> IntraGroupM (Body GPU)
intraGroupBody SegLevel
lvl BodyT SOACS
loopbody
          Certificates -> IntraGroupM () -> IntraGroupM ()
forall (m :: * -> *) a. MonadBinder m => Certificates -> m a -> m a
certifying (StmAux () -> Certificates
forall dec. StmAux dec -> Certificates
stmAuxCerts StmAux ()
StmAux (ExpDec SOACS)
aux) (IntraGroupM () -> IntraGroupM ())
-> IntraGroupM () -> IntraGroupM ()
forall a b. (a -> b) -> a -> b
$
            Pattern (Rep IntraGroupM)
-> Exp (Rep IntraGroupM) -> IntraGroupM ()
forall (m :: * -> *).
MonadBinder m =>
Pattern (Rep m) -> Exp (Rep m) -> m ()
letBind Pattern (Rep IntraGroupM)
Pattern SOACS
pat (Exp (Rep IntraGroupM) -> IntraGroupM ())
-> Exp (Rep IntraGroupM) -> IntraGroupM ()
forall a b. (a -> b) -> a -> b
$ [(FParam GPU, SubExp)]
-> [(FParam GPU, SubExp)] -> LoopForm GPU -> Body GPU -> ExpT GPU
forall rep.
[(FParam rep, SubExp)]
-> [(FParam rep, SubExp)] -> LoopForm rep -> BodyT rep -> ExpT rep
DoLoop [(FParam SOACS, SubExp)]
[(FParam GPU, SubExp)]
ctx [(FParam SOACS, SubExp)]
[(FParam GPU, SubExp)]
val LoopForm GPU
form' Body GPU
loopbody'
      where
        form' :: LoopForm GPU
form' = case LoopForm SOACS
form of
          ForLoop VName
i IntType
it SubExp
bound [(LParam SOACS, VName)]
inps -> VName -> IntType -> SubExp -> [(LParam GPU, VName)] -> LoopForm GPU
forall rep.
VName -> IntType -> SubExp -> [(LParam rep, VName)] -> LoopForm rep
ForLoop VName
i IntType
it SubExp
bound [(LParam SOACS, VName)]
[(LParam GPU, VName)]
inps
          WhileLoop VName
cond -> VName -> LoopForm GPU
forall rep. VName -> LoopForm rep
WhileLoop VName
cond
    If SubExp
cond BodyT SOACS
tbody BodyT SOACS
fbody IfDec (BranchType SOACS)
ifdec -> do
      Body GPU
tbody' <- SegLevel -> BodyT SOACS -> IntraGroupM (Body GPU)
intraGroupBody SegLevel
lvl BodyT SOACS
tbody
      Body GPU
fbody' <- SegLevel -> BodyT SOACS -> IntraGroupM (Body GPU)
intraGroupBody SegLevel
lvl BodyT SOACS
fbody
      Certificates -> IntraGroupM () -> IntraGroupM ()
forall (m :: * -> *) a. MonadBinder m => Certificates -> m a -> m a
certifying (StmAux () -> Certificates
forall dec. StmAux dec -> Certificates
stmAuxCerts StmAux ()
StmAux (ExpDec SOACS)
aux) (IntraGroupM () -> IntraGroupM ())
-> IntraGroupM () -> IntraGroupM ()
forall a b. (a -> b) -> a -> b
$
        Pattern (Rep IntraGroupM)
-> Exp (Rep IntraGroupM) -> IntraGroupM ()
forall (m :: * -> *).
MonadBinder m =>
Pattern (Rep m) -> Exp (Rep m) -> m ()
letBind Pattern (Rep IntraGroupM)
Pattern SOACS
pat (Exp (Rep IntraGroupM) -> IntraGroupM ())
-> Exp (Rep IntraGroupM) -> IntraGroupM ()
forall a b. (a -> b) -> a -> b
$ SubExp
-> Body GPU -> Body GPU -> IfDec (BranchType GPU) -> ExpT GPU
forall rep.
SubExp
-> BodyT rep -> BodyT rep -> IfDec (BranchType rep) -> ExpT rep
If SubExp
cond Body GPU
tbody' Body GPU
fbody' IfDec (BranchType SOACS)
IfDec (BranchType GPU)
ifdec
    Op Op SOACS
soac
      | Attr
"sequential_outer" Attr -> Attrs -> Bool
`inAttrs` StmAux () -> Attrs
forall dec. StmAux dec -> Attrs
stmAuxAttrs StmAux ()
StmAux (ExpDec SOACS)
aux ->
        SegLevel -> Stms SOACS -> IntraGroupM ()
intraGroupStms SegLevel
lvl (Stms SOACS -> IntraGroupM ())
-> (Stms SOACS -> Stms SOACS) -> Stms SOACS -> IntraGroupM ()
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Stm -> Stm) -> Stms SOACS -> Stms SOACS
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (Certificates -> Stm -> Stm
forall rep. Certificates -> Stm rep -> Stm rep
certify (StmAux () -> Certificates
forall dec. StmAux dec -> Certificates
stmAuxCerts StmAux ()
StmAux (ExpDec SOACS)
aux))
          (Stms SOACS -> IntraGroupM ())
-> BinderT GPU (RWST () IntraAcc VNameSource Identity) (Stms SOACS)
-> IntraGroupM ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< Binder SOACS ()
-> BinderT GPU (RWST () IntraAcc VNameSource Identity) (Stms SOACS)
forall (m :: * -> *) somerep rep a.
(MonadFreshNames m, HasScope somerep m, SameScope somerep rep) =>
Binder rep a -> m (Stms rep)
runBinder_ (Pattern (Rep (BinderT SOACS (State VNameSource)))
-> SOAC (Rep (BinderT SOACS (State VNameSource)))
-> Binder SOACS ()
forall (m :: * -> *).
Transformer m =>
Pattern (Rep m) -> SOAC (Rep m) -> m ()
FOT.transformSOAC Pattern (Rep (BinderT SOACS (State VNameSource)))
Pattern SOACS
pat Op SOACS
SOAC (Rep (BinderT SOACS (State VNameSource)))
soac)
    Op (Screma SubExp
w [VName]
arrs ScremaForm SOACS
form)
      | Just Lambda SOACS
lam <- ScremaForm SOACS -> Maybe (Lambda SOACS)
forall rep. ScremaForm rep -> Maybe (Lambda rep)
isMapSOAC ScremaForm SOACS
form -> do
        let loopnest :: LoopNesting
loopnest = PatternT Type
-> StmAux () -> SubExp -> [(Param Type, VName)] -> LoopNesting
MapNesting PatternT Type
Pattern SOACS
pat StmAux ()
StmAux (ExpDec SOACS)
aux SubExp
w ([(Param Type, VName)] -> LoopNesting)
-> [(Param Type, VName)] -> LoopNesting
forall a b. (a -> b) -> a -> b
$ [Param Type] -> [VName] -> [(Param Type, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip (Lambda SOACS -> [LParam SOACS]
forall rep. LambdaT rep -> [LParam rep]
lambdaParams Lambda SOACS
lam) [VName]
arrs
            env :: DistEnv GPU IntraGroupM
env =
              DistEnv :: forall rep (m :: * -> *).
Nestings
-> Scope rep
-> (Stms SOACS -> DistNestT rep m (Stms rep))
-> (MapLoop -> DistAcc rep -> DistNestT rep m (DistAcc rep))
-> (Stm -> Binder rep (Stms rep))
-> (Lambda SOACS -> Binder rep (Lambda rep))
-> MkSegLevel rep m
-> DistEnv rep m
DistEnv
                { distNest :: Nestings
distNest =
                    Nesting -> Nestings
singleNesting (Nesting -> Nestings) -> Nesting -> Nestings
forall a b. (a -> b) -> a -> b
$ Names -> LoopNesting -> Nesting
Nesting Names
forall a. Monoid a => a
mempty LoopNesting
loopnest,
                  distScope :: Scope GPU
distScope =
                    PatternT Type -> Scope GPU
forall rep dec. (LetDec rep ~ dec) => PatternT dec -> Scope rep
scopeOfPattern PatternT Type
Pattern SOACS
pat
                      Scope GPU -> Scope GPU -> Scope GPU
forall a. Semigroup a => a -> a -> a
<> Scope SOACS -> Scope GPU
scopeForGPU (Lambda SOACS -> Scope SOACS
forall rep a. Scoped rep a => a -> Scope rep
scopeOf Lambda SOACS
lam)
                      Scope GPU -> Scope GPU -> Scope GPU
forall a. Semigroup a => a -> a -> a
<> Scope GPU
scope,
                  distOnInnerMap :: MapLoop -> DistAcc GPU -> DistNestT GPU IntraGroupM (DistAcc GPU)
distOnInnerMap =
                    MapLoop -> DistAcc GPU -> DistNestT GPU IntraGroupM (DistAcc GPU)
forall (m :: * -> *) rep.
(MonadFreshNames m, LocalScope rep m, DistRep rep) =>
MapLoop -> DistAcc rep -> DistNestT rep m (DistAcc rep)
distributeMap,
                  distOnTopLevelStms :: Stms SOACS -> DistNestT GPU IntraGroupM (Stms GPU)
distOnTopLevelStms =
                    BinderT GPU (RWST () IntraAcc VNameSource Identity) (Stms GPU)
-> DistNestT GPU IntraGroupM (Stms GPU)
forall rep (m :: * -> *) a.
(LocalScope rep m, DistRep rep) =>
m a -> DistNestT rep m a
liftInner (BinderT GPU (RWST () IntraAcc VNameSource Identity) (Stms GPU)
 -> DistNestT GPU IntraGroupM (Stms GPU))
-> (Stms SOACS
    -> BinderT GPU (RWST () IntraAcc VNameSource Identity) (Stms GPU))
-> Stms SOACS
-> DistNestT GPU IntraGroupM (Stms GPU)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. IntraGroupM ()
-> BinderT GPU (RWST () IntraAcc VNameSource Identity) (Stms GPU)
forall (m :: * -> *) a. MonadBinder m => m a -> m (Stms (Rep m))
collectStms_ (IntraGroupM ()
 -> BinderT GPU (RWST () IntraAcc VNameSource Identity) (Stms GPU))
-> (Stms SOACS -> IntraGroupM ())
-> Stms SOACS
-> BinderT GPU (RWST () IntraAcc VNameSource Identity) (Stms GPU)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SegLevel -> Stms SOACS -> IntraGroupM ()
intraGroupStms SegLevel
lvl,
                  distSegLevel :: MkSegLevel GPU IntraGroupM
distSegLevel = \[SubExp]
minw String
_ ThreadRecommendation
_ -> do
                    IntraGroupM () -> BinderT GPU IntraGroupM ()
forall (t :: (* -> *) -> * -> *) (m :: * -> *) a.
(MonadTrans t, Monad m) =>
m a -> t m a
lift (IntraGroupM () -> BinderT GPU IntraGroupM ())
-> IntraGroupM () -> BinderT GPU IntraGroupM ()
forall a b. (a -> b) -> a -> b
$ [SubExp] -> IntraGroupM ()
parallelMin [SubExp]
minw
                    SegLevel -> BinderT GPU IntraGroupM SegLevel
forall (m :: * -> *) a. Monad m => a -> m a
return SegLevel
lvl,
                  distOnSOACSStms :: Stm -> BinderT GPU (State VNameSource) (Stms GPU)
distOnSOACSStms =
                    Stms GPU -> BinderT GPU (State VNameSource) (Stms GPU)
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Stms GPU -> BinderT GPU (State VNameSource) (Stms GPU))
-> (Stm -> Stms GPU)
-> Stm
-> BinderT GPU (State VNameSource) (Stms GPU)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Stm GPU -> Stms GPU
forall rep. Stm rep -> Stms rep
oneStm (Stm GPU -> Stms GPU) -> (Stm -> Stm GPU) -> Stm -> Stms GPU
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Stm -> Stm GPU
soacsStmToGPU,
                  distOnSOACSLambda :: Lambda SOACS -> Binder GPU (Lambda GPU)
distOnSOACSLambda =
                    Lambda GPU -> Binder GPU (Lambda GPU)
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Lambda GPU -> Binder GPU (Lambda GPU))
-> (Lambda SOACS -> Lambda GPU)
-> Lambda SOACS
-> Binder GPU (Lambda GPU)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Lambda SOACS -> Lambda GPU
soacsLambdaToGPU
                }
            acc :: DistAcc GPU
acc =
              DistAcc :: forall rep. Targets -> Stms rep -> DistAcc rep
DistAcc
                { distTargets :: Targets
distTargets = Target -> Targets
singleTarget (PatternT Type
Pattern SOACS
pat, BodyT SOACS -> [SubExp]
forall rep. BodyT rep -> [SubExp]
bodyResult (BodyT SOACS -> [SubExp]) -> BodyT SOACS -> [SubExp]
forall a b. (a -> b) -> a -> b
$ Lambda SOACS -> BodyT SOACS
forall rep. LambdaT rep -> BodyT rep
lambdaBody Lambda SOACS
lam),
                  distStms :: Stms GPU
distStms = Stms GPU
forall a. Monoid a => a
mempty
                }

        Stms GPU -> IntraGroupM ()
forall (m :: * -> *). MonadBinder m => Stms (Rep m) -> m ()
addStms
          (Stms GPU -> IntraGroupM ())
-> BinderT GPU (RWST () IntraAcc VNameSource Identity) (Stms GPU)
-> IntraGroupM ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< DistEnv GPU IntraGroupM
-> DistNestT GPU IntraGroupM (DistAcc GPU)
-> BinderT GPU (RWST () IntraAcc VNameSource Identity) (Stms GPU)
forall (m :: * -> *) rep.
(MonadLogger m, DistRep rep) =>
DistEnv rep m -> DistNestT rep m (DistAcc rep) -> m (Stms rep)
runDistNestT DistEnv GPU IntraGroupM
env (DistAcc GPU
-> Stms SOACS -> DistNestT GPU IntraGroupM (DistAcc GPU)
forall (m :: * -> *) rep.
(MonadFreshNames m, LocalScope rep m, DistRep rep) =>
DistAcc rep -> Stms SOACS -> DistNestT rep m (DistAcc rep)
distributeMapBodyStms DistAcc GPU
acc (BodyT SOACS -> Stms SOACS
forall rep. BodyT rep -> Stms rep
bodyStms (BodyT SOACS -> Stms SOACS) -> BodyT SOACS -> Stms SOACS
forall a b. (a -> b) -> a -> b
$ Lambda SOACS -> BodyT SOACS
forall rep. LambdaT rep -> BodyT rep
lambdaBody Lambda SOACS
lam))
    Op (Screma SubExp
w [VName]
arrs ScremaForm SOACS
form)
      | Just ([Scan SOACS]
scans, Lambda SOACS
mapfun) <- ScremaForm SOACS -> Maybe ([Scan SOACS], Lambda SOACS)
forall rep. ScremaForm rep -> Maybe ([Scan rep], Lambda rep)
isScanomapSOAC ScremaForm SOACS
form,
        Scan Lambda SOACS
scanfun [SubExp]
nes <- [Scan SOACS] -> Scan SOACS
forall rep. Bindable rep => [Scan rep] -> Scan rep
singleScan [Scan SOACS]
scans -> do
        let scanfun' :: Lambda GPU
scanfun' = Lambda SOACS -> Lambda GPU
soacsLambdaToGPU Lambda SOACS
scanfun
            mapfun' :: Lambda GPU
mapfun' = Lambda SOACS -> Lambda GPU
soacsLambdaToGPU Lambda SOACS
mapfun
        Certificates -> IntraGroupM () -> IntraGroupM ()
forall (m :: * -> *) a. MonadBinder m => Certificates -> m a -> m a
certifying (StmAux () -> Certificates
forall dec. StmAux dec -> Certificates
stmAuxCerts StmAux ()
StmAux (ExpDec SOACS)
aux) (IntraGroupM () -> IntraGroupM ())
-> IntraGroupM () -> IntraGroupM ()
forall a b. (a -> b) -> a -> b
$
          Stms GPU -> IntraGroupM ()
forall (m :: * -> *). MonadBinder m => Stms (Rep m) -> m ()
addStms (Stms GPU -> IntraGroupM ())
-> BinderT GPU (RWST () IntraAcc VNameSource Identity) (Stms GPU)
-> IntraGroupM ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< SegOpLevel GPU
-> Pattern GPU
-> SubExp
-> [SegBinOp GPU]
-> Lambda GPU
-> [VName]
-> [(VName, SubExp)]
-> [KernelInput]
-> BinderT GPU (RWST () IntraAcc VNameSource Identity) (Stms GPU)
forall (m :: * -> *) rep.
(MonadFreshNames m, DistRep rep, HasScope rep m) =>
SegOpLevel rep
-> Pattern rep
-> SubExp
-> [SegBinOp rep]
-> Lambda rep
-> [VName]
-> [(VName, SubExp)]
-> [KernelInput]
-> m (Stms rep)
segScan SegOpLevel GPU
SegLevel
lvl' Pattern SOACS
Pattern GPU
pat SubExp
w [Commutativity -> Lambda GPU -> [SubExp] -> Shape -> SegBinOp GPU
forall rep.
Commutativity -> Lambda rep -> [SubExp] -> Shape -> SegBinOp rep
SegBinOp Commutativity
Noncommutative Lambda GPU
scanfun' [SubExp]
nes Shape
forall a. Monoid a => a
mempty] Lambda GPU
mapfun' [VName]
arrs [] []
        [SubExp] -> IntraGroupM ()
parallelMin [SubExp
w]
    Op (Screma SubExp
w [VName]
arrs ScremaForm SOACS
form)
      | Just ([Reduce SOACS]
reds, Lambda SOACS
map_lam) <- ScremaForm SOACS -> Maybe ([Reduce SOACS], Lambda SOACS)
forall rep. ScremaForm rep -> Maybe ([Reduce rep], Lambda rep)
isRedomapSOAC ScremaForm SOACS
form,
        Reduce Commutativity
comm Lambda SOACS
red_lam [SubExp]
nes <- [Reduce SOACS] -> Reduce SOACS
forall rep. Bindable rep => [Reduce rep] -> Reduce rep
singleReduce [Reduce SOACS]
reds -> do
        let red_lam' :: Lambda GPU
red_lam' = Lambda SOACS -> Lambda GPU
soacsLambdaToGPU Lambda SOACS
red_lam
            map_lam' :: Lambda GPU
map_lam' = Lambda SOACS -> Lambda GPU
soacsLambdaToGPU Lambda SOACS
map_lam
        Certificates -> IntraGroupM () -> IntraGroupM ()
forall (m :: * -> *) a. MonadBinder m => Certificates -> m a -> m a
certifying (StmAux () -> Certificates
forall dec. StmAux dec -> Certificates
stmAuxCerts StmAux ()
StmAux (ExpDec SOACS)
aux) (IntraGroupM () -> IntraGroupM ())
-> IntraGroupM () -> IntraGroupM ()
forall a b. (a -> b) -> a -> b
$
          Stms GPU -> IntraGroupM ()
forall (m :: * -> *). MonadBinder m => Stms (Rep m) -> m ()
addStms (Stms GPU -> IntraGroupM ())
-> BinderT GPU (RWST () IntraAcc VNameSource Identity) (Stms GPU)
-> IntraGroupM ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< SegOpLevel GPU
-> Pattern GPU
-> SubExp
-> [SegBinOp GPU]
-> Lambda GPU
-> [VName]
-> [(VName, SubExp)]
-> [KernelInput]
-> BinderT GPU (RWST () IntraAcc VNameSource Identity) (Stms GPU)
forall (m :: * -> *) rep.
(MonadFreshNames m, DistRep rep, HasScope rep m) =>
SegOpLevel rep
-> Pattern rep
-> SubExp
-> [SegBinOp rep]
-> Lambda rep
-> [VName]
-> [(VName, SubExp)]
-> [KernelInput]
-> m (Stms rep)
segRed SegOpLevel GPU
SegLevel
lvl' Pattern SOACS
Pattern GPU
pat SubExp
w [Commutativity -> Lambda GPU -> [SubExp] -> Shape -> SegBinOp GPU
forall rep.
Commutativity -> Lambda rep -> [SubExp] -> Shape -> SegBinOp rep
SegBinOp Commutativity
comm Lambda GPU
red_lam' [SubExp]
nes Shape
forall a. Monoid a => a
mempty] Lambda GPU
map_lam' [VName]
arrs [] []
        [SubExp] -> IntraGroupM ()
parallelMin [SubExp
w]
    Op (Hist SubExp
w [HistOp SOACS]
ops Lambda SOACS
bucket_fun [VName]
arrs) -> do
      [HistOp GPU]
ops' <- [HistOp SOACS]
-> (HistOp SOACS
    -> BinderT
         GPU (RWST () IntraAcc VNameSource Identity) (HistOp GPU))
-> BinderT GPU (RWST () IntraAcc VNameSource Identity) [HistOp GPU]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM [HistOp SOACS]
ops ((HistOp SOACS
  -> BinderT
       GPU (RWST () IntraAcc VNameSource Identity) (HistOp GPU))
 -> BinderT
      GPU (RWST () IntraAcc VNameSource Identity) [HistOp GPU])
-> (HistOp SOACS
    -> BinderT
         GPU (RWST () IntraAcc VNameSource Identity) (HistOp GPU))
-> BinderT GPU (RWST () IntraAcc VNameSource Identity) [HistOp GPU]
forall a b. (a -> b) -> a -> b
$ \(HistOp SubExp
num_bins SubExp
rf [VName]
dests [SubExp]
nes Lambda SOACS
op) -> do
        (Lambda SOACS
op', [SubExp]
nes', Shape
shape) <- Lambda SOACS
-> [SubExp]
-> BinderT
     GPU
     (RWST () IntraAcc VNameSource Identity)
     (Lambda SOACS, [SubExp], Shape)
forall (m :: * -> *).
MonadBinder m =>
Lambda SOACS -> [SubExp] -> m (Lambda SOACS, [SubExp], Shape)
determineReduceOp Lambda SOACS
op [SubExp]
nes
        let op'' :: Lambda GPU
op'' = Lambda SOACS -> Lambda GPU
soacsLambdaToGPU Lambda SOACS
op'
        HistOp GPU
-> BinderT GPU (RWST () IntraAcc VNameSource Identity) (HistOp GPU)
forall (m :: * -> *) a. Monad m => a -> m a
return (HistOp GPU
 -> BinderT
      GPU (RWST () IntraAcc VNameSource Identity) (HistOp GPU))
-> HistOp GPU
-> BinderT GPU (RWST () IntraAcc VNameSource Identity) (HistOp GPU)
forall a b. (a -> b) -> a -> b
$ SubExp
-> SubExp
-> [VName]
-> [SubExp]
-> Shape
-> Lambda GPU
-> HistOp GPU
forall rep.
SubExp
-> SubExp
-> [VName]
-> [SubExp]
-> Shape
-> Lambda rep
-> HistOp rep
Out.HistOp SubExp
num_bins SubExp
rf [VName]
dests [SubExp]
nes' Shape
shape Lambda GPU
op''

      let bucket_fun' :: Lambda GPU
bucket_fun' = Lambda SOACS -> Lambda GPU
soacsLambdaToGPU Lambda SOACS
bucket_fun
      Certificates -> IntraGroupM () -> IntraGroupM ()
forall (m :: * -> *) a. MonadBinder m => Certificates -> m a -> m a
certifying (StmAux () -> Certificates
forall dec. StmAux dec -> Certificates
stmAuxCerts StmAux ()
StmAux (ExpDec SOACS)
aux) (IntraGroupM () -> IntraGroupM ())
-> IntraGroupM () -> IntraGroupM ()
forall a b. (a -> b) -> a -> b
$
        Stms GPU -> IntraGroupM ()
forall (m :: * -> *). MonadBinder m => Stms (Rep m) -> m ()
addStms (Stms GPU -> IntraGroupM ())
-> BinderT GPU (RWST () IntraAcc VNameSource Identity) (Stms GPU)
-> IntraGroupM ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< SegOpLevel GPU
-> Pattern GPU
-> SubExp
-> [(VName, SubExp)]
-> [KernelInput]
-> [HistOp GPU]
-> Lambda GPU
-> [VName]
-> BinderT GPU (RWST () IntraAcc VNameSource Identity) (Stms GPU)
forall rep (m :: * -> *).
(DistRep rep, MonadFreshNames m, HasScope rep m) =>
SegOpLevel rep
-> Pattern rep
-> SubExp
-> [(VName, SubExp)]
-> [KernelInput]
-> [HistOp rep]
-> Lambda rep
-> [VName]
-> m (Stms rep)
segHist SegOpLevel GPU
SegLevel
lvl' Pattern SOACS
Pattern GPU
pat SubExp
w [] [] [HistOp GPU]
ops' Lambda GPU
bucket_fun' [VName]
arrs
      [SubExp] -> IntraGroupM ()
parallelMin [SubExp
w]
    Op (Stream SubExp
w [VName]
arrs StreamForm SOACS
Sequential [SubExp]
accs Lambda SOACS
lam)
      | LParam SOACS
chunk_size_param : [LParam SOACS]
_ <- Lambda SOACS -> [LParam SOACS]
forall rep. LambdaT rep -> [LParam rep]
lambdaParams Lambda SOACS
lam -> do
        Scope SOACS
types <- (Scope GPU -> Scope SOACS)
-> BinderT
     GPU (RWST () IntraAcc VNameSource Identity) (Scope SOACS)
forall rep (m :: * -> *) a.
HasScope rep m =>
(Scope rep -> a) -> m a
asksScope Scope GPU -> Scope SOACS
forall fromrep torep.
SameScope fromrep torep =>
Scope fromrep -> Scope torep
castScope
        ((), Stms SOACS
stream_bnds) <-
          BinderT SOACS IntraGroupM ()
-> Scope SOACS
-> BinderT
     GPU (RWST () IntraAcc VNameSource Identity) ((), Stms SOACS)
forall (m :: * -> *) rep a.
MonadFreshNames m =>
BinderT rep m a -> Scope rep -> m (a, Stms rep)
runBinderT (Pattern (Rep (BinderT SOACS IntraGroupM))
-> SubExp
-> [SubExp]
-> LambdaT (Rep (BinderT SOACS IntraGroupM))
-> [VName]
-> BinderT SOACS IntraGroupM ()
forall (m :: * -> *).
(MonadBinder m, Bindable (Rep m)) =>
Pattern (Rep m)
-> SubExp -> [SubExp] -> LambdaT (Rep m) -> [VName] -> m ()
sequentialStreamWholeArray Pattern (Rep (BinderT SOACS IntraGroupM))
Pattern SOACS
pat SubExp
w [SubExp]
accs LambdaT (Rep (BinderT SOACS IntraGroupM))
Lambda SOACS
lam [VName]
arrs) Scope SOACS
types
        let replace :: SubExp -> SubExp
replace (Var VName
v) | VName
v VName -> VName -> Bool
forall a. Eq a => a -> a -> Bool
== Param Type -> VName
forall dec. Param dec -> VName
paramName Param Type
LParam SOACS
chunk_size_param = SubExp
w
            replace SubExp
se = SubExp
se
            replaceSets :: IntraAcc -> IntraAcc
replaceSets (IntraAcc Set [SubExp]
x Set [SubExp]
y Log
log) =
              Set [SubExp] -> Set [SubExp] -> Log -> IntraAcc
IntraAcc (([SubExp] -> [SubExp]) -> Set [SubExp] -> Set [SubExp]
forall b a. Ord b => (a -> b) -> Set a -> Set b
S.map ((SubExp -> SubExp) -> [SubExp] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> SubExp
replace) Set [SubExp]
x) (([SubExp] -> [SubExp]) -> Set [SubExp] -> Set [SubExp]
forall b a. Ord b => (a -> b) -> Set a -> Set b
S.map ((SubExp -> SubExp) -> [SubExp] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> SubExp
replace) Set [SubExp]
y) Log
log
        (IntraAcc -> IntraAcc) -> IntraGroupM () -> IntraGroupM ()
forall w (m :: * -> *) a. MonadWriter w m => (w -> w) -> m a -> m a
censor IntraAcc -> IntraAcc
replaceSets (IntraGroupM () -> IntraGroupM ())
-> IntraGroupM () -> IntraGroupM ()
forall a b. (a -> b) -> a -> b
$ SegLevel -> Stms SOACS -> IntraGroupM ()
intraGroupStms SegLevel
lvl Stms SOACS
stream_bnds
    Op (Scatter SubExp
w Lambda SOACS
lam [VName]
ivs [(Shape, Int, VName)]
dests) -> do
      VName
write_i <- String -> BinderT GPU (RWST () IntraAcc VNameSource Identity) VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName String
"write_i"
      SegSpace
space <- [(VName, SubExp)]
-> BinderT GPU (RWST () IntraAcc VNameSource Identity) SegSpace
forall (m :: * -> *).
MonadFreshNames m =>
[(VName, SubExp)] -> m SegSpace
mkSegSpace [(VName
write_i, SubExp
w)]

      let lam' :: Lambda GPU
lam' = Lambda SOACS -> Lambda GPU
soacsLambdaToGPU Lambda SOACS
lam
          ([Shape]
dests_ws, [Int]
_, [VName]
_) = [(Shape, Int, VName)] -> ([Shape], [Int], [VName])
forall a b c. [(a, b, c)] -> ([a], [b], [c])
unzip3 [(Shape, Int, VName)]
dests
          krets :: [KernelResult]
krets = do
            (Shape
a_w, VName
a, [([SubExp], SubExp)]
is_vs) <-
              [(Shape, Int, VName)]
-> [SubExp] -> [(Shape, VName, [([SubExp], SubExp)])]
forall array a.
[(Shape, Int, array)] -> [a] -> [(Shape, array, [([a], a)])]
groupScatterResults [(Shape, Int, VName)]
dests ([SubExp] -> [(Shape, VName, [([SubExp], SubExp)])])
-> [SubExp] -> [(Shape, VName, [([SubExp], SubExp)])]
forall a b. (a -> b) -> a -> b
$ Body GPU -> [SubExp]
forall rep. BodyT rep -> [SubExp]
bodyResult (Body GPU -> [SubExp]) -> Body GPU -> [SubExp]
forall a b. (a -> b) -> a -> b
$ Lambda GPU -> Body GPU
forall rep. LambdaT rep -> BodyT rep
lambdaBody Lambda GPU
lam'
            KernelResult -> [KernelResult]
forall (m :: * -> *) a. Monad m => a -> m a
return (KernelResult -> [KernelResult]) -> KernelResult -> [KernelResult]
forall a b. (a -> b) -> a -> b
$ Shape -> VName -> [(Slice SubExp, SubExp)] -> KernelResult
WriteReturns Shape
a_w VName
a [((SubExp -> DimIndex SubExp) -> [SubExp] -> Slice SubExp
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix [SubExp]
is, SubExp
v) | ([SubExp]
is, SubExp
v) <- [([SubExp], SubExp)]
is_vs]
          inputs :: [KernelInput]
inputs = do
            (Param Type
p, VName
p_a) <- [Param Type] -> [VName] -> [(Param Type, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip (Lambda GPU -> [LParam GPU]
forall rep. LambdaT rep -> [LParam rep]
lambdaParams Lambda GPU
lam') [VName]
ivs
            KernelInput -> [KernelInput]
forall (m :: * -> *) a. Monad m => a -> m a
return (KernelInput -> [KernelInput]) -> KernelInput -> [KernelInput]
forall a b. (a -> b) -> a -> b
$ VName -> Type -> VName -> [SubExp] -> KernelInput
KernelInput (Param Type -> VName
forall dec. Param dec -> VName
paramName Param Type
p) (Param Type -> Type
forall dec. Typed dec => Param dec -> Type
paramType Param Type
p) VName
p_a [VName -> SubExp
Var VName
write_i]

      Stms GPU
kstms <- BinderT GPU (State VNameSource) ()
-> BinderT GPU (RWST () IntraAcc VNameSource Identity) (Stms GPU)
forall (m :: * -> *) somerep rep a.
(MonadFreshNames m, HasScope somerep m, SameScope somerep rep) =>
Binder rep a -> m (Stms rep)
runBinder_ (BinderT GPU (State VNameSource) ()
 -> BinderT GPU (RWST () IntraAcc VNameSource Identity) (Stms GPU))
-> BinderT GPU (State VNameSource) ()
-> BinderT GPU (RWST () IntraAcc VNameSource Identity) (Stms GPU)
forall a b. (a -> b) -> a -> b
$
        Scope GPU
-> BinderT GPU (State VNameSource) ()
-> BinderT GPU (State VNameSource) ()
forall rep (m :: * -> *) a.
LocalScope rep m =>
Scope rep -> m a -> m a
localScope (SegSpace -> Scope GPU
forall rep. SegSpace -> Scope rep
scopeOfSegSpace SegSpace
space) (BinderT GPU (State VNameSource) ()
 -> BinderT GPU (State VNameSource) ())
-> BinderT GPU (State VNameSource) ()
-> BinderT GPU (State VNameSource) ()
forall a b. (a -> b) -> a -> b
$ do
          (KernelInput -> BinderT GPU (State VNameSource) ())
-> [KernelInput] -> BinderT GPU (State VNameSource) ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
(a -> m b) -> t a -> m ()
mapM_ KernelInput -> BinderT GPU (State VNameSource) ()
forall (m :: * -> *).
(DistRep (Rep m), MonadBinder m) =>
KernelInput -> m ()
readKernelInput [KernelInput]
inputs
          Stms (Rep (BinderT GPU (State VNameSource)))
-> BinderT GPU (State VNameSource) ()
forall (m :: * -> *). MonadBinder m => Stms (Rep m) -> m ()
addStms (Stms (Rep (BinderT GPU (State VNameSource)))
 -> BinderT GPU (State VNameSource) ())
-> Stms (Rep (BinderT GPU (State VNameSource)))
-> BinderT GPU (State VNameSource) ()
forall a b. (a -> b) -> a -> b
$ Body GPU -> Stms GPU
forall rep. BodyT rep -> Stms rep
bodyStms (Body GPU -> Stms GPU) -> Body GPU -> Stms GPU
forall a b. (a -> b) -> a -> b
$ Lambda GPU -> Body GPU
forall rep. LambdaT rep -> BodyT rep
lambdaBody Lambda GPU
lam'

      Certificates -> IntraGroupM () -> IntraGroupM ()
forall (m :: * -> *) a. MonadBinder m => Certificates -> m a -> m a
certifying (StmAux () -> Certificates
forall dec. StmAux dec -> Certificates
stmAuxCerts StmAux ()
StmAux (ExpDec SOACS)
aux) (IntraGroupM () -> IntraGroupM ())
-> IntraGroupM () -> IntraGroupM ()
forall a b. (a -> b) -> a -> b
$ do
        let ts :: [Type]
ts = (Shape -> Type -> Type) -> [Shape] -> [Type] -> [Type]
forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith (Int -> Type -> Type
forall shape u.
ArrayShape shape =>
Int -> TypeBase shape u -> TypeBase shape u
stripArray (Int -> Type -> Type) -> (Shape -> Int) -> Shape -> Type -> Type
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Shape -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length) [Shape]
dests_ws ([Type] -> [Type]) -> [Type] -> [Type]
forall a b. (a -> b) -> a -> b
$ PatternT Type -> [Type]
forall dec. Typed dec => PatternT dec -> [Type]
patternTypes PatternT Type
Pattern SOACS
pat
            body :: KernelBody GPU
body = BodyDec GPU -> Stms GPU -> [KernelResult] -> KernelBody GPU
forall rep.
BodyDec rep -> Stms rep -> [KernelResult] -> KernelBody rep
KernelBody () Stms GPU
kstms [KernelResult]
krets
        Pattern (Rep IntraGroupM)
-> Exp (Rep IntraGroupM) -> IntraGroupM ()
forall (m :: * -> *).
MonadBinder m =>
Pattern (Rep m) -> Exp (Rep m) -> m ()
letBind Pattern (Rep IntraGroupM)
Pattern SOACS
pat (Exp (Rep IntraGroupM) -> IntraGroupM ())
-> Exp (Rep IntraGroupM) -> IntraGroupM ()
forall a b. (a -> b) -> a -> b
$ Op GPU -> ExpT GPU
forall rep. Op rep -> ExpT rep
Op (Op GPU -> ExpT GPU) -> Op GPU -> ExpT GPU
forall a b. (a -> b) -> a -> b
$ SegOp SegLevel GPU -> HostOp GPU (SOAC GPU)
forall rep op. SegOp SegLevel rep -> HostOp rep op
SegOp (SegOp SegLevel GPU -> HostOp GPU (SOAC GPU))
-> SegOp SegLevel GPU -> HostOp GPU (SOAC GPU)
forall a b. (a -> b) -> a -> b
$ SegLevel
-> SegSpace -> [Type] -> KernelBody GPU -> SegOp SegLevel GPU
forall lvl rep.
lvl -> SegSpace -> [Type] -> KernelBody rep -> SegOp lvl rep
SegMap SegLevel
lvl' SegSpace
space [Type]
ts KernelBody GPU
body

      [SubExp] -> IntraGroupM ()
parallelMin [SubExp
w]
    Exp SOACS
_ ->
      Stm (Rep IntraGroupM) -> IntraGroupM ()
forall (m :: * -> *). MonadBinder m => Stm (Rep m) -> m ()
addStm (Stm (Rep IntraGroupM) -> IntraGroupM ())
-> Stm (Rep IntraGroupM) -> IntraGroupM ()
forall a b. (a -> b) -> a -> b
$ Stm -> Stm GPU
soacsStmToGPU Stm
stm

-- | Translate a sequence of SOACS statements for execution inside a
-- single workgroup at the given 'SegLevel'.
--
-- Each statement is handled in order by 'intraGroupStm'. Its
-- generated GPU statements and parallelism bookkeeping are
-- accumulated in the 'IntraGroupM' monad.
intraGroupStms :: SegLevel -> Stms SOACS -> IntraGroupM ()
intraGroupStms lvl = mapM_ (intraGroupStm lvl)

-- | Convert the statements of a SOACS 'Body' into an intra-group GPU
-- kernel body at the given 'SegLevel'.
--
-- The statements are processed with 'intraGroupStms' and run through
-- 'runIntraGroupM'. The result is a 4-tuple:
--
--   1. the accumulated minimum-parallelism widths, as a list taken
--      from the accumulated set;
--   2. the accumulated available-parallelism widths, likewise as a
--      list;
--   3. the accumulated 'Log';
--   4. a kernel body containing the generated statements. Every
--      result of the original body is returned via
--      @Returns ResultMaySimplify@.
intraGroupParalleliseBody ::
  (MonadFreshNames m, HasScope Out.GPU m) =>
  SegLevel ->
  Body ->
  m ([[SubExp]], [[SubExp]], Log, Out.KernelBody Out.GPU)
intraGroupParalleliseBody lvl body = do
  (IntraAcc min_ws avail_ws log, kstms) <-
    runIntraGroupM $ intraGroupStms lvl $ bodyStms body
  return
    ( S.toList min_ws,
      S.toList avail_ws,
      log,
      KernelBody () kstms $ map (Returns ResultMaySimplify) $ bodyResult body
    )