{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TypeFamilies #-}

-- | Extract limited nested parallelism for execution inside
-- individual kernel workgroups.
module Futhark.Pass.ExtractKernels.Intragroup (intraGroupParallelise) where

import Control.Monad.Identity
import Control.Monad.RWS
import Control.Monad.Trans.Maybe
import qualified Data.Map.Strict as M
import qualified Data.Set as S
import Futhark.Analysis.PrimExp.Convert
import qualified Futhark.IR.GPU as Out
import Futhark.IR.GPU.Op hiding (HistOp)
import Futhark.IR.SOACS
import Futhark.MonadFreshNames
import Futhark.Pass.ExtractKernels.BlockedKernel
import Futhark.Pass.ExtractKernels.DistributeNests
import Futhark.Pass.ExtractKernels.Distribution
import Futhark.Pass.ExtractKernels.ToGPU
import Futhark.Tools
import qualified Futhark.Transform.FirstOrderTransform as FOT
import Futhark.Util.Log
import Prelude hiding (log)

-- | Convert the statements inside a map nest to kernel statements,
-- attempting to parallelise any remaining (top-level) parallel
-- statements.  Anything that is not a map, scan or reduction will
-- simply be sequentialised.  This includes sequential loops that
-- contain maps, scans or reductions.  In the future, we could probably
-- do something more clever.  Make sure that the amount of parallelism
-- to be exploited does not exceed the group size.  Further, as a hack
-- we also consider the size of all intermediate arrays as
-- "parallelism to be exploited" to avoid exploding local memory.
--
-- We distinguish between "minimum group size" and "maximum
-- exploitable parallelism".
--
-- The result is 'Nothing' when any of the collected sizes mentions a
-- name that is not in the enclosing scope, or that is one of the
-- kernel inputs ("irregular parallelism").  Otherwise the components
-- of the result are, in order:
--
-- * the pair of minimum and available intra-group parallelism (at
--   present both are the same value);
--
-- * the computed group size;
--
-- * the log accumulated while transforming the body;
--
-- * the prelude statements built here, which bind the sizes and the
--   group size;
--
-- * a single statement binding the pattern of the outermost nesting
--   to a group-level 'SegMap'.
intraGroupParallelise ::
  (MonadFreshNames m, LocalScope Out.GPU m) =>
  KernelNest ->
  Lambda ->
  m
    ( Maybe
        ( (SubExp, SubExp),
          SubExp,
          Log,
          Out.Stms Out.GPU,
          Out.Stms Out.GPU
        )
    )
intraGroupParallelise knest lam = runMaybeT $ do
  (ispace, inps) <- lift $ flatKernel knest

  -- The number of groups is the product of the widths of the
  -- flattened iteration space.
  (num_groups, w_stms) <-
    lift $
      runBuilder $
        letSubExp "intra_num_groups"
          =<< foldBinOp (Mul Int64 OverflowUndef) (intConst Int64 1) (map snd ispace)

  let body = lambdaBody lam

  -- The name is created up front so that it can be referenced by the
  -- body; it is bound further down, once the sizes are known.
  group_size <- newVName "computed_group_size"
  let intra_lvl = SegThread (Count num_groups) (Count $ Var group_size) SegNoVirt

  (wss_min, wss_avail, log, kbody) <-
    lift $
      localScope (scopeOfLParams $ lambdaParams lam) $
        intraGroupParalleliseBody intra_lvl body

  outside_scope <- lift askScope
  -- outside_scope may also contain the inputs, even though those are
  -- not actually available outside the kernel.
  let available v =
        v `M.member` outside_scope
          && v `notElem` map kernelInputName inps
  unless (all available $ namesToList $ freeIn (wss_min ++ wss_avail)) $
    fail "Irregular parallelism"

  ((intra_avail_par, kspace, read_input_stms), prelude_stms) <- lift $
    runBuilder $ do
      -- Like 'foldBinOp', but without a neutral element: an empty
      -- list produces the constant 0.
      let foldBinOp' _ [] = eSubExp $ intConst Int64 0
          foldBinOp' bop (x : xs) = foldBinOp bop x xs
      ws_min <-
        mapM (letSubExp "one_intra_par_min" <=< foldBinOp' (Mul Int64 OverflowUndef)) $
          filter (not . null) wss_min
      ws_avail <-
        mapM (letSubExp "one_intra_par_avail" <=< foldBinOp' (Mul Int64 OverflowUndef)) $
          filter (not . null) wss_avail

      -- The amount of parallelism available *in the worst case* is
      -- equal to the smallest parallel loop.
      intra_avail_par <- letSubExp "intra_avail_par" =<< foldBinOp' (SMin Int64) ws_avail

      -- The group size is either the maximum of the minimum parallelism
      -- exploited, or the desired parallelism (bounded by the max group
      -- size) in case there is no minimum.
      letBindNames [group_size]
        =<< if null ws_min
          then
            eBinOp
              (SMin Int64)
              (eSubExp =<< letSubExp "max_group_size" (Op $ SizeOp $ Out.GetSizeMax Out.SizeGroup))
              (eSubExp intra_avail_par)
          else foldBinOp' (SMax Int64) ws_min

      -- Only inputs that occur free in the body are read.
      let inputIsUsed input = kernelInputName input `nameIn` freeIn body
          used_inps = filter inputIsUsed inps

      addStms w_stms
      -- 'runBuilder_' discards the result of the action, so there is
      -- no reason to build a list of units with 'mapM'.
      read_input_stms <- runBuilder_ $ mapM_ readGroupKernelInput used_inps
      space <- mkSegSpace ispace
      return (intra_avail_par, space, read_input_stms)

  -- The input reads go first in the kernel body.
  let kbody' = kbody {kernelBodyStms = read_input_stms <> kernelBodyStms kbody}

  let nested_pat = loopNestingPat first_nest
      -- Per-group result types: the pattern types with one array
      -- dimension stripped for every dimension of the iteration space.
      rts = map (length ispace `stripArray`) $ patTypes nested_pat
      lvl = SegGroup (Count num_groups) (Count $ Var group_size) SegNoVirt
      kstm =
        Let nested_pat aux $
          Op $ SegOp $ SegMap lvl kspace rts kbody'

  let intra_min_par = intra_avail_par
  return
    ( (intra_min_par, intra_avail_par),
      Var group_size,
      log,
      prelude_stms,
      oneStm kstm
    )
  where
    first_nest = fst knest
    aux = loopNestingAux first_nest

-- | Emit the statements that read a single kernel input.
--
-- An array input is first read under a fresh name (with the same base
-- string as the input), after which the input's own name is bound to
-- a 'Copy' of that array.  Any other input is read directly with
-- 'readKernelInput'.
--
-- NOTE(review): the copy presumably exists so that the group works on
-- its own copy of the array - confirm.
readGroupKernelInput ::
  (DistRep (Rep m), MonadBuilder m) =>
  KernelInput ->
  m ()
readGroupKernelInput inp
  | Array {} <- kernelInputType inp = do
      v <- newVName $ baseString $ kernelInputName inp
      readKernelInput inp {kernelInputName = v}
      letBindNames [kernelInputName inp] $ BasicOp $ Copy v
  | otherwise =
      readKernelInput inp

-- | The information accumulated (through the writer part of
-- 'IntraGroupM') while transforming a body.  Accumulators are
-- combined field by field.
data IntraAcc = IntraAcc
  { -- | Sizes recorded as minimum parallelism; each element is a
    -- list of sizes.  'parallelMin' adds the same list here and to
    -- 'accAvailPar'.
    accMinPar :: S.Set [SubExp],
    -- | Sizes recorded as available parallelism; each element is a
    -- list of sizes.
    accAvailPar :: S.Set [SubExp],
    -- | Log messages added through 'addLog'.
    accLog :: Log
  }

-- | Accumulators are combined field by field, using the '<>' of each
-- field: set union for the two size sets, and the 'Log' append.
instance Semigroup IntraAcc where
  IntraAcc min_x avail_x log_x <> IntraAcc min_y avail_y log_y =
    IntraAcc (min_x <> min_y) (avail_x <> avail_y) (log_x <> log_y)

-- | The empty accumulator: no recorded sizes and an empty log.
instance Monoid IntraAcc where
  mempty = IntraAcc mempty mempty mempty

-- | The monad in which the intra-group transformation runs: a
-- statement builder for the GPU representation on top of an 'RWS'
-- monad with no reader environment, an 'IntraAcc' as writer output,
-- and the 'VNameSource' as state.
type IntraGroupM =
  BuilderT Out.GPU (RWS () IntraAcc VNameSource)

-- | Log messages are emitted through the writer, in the 'accLog'
-- field of an otherwise empty 'IntraAcc'.
instance MonadLogger IntraGroupM where
  addLog log = tell mempty {accLog = log}

-- | Run an 'IntraGroupM' action inside any monad that can supply
-- fresh names and a GPU scope.
--
-- The action is run with the current scope of the surrounding monad
-- and with its name source, which is put back after the run.  The
-- result is the accumulated 'IntraAcc' together with the statements
-- that the action added.
runIntraGroupM ::
  (MonadFreshNames m, HasScope Out.GPU m) =>
  IntraGroupM () ->
  m (IntraAcc, Out.Stms Out.GPU)
runIntraGroupM m = do
  scope <- castScope <$> askScope
  modifyNameSource $ \src ->
    -- The RWS layer is run with the unit environment and the current
    -- name source as its state.
    let (((), kstms), src', acc) = runRWS (runBuilderT m scope) () src
     in ((acc, kstms), src')

-- | Record the given list of sizes in the accumulator, both as
-- minimum parallelism ('accMinPar') and as available parallelism
-- ('accAvailPar').
parallelMin :: [SubExp] -> IntraGroupM ()
parallelMin ws =
  tell
    mempty
      { accMinPar = S.singleton ws,
        accAvailPar = S.singleton ws
      }

-- | Transform a SOACS body into a GPU body.  The statements of the
-- body are passed to 'intraGroupStms' at the given level; whatever
-- statements that adds are collected and become the statements of the
-- new body.  The result of the body is kept as it is.
intraGroupBody :: SegLevel -> Body -> IntraGroupM (Out.Body Out.GPU)
intraGroupBody lvl body = do
  stms <- collectStms_ $ intraGroupStms lvl $ bodyStms body
  return $ mkBody stms $ bodyResult body

intraGroupStm :: SegLevel -> Stm -> IntraGroupM ()
intraGroupStm :: SegLevel -> Stm -> IntraGroupM ()
intraGroupStm SegLevel
lvl stm :: Stm
stm@(Let Pat SOACS
pat StmAux (ExpDec SOACS)
aux Exp SOACS
e) = do
  Scope GPU
scope <- BuilderT GPU (RWST () IntraAcc VNameSource Identity) (Scope GPU)
forall rep (m :: * -> *). HasScope rep m => m (Scope rep)
askScope
  let lvl' :: SegLevel
lvl' = Count NumGroups SubExp
-> Count GroupSize SubExp -> SegVirt -> SegLevel
SegThread (SegLevel -> Count NumGroups SubExp
segNumGroups SegLevel
lvl) (SegLevel -> Count GroupSize SubExp
segGroupSize SegLevel
lvl) SegVirt
SegNoVirt

  case Exp SOACS
e of
    DoLoop [(FParam SOACS, SubExp)]
merge LoopForm SOACS
form BodyT SOACS
loopbody ->
      Scope GPU -> IntraGroupM () -> IntraGroupM ()
forall rep (m :: * -> *) a.
LocalScope rep m =>
Scope rep -> m a -> m a
localScope (LoopForm GPU -> Scope GPU
forall rep a. Scoped rep a => a -> Scope rep
scopeOf LoopForm GPU
form') (IntraGroupM () -> IntraGroupM ())
-> IntraGroupM () -> IntraGroupM ()
forall a b. (a -> b) -> a -> b
$
        Scope GPU -> IntraGroupM () -> IntraGroupM ()
forall rep (m :: * -> *) a.
LocalScope rep m =>
Scope rep -> m a -> m a
localScope ([Param DeclType] -> Scope GPU
forall rep dec. (FParamInfo rep ~ dec) => [Param dec] -> Scope rep
scopeOfFParams ([Param DeclType] -> Scope GPU) -> [Param DeclType] -> Scope GPU
forall a b. (a -> b) -> a -> b
$ ((Param DeclType, SubExp) -> Param DeclType)
-> [(Param DeclType, SubExp)] -> [Param DeclType]
forall a b. (a -> b) -> [a] -> [b]
map (Param DeclType, SubExp) -> Param DeclType
forall a b. (a, b) -> a
fst [(Param DeclType, SubExp)]
[(FParam SOACS, SubExp)]
merge) (IntraGroupM () -> IntraGroupM ())
-> IntraGroupM () -> IntraGroupM ()
forall a b. (a -> b) -> a -> b
$ do
          Body GPU
loopbody' <- SegLevel -> BodyT SOACS -> IntraGroupM (Body GPU)
intraGroupBody SegLevel
lvl BodyT SOACS
loopbody
          Certs -> IntraGroupM () -> IntraGroupM ()
forall (m :: * -> *) a. MonadBuilder m => Certs -> m a -> m a
certifying (StmAux () -> Certs
forall dec. StmAux dec -> Certs
stmAuxCerts StmAux ()
StmAux (ExpDec SOACS)
aux) (IntraGroupM () -> IntraGroupM ())
-> IntraGroupM () -> IntraGroupM ()
forall a b. (a -> b) -> a -> b
$
            Pat (Rep IntraGroupM) -> Exp (Rep IntraGroupM) -> IntraGroupM ()
forall (m :: * -> *).
MonadBuilder m =>
Pat (Rep m) -> Exp (Rep m) -> m ()
letBind Pat (Rep IntraGroupM)
Pat SOACS
pat (Exp (Rep IntraGroupM) -> IntraGroupM ())
-> Exp (Rep IntraGroupM) -> IntraGroupM ()
forall a b. (a -> b) -> a -> b
$ [(FParam GPU, SubExp)] -> LoopForm GPU -> Body GPU -> ExpT GPU
forall rep.
[(FParam rep, SubExp)] -> LoopForm rep -> BodyT rep -> ExpT rep
DoLoop [(FParam SOACS, SubExp)]
[(FParam GPU, SubExp)]
merge LoopForm GPU
form' Body GPU
loopbody'
      where
        form' :: LoopForm GPU
form' = case LoopForm SOACS
form of
          ForLoop VName
i IntType
it SubExp
bound [(LParam SOACS, VName)]
inps -> VName -> IntType -> SubExp -> [(LParam GPU, VName)] -> LoopForm GPU
forall rep.
VName -> IntType -> SubExp -> [(LParam rep, VName)] -> LoopForm rep
ForLoop VName
i IntType
it SubExp
bound [(LParam SOACS, VName)]
[(LParam GPU, VName)]
inps
          WhileLoop VName
cond -> VName -> LoopForm GPU
forall rep. VName -> LoopForm rep
WhileLoop VName
cond
    If SubExp
cond BodyT SOACS
tbody BodyT SOACS
fbody IfDec (BranchType SOACS)
ifdec -> do
      Body GPU
tbody' <- SegLevel -> BodyT SOACS -> IntraGroupM (Body GPU)
intraGroupBody SegLevel
lvl BodyT SOACS
tbody
      Body GPU
fbody' <- SegLevel -> BodyT SOACS -> IntraGroupM (Body GPU)
intraGroupBody SegLevel
lvl BodyT SOACS
fbody
      Certs -> IntraGroupM () -> IntraGroupM ()
forall (m :: * -> *) a. MonadBuilder m => Certs -> m a -> m a
certifying (StmAux () -> Certs
forall dec. StmAux dec -> Certs
stmAuxCerts StmAux ()
StmAux (ExpDec SOACS)
aux) (IntraGroupM () -> IntraGroupM ())
-> IntraGroupM () -> IntraGroupM ()
forall a b. (a -> b) -> a -> b
$
        Pat (Rep IntraGroupM) -> Exp (Rep IntraGroupM) -> IntraGroupM ()
forall (m :: * -> *).
MonadBuilder m =>
Pat (Rep m) -> Exp (Rep m) -> m ()
letBind Pat (Rep IntraGroupM)
Pat SOACS
pat (Exp (Rep IntraGroupM) -> IntraGroupM ())
-> Exp (Rep IntraGroupM) -> IntraGroupM ()
forall a b. (a -> b) -> a -> b
$ SubExp
-> Body GPU -> Body GPU -> IfDec (BranchType GPU) -> ExpT GPU
forall rep.
SubExp
-> BodyT rep -> BodyT rep -> IfDec (BranchType rep) -> ExpT rep
If SubExp
cond Body GPU
tbody' Body GPU
fbody' IfDec (BranchType SOACS)
IfDec (BranchType GPU)
ifdec
    Op Op SOACS
soac
      | Attr
"sequential_outer" Attr -> Attrs -> Bool
`inAttrs` StmAux () -> Attrs
forall dec. StmAux dec -> Attrs
stmAuxAttrs StmAux ()
StmAux (ExpDec SOACS)
aux ->
        SegLevel -> Stms SOACS -> IntraGroupM ()
intraGroupStms SegLevel
lvl (Stms SOACS -> IntraGroupM ())
-> (Stms SOACS -> Stms SOACS) -> Stms SOACS -> IntraGroupM ()
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Stm -> Stm) -> Stms SOACS -> Stms SOACS
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (Certs -> Stm -> Stm
forall rep. Certs -> Stm rep -> Stm rep
certify (StmAux () -> Certs
forall dec. StmAux dec -> Certs
stmAuxCerts StmAux ()
StmAux (ExpDec SOACS)
aux))
          (Stms SOACS -> IntraGroupM ())
-> BuilderT
     GPU (RWST () IntraAcc VNameSource Identity) (Stms SOACS)
-> IntraGroupM ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< Builder SOACS ()
-> BuilderT
     GPU (RWST () IntraAcc VNameSource Identity) (Stms SOACS)
forall (m :: * -> *) somerep rep a.
(MonadFreshNames m, HasScope somerep m, SameScope somerep rep) =>
Builder rep a -> m (Stms rep)
runBuilder_ (Pat (Rep (BuilderT SOACS (State VNameSource)))
-> SOAC (Rep (BuilderT SOACS (State VNameSource)))
-> Builder SOACS ()
forall (m :: * -> *).
Transformer m =>
Pat (Rep m) -> SOAC (Rep m) -> m ()
FOT.transformSOAC Pat (Rep (BuilderT SOACS (State VNameSource)))
Pat SOACS
pat Op SOACS
SOAC (Rep (BuilderT SOACS (State VNameSource)))
soac)
    Op (Screma w arrs form)
      | Just Lambda
lam <- ScremaForm SOACS -> Maybe Lambda
forall rep. ScremaForm rep -> Maybe (Lambda rep)
isMapSOAC ScremaForm SOACS
form -> do
        let loopnest :: LoopNesting
loopnest = PatT Type
-> StmAux () -> SubExp -> [(Param Type, VName)] -> LoopNesting
MapNesting PatT Type
Pat SOACS
pat StmAux ()
StmAux (ExpDec SOACS)
aux SubExp
w ([(Param Type, VName)] -> LoopNesting)
-> [(Param Type, VName)] -> LoopNesting
forall a b. (a -> b) -> a -> b
$ [Param Type] -> [VName] -> [(Param Type, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip (Lambda -> [LParam SOACS]
forall rep. LambdaT rep -> [LParam rep]
lambdaParams Lambda
lam) [VName]
arrs
            env :: DistEnv GPU IntraGroupM
env =
              DistEnv :: forall rep (m :: * -> *).
Nestings
-> Scope rep
-> (Stms SOACS -> DistNestT rep m (Stms rep))
-> (MapLoop -> DistAcc rep -> DistNestT rep m (DistAcc rep))
-> (Stm -> Builder rep (Stms rep))
-> (Lambda -> Builder rep (Lambda rep))
-> MkSegLevel rep m
-> DistEnv rep m
DistEnv
                { distNest :: Nestings
distNest =
                    Nesting -> Nestings
singleNesting (Nesting -> Nestings) -> Nesting -> Nestings
forall a b. (a -> b) -> a -> b
$ Names -> LoopNesting -> Nesting
Nesting Names
forall a. Monoid a => a
mempty LoopNesting
loopnest,
                  distScope :: Scope GPU
distScope =
                    PatT Type -> Scope GPU
forall rep dec. (LetDec rep ~ dec) => PatT dec -> Scope rep
scopeOfPat PatT Type
Pat SOACS
pat
                      Scope GPU -> Scope GPU -> Scope GPU
forall a. Semigroup a => a -> a -> a
<> Scope SOACS -> Scope GPU
scopeForGPU (Lambda -> Scope SOACS
forall rep a. Scoped rep a => a -> Scope rep
scopeOf Lambda
lam)
                      Scope GPU -> Scope GPU -> Scope GPU
forall a. Semigroup a => a -> a -> a
<> Scope GPU
scope,
                  distOnInnerMap :: MapLoop -> DistAcc GPU -> DistNestT GPU IntraGroupM (DistAcc GPU)
distOnInnerMap =
                    MapLoop -> DistAcc GPU -> DistNestT GPU IntraGroupM (DistAcc GPU)
forall (m :: * -> *) rep.
(MonadFreshNames m, LocalScope rep m, DistRep rep) =>
MapLoop -> DistAcc rep -> DistNestT rep m (DistAcc rep)
distributeMap,
                  distOnTopLevelStms :: Stms SOACS -> DistNestT GPU IntraGroupM (Stms GPU)
distOnTopLevelStms =
                    BuilderT GPU (RWST () IntraAcc VNameSource Identity) (Stms GPU)
-> DistNestT GPU IntraGroupM (Stms GPU)
forall rep (m :: * -> *) a.
(LocalScope rep m, DistRep rep) =>
m a -> DistNestT rep m a
liftInner (BuilderT GPU (RWST () IntraAcc VNameSource Identity) (Stms GPU)
 -> DistNestT GPU IntraGroupM (Stms GPU))
-> (Stms SOACS
    -> BuilderT GPU (RWST () IntraAcc VNameSource Identity) (Stms GPU))
-> Stms SOACS
-> DistNestT GPU IntraGroupM (Stms GPU)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. IntraGroupM ()
-> BuilderT GPU (RWST () IntraAcc VNameSource Identity) (Stms GPU)
forall (m :: * -> *) a. MonadBuilder m => m a -> m (Stms (Rep m))
collectStms_ (IntraGroupM ()
 -> BuilderT GPU (RWST () IntraAcc VNameSource Identity) (Stms GPU))
-> (Stms SOACS -> IntraGroupM ())
-> Stms SOACS
-> BuilderT GPU (RWST () IntraAcc VNameSource Identity) (Stms GPU)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SegLevel -> Stms SOACS -> IntraGroupM ()
intraGroupStms SegLevel
lvl,
                  distSegLevel :: MkSegLevel GPU IntraGroupM
distSegLevel = \[SubExp]
minw String
_ ThreadRecommendation
_ -> do
                    IntraGroupM () -> BuilderT GPU IntraGroupM ()
forall (t :: (* -> *) -> * -> *) (m :: * -> *) a.
(MonadTrans t, Monad m) =>
m a -> t m a
lift (IntraGroupM () -> BuilderT GPU IntraGroupM ())
-> IntraGroupM () -> BuilderT GPU IntraGroupM ()
forall a b. (a -> b) -> a -> b
$ [SubExp] -> IntraGroupM ()
parallelMin [SubExp]
minw
                    SegLevel -> BuilderT GPU IntraGroupM SegLevel
forall (m :: * -> *) a. Monad m => a -> m a
return SegLevel
lvl,
                  distOnSOACSStms :: Stm -> BuilderT GPU (State VNameSource) (Stms GPU)
distOnSOACSStms =
                    Stms GPU -> BuilderT GPU (State VNameSource) (Stms GPU)
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Stms GPU -> BuilderT GPU (State VNameSource) (Stms GPU))
-> (Stm -> Stms GPU)
-> Stm
-> BuilderT GPU (State VNameSource) (Stms GPU)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Stm GPU -> Stms GPU
forall rep. Stm rep -> Stms rep
oneStm (Stm GPU -> Stms GPU) -> (Stm -> Stm GPU) -> Stm -> Stms GPU
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Stm -> Stm GPU
soacsStmToGPU,
                  distOnSOACSLambda :: Lambda -> Builder GPU (Lambda GPU)
distOnSOACSLambda =
                    Lambda GPU -> Builder GPU (Lambda GPU)
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Lambda GPU -> Builder GPU (Lambda GPU))
-> (Lambda -> Lambda GPU) -> Lambda -> Builder GPU (Lambda GPU)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Lambda -> Lambda GPU
soacsLambdaToGPU
                }
            acc :: DistAcc GPU
acc =
              DistAcc :: forall rep. Targets -> Stms rep -> DistAcc rep
DistAcc
                { distTargets :: Targets
distTargets = Target -> Targets
singleTarget (PatT Type
Pat SOACS
pat, BodyT SOACS -> Result
forall rep. BodyT rep -> Result
bodyResult (BodyT SOACS -> Result) -> BodyT SOACS -> Result
forall a b. (a -> b) -> a -> b
$ Lambda -> BodyT SOACS
forall rep. LambdaT rep -> BodyT rep
lambdaBody Lambda
lam),
                  distStms :: Stms GPU
distStms = Stms GPU
forall a. Monoid a => a
mempty
                }

        Stms GPU -> IntraGroupM ()
forall (m :: * -> *). MonadBuilder m => Stms (Rep m) -> m ()
addStms
          (Stms GPU -> IntraGroupM ())
-> BuilderT GPU (RWST () IntraAcc VNameSource Identity) (Stms GPU)
-> IntraGroupM ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< DistEnv GPU IntraGroupM
-> DistNestT GPU IntraGroupM (DistAcc GPU)
-> BuilderT GPU (RWST () IntraAcc VNameSource Identity) (Stms GPU)
forall (m :: * -> *) rep.
(MonadLogger m, DistRep rep) =>
DistEnv rep m -> DistNestT rep m (DistAcc rep) -> m (Stms rep)
runDistNestT DistEnv GPU IntraGroupM
env (DistAcc GPU
-> Stms SOACS -> DistNestT GPU IntraGroupM (DistAcc GPU)
forall (m :: * -> *) rep.
(MonadFreshNames m, LocalScope rep m, DistRep rep) =>
DistAcc rep -> Stms SOACS -> DistNestT rep m (DistAcc rep)
distributeMapBodyStms DistAcc GPU
acc (BodyT SOACS -> Stms SOACS
forall rep. BodyT rep -> Stms rep
bodyStms (BodyT SOACS -> Stms SOACS) -> BodyT SOACS -> Stms SOACS
forall a b. (a -> b) -> a -> b
$ Lambda -> BodyT SOACS
forall rep. LambdaT rep -> BodyT rep
lambdaBody Lambda
lam))
    Op (Screma w arrs form)
      | Just ([Scan SOACS]
scans, Lambda
mapfun) <- ScremaForm SOACS -> Maybe ([Scan SOACS], Lambda)
forall rep. ScremaForm rep -> Maybe ([Scan rep], Lambda rep)
isScanomapSOAC ScremaForm SOACS
form,
        Scan Lambda
scanfun [SubExp]
nes <- [Scan SOACS] -> Scan SOACS
forall rep. Buildable rep => [Scan rep] -> Scan rep
singleScan [Scan SOACS]
scans -> do
        let scanfun' :: Lambda GPU
scanfun' = Lambda -> Lambda GPU
soacsLambdaToGPU Lambda
scanfun
            mapfun' :: Lambda GPU
mapfun' = Lambda -> Lambda GPU
soacsLambdaToGPU Lambda
mapfun
        Certs -> IntraGroupM () -> IntraGroupM ()
forall (m :: * -> *) a. MonadBuilder m => Certs -> m a -> m a
certifying (StmAux () -> Certs
forall dec. StmAux dec -> Certs
stmAuxCerts StmAux ()
StmAux (ExpDec SOACS)
aux) (IntraGroupM () -> IntraGroupM ())
-> IntraGroupM () -> IntraGroupM ()
forall a b. (a -> b) -> a -> b
$
          Stms GPU -> IntraGroupM ()
forall (m :: * -> *). MonadBuilder m => Stms (Rep m) -> m ()
addStms (Stms GPU -> IntraGroupM ())
-> BuilderT GPU (RWST () IntraAcc VNameSource Identity) (Stms GPU)
-> IntraGroupM ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< SegOpLevel GPU
-> Pat GPU
-> Certs
-> SubExp
-> [SegBinOp GPU]
-> Lambda GPU
-> [VName]
-> [(VName, SubExp)]
-> [KernelInput]
-> BuilderT GPU (RWST () IntraAcc VNameSource Identity) (Stms GPU)
forall (m :: * -> *) rep.
(MonadFreshNames m, DistRep rep, HasScope rep m) =>
SegOpLevel rep
-> Pat rep
-> Certs
-> SubExp
-> [SegBinOp rep]
-> Lambda rep
-> [VName]
-> [(VName, SubExp)]
-> [KernelInput]
-> m (Stms rep)
segScan SegOpLevel GPU
SegLevel
lvl' Pat SOACS
Pat GPU
pat Certs
forall a. Monoid a => a
mempty SubExp
w [Commutativity -> Lambda GPU -> [SubExp] -> Shape -> SegBinOp GPU
forall rep.
Commutativity -> Lambda rep -> [SubExp] -> Shape -> SegBinOp rep
SegBinOp Commutativity
Noncommutative Lambda GPU
scanfun' [SubExp]
nes Shape
forall a. Monoid a => a
mempty] Lambda GPU
mapfun' [VName]
arrs [] []
        [SubExp] -> IntraGroupM ()
parallelMin [SubExp
w]
    Op (Screma w arrs form)
      | Just ([Reduce SOACS]
reds, Lambda
map_lam) <- ScremaForm SOACS -> Maybe ([Reduce SOACS], Lambda)
forall rep. ScremaForm rep -> Maybe ([Reduce rep], Lambda rep)
isRedomapSOAC ScremaForm SOACS
form,
        Reduce Commutativity
comm Lambda
red_lam [SubExp]
nes <- [Reduce SOACS] -> Reduce SOACS
forall rep. Buildable rep => [Reduce rep] -> Reduce rep
singleReduce [Reduce SOACS]
reds -> do
        let red_lam' :: Lambda GPU
red_lam' = Lambda -> Lambda GPU
soacsLambdaToGPU Lambda
red_lam
            map_lam' :: Lambda GPU
map_lam' = Lambda -> Lambda GPU
soacsLambdaToGPU Lambda
map_lam
        Certs -> IntraGroupM () -> IntraGroupM ()
forall (m :: * -> *) a. MonadBuilder m => Certs -> m a -> m a
certifying (StmAux () -> Certs
forall dec. StmAux dec -> Certs
stmAuxCerts StmAux ()
StmAux (ExpDec SOACS)
aux) (IntraGroupM () -> IntraGroupM ())
-> IntraGroupM () -> IntraGroupM ()
forall a b. (a -> b) -> a -> b
$
          Stms GPU -> IntraGroupM ()
forall (m :: * -> *). MonadBuilder m => Stms (Rep m) -> m ()
addStms (Stms GPU -> IntraGroupM ())
-> BuilderT GPU (RWST () IntraAcc VNameSource Identity) (Stms GPU)
-> IntraGroupM ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< SegOpLevel GPU
-> Pat GPU
-> Certs
-> SubExp
-> [SegBinOp GPU]
-> Lambda GPU
-> [VName]
-> [(VName, SubExp)]
-> [KernelInput]
-> BuilderT GPU (RWST () IntraAcc VNameSource Identity) (Stms GPU)
forall (m :: * -> *) rep.
(MonadFreshNames m, DistRep rep, HasScope rep m) =>
SegOpLevel rep
-> Pat rep
-> Certs
-> SubExp
-> [SegBinOp rep]
-> Lambda rep
-> [VName]
-> [(VName, SubExp)]
-> [KernelInput]
-> m (Stms rep)
segRed SegOpLevel GPU
SegLevel
lvl' Pat SOACS
Pat GPU
pat Certs
forall a. Monoid a => a
mempty SubExp
w [Commutativity -> Lambda GPU -> [SubExp] -> Shape -> SegBinOp GPU
forall rep.
Commutativity -> Lambda rep -> [SubExp] -> Shape -> SegBinOp rep
SegBinOp Commutativity
comm Lambda GPU
red_lam' [SubExp]
nes Shape
forall a. Monoid a => a
mempty] Lambda GPU
map_lam' [VName]
arrs [] []
        [SubExp] -> IntraGroupM ()
parallelMin [SubExp
w]
    Op (Hist w arrs ops bucket_fun) -> do
      [HistOp GPU]
ops' <- [HistOp SOACS]
-> (HistOp SOACS
    -> BuilderT
         GPU (RWST () IntraAcc VNameSource Identity) (HistOp GPU))
-> BuilderT
     GPU (RWST () IntraAcc VNameSource Identity) [HistOp GPU]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM [HistOp SOACS]
ops ((HistOp SOACS
  -> BuilderT
       GPU (RWST () IntraAcc VNameSource Identity) (HistOp GPU))
 -> BuilderT
      GPU (RWST () IntraAcc VNameSource Identity) [HistOp GPU])
-> (HistOp SOACS
    -> BuilderT
         GPU (RWST () IntraAcc VNameSource Identity) (HistOp GPU))
-> BuilderT
     GPU (RWST () IntraAcc VNameSource Identity) [HistOp GPU]
forall a b. (a -> b) -> a -> b
$ \(HistOp SubExp
num_bins SubExp
rf [VName]
dests [SubExp]
nes Lambda
op) -> do
        (Lambda
op', [SubExp]
nes', Shape
shape) <- Lambda
-> [SubExp]
-> BuilderT
     GPU
     (RWST () IntraAcc VNameSource Identity)
     (Lambda, [SubExp], Shape)
forall (m :: * -> *).
MonadBuilder m =>
Lambda -> [SubExp] -> m (Lambda, [SubExp], Shape)
determineReduceOp Lambda
op [SubExp]
nes
        let op'' :: Lambda GPU
op'' = Lambda -> Lambda GPU
soacsLambdaToGPU Lambda
op'
        HistOp GPU
-> BuilderT
     GPU (RWST () IntraAcc VNameSource Identity) (HistOp GPU)
forall (m :: * -> *) a. Monad m => a -> m a
return (HistOp GPU
 -> BuilderT
      GPU (RWST () IntraAcc VNameSource Identity) (HistOp GPU))
-> HistOp GPU
-> BuilderT
     GPU (RWST () IntraAcc VNameSource Identity) (HistOp GPU)
forall a b. (a -> b) -> a -> b
$ SubExp
-> SubExp
-> [VName]
-> [SubExp]
-> Shape
-> Lambda GPU
-> HistOp GPU
forall rep.
SubExp
-> SubExp
-> [VName]
-> [SubExp]
-> Shape
-> Lambda rep
-> HistOp rep
Out.HistOp SubExp
num_bins SubExp
rf [VName]
dests [SubExp]
nes' Shape
shape Lambda GPU
op''

      let bucket_fun' :: Lambda GPU
bucket_fun' = Lambda -> Lambda GPU
soacsLambdaToGPU Lambda
bucket_fun
      Certs -> IntraGroupM () -> IntraGroupM ()
forall (m :: * -> *) a. MonadBuilder m => Certs -> m a -> m a
certifying (StmAux () -> Certs
forall dec. StmAux dec -> Certs
stmAuxCerts StmAux ()
StmAux (ExpDec SOACS)
aux) (IntraGroupM () -> IntraGroupM ())
-> IntraGroupM () -> IntraGroupM ()
forall a b. (a -> b) -> a -> b
$
        Stms GPU -> IntraGroupM ()
forall (m :: * -> *). MonadBuilder m => Stms (Rep m) -> m ()
addStms (Stms GPU -> IntraGroupM ())
-> BuilderT GPU (RWST () IntraAcc VNameSource Identity) (Stms GPU)
-> IntraGroupM ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< SegOpLevel GPU
-> Pat GPU
-> SubExp
-> [(VName, SubExp)]
-> [KernelInput]
-> [HistOp GPU]
-> Lambda GPU
-> [VName]
-> BuilderT GPU (RWST () IntraAcc VNameSource Identity) (Stms GPU)
forall rep (m :: * -> *).
(DistRep rep, MonadFreshNames m, HasScope rep m) =>
SegOpLevel rep
-> Pat rep
-> SubExp
-> [(VName, SubExp)]
-> [KernelInput]
-> [HistOp rep]
-> Lambda rep
-> [VName]
-> m (Stms rep)
segHist SegOpLevel GPU
SegLevel
lvl' Pat SOACS
Pat GPU
pat SubExp
w [] [] [HistOp GPU]
ops' Lambda GPU
bucket_fun' [VName]
arrs
      [SubExp] -> IntraGroupM ()
parallelMin [SubExp
w]
    Op (Stream w arrs Sequential accs lam)
      | LParam SOACS
chunk_size_param : [LParam SOACS]
_ <- Lambda -> [LParam SOACS]
forall rep. LambdaT rep -> [LParam rep]
lambdaParams Lambda
lam -> do
        Scope SOACS
types <- (Scope GPU -> Scope SOACS)
-> BuilderT
     GPU (RWST () IntraAcc VNameSource Identity) (Scope SOACS)
forall rep (m :: * -> *) a.
HasScope rep m =>
(Scope rep -> a) -> m a
asksScope Scope GPU -> Scope SOACS
forall fromrep torep.
SameScope fromrep torep =>
Scope fromrep -> Scope torep
castScope
        ((), Stms SOACS
stream_stms) <-
          BuilderT SOACS IntraGroupM ()
-> Scope SOACS
-> BuilderT
     GPU (RWST () IntraAcc VNameSource Identity) ((), Stms SOACS)
forall (m :: * -> *) rep a.
MonadFreshNames m =>
BuilderT rep m a -> Scope rep -> m (a, Stms rep)
runBuilderT (Pat (Rep (BuilderT SOACS IntraGroupM))
-> SubExp
-> [SubExp]
-> LambdaT (Rep (BuilderT SOACS IntraGroupM))
-> [VName]
-> BuilderT SOACS IntraGroupM ()
forall (m :: * -> *).
(MonadBuilder m, Buildable (Rep m)) =>
Pat (Rep m)
-> SubExp -> [SubExp] -> LambdaT (Rep m) -> [VName] -> m ()
sequentialStreamWholeArray Pat (Rep (BuilderT SOACS IntraGroupM))
Pat SOACS
pat SubExp
w [SubExp]
accs LambdaT (Rep (BuilderT SOACS IntraGroupM))
Lambda
lam [VName]
arrs) Scope SOACS
types
        let replace :: SubExp -> SubExp
replace (Var VName
v) | VName
v VName -> VName -> Bool
forall a. Eq a => a -> a -> Bool
== Param Type -> VName
forall dec. Param dec -> VName
paramName Param Type
LParam SOACS
chunk_size_param = SubExp
w
            replace SubExp
se = SubExp
se
            replaceSets :: IntraAcc -> IntraAcc
replaceSets (IntraAcc Set [SubExp]
x Set [SubExp]
y Log
log) =
              Set [SubExp] -> Set [SubExp] -> Log -> IntraAcc
IntraAcc (([SubExp] -> [SubExp]) -> Set [SubExp] -> Set [SubExp]
forall b a. Ord b => (a -> b) -> Set a -> Set b
S.map ((SubExp -> SubExp) -> [SubExp] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> SubExp
replace) Set [SubExp]
x) (([SubExp] -> [SubExp]) -> Set [SubExp] -> Set [SubExp]
forall b a. Ord b => (a -> b) -> Set a -> Set b
S.map ((SubExp -> SubExp) -> [SubExp] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> SubExp
replace) Set [SubExp]
y) Log
log
        (IntraAcc -> IntraAcc) -> IntraGroupM () -> IntraGroupM ()
forall w (m :: * -> *) a. MonadWriter w m => (w -> w) -> m a -> m a
censor IntraAcc -> IntraAcc
replaceSets (IntraGroupM () -> IntraGroupM ())
-> IntraGroupM () -> IntraGroupM ()
forall a b. (a -> b) -> a -> b
$ SegLevel -> Stms SOACS -> IntraGroupM ()
intraGroupStms SegLevel
lvl Stms SOACS
stream_stms
    Op (Scatter w ivs lam dests) -> do
      VName
write_i <- String
-> BuilderT GPU (RWST () IntraAcc VNameSource Identity) VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName String
"write_i"
      SegSpace
space <- [(VName, SubExp)]
-> BuilderT GPU (RWST () IntraAcc VNameSource Identity) SegSpace
forall (m :: * -> *).
MonadFreshNames m =>
[(VName, SubExp)] -> m SegSpace
mkSegSpace [(VName
write_i, SubExp
w)]

      let lam' :: Lambda GPU
lam' = Lambda -> Lambda GPU
soacsLambdaToGPU Lambda
lam
          ([Shape]
dests_ws, [Int]
_, [VName]
_) = [(Shape, Int, VName)] -> ([Shape], [Int], [VName])
forall a b c. [(a, b, c)] -> ([a], [b], [c])
unzip3 [(Shape, Int, VName)]
dests
          krets :: [KernelResult]
krets = do
            (Shape
a_w, VName
a, [(Result, SubExpRes)]
is_vs) <-
              [(Shape, Int, VName)]
-> Result -> [(Shape, VName, [(Result, SubExpRes)])]
forall array a.
[(Shape, Int, array)] -> [a] -> [(Shape, array, [([a], a)])]
groupScatterResults [(Shape, Int, VName)]
dests (Result -> [(Shape, VName, [(Result, SubExpRes)])])
-> Result -> [(Shape, VName, [(Result, SubExpRes)])]
forall a b. (a -> b) -> a -> b
$ Body GPU -> Result
forall rep. BodyT rep -> Result
bodyResult (Body GPU -> Result) -> Body GPU -> Result
forall a b. (a -> b) -> a -> b
$ Lambda GPU -> Body GPU
forall rep. LambdaT rep -> BodyT rep
lambdaBody Lambda GPU
lam'
            let cs :: Certs
cs =
                  ((Result, SubExpRes) -> Certs) -> [(Result, SubExpRes)] -> Certs
forall (t :: * -> *) m a.
(Foldable t, Monoid m) =>
(a -> m) -> t a -> m
foldMap ((SubExpRes -> Certs) -> Result -> Certs
forall (t :: * -> *) m a.
(Foldable t, Monoid m) =>
(a -> m) -> t a -> m
foldMap SubExpRes -> Certs
resCerts (Result -> Certs)
-> ((Result, SubExpRes) -> Result) -> (Result, SubExpRes) -> Certs
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Result, SubExpRes) -> Result
forall a b. (a, b) -> a
fst) [(Result, SubExpRes)]
is_vs
                    Certs -> Certs -> Certs
forall a. Semigroup a => a -> a -> a
<> ((Result, SubExpRes) -> Certs) -> [(Result, SubExpRes)] -> Certs
forall (t :: * -> *) m a.
(Foldable t, Monoid m) =>
(a -> m) -> t a -> m
foldMap (SubExpRes -> Certs
resCerts (SubExpRes -> Certs)
-> ((Result, SubExpRes) -> SubExpRes)
-> (Result, SubExpRes)
-> Certs
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Result, SubExpRes) -> SubExpRes
forall a b. (a, b) -> b
snd) [(Result, SubExpRes)]
is_vs
                is_vs' :: [(Slice SubExp, SubExp)]
is_vs' = [([DimIndex SubExp] -> Slice SubExp
forall d. [DimIndex d] -> Slice d
Slice ([DimIndex SubExp] -> Slice SubExp)
-> [DimIndex SubExp] -> Slice SubExp
forall a b. (a -> b) -> a -> b
$ (SubExpRes -> DimIndex SubExp) -> Result -> [DimIndex SubExp]
forall a b. (a -> b) -> [a] -> [b]
map (SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix (SubExp -> DimIndex SubExp)
-> (SubExpRes -> SubExp) -> SubExpRes -> DimIndex SubExp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SubExpRes -> SubExp
resSubExp) Result
is, SubExpRes -> SubExp
resSubExp SubExpRes
v) | (Result
is, SubExpRes
v) <- [(Result, SubExpRes)]
is_vs]
            KernelResult -> [KernelResult]
forall (m :: * -> *) a. Monad m => a -> m a
return (KernelResult -> [KernelResult]) -> KernelResult -> [KernelResult]
forall a b. (a -> b) -> a -> b
$ Certs -> Shape -> VName -> [(Slice SubExp, SubExp)] -> KernelResult
WriteReturns Certs
cs Shape
a_w VName
a [(Slice SubExp, SubExp)]
is_vs'
          inputs :: [KernelInput]
inputs = do
            (Param Type
p, VName
p_a) <- [Param Type] -> [VName] -> [(Param Type, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip (Lambda GPU -> [LParam GPU]
forall rep. LambdaT rep -> [LParam rep]
lambdaParams Lambda GPU
lam') [VName]
ivs
            KernelInput -> [KernelInput]
forall (m :: * -> *) a. Monad m => a -> m a
return (KernelInput -> [KernelInput]) -> KernelInput -> [KernelInput]
forall a b. (a -> b) -> a -> b
$ VName -> Type -> VName -> [SubExp] -> KernelInput
KernelInput (Param Type -> VName
forall dec. Param dec -> VName
paramName Param Type
p) (Param Type -> Type
forall dec. Typed dec => Param dec -> Type
paramType Param Type
p) VName
p_a [VName -> SubExp
Var VName
write_i]

      Stms GPU
kstms <- BuilderT GPU (State VNameSource) ()
-> BuilderT GPU (RWST () IntraAcc VNameSource Identity) (Stms GPU)
forall (m :: * -> *) somerep rep a.
(MonadFreshNames m, HasScope somerep m, SameScope somerep rep) =>
Builder rep a -> m (Stms rep)
runBuilder_ (BuilderT GPU (State VNameSource) ()
 -> BuilderT GPU (RWST () IntraAcc VNameSource Identity) (Stms GPU))
-> BuilderT GPU (State VNameSource) ()
-> BuilderT GPU (RWST () IntraAcc VNameSource Identity) (Stms GPU)
forall a b. (a -> b) -> a -> b
$
        Scope GPU
-> BuilderT GPU (State VNameSource) ()
-> BuilderT GPU (State VNameSource) ()
forall rep (m :: * -> *) a.
LocalScope rep m =>
Scope rep -> m a -> m a
localScope (SegSpace -> Scope GPU
forall rep. SegSpace -> Scope rep
scopeOfSegSpace SegSpace
space) (BuilderT GPU (State VNameSource) ()
 -> BuilderT GPU (State VNameSource) ())
-> BuilderT GPU (State VNameSource) ()
-> BuilderT GPU (State VNameSource) ()
forall a b. (a -> b) -> a -> b
$ do
          (KernelInput -> BuilderT GPU (State VNameSource) ())
-> [KernelInput] -> BuilderT GPU (State VNameSource) ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
(a -> m b) -> t a -> m ()
mapM_ KernelInput -> BuilderT GPU (State VNameSource) ()
forall (m :: * -> *).
(DistRep (Rep m), MonadBuilder m) =>
KernelInput -> m ()
readKernelInput [KernelInput]
inputs
          Stms (Rep (BuilderT GPU (State VNameSource)))
-> BuilderT GPU (State VNameSource) ()
forall (m :: * -> *). MonadBuilder m => Stms (Rep m) -> m ()
addStms (Stms (Rep (BuilderT GPU (State VNameSource)))
 -> BuilderT GPU (State VNameSource) ())
-> Stms (Rep (BuilderT GPU (State VNameSource)))
-> BuilderT GPU (State VNameSource) ()
forall a b. (a -> b) -> a -> b
$ Body GPU -> Stms GPU
forall rep. BodyT rep -> Stms rep
bodyStms (Body GPU -> Stms GPU) -> Body GPU -> Stms GPU
forall a b. (a -> b) -> a -> b
$ Lambda GPU -> Body GPU
forall rep. LambdaT rep -> BodyT rep
lambdaBody Lambda GPU
lam'

      Certs -> IntraGroupM () -> IntraGroupM ()
forall (m :: * -> *) a. MonadBuilder m => Certs -> m a -> m a
certifying (StmAux () -> Certs
forall dec. StmAux dec -> Certs
stmAuxCerts StmAux ()
StmAux (ExpDec SOACS)
aux) (IntraGroupM () -> IntraGroupM ())
-> IntraGroupM () -> IntraGroupM ()
forall a b. (a -> b) -> a -> b
$ do
        let ts :: [Type]
ts = (Shape -> Type -> Type) -> [Shape] -> [Type] -> [Type]
forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith (Int -> Type -> Type
forall shape u.
ArrayShape shape =>
Int -> TypeBase shape u -> TypeBase shape u
stripArray (Int -> Type -> Type) -> (Shape -> Int) -> Shape -> Type -> Type
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Shape -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length) [Shape]
dests_ws ([Type] -> [Type]) -> [Type] -> [Type]
forall a b. (a -> b) -> a -> b
$ PatT Type -> [Type]
forall dec. Typed dec => PatT dec -> [Type]
patTypes PatT Type
Pat SOACS
pat
            body :: KernelBody GPU
body = BodyDec GPU -> Stms GPU -> [KernelResult] -> KernelBody GPU
forall rep.
BodyDec rep -> Stms rep -> [KernelResult] -> KernelBody rep
KernelBody () Stms GPU
kstms [KernelResult]
krets
        Pat (Rep IntraGroupM) -> Exp (Rep IntraGroupM) -> IntraGroupM ()
forall (m :: * -> *).
MonadBuilder m =>
Pat (Rep m) -> Exp (Rep m) -> m ()
letBind Pat (Rep IntraGroupM)
Pat SOACS
pat (Exp (Rep IntraGroupM) -> IntraGroupM ())
-> Exp (Rep IntraGroupM) -> IntraGroupM ()
forall a b. (a -> b) -> a -> b
$ Op GPU -> ExpT GPU
forall rep. Op rep -> ExpT rep
Op (Op GPU -> ExpT GPU) -> Op GPU -> ExpT GPU
forall a b. (a -> b) -> a -> b
$ SegOp SegLevel GPU -> HostOp GPU (SOAC GPU)
forall rep op. SegOp SegLevel rep -> HostOp rep op
SegOp (SegOp SegLevel GPU -> HostOp GPU (SOAC GPU))
-> SegOp SegLevel GPU -> HostOp GPU (SOAC GPU)
forall a b. (a -> b) -> a -> b
$ SegLevel
-> SegSpace -> [Type] -> KernelBody GPU -> SegOp SegLevel GPU
forall lvl rep.
lvl -> SegSpace -> [Type] -> KernelBody rep -> SegOp lvl rep
SegMap SegLevel
lvl' SegSpace
space [Type]
ts KernelBody GPU
body

      [SubExp] -> IntraGroupM ()
parallelMin [SubExp
w]
    Exp SOACS
_ ->
      Stm (Rep IntraGroupM) -> IntraGroupM ()
forall (m :: * -> *). MonadBuilder m => Stm (Rep m) -> m ()
addStm (Stm (Rep IntraGroupM) -> IntraGroupM ())
-> Stm (Rep IntraGroupM) -> IntraGroupM ()
forall a b. (a -> b) -> a -> b
$ Stm -> Stm GPU
soacsStmToGPU Stm
stm

-- | Process a sequence of SOACS statements by running 'intraGroupStm'
-- on each of them, in order, at the given segment level.  The results
-- are accumulated through the effects of 'IntraGroupM'; nothing is
-- returned directly.
intraGroupStms :: SegLevel -> Stms SOACS -> IntraGroupM ()
intraGroupStms lvl = mapM_ (intraGroupStm lvl)

-- | Translate the statements of a SOACS body into a GPU kernel body.
--
-- The body's statements are run through 'intraGroupStms' inside
-- 'runIntraGroupM'.  The returned tuple contains, in order:
--
-- * the first 'IntraAcc' set, converted to a list (bound here as
--   @min_ws@ -- presumably the minimum required widths; confirm
--   against the definition of 'IntraAcc');
--
-- * the second 'IntraAcc' set, converted to a list (bound here as
--   @avail_ws@ -- presumably the available parallelism; confirm
--   likewise);
--
-- * the log collected while translating;
--
-- * a 'KernelBody' holding the generated GPU statements, whose results
--   are the results of the original body, each wrapped in 'Returns'
--   with 'ResultMaySimplify'.
intraGroupParalleliseBody ::
  (MonadFreshNames m, HasScope Out.GPU m) =>
  SegLevel ->
  Body ->
  m ([[SubExp]], [[SubExp]], Log, Out.KernelBody Out.GPU)
intraGroupParalleliseBody lvl body = do
  (IntraAcc min_ws avail_ws log, kstms) <-
    runIntraGroupM $ intraGroupStms lvl $ bodyStms body
  return
    ( S.toList min_ws,
      S.toList avail_ws,
      log,
      KernelBody () kstms $ map ret $ bodyResult body
    )
  where
    -- Keep the certificates of each body result on the kernel result.
    ret (SubExpRes cs se) = Returns ResultMaySimplify cs se