-- | Facilities for inspecting the data dependencies of a program.
module Futhark.Analysis.DataDependencies
  ( Dependencies,
    dataDependencies,
    depsOf,
    depsOf',
    depsOfArrays,
    depsOfShape,
    lambdaDependencies,
    reductionDependencies,
    findNecessaryForReturned,
  )
where

import Data.List qualified as L
import Data.Map.Strict qualified as M
import Futhark.IR

-- | A mapping from a variable name @v@ to those variables on which
-- the value of @v@ is dependent.  The intuition is that we could
-- remove all other variables, and @v@ would still be computable.
-- This also includes names bound in loops or by lambdas.
--
-- A name that is absent from the map is treated by 'depsOfVar' as
-- having no recorded dependencies beyond itself.
type Dependencies = M.Map VName Names

-- | Compute the data dependencies for an entire body, starting from an
-- empty dependency map.
dataDependencies :: (ASTRep rep) => Body rep -> Dependencies
dataDependencies = dataDependencies' M.empty

-- | Extend @startdeps@ with the data dependencies of every statement
-- in a body.  The statements are folded over from first to last, so
-- each statement sees the dependencies already recorded for the
-- statements before it.  A strict left fold is used so that the
-- accumulated map is forced after every statement instead of building
-- a chain of thunks.
dataDependencies' ::
  (ASTRep rep) =>
  Dependencies ->
  Body rep ->
  Dependencies
