{-# LANGUAGE FlexibleContexts #-}
-- | Facilities for inspecting the data dependencies of a program.
module Futhark.Analysis.DataDependencies
  ( Dependencies
  , dataDependencies
  , findNecessaryForReturned
  )
  where

import Data.Foldable (foldl')
import qualified Data.Map.Strict as M

import Futhark.IR

-- | A mapping from a variable name @v@ to those variables on which
-- the value of @v@ depends.  The intuition is that we could remove
-- all other variables, and @v@ would still be computable.  This also
-- includes names bound in loops or by lambdas.  Variables that do not
-- appear as keys are treated by lookups in this module as having no
-- recorded dependencies.
type Dependencies = M.Map VName Names

-- | Compute the data dependencies for an entire body, starting from an
-- empty dependency map, so only names bound within the body get
-- entries of their own.
dataDependencies :: ASTLore lore => Body lore -> Dependencies
dataDependencies body = dataDependencies' M.empty body

-- | Extend a starting dependency map with the dependencies of every
-- statement of a body, processing the statements in order.
--
-- For an ordinary statement, every name bound by its pattern is mapped
-- to the union of 'depsOfVar' over the names free in the pattern and
-- the expression.
--
-- For an @if@ statement, the name of each pattern element is mapped to
-- the union of:
--
--   * the dependencies of the branch condition,
--   * the dependencies of the corresponding result of each branch
--     (looked up in that branch's own dependency map), and
--   * the names free in the pattern element, plus their dependencies.
--
-- The bindings made inside both branches are also kept in the result.
dataDependencies' :: ASTLore lore =>
                     Dependencies -> Body lore -> Dependencies
dataDependencies' startdeps = foldl' grow startdeps . bodyStms
  -- A strict left fold forces the accumulated map after each statement,
  -- instead of building a chain of unevaluated 'grow' applications that
  -- is only collapsed after the whole body has been traversed.
  where grow deps (Let pat _ (If c tb fb _)) =
          -- Both branches are analysed starting from the dependencies
          -- known just before the @if@.
          let tdeps = dataDependencies' deps tb
              fdeps = dataDependencies' deps fb
              cdeps = depsOf deps c
              comb (pe, tres, fres) =
                (patElemName pe,
                 mconcat $ [freeIn pe, cdeps, depsOf tdeps tres, depsOf fdeps fres] ++
                 map (depsOfVar deps) (namesToList $ freeIn pe))
              branchdeps =
                M.fromList $ map comb $ zip3 (patternElements pat)
                (bodyResult tb)
                (bodyResult fb)
          -- 'M.unions' is left-biased, so the entries computed for the
          -- @if@ pattern take precedence.
          in M.unions [branchdeps, deps, tdeps, fdeps]

        grow deps (Let pat _ e) =
          let free = freeIn pat <> freeIn e
              freeDeps = mconcat $ map (depsOfVar deps) $ namesToList free
          in M.fromList [ (name, freeDeps) | name <- patternNames pat ] `M.union` deps

-- | The dependencies of a subexpression.  A constant depends on
-- nothing.  A variable is handled by 'depsOfVar'.
depsOf :: Dependencies -> SubExp -> Names
depsOf deps se =
  case se of
    Var v -> depsOfVar deps v
    Constant _ -> mempty

-- | The dependencies of a variable: the variable itself, together with
-- whatever it is recorded as depending on in the map.  A variable
-- without an entry depends only on itself.
depsOfVar :: Dependencies -> VName -> Names
depsOfVar deps name =
  maybe self (self <>) (M.lookup name deps)
  where self = oneName name

-- | @findNecessaryForReturned p merge deps@ computes which names are
-- necessary for the result of a loop.
--
-- @merge@ pairs each loop parameter with the body result that provides
-- its value.  @p@, given a loop parameter, indicates whether the final
-- value of that parameter is live after the loop.  @deps@ is the data
-- dependencies of the loop body.
--
-- The returned set contains the names of the parameters for which @p@
-- holds.  It also contains every name that the results of relevant
-- parameters depend on according to @deps@, so it may include names
-- other than loop parameters.  A parameter is relevant if @p@ holds
-- for it or it has already been found necessary.
--
-- This is computed by straightforward fixpoint iteration, starting
-- from the empty set.
findNecessaryForReturned :: (Param dec -> Bool) -> [(Param dec, SubExp)]
                         -> M.Map VName Names
                         -> Names
findNecessaryForReturned usedAfterLoop merge_and_res allDependencies =
  fixpoint mempty <> liveParams
  where
    -- Names of the parameters whose final value is used after the loop.
    liveParams =
      namesFromList [ paramName p | (p, _) <- merge_and_res, usedAfterLoop p ]

    -- Apply 'step' repeatedly until the set stops changing.
    fixpoint known =
      let next = step known
      in if next == known then next else fixpoint next

    -- Union of the dependencies of the results of all parameters that
    -- are live after the loop or already known to be necessary.
    step known =
      mconcat [ resultDeps res
              | (p, res) <- merge_and_res
              , usedAfterLoop p || paramName p `nameIn` known ]

    -- Dependencies of a single body result.  A variable with an entry
    -- in the map contributes exactly that entry; a variable without one
    -- contributes only itself.
    resultDeps (Constant _) = mempty
    resultDeps (Var v) = M.findWithDefault (oneName v) v allDependencies