{-# LANGUAGE FlexibleContexts #-}
-- | Facilities for inspecting the data dependencies of a program.
module Futhark.Analysis.DataDependencies
  ( Dependencies
  , dataDependencies
  , findNecessaryForReturned
  )
  where

import qualified Data.Map.Strict as M

import Futhark.Representation.AST

-- | A mapping from a variable name @v@, to those variables on which
-- the value of @v@ is dependent.  The intuition is that we could
-- remove all other variables, and @v@ would still be computable.
-- This also includes names bound in loops or by lambdas.
--
-- The lookups in this module treat a name that has no entry in the
-- map as depending only on itself.
type Dependencies = M.Map VName Names

-- | Compute the data dependencies for an entire body, starting from
-- an empty dependency table.
dataDependencies :: Attributes lore => Body lore -> Dependencies
dataDependencies = dataDependencies' M.empty

-- | Extend @startdeps@ with the dependencies of every statement of a
-- body.  Statements are folded over left to right, so each statement
-- sees the dependencies recorded for the statements before it.
dataDependencies' :: Attributes lore =>
                     Dependencies -> Body lore -> Dependencies
dataDependencies' startdeps = foldl grow startdeps . bodyStms
  where -- Conditionals: each pattern element depends on the condition,
        -- on the corresponding result of both branches, and on
        -- whatever is free in the pattern element itself.
        grow deps (Let pat _ (If c tb fb _)) =
          let -- Both branches are analysed starting from the
              -- dependencies known before the conditional.
              tdeps = dataDependencies' deps tb
              fdeps = dataDependencies' deps fb
              cdeps = depsOf deps c
              comb (pe, tres, fres) =
                (patElemName pe,
                 mconcat $ [freeIn pe, cdeps, depsOf tdeps tres, depsOf fdeps fres] ++
                 map (depsOfVar deps) (namesToList $ freeIn pe))
              -- 'zip3' pairs the i'th pattern element with the i'th
              -- result of each branch.
              branchdeps =
                M.fromList $ map comb $ zip3 (patternElements pat)
                (bodyResult tb)
                (bodyResult fb)
          -- 'M.unions' is left-biased, so on a key clash the entries
          -- of 'branchdeps' win over those of the later maps.
          in M.unions [branchdeps, deps, tdeps, fdeps]

        -- Any other statement: every name bound by the pattern depends
        -- on everything free in the pattern or the expression, plus
        -- the already-known dependencies of those free names.
        grow deps (Let pat _ e) =
          let free = freeIn pat <> freeIn e
              freeDeps = mconcat $ map (depsOfVar deps) $ namesToList free
          -- 'M.union' is left-biased: the new entries shadow any
          -- existing entry for the same name.
          in M.fromList [ (name, freeDeps) | name <- patternNames pat ] `M.union` deps

-- | The dependencies of a subexpression: nothing for a constant, and
-- the result of 'depsOfVar' for a variable.
depsOf :: Dependencies -> SubExp -> Names
depsOf _ (Constant _) = mempty
depsOf deps (Var v)   = depsOfVar deps v

-- | The dependencies of a variable: the variable itself, together
-- with whatever the table records for it.  A variable with no entry
-- in the table contributes only itself.
depsOfVar :: Dependencies -> VName -> Names
depsOfVar deps name = oneName name <> M.findWithDefault mempty name deps

-- | Find the names that are necessary for computing the results that
-- matter, given a list of parameters paired with their corresponding
-- result subexpressions and a table of dependencies.
--
-- A result matters if its parameter satisfies @usedAfterLoop@, or if
-- its parameter was itself found necessary in a previous round.  The
-- set of necessary names is grown until it no longer changes.  The
-- names of the parameters satisfying @usedAfterLoop@ are always part
-- of the returned set.
--
-- NOTE(review): the argument names suggest this is meant for the
-- merge parameters and results of a loop -- confirm against callers.
findNecessaryForReturned :: (Param attr -> Bool) -> [(Param attr, SubExp)]
                         -> M.Map VName Names
                         -> Names
findNecessaryForReturned usedAfterLoop merge_and_res allDependencies =
  iterateNecessary mempty <>
  namesFromList (map paramName $ filter usedAfterLoop $ map fst merge_and_res)
  where -- Fixpoint iteration: recompute the necessary set from the
        -- previous one until two consecutive rounds agree.
        iterateNecessary prev_necessary
          | necessary == prev_necessary = necessary
          | otherwise                   = iterateNecessary necessary
          where necessary = mconcat $ map dependencies returnedResultSubExps
                usedAfterLoopOrNecessary param =
                  usedAfterLoop param || paramName param `nameIn` prev_necessary
                -- Only the results of parameters that currently
                -- matter contribute dependencies in this round.
                returnedResultSubExps =
                  map snd $ filter (usedAfterLoopOrNecessary . fst) merge_and_res
                dependencies (Constant _) =
                  mempty
                -- A variable with no entry in the table depends only
                -- on itself.
                dependencies (Var v)      =
                  M.findWithDefault (oneName v) v allDependencies