{-# LANGUAGE TypeFamilies #-}
-- | Building blocks for determining the aliases of the values
-- produced by expressions, and for inspecting which variables an
-- expression or statement consumes.
module Futhark.IR.Prop.Aliases
( subExpAliases,
expAliases,
patAliases,
lookupAliases,
Aliased (..),
AliasesOf (..),
-- * Consumption
consumedInStm,
consumedInExp,
consumedByLambda,
-- * Extensibility
AliasTable,
AliasedOp (..),
)
where
import Data.Bifunctor (first, second)
import Data.List (find, transpose)
import Data.Map qualified as M
import Futhark.IR.Prop (ASTRep, IsOp, NameInfo (..), Scope)
import Futhark.IR.Prop.Names
import Futhark.IR.Prop.Patterns
import Futhark.IR.Prop.Types
import Futhark.IR.Syntax
-- | The class of representations whose bodies carry aliasing and
-- consumption information.
class (ASTRep rep, AliasedOp (Op rep), AliasesOf (LetDec rep)) => Aliased rep where
  -- | The aliases of each result of the body.  Callers in this module
  -- subtract the names bound inside the body from these sets, so the
  -- result may mention body-local names.
  bodyAliases :: Body rep -> [Names]

  -- | The variables consumed in the body.
  consumedInBody :: Body rep -> Names
-- | The aliases of a variable: the name set that 'oneName' builds
-- from the variable itself.
vnameAliases :: VName -> Names
vnameAliases = oneName
-- | The aliases of a subexpression.  Constants alias nothing; a
-- variable has the aliases given by 'vnameAliases'.
subExpAliases :: SubExp -> Names
subExpAliases Constant {} = mempty
subExpAliases (Var v) = vnameAliases v
-- | The aliases of the value produced by a 'BasicOp'.  Every
-- equation returns a singleton list.
--
-- 'SubExp' and 'Opaque' carry the aliases of their operand; 'Index',
-- 'FlatIndex', 'Reshape', 'Rearrange' and 'Rotate' carry the aliases
-- of their source variable; every other operation produces a value
-- with no aliases.
--
-- One equation per constructor is kept deliberately (no catch-all),
-- so that a newly added constructor shows up as an incomplete match.
basicOpAliases :: BasicOp -> [Names]
basicOpAliases (SubExp se) = [subExpAliases se]
basicOpAliases (Opaque _ se) = [subExpAliases se]
basicOpAliases (ArrayLit _ _) = [mempty]
basicOpAliases BinOp {} = [mempty]
basicOpAliases ConvOp {} = [mempty]
basicOpAliases CmpOp {} = [mempty]
basicOpAliases UnOp {} = [mempty]
basicOpAliases (Index ident _) = [vnameAliases ident]
basicOpAliases Update {} = [mempty]
basicOpAliases (FlatIndex ident _) = [vnameAliases ident]
basicOpAliases FlatUpdate {} = [mempty]
basicOpAliases Iota {} = [mempty]
basicOpAliases Replicate {} = [mempty]
basicOpAliases Scratch {} = [mempty]
basicOpAliases (Reshape _ _ e) = [vnameAliases e]
basicOpAliases (Rearrange _ e) = [vnameAliases e]
basicOpAliases (Rotate _ e) = [vnameAliases e]
basicOpAliases Concat {} = [mempty]
basicOpAliases Copy {} = [mempty]
basicOpAliases Manifest {} = [mempty]
basicOpAliases Assert {} = [mempty]
basicOpAliases UpdateAcc {} = [mempty]
-- | Combine per-branch alias information into the aliases of a
-- multi-way branch.  Each element of the argument describes one
-- branch as a pair of (aliases of each result, names consumed).
--
-- For every result position the aliases of all branches are unioned,
-- and then every name consumed in /any/ branch is removed.
matchAliases :: [([Names], Names)] -> [Names]
matchAliases l =
  -- 'transpose' regroups the per-branch lists into per-result lists.
  map ((`namesSubtract` mconcat conses) . mconcat) $ transpose alses
  where
    (alses, conses) = unzip l
-- | Given the return types of a function and the aliases and diet of
-- each argument, compute the aliases of each returned value.
--
-- Non-unique arrays and memory blocks may alias every argument that
-- survives 'maskAliases'; unique arrays and primitive values alias
-- nothing.  Accumulator return types are not handled and raise an
-- error.
returnAliases :: [TypeBase shape Uniqueness] -> [(Names, Diet)] -> [Names]
returnAliases rts args = map returnType' rts
  where
    -- Union of the argument aliases that are still visible after
    -- masking by diet; shared by the two aliasing cases below.
    visible_arg_aliases = mconcat $ map (uncurry maskAliases) args

    returnType' (Array _ _ Nonunique) =
      visible_arg_aliases
    returnType' (Array _ _ Unique) =
      mempty
    returnType' (Prim _) =
      mempty
    returnType' Acc {} =
      error "returnAliases Acc"
    returnType' Mem {} =
      visible_arg_aliases
-- | The aliases of the values returned by a function call, given the
-- call arguments (with their diets) and the declared return types.
-- Each argument is replaced by its aliases and the work is delegated
-- to 'returnAliases'.
funcallAliases :: [(SubExp, Diet)] -> [TypeBase shape Uniqueness] -> [Names]
funcallAliases args t =
  returnAliases t $ map (first subExpAliases) args
-- | The aliases of each value produced by an expression.  The pattern
-- elements that the expression is bound to are only inspected in the
-- 'Match' case.
expAliases :: (Aliased rep) => [PatElem dec] -> Exp rep -> [Names]
expAliases pes (Match _ cases defbody _) =
  -- Pad with empty alias sets so every pattern element gets an entry
  -- even when the branches produce fewer alias sets than 'pes'.
  zipWith grow (map patElemName pes) $ als ++ repeat mempty
  where
    als = matchAliases $ onBody defbody : map (onBody . caseBody) cases
    onBody body = (bodyAliases body, consumedInBody body)
    -- Names bound inside any branch; removed from the final result.
    bound = foldMap boundInBody $ defbody : map caseBody cases
    -- Extend the aliases of pattern element 'v' with the names of the
    -- /other/ pattern elements whose alias sets intersect with it.
    grow v names = (names <> pe_names) `namesSubtract` bound
      where
        pe_names =
          namesFromList
            . filter (/= v)
            . map (patElemName . fst)
            . filter (namesIntersect names . snd)
            $ zip pes als
expAliases _ (BasicOp op) = basicOpAliases op
expAliases _ (DoLoop merge _ loopbody) = do
  -- List monad: exactly one alias set is produced per merge parameter.
  (p, als) <-
    transitive . zip params $ zipWith mappend arg_aliases (bodyAliases loopbody)
  let als' = als `namesSubtract` param_names
  if unique $ paramDeclType p
    then pure mempty
    else pure $ als' `namesSubtract` bound
  where
    bound = boundInBody loopbody
    arg_aliases = map (subExpAliases . snd) merge
    params = map fst merge
    param_names = namesFromList $ map paramName params
    -- Repeatedly add, to each parameter's alias set, the alias sets of
    -- any merge parameters it mentions, until nothing changes.
    transitive merge_and_als =
      let merge_and_als' = map (second expand) merge_and_als
       in if merge_and_als' == merge_and_als
            then merge_and_als
            else transitive merge_and_als'
      where
        -- Alias set currently recorded for the merge parameter named
        -- 'v'; empty when 'v' is not a merge parameter.
        look v = maybe mempty snd $ find ((== v) . paramName . fst) merge_and_als
        expand als = als <> foldMap look (namesToList als)
expAliases _ (Apply _ args t _) =
  funcallAliases args $ map declExtTypeOf t
expAliases _ (WithAcc inputs lam) =
  -- One empty alias set per array of every input, followed by the
  -- aliases of the lambda results that come after the first
  -- 'num_accs' results, with body-local names removed.
  concatMap inputAliases inputs
    ++ drop num_accs (map (`namesSubtract` boundInBody body) $ bodyAliases body)
  where
    body = lambdaBody lam
    inputAliases (_, arrs, _) = replicate (length arrs) mempty
    num_accs = length inputs
expAliases _ (Op op) = opAliases op
-- | Filter an argument's aliases by the diet it is passed with: only
-- an argument passed as 'Observe' keeps its aliases; 'Consume' and
-- 'ObservePrim' yield the empty set.
maskAliases :: Names -> Diet -> Names
maskAliases _ Consume = mempty
maskAliases _ ObservePrim = mempty
maskAliases als Observe = als
-- | The variables consumed by a statement, i.e. those consumed by its
-- expression (see 'consumedInExp').
consumedInStm :: Aliased rep => Stm rep -> Names
consumedInStm = consumedInExp . stmExp
-- | The variables consumed by an expression.
consumedInExp :: (Aliased rep) => Exp rep -> Names
consumedInExp (Apply _ args _ _) =
  -- Only arguments passed with the 'Consume' diet contribute.
  mconcat (map (consumeArg . first subExpAliases) args)
  where
    consumeArg (als, Consume) = als
    consumeArg _ = mempty
consumedInExp (Match _ cases defbody _) =
  foldMap (consumedInBody . caseBody) cases <> consumedInBody defbody
consumedInExp (DoLoop merge form body) =
  -- The initial values of unique merge parameters, plus whatever the
  -- loop form consumes.
  mconcat
    ( map (subExpAliases . snd) $
        filter (unique . paramDeclType . fst) merge
    )
    <> consumedInForm form
  where
    body_consumed = consumedInBody body
    varConsumed = (`nameIn` body_consumed) . paramName . fst
    -- A for-loop consumes each looped-over variable whose loop
    -- parameter is consumed in the body.
    consumedInForm (ForLoop _ _ _ loopvars) =
      namesFromList $ map snd $ filter varConsumed loopvars
    consumedInForm WhileLoop {} =
      mempty
consumedInExp (WithAcc inputs lam) =
  -- All input arrays, plus what the lambda consumes apart from its
  -- own parameters.
  mconcat (map inputConsumed inputs)
    <> ( consumedByLambda lam
           `namesSubtract` namesFromList (map paramName (lambdaParams lam))
       )
  where
    inputConsumed (_, arrs, _) = namesFromList arrs
consumedInExp (BasicOp (Update _ src _ _)) = oneName src
consumedInExp (BasicOp (FlatUpdate src _ _)) = oneName src
consumedInExp (BasicOp (UpdateAcc acc _ _)) = oneName acc
consumedInExp (BasicOp _) = mempty
consumedInExp (Op op) = consumedInOp op
-- | The variables consumed in the body of a lambda.  The lambda's own
-- parameters are not filtered out here.
consumedByLambda :: Aliased rep => Lambda rep -> Names
consumedByLambda = consumedInBody . lambdaBody
-- | The aliases of each element of a pattern, in pattern order, as
-- recorded in the element decorations.
patAliases :: AliasesOf dec => Pat dec -> [Names]
patAliases = map aliasesOf . patElems
-- | Things from which a set of aliases can be extracted.
class AliasesOf a where
  -- | The alias set of the given value.
  aliasesOf :: a -> Names
-- | A set of names is its own alias set.
instance AliasesOf Names where
  aliasesOf = id
-- | The aliases of a pattern element are those of its decoration.
instance AliasesOf dec => AliasesOf (PatElem dec) where
  aliasesOf = aliasesOf . patElemDec
-- | Transitively expand the aliases of a name by following the
-- let-bindings recorded in the scope.  The result always contains the
-- name itself.  A name that is absent from the scope, or that is not
-- bound by a @let@, is not expanded any further.
lookupAliases :: AliasesOf (LetDec rep) => VName -> Scope rep -> Names
lookupAliases root scope =
  expand mempty root
  where
    -- 'prev' holds the names already visited on the current path;
    -- aliases found in it are skipped so that cyclic alias
    -- information cannot cause unbounded recursion.
    expand prev v =
      case M.lookup v scope of
        Just (LetName dec) ->
          oneName v
            <> foldMap
              (expand (oneName v <> prev))
              (filter (`notNameIn` prev) (namesToList (aliasesOf dec)))
        _ -> oneName v
-- | The class of operations that can report aliasing and consumption
-- information; this is what 'expAliases' and 'consumedInExp' defer to
-- for the 'Op' case.
class IsOp op => AliasedOp op where
  -- | The aliases of each value produced by the operation.
  opAliases :: op -> [Names]

  -- | The variables consumed by the operation.
  consumedInOp :: op -> Names
-- | 'NoOp' reports no aliases (an empty list) and consumes nothing.
instance AliasedOp (NoOp rep) where
  opAliases NoOp = []
  consumedInOp NoOp = mempty
-- | A mapping from a variable name to a set of names.
-- NOTE(review): presumably the already-known aliases of that
-- variable; it is not used within this module, so confirm against
-- its users.
type AliasTable = M.Map VName Names