{-# LANGUAGE TypeFamilies #-}

-- | Alias analysis of a full Futhark program.  Takes as input a
-- program with an arbitrary rep and produces one with aliases.  This
-- module does not implement the aliasing logic itself, and derives
-- its information from definitions in
-- "Futhark.IR.Prop.Aliases" and
-- "Futhark.IR.Aliases".  The alias information computed
-- here will include transitive aliases (note that this is not what
-- the building blocks do).
module Futhark.Analysis.Alias
  ( aliasAnalysis,
    AliasableRep,

    -- * Ad-hoc utilities
    analyseFun,
    analyseStms,
    analyseStm,
    analyseExp,
    analyseBody,
    analyseLambda,
  )
where

import Data.List (foldl')
import Data.Map qualified as M
import Futhark.IR.Aliases

-- | Perform alias analysis on a Futhark program.
--
-- The top-level constants are analysed as a single statement sequence
-- starting from an empty alias table.  Only the analysed statements
-- are kept; the resulting alias/consumption table is discarded.  Every
-- function is analysed separately with 'analyseFun'.  All other fields
-- of the program are carried over unchanged by the record update.
aliasAnalysis ::
  AliasableRep rep =>
  Prog rep ->
  Prog (Aliases rep)
aliasAnalysis prog =
  prog
    { progConsts = fst (analyseStms mempty (progConsts prog)),
      progFuns = map analyseFun (progFuns prog)
    }

-- | Perform alias analysis on a function.
--
-- Only the body is rewritten.  It is analysed starting from an empty
-- 'AliasTable', so no aliases are recorded for the parameters.  Entry
-- point, attributes, name, return type and parameters are passed
-- through unchanged.
analyseFun ::
  AliasableRep rep =>
  FunDef rep ->
  FunDef (Aliases rep)
analyseFun (FunDef entry attrs fname restype params body) =
  FunDef entry attrs fname restype params body'
  where
    body' = analyseBody mempty body

-- | Perform alias analysis on a 'Body'.
--
-- The statements are analysed with 'analyseStms' under the given alias
-- table.  The table that 'analyseStms' returns is not used here.  The
-- aliased body is built by 'mkAliasedBody' from three parts: the
-- original body decoration, the analysed statements and the unchanged
-- result.
analyseBody ::
  AliasableRep rep =>
  AliasTable ->
  Body rep ->
  Body (Aliases rep)
analyseBody atable (Body rep stms result) =
  let (stms', _atable') = analyseStms atable stms
   in mkAliasedBody rep stms' result

-- | Perform alias analysis on statements.
--
-- Statements are processed left to right.  Each statement is analysed
-- under the alias table accumulated so far.  'trackAliases' then
-- updates the accumulated aliases and consumed names with the analysed
-- statement.
--
-- Returns the analysed statements together with the final aliases and
-- consumed names.  Names bound by the patterns of these statements are
-- removed from the consumed set, so only consumption of names not bound
-- by the sequence itself is reported.
analyseStms ::
  AliasableRep rep =>
  AliasTable ->
  Stms rep ->
  (Stms (Aliases rep), AliasesAndConsumed)
analyseStms orig_aliases =
  withoutBound . foldl' step (mempty, (orig_aliases, mempty)) . stmsToList
  where
    -- Consumption of a name bound inside the sequence is not visible
    -- outside of it, so drop every name bound by a statement pattern.
    withoutBound (stms, (aliases, consumed)) =
      let bound = foldMap (namesFromList . patNames . stmPat) stms
          consumed' = consumed `namesSubtract` bound
       in (stms, (aliases, consumed'))

    -- Analyse one statement with the aliases known so far.  Append it
    -- to the output, then fold its effects into the running table.
    step (stms, acc) stm =
      let stm' = analyseStm (fst acc) stm
          acc' = trackAliases acc stm'
       in (stms <> oneStm stm', acc')

-- | Perform alias analysis on a statement.
--
-- The expression is analysed first.  'mkAliasedPat' then derives the
-- pattern's aliasing information from the analysed expression.  The
-- expression decoration is paired with the names consumed by that
-- expression, as computed by 'consumedInExp'.  Certificates and
-- attributes are kept as they are.
analyseStm ::
  AliasableRep rep =>
  AliasTable ->
  Stm rep ->
  Stm (Aliases rep)
analyseStm aliases (Let pat (StmAux cs attrs dec) e) =
  let e' = analyseExp aliases e
      pat' = mkAliasedPat pat e'
      rep' = (AliasDec $ consumedInExp e', dec)
   in Let pat' (StmAux cs attrs rep') e'

-- | Perform alias analysis on an expression.
--
-- 'Match' expressions are handled specially (see the comment on that
-- equation).  Every other expression is traversed with 'mapExp'.
-- Nested bodies are analysed with 'analyseBody' and operations with
-- 'addOpAliases'.  Everything else is left unchanged.
analyseExp ::
  AliasableRep rep =>
  AliasTable ->
  Exp rep ->
  Exp (Aliases rep)
-- For a 'Match', any name that is consumed in some branch is removed
-- from the result aliases of every branch.  A name also counts as
-- consumed when one of its aliases in the given table is consumed.
--
-- Would be better to put this in a BranchType annotation, but that
-- requires a lot of other work.
analyseExp aliases (Match cond cases defbody matchdec) =
  let cases' = map (fmap $ analyseBody aliases) cases
      defbody' = analyseBody aliases defbody
      -- Consumption annotations of all branches, combined.
      all_cons = foldMap (snd . fst . bodyDec) $ defbody' : map caseBody cases'
      -- Is 'v', or any name it aliases, consumed in some branch?
      isConsumed v =
        any (`nameIn` unAliases all_cons) $
          v : namesToList (M.findWithDefault mempty v aliases)
      -- Remove consumed names from one result's alias set.
      notConsumed =
        AliasDec
          . namesFromList
          . filter (not . isConsumed)
          . namesToList
          . unAliases
      -- Filter the per-result aliases; the consumption annotation and
      -- the rest of the body are untouched.
      onBody (Body ((als, cons), dec) stms res) =
        Body ((map notConsumed als, cons), dec) stms res
      cases'' = map (fmap onBody) cases'
      defbody'' = onBody defbody'
   in Match cond cases'' defbody'' matchdec
analyseExp aliases e = mapExp analyse e
  where
    analyse =
      Mapper
        { mapOnSubExp = pure,
          mapOnVName = pure,
          mapOnBody = const $ pure . analyseBody aliases,
          mapOnRetType = pure,
          mapOnBranchType = pure,
          mapOnFParam = pure,
          mapOnLParam = pure,
          mapOnOp = pure . addOpAliases aliases
        }

-- | Perform alias analysis on a lambda.
--
-- Only the body is analysed, under the given alias table.
-- 'lambdaParams' is reassigned with its original value.  The record
-- update changes the representation type parameter, and a field whose
-- type mentions that parameter must be set explicitly.
analyseLambda ::
  AliasableRep rep =>
  AliasTable ->
  Lambda rep ->
  Lambda (Aliases rep)
analyseLambda aliases lam =
  let body = analyseBody aliases $ lambdaBody lam
   in lam
        { lambdaBody = body,
          lambdaParams = lambdaParams lam
        }