{-# LANGUAGE TypeFamilies #-}
module Futhark.Optimise.Simplify.Rules.Match (matchRules) where
import Control.Monad
import Data.Either
import Data.List (partition, transpose, unzip4, zip5)
import Futhark.Analysis.PrimExp.Convert
import Futhark.Analysis.SymbolTable qualified as ST
import Futhark.Analysis.UsageTable qualified as UT
import Futhark.Construct
import Futhark.IR
import Futhark.Optimise.Simplify.Rule
import Futhark.Util
-- | Does the given case match the scrutinees no matter what their runtime
-- values are?  Every pattern position must be either a wildcard
-- ('Nothing') or a constant that is equal to the corresponding scrutinee.
-- Positions beyond the shorter of the two lists are ignored, since
-- 'zipWith' truncates.
caseAlwaysMatches :: [SubExp] -> Case a -> Bool
caseAlwaysMatches ses = and . zipWith match ses . casePat
  where
    match :: SubExp -> Maybe PrimValue -> Bool
    match se (Just v) = se == Constant v
    match _ Nothing = True
-- | Is the given case impossible for these scrutinees?  This is only
-- decided when some pattern position holds a constant and the
-- corresponding scrutinee is a /different/ constant.  Wildcard positions
-- and non-constant scrutinees never rule a case out.
caseNeverMatches :: [SubExp] -> Case a -> Bool
caseNeverMatches ses = or . zipWith impossible ses . casePat
  where
    impossible :: SubExp -> Maybe PrimValue -> Bool
    impossible (Constant v1) (Just v2) = v1 /= v2
    impossible _ _ = False
-- | Top-down simplification rules for 'Match' expressions.  Each equation
-- below is an independent rewrite.  Equations are tried in order, and the
-- final catch-all yields 'Skip'.
--
-- NOTE(review): the case-pruning rules assume that 'Match' tries its cases
-- in order and takes the first one that matches; confirm against the IR
-- documentation.
ruleMatch :: (BuilderOps rep) => TopDownRuleMatch rep
-- Remove cases that can never match the scrutinees.
ruleMatch _ pat _ (cond, cases, defbody, ifdec)
  | (impossible, cases') <- partition (caseNeverMatches cond) cases,
    not $ null impossible =
      Simplify $ letBind pat $ Match cond cases' defbody ifdec
-- Find a new default case.  The first case that always matches becomes the
-- default body.  Every case after it is unreachable and is dropped.  The
-- cases before it keep their priority.
ruleMatch _ pat _ (cond, cases, _, ifdec)
  | (cases', new_default : _) <- break (caseAlwaysMatches cond) cases =
      Simplify $ letBind pat $ Match cond cases' (caseBody new_default) ifdec
-- A match without any cases evaluates its default body, so inline that
-- body and bind its results directly, keeping the statement's certificates
-- and the per-result certificates.
ruleMatch _ pat (StmAux cs _ _) (_, [], defbody, _) = Simplify $ do
  defbody_res <- bodyBind defbody
  certifying cs $ forM_ (zip (patElems pat) defbody_res) $ \(pe, res) ->
    certifying (resCerts res) . letBind (Pat [pe]) $
      BasicOp (SubExp $ resSubExp res)
-- if c then True else v  ==>  c || v
-- (both branches must be free of statements, and the result a boolean).
ruleMatch
  _
  pat
  _
  ( [cond],
    [Case [Just (BoolValue True)] (Body _ tstms [SubExpRes tcs (Constant (BoolValue True))])],
    Body _ fstms [SubExpRes fcs se],
    MatchDec ts _
    )
    | null tstms,
      null fstms,
      [Prim Bool] <- map extTypeOf ts =
        Simplify $ certifying (tcs <> fcs) $ letBind pat $ BasicOp $ BinOp LogOr cond se
-- For boolean results:  if c then x else y  ==>  (c && x) || (!c && y).
-- The statements of both branches are executed unconditionally, so they
-- must all be safe.
ruleMatch _ pat _ ([cond], [Case [Just (BoolValue True)] tb], fb, MatchDec ts _)
  | Body _ tstms [SubExpRes tcs tres] <- tb,
    Body _ fstms [SubExpRes fcs fres] <- fb,
    all (safeExp . stmExp) $ tstms <> fstms,
    all ((== Prim Bool) . extTypeOf) ts = Simplify $ do
      addStms tstms
      addStms fstms
      e <-
        eBinOp
          LogOr
          (pure $ BasicOp $ BinOp LogAnd cond tres)
          ( eBinOp
              LogAnd
              (pure $ BasicOp $ UnOp Not cond)
              (pure $ BasicOp $ SubExp fres)
          )
      certifying (tcs <> fcs) $ letBind pat e
-- A fallback match with a single case whose statements are all safe:
-- execute that case unconditionally and discard the default body.
-- NOTE(review): relies on the MatchFallback contract that the case body may
-- stand in for the whole match; confirm against the MatchSort docs.
ruleMatch _ pat _ (_, [Case _ tbranch], _, MatchDec _ MatchFallback)
  | all (safeExp . stmExp) $ bodyStms tbranch = Simplify $ do
      let ses = bodyResult tbranch
      addStms $ bodyStms tbranch
      sequence_
        [ certifying cs $ letBindNames [patElemName p] $ BasicOp $ SubExp se
          | (p, SubExpRes cs se) <- zip (patElems pat) ses
        ]
-- A boolean branch returning integer constants that 'oneIshInt' and
-- 'zeroIshInt' accept (1/0 or 0/1) becomes a bool-to-int conversion of the
-- condition or of its negation.  The branch results' certificates are
-- attached to the new binding in both directions.
-- NOTE(review): the branches' statements are discarded here; this is only
-- sound if they have no observable effect — verify.
ruleMatch _ pat _ ([cond], [Case [Just (BoolValue True)] tb], fb, _)
  | Body _ _ [SubExpRes tcs (Constant (IntValue t))] <- tb,
    Body _ _ [SubExpRes fcs (Constant (IntValue f))] <- fb =
      if oneIshInt t && zeroIshInt f
        then
          Simplify . certifying (tcs <> fcs) . letBind pat . BasicOp $
            ConvOp (BToI (intValueType t)) cond
        else
          if zeroIshInt t && oneIshInt f
            then Simplify $ do
              cond_neg <- letSubExp "cond_neg" $ BasicOp $ UnOp Not cond
              certifying (tcs <> fcs) . letBind pat . BasicOp $
                ConvOp (BToI (intValueType t)) cond_neg
            else Skip
-- Simplify
--
--   let z = if c then x else y
--
-- to
--
--   let z = y
--
-- when 'x' is a loop parameter whose initial value is 'y' and whose new
-- value is 'z'.  The roles of 'x' and 'y' may also be swapped.  Both
-- branches must be free of statements.
ruleMatch vtable (Pat [pe]) aux (_, [Case _ tb], fb, MatchDec [_] _)
  | Body _ tstms [SubExpRes xcs x] <- tb,
    null tstms,
    Body _ fstms [SubExpRes ycs y] <- fb,
    null fstms,
    matches x y || matches y x =
      Simplify . certifying (stmAuxCerts aux <> xcs <> ycs) $
        letBind (Pat [pe]) (BasicOp $ SubExp y)
  where
    z = patElemName pe
    -- Does 'param' name a loop parameter that, according to
    -- 'ST.lookupLoopParam', starts at 'start' and is updated to 'z'?
    matches (Var param) start
      | Just (initial, res) <- ST.lookupLoopParam param vtable =
          initial == start && res == Var z
    matches _ _ = False
ruleMatch _ _ _ _ = Skip
-- | Move results that are invariant across the branches of a 'Match' out
-- of the match.  A result is hoisted when either
--
-- * every branch returns the same 'SubExp', in which case it is bound
--   directly outside the match; or
--
-- * no branch result depends on a name bound inside any branch, the match
--   has more than one result, and this result has a primitive type.  In
--   that case it is moved into its own single-result match.
--
-- Results that depend on names bound inside a branch are never hoisted.
-- Existential sizes in the remaining return types are fixed up with the
-- hoisted values, and array results are shape-coerced where necessary.
hoistBranchInvariant :: (BuilderOps rep) => TopDownRuleMatch rep
hoistBranchInvariant _ pat _ (cond, cases, defbody, MatchDec ret ifsort) =
  let case_reses = map (bodyResult . caseBody) cases
      defbody_res = bodyResult defbody
      (hoistings, (pes, ts, case_reses_tr, defbody_res')) =
        (fmap unzip4 . partitionEithers) . map branchInvariant $
          zip5 [0 ..] (patElems pat) ret (transpose case_reses) defbody_res
   in if null hoistings
        then Skip
        else Simplify $ do
          ctx_fixes <- sequence hoistings
          let onCase (Case vs body) case_res = Case vs $ body {bodyResult = case_res}
              -- When no results remain, 'transpose' returns [] and would
              -- lose the number of cases.  Padding with empty result lists
              -- keeps every case (and its statements) in the match.
              case_reses' = transpose case_reses_tr ++ repeat []
              cases' = zipWith onCase cases case_reses'
              defbody' = defbody {bodyResult = defbody_res'}
              ret' = foldr (uncurry fixExt) ts ctx_fixes
          -- Fixing existentials may pin down sizes in the return type, so
          -- coerce array results whose shape no longer agrees with it.
          cases'' <- mapM (traverse $ reshapeBodyResults $ map extTypeOf ret') cases'
          defbody'' <- reshapeBodyResults (map extTypeOf ret') defbody'
          letBind (Pat pes) $ Match cond cases'' defbody'' (MatchDec ret' ifsort)
  where
    -- Every name bound by a statement in any case body or the default body.
    bound_in_branches =
      namesFromList . concatMap (patNames . stmPat) $
        foldMap (bodyStms . caseBody) cases <> bodyStms defbody

    -- Classify result number i.  'Left' is an action that binds pe outside
    -- the match and yields (i, new value) for 'fixExt'.  'Right' keeps the
    -- result inside the match.
    branchInvariant (i, pe, t, case_reses, defres)
      -- A result mentioning anything bound inside a branch must stay.
      | namesIntersect bound_in_branches $ freeIn $ defres : case_reses =
          noHoisting
      -- Every branch returns the same value: bind it directly.
      | all ((== resSubExp defres) . resSubExp) case_reses = Left $ do
          certifying (foldMap resCerts case_reses <> resCerts defres) $
            letBindNames [patElemName pe] . BasicOp . SubExp $
              resSubExp defres
          hoisted i pe
      -- The results differ, but (by the first guard) none depends on a name
      -- bound in a branch.  Split this primitive result into its own
      -- single-result match, keeping each result's certificates.  Requiring
      -- more than one pattern element avoids applying this rule forever.
      | patSize pat > 1,
        Prim _ <- patElemType pe = Left $ do
          bt <- expTypesFromPat $ Pat [pe]
          letBindNames [patElemName pe]
            =<< ( Match cond
                    <$> ( zipWith Case (map casePat cases)
                            <$> mapM (mkBodyM mempty . pure) case_reses
                        )
                    <*> mkBodyM mempty [defres]
                    <*> pure (MatchDec bt ifsort)
                )
          hoisted i pe
      | otherwise = noHoisting
      where
        noHoisting = Right (pe, t, case_reses, defres)

    -- The existential fix for a hoisted result: index i is now pe's name.
    hoisted i pe = pure (i, Var $ patElemName pe)

    -- Rebuild a body, reshaping its last (length rets) results to match
    -- rets.  Any leading results are passed through unchanged.
    reshapeBodyResults rets body = buildBody_ $ do
      ses <- bodyBind body
      let (ctx_ses, val_ses) = splitFromEnd (length rets) ses
      (ctx_ses ++) <$> zipWithM reshapeResult val_ses rets

    -- Coerce an array variable to the dimensions given by its (partially
    -- fixed) expected type.  Other results are returned unchanged.
    reshapeResult (SubExpRes cs (Var v)) t@Array {} = do
      v_t <- lookupType v
      let newshape = arrayDims $ removeExistentials t v_t
      SubExpRes cs
        <$> if newshape /= arrayDims v_t
          then letSubExp "branch_ctx_reshaped" (shapeCoerce newshape v)
          else pure $ Var v
    reshapeResult se _ =
      pure se
-- | Remove results of a 'Match' whose pattern elements are not used
-- directly according to the usage table.  The corresponding branch results
-- and return types are dropped as well.  Statements that only computed the
-- dropped results stay in the branch bodies; they are left for later
-- dead-code removal.
removeDeadBranchResult :: (BuilderOps rep) => BottomUpRuleMatch rep
removeDeadBranchResult (_, used) pat _ (cond, cases, defbody, MatchDec rettype ifsort)
  | -- Only applies when no pattern name occurs free in any pattern element
    -- (presumably to avoid dropping a name that serves as a size of
    -- another result).
    all (`notNameIn` foldMap freeIn (patElems pat)) (patNames pat),
    patused <- map (`UT.isUsedDirectly` used) $ patNames pat,
    -- Nothing to do if every result is used.
    not (and patused) =
      let -- Keep only the positions whose pattern element is used.
          pick :: [a] -> [a]
          pick = map snd . filter fst . zip patused
       in Simplify $ do
            cases' <- mapM (traverse $ onBody pick) cases
            defbody' <- onBody pick defbody
            letBind (Pat $ pick $ patElems pat) $
              Match cond cases' defbody' $
                MatchDec (pick rettype) ifsort
  | otherwise = Skip
  where
    onBody pick (Body _ stms res) = mkBodyM stms $ pick res
-- | The top-down simplification rules for 'Match' expressions, in the
-- order they are listed.
topDownRules :: (BuilderOps rep) => [TopDownRule rep]
topDownRules =
  [ RuleMatch ruleMatch,
    RuleMatch hoistBranchInvariant
  ]
-- | The bottom-up simplification rules for 'Match' expressions.
bottomUpRules :: (BuilderOps rep) => [BottomUpRule rep]
bottomUpRules =
  [ RuleMatch removeDeadBranchResult
  ]
-- | The rule book for 'Match' expressions, combining 'topDownRules' and
-- 'bottomUpRules'.
matchRules :: (BuilderOps rep) => RuleBook rep
matchRules = ruleBook topDownRules bottomUpRules