dataDependencies' startdeps = L.foldl' grow startdeps . bodyStms
  where
    -- 'WithAcc': each pattern name receives the dependencies of the
    -- corresponding result of @lam@, extended with what @deps@
    -- already records for those names.
    grow deps (Let pat _ (WithAcc inputs lam)) =
      let input_deps = foldMap depsOfWithAccInput inputs
          -- Dependencies of each input reduction are concatenated.
          -- Input to lam is cert_1, ..., cert_n, acc_1, ..., acc_n.
          lam_deps = lambdaDependencies deps lam (input_deps <> input_deps)
          transitive = map (depsOfNames deps) lam_deps
       in -- 'M.union' is left-biased, so the new entries take precedence.
          M.fromList (zip (patNames pat) transitive) `M.union` deps
      where
        -- Every array depends on itself and on the names in the shape.
        depsOfArrays' shape =
          map (\arr -> oneName arr <> depsOfShape shape)
        -- Without a combining operator only the arrays matter.
        depsOfWithAccInput (shape, arrs, Nothing) =
          depsOfArrays' shape arrs
        -- With a combining operator, treat the input as a reduction
        -- with neutral elements @nes@.
        depsOfWithAccInput (shape, arrs, Just (lam', nes)) =
          reductionDependencies deps lam' nes (depsOfArrays' shape arrs)
    -- 'Op': the operation reports one dependency set per result; these
    -- are combined pairwise with the dependencies of the names free
    -- in each pattern element.
    grow deps (Let pat _ (Op op)) =
      let op_deps = map (depsOfNames deps) (opDependencies op)
          pat_deps = map (depsOfNames deps . freeIn) (patElems pat)
       in if length op_deps /= length pat_deps
            then
              -- The pairing below would silently truncate, so a
              -- mismatch is reported as an internal error instead.
              error . unlines $
                [ "dataDependencies':",
                  "Pattern size: " <> show (length pat_deps),
                  "Op deps size: " <> show (length op_deps),
                  "Expression:",
                  prettyString op
                ]
            else
              M.fromList (zip (patNames pat) $ zipWith (<>) pat_deps op_deps)
                `M.union` deps
    -- 'Match': each pattern name depends on the corresponding result
    -- of every case body and of the default body, on the
    -- subexpressions @c@ given to the 'Match', and on the names free
    -- in the pattern element itself.
    grow deps (Let pat _ (Match c cases defbody _)) =
      let -- Each branch is analysed separately, starting from @deps@.
          cases_deps = map (dataDependencies' deps . caseBody) cases
          defbody_deps = dataDependencies' deps defbody
          cdeps = foldMap (depsOf deps) c
          comb (pe, se_cases_deps, se_defbody_deps) =
            ( patElemName pe,
              mconcat $
                se_cases_deps
                  ++ [freeIn pe, cdeps, se_defbody_deps]
                  ++ map (depsOfVar deps) (namesToList $ freeIn pe)
            )
          branchdeps =
            M.fromList $
              map comb $
                zip3
                  (patElems pat)
                  -- One list per case of per-result dependencies,
                  -- transposed into one list per result.
                  ( L.transpose . zipWith (map . depsOf) cases_deps $
                      map (map resSubExp . bodyResult . caseBody) cases
                  )
                  (map (depsOf defbody_deps . resSubExp) (bodyResult defbody))
       in -- The branch maps are kept too, so names bound inside the
          -- branches remain available in the result.
          M.unions $ [branchdeps, deps, defbody_deps] ++ cases_deps
    -- Any other expression: every name bound by the pattern depends
    -- on everything free in the pattern and the expression, together
    -- with what @deps@ records for those names.
    grow deps (Let pat _ e) =
      let free = freeIn pat <> freeIn e
          free_deps = depsOfNames deps free
       in M.fromList [(name, free_deps) | name <- patNames pat] `M.union` deps

-- | The dependencies of a subexpression under @deps@.  A constant
-- depends on nothing; a variable depends on itself and on whatever
-- @deps@ records for it (see 'depsOfVar').
depsOf :: Dependencies -> SubExp -> Names
depsOf _ (Constant _) = mempty
depsOf deps (Var v) = depsOfVar deps v

-- | Like 'depsOf', but with an empty dependency map: a constant
-- yields no names, and a variable yields only its own name.
depsOf' :: SubExp -> Names
depsOf' (Constant _) = mempty
depsOf' (Var v) = depsOfVar mempty v

-- | The variable itself together with the dependencies recorded for
-- it in @deps@.  A variable without an entry contributes only its
-- own name.
depsOfVar :: Dependencies -> VName -> Names
depsOfVar deps name = oneName name <> M.findWithDefault mempty name deps

-- | The dependencies of a body result: those of its subexpression.
-- The first field of the 'SubExpRes' is ignored.
depsOfRes :: Dependencies -> SubExpRes -> Names
depsOfRes deps (SubExpRes _ se) = depsOf deps se

-- | Extend @names@ with direct dependencies in @deps@.  Every name in
-- @names@ is kept, and whatever @deps@ records for it is added; the
-- lookup is a single step and is not iterated to a fixpoint.
depsOfNames :: Dependencies -> Names -> Names
depsOfNames deps names = foldMap (depsOfVar deps) (namesToList names)

-- | One dependency set per array: the array itself, plus the name of
-- @size@ when @size@ is a variable.  No dependency map is consulted,
-- so only the names mentioned directly are included.
depsOfArrays :: SubExp -> [VName] -> [Names]
depsOfArrays size = map (\arr -> oneName arr <> depsOf mempty size)

-- | The names of the variables occurring as dimensions of @shape@.
-- Constant dimensions contribute nothing, and no dependency map is
-- consulted.
depsOfShape :: Shape -> Names
depsOfShape shape = foldMap (depsOf mempty) (shapeDims shape)

-- | Determine the variables on which the results of applying
-- anonymous function @lam@ to @inputs@ depend.
--
-- The names given by 'boundByLambda' are paired positionally with
-- @inputs@, and the body of @lam@ is then analysed on top of @deps@.
-- One dependency set is returned per result of the body, restricted
-- to names that are either free in @lam@ or mentioned in @inputs@.
lambdaDependencies ::
  (ASTRep rep) =>
  Dependencies ->
  Lambda rep ->
  [Names] ->
  [Names]
lambdaDependencies deps lam inputs =
  let names_in_scope = freeIn lam <> mconcat inputs
      deps_in = M.fromList $ zip (boundByLambda lam) inputs
      -- The map union is left-biased, so the entries for the
      -- lambda-bound names take precedence over those in @deps@.
      deps' = dataDependencies' (deps_in <> deps) (lambdaBody lam)
   in map
        (namesIntersection names_in_scope . depsOfRes deps')
        (bodyResult $ lambdaBody lam)

-- | Like 'lambdaDependencies', but @lam@ is a binary operation
-- with a neutral element.
--
-- The dependencies of each neutral element in @nes@ are merged with
-- the input dependency set at the same position before @lam@ is
-- analysed.  The pairing uses 'zipWith', so surplus elements of the
-- longer list are dropped.
reductionDependencies ::
  (ASTRep rep) =>
  Dependencies ->
  Lambda rep ->
  [SubExp] ->
  [Names] ->
  [Names]
reductionDependencies deps lam nes inputs =
  let nes' = map (depsOf deps) nes
   in lambdaDependencies deps lam (zipWith (<>) nes' inputs)

-- | @findNecessaryForReturned p merge deps@ computes which of the
-- loop parameters (@merge@) are necessary for the result of the loop,
-- where @p@ given a loop parameter indicates whether the final value
-- of that parameter is live after the loop.  @deps@ is the data
-- dependencies of the loop body.  This is computed by straightforward
-- fixpoint iteration.
--
-- The returned set is the fixpoint together with the names of all
-- parameters for which @p@ holds.
findNecessaryForReturned ::
  (Param dec -> Bool) ->
  [(Param dec, SubExp)] ->
  M.Map VName Names ->
  Names
findNecessaryForReturned usedAfterLoop merge_and_res allDependencies =
  iterateNecessary mempty
    <> namesFromList (map paramName $ filter usedAfterLoop $ map fst merge_and_res)
  where
    -- Recompute the necessary set from the previous one until it no
    -- longer changes.
    iterateNecessary prev_necessary
      | necessary == prev_necessary = necessary
      | otherwise = iterateNecessary necessary
      where
        necessary = foldMap dependencies returnedResultSubExps
        -- A parameter matters if it is live after the loop or was
        -- found necessary in the previous round.
        usedAfterLoopOrNecessary param =
          usedAfterLoop param || paramName param `nameIn` prev_necessary
        -- The result subexpressions paired with such parameters.
        returnedResultSubExps =
          map snd $ filter (usedAfterLoopOrNecessary . fst) merge_and_res
        dependencies (Constant _) =
          mempty
        -- A variable without an entry depends only on itself.
        dependencies (Var v) =
          M.findWithDefault (oneName v) v allDependencies