{-# LANGUAGE TypeFamilies #-}

-- | It is well known that fully parallel loops can always be
-- interchanged inwards with a sequential loop.  This module
-- implements that transformation.
--
-- This is also where we implement loop-switching (for branches),
-- which is semantically similar to interchange.
module Futhark.Pass.ExtractKernels.Interchange
  ( SeqLoop (..),
    interchangeLoops,
    Branch (..),
    interchangeBranch,
    WithAccStm (..),
    interchangeWithAcc,
  )
where

import Control.Monad
import Data.List (find)
import Data.Maybe
import Futhark.IR.SOACS
import Futhark.MonadFreshNames
import Futhark.Pass.ExtractKernels.Distribution
  ( KernelNest,
    LoopNesting (..),
    kernelNestLoops,
    scopeOfKernelNest,
  )
import Futhark.Tools
import Futhark.Transform.Rename
import Futhark.Util (splitFromEnd)

-- | A sequential do-loop without existential context, packaged
-- together with the pattern that binds its results.
data SeqLoop
  = SeqLoop
      [Int]
      -- ^ Permutation applied (via 'rearrangeShape') to the pattern of
      -- each map nesting the loop is interchanged with.
      (Pat Type)
      -- ^ Pattern binding the results of the loop.
      [(FParam SOACS, SubExp)]
      -- ^ Merge parameters paired with their initial values.
      (LoopForm SOACS)
      -- ^ The form of the loop.
      (Body SOACS)
      -- ^ The loop body.

-- | Extract the result permutation stored in a 'SeqLoop'.
loopPerm :: SeqLoop -> [Int]
loopPerm loop = case loop of
  SeqLoop perm _ _ _ _ -> perm

-- | Rebuild a loop statement from a 'SeqLoop'.  The statement binds the
-- loop's pattern with default auxiliary information.  The permutation
-- is not part of the statement and is dropped.
seqLoopStm :: SeqLoop -> Stm SOACS
seqLoopStm (SeqLoop _perm pat merge form body) =
  Let pat (defAux ()) (DoLoop merge form body)

-- | Interchange a single map nesting with a sequential loop.
--
-- Each merge parameter is expanded with an outer dimension of the map
-- width, and the original loop body becomes the body of a @map@ inside
-- the new loop.  The returned loop binds the map nesting's pattern,
-- permuted by the loop's permutation.
--
-- The first argument returns, for a map parameter, the array it is
-- drawn from.  Merge initialisers that are map parameters use that
-- array directly instead of being replicated.
interchangeLoop ::
  (MonadBuilder m, Rep m ~ SOACS) =>
  (VName -> Maybe VName) ->
  SeqLoop ->
  LoopNesting ->
  m SeqLoop
interchangeLoop
  isMapParameter
  (SeqLoop perm loop_pat merge form body)
  (MapNesting pat aux w params_and_arrs) = do
    merge_expanded <-
      localScope (scopeOfLParams $ map fst params_and_arrs) $
        mapM expand merge

    let loop_pat_expanded =
          Pat $ map expandPatElem $ patElems loop_pat
        new_params =
          [Param attrs pname $ fromDecl ptype | (Param attrs pname ptype, _) <- merge]
        new_arrs = map (paramName . fst) merge_expanded
        rettype = map rowType $ patTypes loop_pat_expanded

        -- Map parameters that the loop body does not reference are
        -- removed outright.  This happens, for example, when a parameter
        -- was only used as the initial value of a merge parameter.
        -- No copies are made here.  Merge initialisers that are map
        -- inputs are copied by 'maybeCopyInitial'.
        (params', arrs') =
          unzip $
            filter ((`nameIn` free_in_body) . paramName . fst) params_and_arrs

        lam = Lambda (params' <> new_params) body rettype
        map_stm =
          Let loop_pat_expanded aux $
            Op $
              Screma w (arrs' <> new_arrs) (mapSOAC lam)
        res = varsRes $ patNames loop_pat_expanded
        pat' = Pat $ rearrangeShape perm $ patElems pat

    pure $
      SeqLoop perm pat' merge_expanded form $
        mkBody (oneStm map_stm) res
    where
      free_in_body = freeIn body

      -- A merge initialiser that is itself a map parameter can use the
      -- corresponding array directly.  Anything else is replicated
      -- across the map width.
      expandedInit _ (Var v)
        | Just arr <- isMapParameter v =
            pure $ Var arr
      expandedInit param_name se =
        letSubExp (param_name <> "_expanded_init") $
          BasicOp $
            Replicate (Shape [w]) se

      -- Give a merge parameter an outer dimension of size w, together
      -- with a matching initial value.
      expand (merge_param, merge_init) = do
        expanded_param <-
          newParam (param_name <> "_expanded") $
            -- FIXME: Unique here is a hack to make sure the copy from
            -- maybeCopyInitial is not prematurely simplified away.
            -- It'd be better to fix this somewhere else...
            arrayOf (paramDeclType merge_param) (Shape [w]) Unique
        expanded_init <- expandedInit param_name merge_init
        pure (expanded_param, expanded_init)
        where
          param_name = baseString $ paramName merge_param

      expandPatElem (PatElem name t) =
        PatElem name $ arrayOfRow t w

-- Some merge initialisers must be copied.  If the loop runs zero times,
-- its result would otherwise alias the input, which is a problem if the
-- result is consumed.  A copy is made for every array-typed merge
-- parameter whose initial value is one of the map inputs.
maybeCopyInitial ::
  (MonadBuilder m) =>
  (VName -> Bool) ->
  SeqLoop ->
  m SeqLoop
maybeCopyInitial isMapInput (SeqLoop perm loop_pat merge form body) = do
  merge' <- mapM copyIfNeeded merge
  pure $ SeqLoop perm loop_pat merge' form body
  where
    copyIfNeeded (p, Var arg)
      | isMapInput arg,
        Array {} <- paramType p = do
          arg_copy <-
            letSubExp (baseString (paramName p) <> "_inter_copy") $
              BasicOp (Copy arg)
          pure (p, arg_copy)
    copyIfNeeded (p, arg) = pure (p, arg)

-- | Wrap the given statements in one @map@ per loop nesting.
--
-- The first nesting in the list becomes the outermost @map@, and the
-- innermost @map@ returns @res@.  The function also returns the names
-- bound by the outermost @map@, or @res@ itself if there are no
-- nestings.
manifestMaps :: [LoopNesting] -> [VName] -> Stms SOACS -> ([VName], Stms SOACS)
manifestMaps nests res stms = foldr wrap (res, stms) nests
  where
    wrap nest (inner_res, inner_stms) =
      let nest_pat = loopNestingPat nest
          (params, arrs) = unzip $ loopNestingParamsAndArrs nest
          lam =
            Lambda
              params
              (mkBody inner_stms $ varsRes inner_res)
              (map rowType $ patTypes nest_pat)
          soac = Screma (loopNestingWidth nest) arrs (mapSOAC lam)
       in ( patNames nest_pat,
            oneStm $ Let nest_pat (loopNestingAux nest) $ Op soac
          )

-- | Given a (parallel) map nesting and an inner sequential loop, move
-- the maps inside the sequential loop.  The result is several
-- statements, one of which is the loop.  The loop then contains
-- statements with @map@ expressions.
--
-- Nestings are interchanged starting from the last (innermost) one.
-- If an interchange step has to emit extra statements, the remaining
-- nestings are manifested as ordinary @map@s around the loop instead.
interchangeLoops ::
  (MonadFreshNames m, HasScope SOACS m) =>
  KernelNest ->
  SeqLoop ->
  m (Stms SOACS)
interchangeLoops full_nest = recurse (kernelNestLoops full_nest)
  where
    recurse nest loop
      | (ns, [n]) <- splitFromEnd 1 nest = do
          let isMapParameter v =
                snd <$> find ((== v) . paramName . fst) (loopNestingParamsAndArrs n)
              isMapInput v =
                v `elem` map snd (loopNestingParamsAndArrs n)
          (loop', stms) <-
            runBuilder . localScope (scopeOfKernelNest full_nest) $
              maybeCopyInitial isMapInput
                =<< interchangeLoop isMapParameter loop n

          -- Only safe to continue interchanging if we didn't need to add
          -- any new statements; otherwise we manifest the remaining nests
          -- as Maps and hand them back to the flattener.
          if null stms
            then recurse ns loop'
            else
              let loop_stm = seqLoopStm loop'
                  -- The pattern of loop_stm is the pattern of nesting n
                  -- permuted by 'loopPerm'.  The enclosing maps must
                  -- return these names in the original order of n's
                  -- pattern, which is the same convention 'manifestMaps'
                  -- uses.  So we apply the inverse permutation.
                  -- Re-applying the permutation itself is only correct
                  -- when it is an involution.
                  names =
                    rearrangeShape
                      (rearrangeInverse (loopPerm loop'))
                      (patNames (stmPat loop_stm))
               in pure $ snd $ manifestMaps ns names $ stms <> oneStm loop_stm
      | otherwise = pure $ oneStm $ seqLoopStm loop

-- | An encoding of a branch alongside its result pattern.
data Branch
  = Branch
      [Int]
      -- ^ Permutation applied (via 'rearrangeShape') to the pattern and
      -- row types of each map nesting the branch is interchanged with.
      (Pat Type)
      -- ^ Pattern binding the results of the branch.
      [SubExp]
      -- ^ The values scrutinised by the @match@.
      [Case (Body SOACS)]
      -- ^ The cases of the @match@.
      (Body SOACS)
      -- ^ The default body of the @match@.
      (MatchDec (BranchType SOACS))
      -- ^ Return types and 'MatchSort' of the @match@.

-- | Rebuild a @match@ statement from a 'Branch'.  The statement binds
-- the branch's pattern with default auxiliary information.  The
-- permutation is not part of the statement and is dropped.
branchStm :: Branch -> Stm SOACS
branchStm (Branch _perm pat cond cases defbody ret) =
  Let pat (defAux ()) (Match cond cases defbody ret)

-- | Interchange a single map nesting with a branch.
--
-- Every case body, and the default body, is wrapped in a @map@ over
-- the nesting's arrays.  The branch's return types gain an outer
-- dimension of the map width.  The resulting branch binds the map
-- nesting's pattern, permuted by the branch's permutation.
interchangeBranch1 ::
  (MonadFreshNames m, HasScope SOACS m) =>
  Branch ->
  LoopNesting ->
  m Branch
interchangeBranch1
  (Branch perm branch_pat cond cases defbody (MatchDec ret if_sort))
  (MapNesting pat aux w params_and_arrs) = do
    let ret' = map (`arrayOfRow` Free w) ret
        pat' = Pat $ rearrangeShape perm $ patElems pat

        (params, arrs) = unzip params_and_arrs
        lam_ret = rearrangeShape perm $ map rowType $ patTypes pat

        branch_pat' =
          Pat $ map (fmap (`arrayOfRow` w)) $ patElems branch_pat

        -- Wrap one branch body in a map over the nesting's arrays.  The
        -- result is renamed so that each copy gets fresh names.
        mkBranch branch = (renameBody =<<) $ do
          let lam = Lambda params branch lam_ret
              res = varsRes $ patNames branch_pat'
              map_stm = Let branch_pat' aux $ Op $ Screma w arrs $ mapSOAC lam
          pure $ mkBody (oneStm map_stm) res

    cases' <- mapM (traverse $ runBodyBuilder . mkBranch) cases
    defbody' <- runBodyBuilder $ mkBranch defbody
    -- Keep the permutation, as 'interchangeLoop' does.  The next, outer
    -- nesting's pattern is ordered like this nesting's pattern, not like
    -- pat', so it must be permuted the same way.  Resetting to the
    -- identity misorders results for non-identity permutations in nests
    -- of depth two or more.
    pure . Branch perm pat' cond cases' defbody' $
      MatchDec ret' if_sort

-- | Given a (parallel) map nesting and an inner branch, move the maps
-- inside the branch.  The result is a single @match@ statement whose
-- branches contain statements with @map@ expressions.
interchangeBranch ::
  (MonadFreshNames m, HasScope SOACS m) =>
  KernelNest ->
  Branch ->
  m (Stm SOACS)
interchangeBranch nest branch = do
  -- Interchange with the last (innermost) nesting first, working outwards.
  let nestings = reverse $ kernelNestLoops nest
  branch' <- foldM interchangeBranch1 branch nestings
  pure $ branchStm branch'

-- | An encoding of a @WithAcc@ alongside its result pattern.
data WithAccStm
  = WithAccStm
      [Int]
      -- ^ Permutation applied (via 'rearrangeShape') to the pattern of
      -- each map nesting the statement is interchanged with.
      (Pat Type)
      -- ^ Pattern binding the results of the @WithAcc@.
      [(Shape, [VName], Maybe (Lambda SOACS, [SubExp]))]
      -- ^ Accumulator inputs: index-space shape, arrays, and an optional
      -- pair of an operator lambda with a list of values.
      (Lambda SOACS)
      -- ^ The lambda that is passed the accumulators.

-- | Rebuild a @WithAcc@ statement from a 'WithAccStm'.  The statement
-- binds its pattern with default auxiliary information.  The
-- permutation is not part of the statement and is dropped.
withAccStm :: WithAccStm -> Stm SOACS
withAccStm (WithAccStm _perm pat inputs lam) =
  Let pat (defAux ()) (WithAcc inputs lam)

interchangeWithAcc1 ::
  (MonadFreshNames m, LocalScope SOACS m) =>
  WithAccStm ->
  LoopNesting ->
  m WithAccStm
interchangeWithAcc1 :: forall (m :: * -> *).
(MonadFreshNames m, LocalScope SOACS m) =>
WithAccStm -> LoopNesting -> m WithAccStm
interchangeWithAcc1
  (WithAccStm [Int]
perm Pat Type
_withacc_pat [(ShapeBase SubExp, [VName], Maybe (Lambda SOACS, [SubExp]))]
inputs Lambda SOACS
acc_lam)
  (MapNesting Pat Type
map_pat StmAux ()
map_aux SubExp
w [(Param Type, VName)]
params_and_arrs) = do
    [(ShapeBase SubExp, [VName], Maybe (Lambda SOACS, [SubExp]))]
inputs' <- forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (ShapeBase SubExp, [VName], Maybe (Lambda SOACS, [SubExp]))
-> m (ShapeBase SubExp, [VName], Maybe (Lambda SOACS, [SubExp]))
onInput [(ShapeBase SubExp, [VName], Maybe (Lambda SOACS, [SubExp]))]
inputs
    [Param Type]
lam_params' <- forall {m :: * -> *} {a}.
MonadFreshNames m =>
[Param a] -> m [Param a]
newAccLamParams forall a b. (a -> b) -> a -> b
$ forall rep. Lambda rep -> [LParam rep]
lambdaParams Lambda SOACS
acc_lam
    Param Type
iota_p <- forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"iota_p" forall a b. (a -> b) -> a -> b
$ forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64
    Lambda SOACS
acc_lam' <- SubExp -> Lambda SOACS -> m (Lambda SOACS)
trLam (VName -> SubExp
Var (forall dec. Param dec -> VName
paramName Param Type
iota_p)) forall (m :: * -> *) b c a.
Monad m =>
(b -> m c) -> (a -> m b) -> a -> m c
<=< forall rep (m :: * -> *) somerep.
(Buildable rep, MonadFreshNames m, HasScope somerep m,
 SameScope somerep rep) =>
[LParam rep] -> Builder rep Result -> m (Lambda rep)
runLambdaBuilder [Param Type]
lam_params' forall a b. (a -> b) -> a -> b
$ do
      let acc_params :: [Param Type]
acc_params = forall a. Int -> [a] -> [a]
drop (forall (t :: * -> *) a. Foldable t => t a -> Int
length [(ShapeBase SubExp, [VName], Maybe (Lambda SOACS, [SubExp]))]
inputs) [Param Type]
lam_params'
          orig_acc_params :: [Param Type]
orig_acc_params = forall a. Int -> [a] -> [a]
drop (forall (t :: * -> *) a. Foldable t => t a -> Int
length [(ShapeBase SubExp, [VName], Maybe (Lambda SOACS, [SubExp]))]
inputs) forall a b. (a -> b) -> a -> b
$ forall rep. Lambda rep -> [LParam rep]
lambdaParams Lambda SOACS
acc_lam
      VName
iota_w <-
        forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp String
"acc_inter_iota" forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall rep. BasicOp -> Exp rep
BasicOp forall a b. (a -> b) -> a -> b
$
          SubExp -> SubExp -> SubExp -> IntType -> BasicOp
Iota SubExp
w (IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
0) (IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
1) IntType
Int64
      let ([Param Type]
params, [VName]
arrs) = forall a b. [(a, b)] -> ([a], [b])
unzip [(Param Type, VName)]
params_and_arrs
          maplam_ret :: [Type]
maplam_ret = forall rep. Lambda rep -> [Type]
lambdaReturnType Lambda SOACS
acc_lam
          maplam :: Lambda SOACS
maplam = forall rep. [LParam rep] -> Body rep -> [Type] -> Lambda rep
Lambda (Param Type
iota_p forall a. a -> [a] -> [a]
: [Param Type]
orig_acc_params forall a. [a] -> [a] -> [a]
++ [Param Type]
params) (forall rep. Lambda rep -> Body rep
lambdaBody Lambda SOACS
acc_lam) [Type]
maplam_ret
      forall (m :: * -> *) anyrep a.
MonadBuilder m =>
StmAux anyrep -> m a -> m a
auxing StmAux ()
map_aux forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap [SubExp] -> Result
subExpsRes forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m [SubExp]
letTupExp' String
"withacc_inter" forall a b. (a -> b) -> a -> b
$
        forall rep. Op rep -> Exp rep
Op forall a b. (a -> b) -> a -> b
$
          forall rep. SubExp -> [VName] -> ScremaForm rep -> SOAC rep
Screma SubExp
w (VName
iota_w forall a. a -> [a] -> [a]
: forall a b. (a -> b) -> [a] -> [b]
map forall dec. Param dec -> VName
paramName [Param Type]
acc_params forall a. [a] -> [a] -> [a]
++ [VName]
arrs) (forall rep. Lambda rep -> ScremaForm rep
mapSOAC Lambda SOACS
maplam)
    let pat :: Pat Type
pat = forall dec. [PatElem dec] -> Pat dec
Pat forall a b. (a -> b) -> a -> b
$ forall a. [Int] -> [a] -> [a]
rearrangeShape [Int]
perm forall a b. (a -> b) -> a -> b
$ forall dec. Pat dec -> [PatElem dec]
patElems Pat Type
map_pat
    forall (f :: * -> *) a. Applicative f => a -> f a
pure forall a b. (a -> b) -> a -> b
$ [Int]
-> Pat Type
-> [(ShapeBase SubExp, [VName], Maybe (Lambda SOACS, [SubExp]))]
-> Lambda SOACS
-> WithAccStm
WithAccStm [Int]
perm Pat Type
pat [(ShapeBase SubExp, [VName], Maybe (Lambda SOACS, [SubExp]))]
inputs' Lambda SOACS
acc_lam'
    where
      newAccLamParams :: [Param a] -> m [Param a]
newAccLamParams [Param a]
ps = do
        let ([Param a]
cert_ps, [Param a]
acc_ps) = forall a. Int -> [a] -> ([a], [a])
splitAt (forall (t :: * -> *) a. Foldable t => t a -> Int
length [Param a]
ps forall a. Integral a => a -> a -> a
`div` Int
2) [Param a]
ps
        -- Should not rename the certificates.
        [Param a]
acc_ps' <- forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM [Param a]
acc_ps forall a b. (a -> b) -> a -> b
$ \(Param Attrs
attrs VName
v a
t) ->
          forall dec. Attrs -> VName -> dec -> Param dec
Param Attrs
attrs forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName (VName -> String
baseString VName
v) forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> forall (f :: * -> *) a. Applicative f => a -> f a
pure a
t
        forall (f :: * -> *) a. Applicative f => a -> f a
pure forall a b. (a -> b) -> a -> b
$ [Param a]
cert_ps forall a. Semigroup a => a -> a -> a
<> [Param a]
acc_ps'

      num_accs :: Int
num_accs = forall (t :: * -> *) a. Foldable t => t a -> Int
length [(ShapeBase SubExp, [VName], Maybe (Lambda SOACS, [SubExp]))]
inputs
      acc_certs :: [VName]
acc_certs = forall a b. (a -> b) -> [a] -> [b]
map forall dec. Param dec -> VName
paramName forall a b. (a -> b) -> a -> b
$ forall a. Int -> [a] -> [a]
take Int
num_accs forall a b. (a -> b) -> a -> b
$ forall rep. Lambda rep -> [LParam rep]
lambdaParams Lambda SOACS
acc_lam
      onArr :: VName -> m VName
onArr VName
v =
        forall (f :: * -> *) a. Applicative f => a -> f a
pure forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall b a. b -> (a -> b) -> Maybe a -> b
maybe VName
v forall a b. (a, b) -> b
snd forall a b. (a -> b) -> a -> b
$
          forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Maybe a
find ((forall a. Eq a => a -> a -> Bool
== VName
v) forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall dec. Param dec -> VName
paramName forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall a b. (a, b) -> a
fst) [(Param Type, VName)]
params_and_arrs
      onInput :: (ShapeBase SubExp, [VName], Maybe (Lambda SOACS, [SubExp]))
-> m (ShapeBase SubExp, [VName], Maybe (Lambda SOACS, [SubExp]))
onInput (ShapeBase SubExp
shape, [VName]
arrs, Maybe (Lambda SOACS, [SubExp])
op) =
        (forall d. [d] -> ShapeBase d
Shape [SubExp
w] forall a. Semigroup a => a -> a -> a
<> ShapeBase SubExp
shape,,) forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM VName -> m VName
onArr [VName]
arrs forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
traverse forall {rep} {shape} {u} {m :: * -> *} {b}.
(LParamInfo rep ~ TypeBase shape u, MonadFreshNames m) =>
(Lambda rep, b) -> m (Lambda rep, b)
onOp Maybe (Lambda SOACS, [SubExp])
op

      onOp :: (Lambda rep, b) -> m (Lambda rep, b)
onOp (Lambda rep
op_lam, b
nes) = do
        -- We need to add an additional index parameter because we are
        -- extending the index space of the accumulator.
        Param (TypeBase shape u)
idx_p <- forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"idx" forall a b. (a -> b) -> a -> b
$ forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64
        forall (f :: * -> *) a. Applicative f => a -> f a
pure (Lambda rep
op_lam {lambdaParams :: [LParam rep]
lambdaParams = Param (TypeBase shape u)
idx_p forall a. a -> [a] -> [a]
: forall rep. Lambda rep -> [LParam rep]
lambdaParams Lambda rep
op_lam}, b
nes)

      trType :: TypeBase shape u -> TypeBase shape u
      trType :: forall shape u. TypeBase shape u -> TypeBase shape u
trType (Acc VName
acc ShapeBase SubExp
ispace [Type]
ts u
u)
        | VName
acc forall (t :: * -> *) a. (Foldable t, Eq a) => a -> t a -> Bool
`elem` [VName]
acc_certs =
            forall shape u.
VName -> ShapeBase SubExp -> [Type] -> u -> TypeBase shape u
Acc VName
acc (forall d. [d] -> ShapeBase d
Shape [SubExp
w] forall a. Semigroup a => a -> a -> a
<> ShapeBase SubExp
ispace) [Type]
ts u
u
      trType TypeBase shape u
t = TypeBase shape u
t

      trParam :: Param (TypeBase shape u) -> Param (TypeBase shape u)
      trParam :: forall shape u.
Param (TypeBase shape u) -> Param (TypeBase shape u)
trParam = forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap forall shape u. TypeBase shape u -> TypeBase shape u
trType

      trLam :: SubExp -> Lambda SOACS -> m (Lambda SOACS)
trLam SubExp
i (Lambda [LParam SOACS]
params Body SOACS
body [Type]
ret) =
        forall rep (m :: * -> *) a.
LocalScope rep m =>
Scope rep -> m a -> m a
localScope (forall rep dec. (LParamInfo rep ~ dec) => [Param dec] -> Scope rep
scopeOfLParams [LParam SOACS]
params) forall a b. (a -> b) -> a -> b
$
          forall rep. [LParam rep] -> Body rep -> [Type] -> Lambda rep
Lambda (forall a b. (a -> b) -> [a] -> [b]
map forall shape u.
Param (TypeBase shape u) -> Param (TypeBase shape u)
trParam [LParam SOACS]
params) forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> SubExp -> Body SOACS -> m (Body SOACS)
trBody SubExp
i Body SOACS
body forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> forall (f :: * -> *) a. Applicative f => a -> f a
pure (forall a b. (a -> b) -> [a] -> [b]
map forall shape u. TypeBase shape u -> TypeBase shape u
trType [Type]
ret)

      trBody :: SubExp -> Body SOACS -> m (Body SOACS)
trBody SubExp
i (Body BodyDec SOACS
dec Stms SOACS
stms Result
res) =
        forall rep a (m :: * -> *) b.
(Scoped rep a, LocalScope rep m) =>
a -> m b -> m b
inScopeOf Stms SOACS
stms forall a b. (a -> b) -> a -> b
$ forall rep. BodyDec rep -> Stms rep -> Result -> Body rep
Body BodyDec SOACS
dec forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
traverse (SubExp -> Stm SOACS -> m (Stm SOACS)
trStm SubExp
i) Stms SOACS
stms forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> forall (f :: * -> *) a. Applicative f => a -> f a
pure Result
res

      trStm :: SubExp -> Stm SOACS -> m (Stm SOACS)
trStm SubExp
i (Let Pat (LetDec SOACS)
pat StmAux (ExpDec SOACS)
aux Exp SOACS
e) =
        forall rep.
Pat (LetDec rep) -> StmAux (ExpDec rep) -> Exp rep -> Stm rep
Let (forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap forall shape u. TypeBase shape u -> TypeBase shape u
trType Pat (LetDec SOACS)
pat) StmAux (ExpDec SOACS)
aux forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> SubExp -> Exp SOACS -> m (Exp SOACS)
trExp SubExp
i Exp SOACS
e

      trSOAC :: SubExp -> SOAC SOACS -> m (SOAC SOACS)
trSOAC SubExp
i = forall (m :: * -> *) frep trep.
Monad m =>
SOACMapper frep trep m -> SOAC frep -> m (SOAC trep)
mapSOACM SOACMapper SOACS SOACS m
mapper
        where
          mapper :: SOACMapper SOACS SOACS m
mapper =
            forall rep (m :: * -> *). Monad m => SOACMapper rep rep m
identitySOACMapper {mapOnSOACLambda :: Lambda SOACS -> m (Lambda SOACS)
mapOnSOACLambda = SubExp -> Lambda SOACS -> m (Lambda SOACS)
trLam SubExp
i}

      trExp :: SubExp -> Exp SOACS -> m (Exp SOACS)
trExp SubExp
i (WithAcc [(ShapeBase SubExp, [VName], Maybe (Lambda SOACS, [SubExp]))]
acc_inputs Lambda SOACS
lam) =
        forall rep. [WithAccInput rep] -> Lambda rep -> Exp rep
WithAcc [(ShapeBase SubExp, [VName], Maybe (Lambda SOACS, [SubExp]))]
acc_inputs forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> SubExp -> Lambda SOACS -> m (Lambda SOACS)
trLam SubExp
i Lambda SOACS
lam
      trExp SubExp
i (BasicOp (UpdateAcc VName
acc [SubExp]
is [SubExp]
ses)) = do
        Type
acc_t <- forall rep (m :: * -> *). HasScope rep m => VName -> m Type
lookupType VName
acc
        forall (f :: * -> *) a. Applicative f => a -> f a
pure forall a b. (a -> b) -> a -> b
$ case Type
acc_t of
          Acc VName
cert ShapeBase SubExp
_ [Type]
_ NoUniqueness
_
            | VName
cert forall (t :: * -> *) a. (Foldable t, Eq a) => a -> t a -> Bool
`elem` [VName]
acc_certs ->
                forall rep. BasicOp -> Exp rep
BasicOp forall a b. (a -> b) -> a -> b
$ VName -> [SubExp] -> [SubExp] -> BasicOp
UpdateAcc VName
acc (SubExp
i forall a. a -> [a] -> [a]
: [SubExp]
is) [SubExp]
ses
          Type
_ ->
            forall rep. BasicOp -> Exp rep
BasicOp forall a b. (a -> b) -> a -> b
$ VName -> [SubExp] -> [SubExp] -> BasicOp
UpdateAcc VName
acc [SubExp]
is [SubExp]
ses
      trExp SubExp
i Exp SOACS
e = forall (m :: * -> *) frep trep.
Monad m =>
Mapper frep trep m -> Exp frep -> m (Exp trep)
mapExpM Mapper SOACS SOACS m
mapper Exp SOACS
e
        where
          mapper :: Mapper SOACS SOACS m
mapper =
            forall rep (m :: * -> *). Monad m => Mapper rep rep m
identityMapper
              { mapOnBody :: Scope SOACS -> Body SOACS -> m (Body SOACS)
mapOnBody = \Scope SOACS
scope -> forall rep (m :: * -> *) a.
LocalScope rep m =>
Scope rep -> m a -> m a
localScope Scope SOACS
scope forall b c a. (b -> c) -> (a -> b) -> a -> c
. SubExp -> Body SOACS -> m (Body SOACS)
trBody SubExp
i,
                mapOnRetType :: RetType SOACS -> m (RetType SOACS)
mapOnRetType = forall (f :: * -> *) a. Applicative f => a -> f a
pure forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall shape u. TypeBase shape u -> TypeBase shape u
trType,
                mapOnBranchType :: BranchType SOACS -> m (BranchType SOACS)
mapOnBranchType = forall (f :: * -> *) a. Applicative f => a -> f a
pure forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall shape u. TypeBase shape u -> TypeBase shape u
trType,
                mapOnFParam :: FParam SOACS -> m (FParam SOACS)
mapOnFParam = forall (f :: * -> *) a. Applicative f => a -> f a
pure forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall shape u.
Param (TypeBase shape u) -> Param (TypeBase shape u)
trParam,
                mapOnLParam :: LParam SOACS -> m (LParam SOACS)
mapOnLParam = forall (f :: * -> *) a. Applicative f => a -> f a
pure forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall shape u.
Param (TypeBase shape u) -> Param (TypeBase shape u)
trParam,
                mapOnOp :: Op SOACS -> m (Op SOACS)
mapOnOp = SubExp -> SOAC SOACS -> m (SOAC SOACS)
trSOAC SubExp
i
              }

-- | Given a (parallel) map nesting and an inner withacc, move the
-- maps inside the withacc.  The result is the resulting withacc
-- statement, whose body will then contain statements with @map@
-- expressions.
--
-- The loops of the nest are interchanged one at a time with
-- 'interchangeWithAcc1', visiting them in the reverse of the order
-- given by 'kernelNestLoops' (presumably innermost first; the
-- ordering of 'kernelNestLoops' is defined elsewhere).
interchangeWithAcc ::
  (MonadFreshNames m, LocalScope SOACS m) =>
  KernelNest ->
  WithAccStm ->
  m (Stm SOACS)
interchangeWithAcc nest withacc = do
  let loops_reversed = reverse $ kernelNestLoops nest
  -- Thread the withacc through every loop of the nest, pushing each
  -- map inside it in turn.
  interchanged <- foldM interchangeWithAcc1 withacc loops_reversed
  pure $ withAccStm interchanged