{-# LANGUAGE TypeFamilies #-}

-- | Alias analysis of a full Futhark program.  Takes as input a
-- program with an arbitrary rep and produces one with aliases.  This
-- module does not implement the aliasing logic itself, and derives
-- its information from definitions in
-- "Futhark.IR.Prop.Aliases" and
-- "Futhark.IR.Aliases".  The alias information computed
-- here will include transitive aliases (note that this is not what
-- the building blocks do).
module Futhark.Analysis.Alias
  ( aliasAnalysis,
    AliasableRep,

    -- * Ad-hoc utilities
    analyseFun,
    analyseStms,
    analyseStm,
    analyseExp,
    analyseBody,
    analyseLambda,
  )
where

import Data.List (foldl')
import Data.Map qualified as M
import Futhark.IR.Aliases

-- | Run alias analysis over an entire program.
--
-- The top-level constants are analysed as one statement sequence,
-- starting from an empty alias table; the aliasing/consumption summary
-- produced for them is not kept.  Every function definition is
-- analysed on its own with 'analyseFun'.
aliasAnalysis ::
  (AliasableRep rep) =>
  Prog rep ->
  Prog (Aliases rep)
aliasAnalysis prog =
  prog {progConsts = consts', progFuns = funs'}
  where
    (consts', _) = analyseStms mempty $ progConsts prog
    funs' = analyseFun <$> progFuns prog

-- | Run alias analysis on a single function definition.
--
-- Only the body is analysed, starting from an empty alias table.  The
-- entry point, attributes, name, return types and parameters are
-- copied over unchanged.
analyseFun ::
  (AliasableRep rep) =>
  FunDef rep ->
  FunDef (Aliases rep)
analyseFun (FunDef entry attrs fname rettype fparams fbody) =
  let fbody' = analyseBody mempty fbody
   in FunDef entry attrs fname rettype fparams fbody'

-- | Run alias analysis on a body, given the alias table for the names
-- in scope around it.
--
-- The statements are analysed with 'analyseStms'; the alias/consumption
-- summary that call returns is discarded, and the body's own aliasing
-- decoration is built by 'mkAliasedBody' from the analysed statements
-- and the (unchanged) result.
analyseBody ::
  (AliasableRep rep) =>
  AliasTable ->
  Body rep ->
  Body (Aliases rep)
analyseBody atable (Body bdec stms res) =
  mkAliasedBody bdec stms' res
  where
    (stms', _) = analyseStms atable stms

-- | Run alias analysis on a sequence of statements.
--
-- Statements are processed left to right.  Each one is analysed under
-- the alias table accumulated so far, and its effects are then folded
-- into the running 'AliasesAndConsumed' pair with 'trackAliases'.
-- Returns the annotated statements together with that final pair.
-- Names bound by the statements themselves are removed from the
-- consumed set before returning; the alias table is returned as
-- accumulated.
analyseStms ::
  (AliasableRep rep) =>
  AliasTable ->
  Stms rep ->
  (Stms (Aliases rep), AliasesAndConsumed)
analyseStms orig_aliases stms =
  dropBound $ foldl' step (mempty, (orig_aliases, mempty)) $ stmsToList stms
  where
    -- Analyse one statement with the current alias table, append it to
    -- the output, and extend the tracked aliases/consumption with it.
    step (acc, tracked) stm =
      let stm' = analyseStm (fst tracked) stm
       in (acc <> oneStm stm', trackAliases tracked stm')

    -- Drop from the consumed set every name bound by the statements.
    dropBound (stms', (als, consumed)) =
      (stms', (als, consumed `namesSubtract` foldMap boundBy stms'))

    -- The names bound by a statement's pattern.
    boundBy = namesFromList . patNames . stmPat

-- | Run alias analysis on a single statement.
--
-- The expression is analysed first.  The pattern is then annotated
-- with 'mkAliasedPat' from the analysed expression.  The statement's
-- expression decoration is paired with the names the analysed
-- expression consumes ('consumedInExp').  Certificates and attributes
-- are kept as they are.
analyseStm ::
  (AliasableRep rep) =>
  AliasTable ->
  Stm rep ->
  Stm (Aliases rep)
analyseStm aliases (Let pat (StmAux cs attrs dec) e) =
  Let (mkAliasedPat pat e') aux' e'
  where
    e' = analyseExp aliases e
    aux' = StmAux cs attrs (AliasDec (consumedInExp e'), dec)

-- | Run alias analysis on an expression, given the alias table for the
-- names in scope.
--
-- A 'Match' is special-cased:
--
-- * Every branch, including the default branch, is analysed with the
--   same alias table.
-- * The names consumed by any branch are collected.
-- * Each branch's per-result aliases are pruned of every name that is
--   consumed, either directly or through one of its entries in the
--   alias table.
--
-- Every other expression is traversed with 'mapExp': nested bodies go
-- through 'analyseBody', operations through 'addOpAliases', and all
-- remaining components are passed through unchanged.
analyseExp ::
  (AliasableRep rep) =>
  AliasTable ->
  Exp rep ->
  Exp (Aliases rep)
-- Would be better to put this in a BranchType annotation, but that
-- requires a lot of other work.
analyseExp aliases (Match cond cases defbody matchdec) =
  Match cond (map (fmap pruneBody) analysedCases) (pruneBody analysedDef) matchdec
  where
    analysedCases = map (fmap (analyseBody aliases)) cases
    analysedDef = analyseBody aliases defbody

    -- Union of the names consumed in any branch (default included).
    consumedAnywhere =
      unAliases . foldMap (snd . fst . bodyDec) $
        analysedDef : map caseBody analysedCases

    -- A name is treated as consumed if it, or any name the alias table
    -- records for it, is consumed in some branch.
    consumedVia v =
      any (`nameIn` consumedAnywhere) $
        v : namesToList (M.findWithDefault mempty v aliases)

    -- Keep only the aliases that are not consumed.
    keepLive als =
      AliasDec $
        namesFromList [v | v <- namesToList (unAliases als), not (consumedVia v)]

    -- Rewrite the per-result aliases in a branch's body decoration; the
    -- consumption set, inner decoration, statements and result are kept.
    pruneBody (Body ((res_als, consumed), bdec) stms res) =
      Body ((map keepLive res_als, consumed), bdec) stms res
analyseExp aliases e = mapExp analyser e
  where
    analyser =
      Mapper
        { mapOnBody = \_scope body -> pure $ analyseBody aliases body,
          mapOnOp = \op -> pure $ addOpAliases aliases op,
          mapOnSubExp = pure,
          mapOnVName = pure,
          mapOnRetType = pure,
          mapOnBranchType = pure,
          mapOnFParam = pure,
          mapOnLParam = pure
        }

-- | Run alias analysis on a lambda, given the alias table for the names
-- in scope.  Only the body is analysed; the parameters are carried over
-- unchanged.
analyseLambda ::
  (AliasableRep rep) =>
  AliasTable ->
  Lambda rep ->
  Lambda (Aliases rep)
analyseLambda aliases lam =
  -- This record update changes the representation type.  The parameter
  -- list's type mentions the representation, so it is listed
  -- explicitly even though its value is unchanged.
  lam
    { lambdaParams = lambdaParams lam,
      lambdaBody = analyseBody aliases (lambdaBody lam)
    }