{-# LANGUAGE TypeFamilies #-}

-- | The IR tracks aliases, mostly to ensure the soundness of in-place
-- updates, but it can also be used for other things (such as memory
-- optimisations).  This module contains the raw building blocks for
-- determining the aliases of the values produced by expressions.  It
-- also contains some building blocks for inspecting consumption.
--
-- One important caveat is that all aliases computed here are /local/.
-- Thus, they do not take aliases-of-aliases into account.  See
-- "Futhark.Analysis.Alias" if this is not what you want.
module Futhark.IR.Prop.Aliases
  ( subExpAliases,
    expAliases,
    patAliases,
    lookupAliases,
    Aliased (..),
    AliasesOf (..),

    -- * Consumption
    consumedInStm,
    consumedInExp,
    consumedByLambda,

    -- * Extensibility
    AliasTable,
    AliasedOp (..),
  )
where

import Data.Bifunctor (first, second)
import Data.List (find, transpose)
import Data.Map qualified as M
import Futhark.IR.Prop (ASTRep, IsOp, NameInfo (..), Scope)
import Futhark.IR.Prop.Names
import Futhark.IR.Prop.Pat
import Futhark.IR.Prop.Types
import Futhark.IR.Syntax

-- | Representations that carry aliasing information, so that one can
-- ask which names the results of a body alias and which names a body
-- consumes.
class
  (ASTRep rep, AliasedOp (Op rep), AliasesOf (LetDec rep)) =>
  Aliased rep
  where
  -- | The variables consumed in the body.
  consumedInBody :: Body rep -> Names

  -- | The aliases of the body results.  Note that this includes names
  -- bound in the body!
  bodyAliases :: Body rep -> [Names]

-- | The aliases of a variable: just the variable itself.
vnameAliases :: VName -> Names
vnameAliases = oneName

-- | The aliases of a subexpression.  A constant aliases nothing, while
-- a variable aliases itself.
subExpAliases :: SubExp -> Names
subExpAliases Constant {} = mempty
subExpAliases (Var v) = vnameAliases v

-- | The aliases of the result of a 'BasicOp', as a singleton list.
--
-- Only 'SubExp', 'Opaque', 'Index', 'FlatIndex', 'Reshape' and
-- 'Rearrange' produce a result that aliases its operand.  Every other
-- operation produces a result with no aliases.  There is deliberately
-- no catch-all equation, so that GHC's incomplete-pattern checking
-- (when enabled) flags any constructor that is not handled here.
basicOpAliases :: BasicOp -> [Names]
basicOpAliases (SubExp se) = [subExpAliases se]
basicOpAliases (Opaque _ se) = [subExpAliases se]
basicOpAliases ArrayLit {} = [mempty]
basicOpAliases BinOp {} = [mempty]
basicOpAliases ConvOp {} = [mempty]
basicOpAliases CmpOp {} = [mempty]
basicOpAliases UnOp {} = [mempty]
basicOpAliases (Index v _) = [vnameAliases v]
basicOpAliases Update {} = [mempty]
basicOpAliases (FlatIndex v _) = [vnameAliases v]
basicOpAliases FlatUpdate {} = [mempty]
basicOpAliases Iota {} = [mempty]
basicOpAliases Replicate {} = [mempty]
basicOpAliases Scratch {} = [mempty]
basicOpAliases (Reshape _ _ v) = [vnameAliases v]
basicOpAliases (Rearrange _ v) = [vnameAliases v]
basicOpAliases Concat {} = [mempty]
basicOpAliases Manifest {} = [mempty]
basicOpAliases Assert {} = [mempty]
basicOpAliases UpdateAcc {} = [mempty]

-- | Combine the result aliases of the branches of a 'Match'.
--
-- Each input element describes one branch as a pair: the aliases of
-- each of its results, and the names the branch consumes.  The aliases
-- of result @i@ are the union of the @i@th result aliases of all
-- branches, minus every name consumed by any branch.  Because
-- 'transpose' is used, a branch with fewer results contributes nothing
-- to the later positions.
matchAliases :: [([Names], Names)] -> [Names]
matchAliases branches =
  map ((`namesSubtract` consumed) . mconcat) $ transpose branch_als
  where
    (branch_als, conses) = unzip branches
    consumed = mconcat conses

-- | The aliases of the results of a function call, one per return type.
--
-- Each return type is paired with a 'RetAls'.  The first list of the
-- 'RetAls' holds the (0-based) positions of the arguments whose aliases
-- the result inherits.  The second list holds the positions of the
-- pattern elements whose names the result aliases.
funcallAliases ::
  [PatElem dec] ->
  [(SubExp, Diet)] ->
  [(TypeBase shape Uniqueness, RetAls)] ->
  [Names]
funcallAliases pes args = map onType
  where
    arg_als = map (subExpAliases . fst) args
    res_als = map (oneName . patElemName) pes

    -- Union of those elements of 'als' whose positions occur in 'is'.
    getAls :: [Names] -> [Int] -> Names
    getAls als is = mconcat [a | (a, i) <- zip als [0 ..], i `elem` is]

    onType (_t, RetAls pals rals) = getAls arg_als pals <> getAls res_als rals

-- | Add mutual aliasing between pattern elements through names that go
-- out of scope, then remove those names.
--
-- @bound@ is the set of names that go out of scope.  @als@ gives the
-- aliases of each pattern element in @pes@.  If the aliases of two
-- pattern elements both contain the same name from @bound@, each
-- element gets the other's name added to its aliases.  Finally, every
-- name in @bound@ is removed.  The result is as long as the shorter of
-- @pes@ and @als@.
mutualAliases :: Names -> [PatElem dec] -> [Names] -> [Names]
mutualAliases bound pes als = zipWith grow (map patElemName pes) als
  where
    -- For each pattern element, only those of its aliases in 'bound'.
    bound_als = map (`namesIntersection` bound) als

    grow v names = (names <> pe_names) `namesSubtract` bound
      where
        -- Names of the other pattern elements that share a bound alias
        -- with 'v'.
        pe_names =
          namesFromList
            [ patElemName pe
            | (pe, pe_bound) <- zip pes bound_als,
              names `namesIntersect` pe_bound,
              patElemName pe /= v
            ]

-- | The aliases of an expression, one for each pattern element.
--
-- The pattern is important because some aliasing might be through
-- variables that are no longer in scope (consider the aliases for a
-- body that returns the same value multiple times).
expAliases :: (Aliased rep) => [PatElem dec] -> Exp rep -> [Names]
expAliases pes (Match _ cases defbody _) =
  -- Repeat mempty in case the pattern has more elements (this
  -- implies a type error).
  mutualAliases bound pes $ als ++ repeat mempty
  where
    bodies = defbody : map caseBody cases
    als = matchAliases $ map onBody bodies
    onBody body = (bodyAliases body, consumedInBody body)
    bound = foldMap boundInBody bodies
expAliases _ (BasicOp op) = basicOpAliases op
expAliases pes (Loop merge _ loopbody) =
  mutualAliases (bound <> param_names) pes $
    map onParam . transitive . zip params $
      zipWith (<>) arg_aliases (bodyAliases loopbody)
  where
    bound = boundInBody loopbody
    arg_aliases = map (subExpAliases . snd) merge
    params = map fst merge
    param_names = namesFromList $ map paramName params

    -- A result bound to a unique loop parameter has no aliases.
    onParam (p, als)
      | unique $ paramDeclType p = mempty
      | otherwise = als

    -- Close the aliases under aliasing through other loop parameters.
    -- Iterate until the alias sets stop changing.  Only the 'Names'
    -- components are compared, because 'second' never changes the
    -- parameters.
    transitive merge_and_als
      | map snd merge_and_als' == map snd merge_and_als = merge_and_als
      | otherwise = transitive merge_and_als'
      where
        merge_and_als' = map (second expand) merge_and_als
        look v = maybe mempty snd $ find ((== v) . paramName . fst) merge_and_als
        expand als = als <> foldMap look (namesToList als)
expAliases pes (Apply _ args t _) =
  funcallAliases pes args $ map (first declExtTypeOf) t
expAliases _ (WithAcc inputs lam) =
  concatMap inputAliases inputs
    ++ drop num_accs (map (`namesSubtract` boundInBody body) $ bodyAliases body)
  where
    body = lambdaBody lam
    -- One alias-free entry per array of each accumulator input.
    inputAliases (_, arrs, _) = replicate (length arrs) mempty
    -- The first 'num_accs' body results are skipped.
    num_accs = length inputs
expAliases _ (Op op) = opAliases op

-- | The variables consumed in this statement, i.e. those consumed by
-- its expression.
consumedInStm :: (Aliased rep) => Stm rep -> Names
consumedInStm = consumedInExp . stmExp

-- | The variables consumed in this expression.
consumedInExp :: (Aliased rep) => Exp rep -> Names
consumedInExp (Apply _ args _ _) =
  -- An argument passed with the 'Consume' diet consumes everything it
  -- aliases.
  foldMap consumeArg args
  where
    consumeArg (se, Consume) = subExpAliases se
    consumeArg _ = mempty
consumedInExp (Match _ cases defbody _) =
  foldMap (consumedInBody . caseBody) cases <> consumedInBody defbody
consumedInExp (Loop merge _ _) =
  -- The initial values of unique loop parameters are consumed.
  foldMap (subExpAliases . snd) $ filter (unique . paramDeclType . fst) merge
consumedInExp (WithAcc inputs lam) =
  foldMap inputConsumed inputs
    <> (consumedByLambda lam `namesSubtract` lam_params)
  where
    -- Every array of each accumulator input is consumed.
    inputConsumed (_, arrs, _) = namesFromList arrs
    -- Names bound as the lambda's parameters are removed from its
    -- consumption.
    lam_params = namesFromList $ map paramName $ lambdaParams lam
consumedInExp (BasicOp (Update _ src _ _)) = oneName src
consumedInExp (BasicOp (FlatUpdate src _ _)) = oneName src
consumedInExp (BasicOp (UpdateAcc acc _ _)) = oneName acc
consumedInExp (BasicOp _) = mempty
consumedInExp (Op op) = consumedInOp op

-- | The variables consumed by this lambda, namely those consumed in its
-- body.
consumedByLambda :: (Aliased rep) => Lambda rep -> Names
consumedByLambda = consumedInBody . lambdaBody

-- | The aliases of each pattern element, in pattern order.
patAliases :: (AliasesOf dec) => Pat dec -> [Names]
patAliases = map aliasesOf . patElems

-- | Types from which alias information can be extracted.
class AliasesOf t where
  -- | The names aliased by the argument.
  aliasesOf :: t -> Names

-- | A set of names is its own alias information.
instance AliasesOf Names where
  aliasesOf = id

-- | The aliases of a pattern element are those of its decoration.
instance (AliasesOf dec) => AliasesOf (PatElem dec) where
  aliasesOf = aliasesOf . patElemDec

-- | All names that the given name transitively aliases, following the
-- alias information of 'LetName' entries in the scope.  Also includes
-- the name itself.  A name that is not bound by a 'LetName' in the
-- scope is included in the result but not expanded further.
lookupAliases :: (AliasesOf (LetDec rep)) => VName -> Scope rep -> Names
lookupAliases root scope = go mempty [root]
  where
    -- Depth-first traversal of the alias graph.  The 'seen' set makes
    -- the traversal terminate on circular aliasing (which can happen
    -- due to Match and Loop).  It also ensures that each name is
    -- expanded at most once, even when several paths reach it.
    go seen [] = seen
    go seen (v : vs)
      | v `nameIn` seen = go seen vs
      | otherwise = go (oneName v <> seen) (directAliases v ++ vs)

    directAliases v =
      case M.lookup v scope of
        Just (LetName dec) -> namesToList $ aliasesOf dec
        _ -> []

-- | Operations that can report which names their results alias and
-- which names they consume.
class (IsOp op) => AliasedOp op where
  -- | The variables consumed by the operation.
  consumedInOp :: op -> Names

  -- | The aliases of each result of the operation.
  opAliases :: op -> [Names]

-- | 'NoOp' reports no result aliases and consumes nothing.
instance AliasedOp (NoOp rep) where
  opAliases NoOp = []
  consumedInOp NoOp = mempty

-- | Pre-existing aliases for variables: a map from each variable to
-- the names it aliases.  Used to add transitive aliases.
type AliasTable = M.Map VName Names