{-# LANGUAGE TypeFamilies #-}

-- | Match simplification rules.
module Futhark.Optimise.Simplify.Rules.Match (matchRules) where

import Control.Monad
import Data.Either
import Data.List (partition, transpose, unzip4, zip5)
import Futhark.Analysis.PrimExp.Convert
import Futhark.Analysis.SymbolTable qualified as ST
import Futhark.Analysis.UsageTable qualified as UT
import Futhark.Construct
import Futhark.IR
import Futhark.Optimise.Simplify.Rule
import Futhark.Util

-- | Does this case match the scrutinees whatever values they take at
-- run time?  A wildcard ('Nothing') pattern element matches anything;
-- a constant pattern element counts only if the corresponding
-- scrutinee is syntactically that very constant.
caseAlwaysMatches :: [SubExp] -> Case a -> Bool
caseAlwaysMatches ses c =
  and [maybe True ((se ==) . Constant) pv | (se, pv) <- zip ses (casePat c)]

-- | Can this case never match the scrutinees?  This is only known when
-- some scrutinee is a constant that differs from the constant in the
-- corresponding pattern position; anything else might match.
caseNeverMatches :: [SubExp] -> Case a -> Bool
caseNeverMatches ses c = any clashes $ zip ses $ casePat c
  where
    clashes (Constant actual, Just expected) = actual /= expected
    clashes _ = False

-- | Top-down simplification rules for a single 'Match' statement.  The
-- equations are tried in order; the first one whose patterns and
-- guards succeed decides the result.
ruleMatch :: (BuilderOps rep) => TopDownRuleMatch rep
-- Remove impossible cases.
ruleMatch _ pat _ (cond, cases, defbody, ifdec)
  | (impossible, cases') <- partition (caseNeverMatches cond) cases,
    not $ null impossible =
      Simplify $ letBind pat $ Match cond cases' defbody ifdec
-- Find new default case.  Cases are tried in order and the first one
-- that matches is taken, so the first case that always matches makes
-- every later case (and the old default) unreachable.  It becomes the
-- new default, and only the cases before it are kept.
ruleMatch _ pat _ (cond, cases, _, ifdec)
  | (reachable, new_default : _) <- break (caseAlwaysMatches cond) cases =
      Simplify $ letBind pat $ Match cond reachable (caseBody new_default) ifdec
-- Remove caseless match: the default body is always taken, so bind its
-- results directly, keeping both statement and result certificates.
ruleMatch _ pat (StmAux cs _ _) (_, [], defbody, _) = Simplify $ do
  defbody_res <- bodyBind defbody
  certifying cs $ forM_ (zip (patElems pat) defbody_res) $ \(pe, res) ->
    certifying (resCerts res) . letBind (Pat [pe]) $
      BasicOp (SubExp $ resSubExp res)
-- IMPROVE: the following two rules can be generalised to work in more
-- cases, especially when the branches have bindings, or return more
-- than one value.
--
-- if c then True else v == c || v
ruleMatch
  _
  pat
  _
  ( [cond],
    [ Case
        [Just (BoolValue True)]
        (Body _ tstms [SubExpRes tcs (Constant (BoolValue True))])
      ],
    Body _ fstms [SubExpRes fcs se],
    MatchDec ts _
    )
    | null tstms,
      null fstms,
      [Prim Bool] <- map extTypeOf ts =
        Simplify $
          certifying (tcs <> fcs) $
            letBind pat $
              BasicOp $
                BinOp LogOr cond se
-- When type(x)==bool, if c then x else y == (c && x) || (!c && y).
-- Both branches' statements are executed unconditionally, so they must
-- all be safe.
ruleMatch _ pat _ ([cond], [Case [Just (BoolValue True)] tb], fb, MatchDec ts _)
  | Body _ tstms [SubExpRes tcs tres] <- tb,
    Body _ fstms [SubExpRes fcs fres] <- fb,
    all (safeExp . stmExp) $ tstms <> fstms,
    all ((== Prim Bool) . extTypeOf) ts = Simplify $ do
      addStms tstms
      addStms fstms
      e <-
        eBinOp
          LogOr
          (pure $ BasicOp $ BinOp LogAnd cond tres)
          ( eBinOp
              LogAnd
              (pure $ BasicOp $ UnOp Not cond)
              (pure $ BasicOp $ SubExp fres)
          )
      certifying (tcs <> fcs) $ letBind pat e
-- A 'MatchFallback' match with a single case whose statements are all
-- safe: take the case unconditionally and drop the fallback default.
ruleMatch _ pat _ (_, [Case _ tbranch], _, MatchDec _ MatchFallback)
  | all (safeExp . stmExp) $ bodyStms tbranch = Simplify $ do
      let ses = bodyResult tbranch
      addStms $ bodyStms tbranch
      sequence_
        [ certifying cs $ letBindNames [patElemName p] $ BasicOp $ SubExp se
          | (p, SubExpRes cs se) <- zip (patElems pat) ses
        ]
-- if c then 1 else 0 == btoi c, and if c then 0 else 1 == btoi (!c).
-- The branch bodies are discarded, so this is only done when neither
-- result carries certificates (they could otherwise be lost, or refer
-- to names bound inside the discarded bodies).  This applies to both
-- directions of the conversion.
ruleMatch _ pat _ ([cond], [Case [Just (BoolValue True)] tb], fb, _)
  | Body _ _ [SubExpRes tcs (Constant (IntValue t))] <- tb,
    Body _ _ [SubExpRes fcs (Constant (IntValue f))] <- fb,
    tcs == mempty,
    fcs == mempty,
    (oneIshInt t && zeroIshInt f) || (zeroIshInt t && oneIshInt f) =
      Simplify $ do
        -- A value cannot be both one-ish and zero-ish, so 'oneIshInt t'
        -- tells which of the two shapes above we are in.
        bool_se <-
          if oneIshInt t
            then pure cond
            else letSubExp "cond_neg" $ BasicOp $ UnOp Not cond
        letBind pat $ BasicOp $ ConvOp (BToI (intValueType t)) bool_se
-- Simplify
--
--   let z = if c then x else y
--
-- to
--
--   let z = y
--
-- in the case where 'x' is a loop parameter with initial value 'y'
-- and the new value of the loop parameter is 'z'.  ('x' and 'y' can
-- be flipped.)
ruleMatch vtable (Pat [pe]) aux (_c, [Case _ tb], fb, MatchDec [_] _)
  | Body _ tstms [SubExpRes xcs x] <- tb,
    null tstms,
    Body _ fstms [SubExpRes ycs y] <- fb,
    null fstms,
    matches x y || matches y x =
      Simplify . certifying (stmAuxCerts aux <> xcs <> ycs) $
        letBind (Pat [pe]) (BasicOp $ SubExp y)
  where
    z = patElemName pe
    -- Is 'param' a loop parameter whose initial value is 'other' and
    -- whose next value is 'z'?
    matches (Var param) other
      | Just (initial, res) <- ST.lookupLoopParam param vtable =
          initial == other && res == Var z
    matches _ _ = False
ruleMatch _ _ _ _ = Skip

-- | Move out results of a conditional expression that do not depend on
-- the branches.  A result is hoisted when either
--
-- * every branch returns the same value (with no branch-bound names
--   involved), in which case it is bound directly outside the match; or
--
-- * every branch returns a value that is not bound inside any branch,
--   the result has primitive type, and the match has more than one
--   result, in which case the result is split off into its own
--   single-result match.
--
-- Existential return types referring to a hoisted result are fixed to
-- the hoisted name in the remaining match, and array results are
-- reshaped where that made their type less existential.
hoistBranchInvariant :: (BuilderOps rep) => TopDownRuleMatch rep
hoistBranchInvariant _ pat _ (cond, cases, defbody, MatchDec ret ifsort) =
  let case_reses = map (bodyResult . caseBody) cases
      defbody_res = bodyResult defbody
      (hoistings, (pes, ts, case_reses_tr, defbody_res')) =
        (fmap unzip4 . partitionEithers) . map branchInvariant $
          zip5 [0 ..] (patElems pat) ret (transpose case_reses) defbody_res
   in if null hoistings
        then Skip
        else Simplify $ do
          ctx_fixes <- sequence hoistings
          let onCase (Case vs body) case_res = Case vs $ body {bodyResult = case_res}
              cases' = zipWith onCase cases $ transpose case_reses_tr
              defbody' = defbody {bodyResult = defbody_res'}
              ret' = foldr (uncurry fixExt) ts ctx_fixes
          -- We may have to add some reshapes if we made the type
          -- less existential.
          cases'' <- mapM (traverse $ reshapeBodyResults $ map extTypeOf ret') cases'
          defbody'' <- reshapeBodyResults (map extTypeOf ret') defbody'
          letBind (Pat pes) $ Match cond cases'' defbody'' (MatchDec ret' ifsort)
  where
    -- Every name bound by a statement in any case body or the default.
    bound_in_branches =
      namesFromList . concatMap (patNames . stmPat) $
        foldMap (bodyStms . caseBody) cases <> bodyStms defbody

    -- Classify one result position: 'Left' with an action that hoists
    -- it (and returns the existential fix), or 'Right' to keep it.
    branchInvariant (i, pe, t, case_reses, defres)
      -- If just one branch has a variant result, then we give up.
      | variant = noHoisting
      -- Do all branches return the same value?
      | all ((== resSubExp defres) . resSubExp) case_reses = Left $ do
          certifying (foldMap resCerts case_reses <> resCerts defres) $
            letBindNames [patElemName pe] . BasicOp . SubExp $
              resSubExp defres
          hoisted i pe
      -- All branches return values that are free in the branches
      -- ('variant' is already known to be false here).  Only done when
      -- we are not the only pattern element, to avoid infinite
      -- application of this rule.
      | patSize pat > 1,
        Prim _ <- patElemType pe = Left $ do
          bt <- expTypesFromPat $ Pat [pe]
          -- Keep each result's certificates, as the same-value case
          -- above does.  The 'variant' check covers 'freeIn' of the
          -- whole 'SubExpRes', so the certificates are assumed to be
          -- free in the branches as well.
          case_bodies <- mapM (mkBodyM mempty . pure) case_reses
          def_body <- mkBodyM mempty [defres]
          letBindNames [patElemName pe] $
            Match
              cond
              (zipWith Case (map casePat cases) case_bodies)
              def_body
              (MatchDec bt ifsort)
          hoisted i pe
      | otherwise = noHoisting
      where
        variant = namesIntersect bound_in_branches $ freeIn $ defres : case_reses
        noHoisting = Right (pe, t, case_reses, defres)

    hoisted i pe = pure (i, Var $ patElemName pe)

    reshapeBodyResults rets body = buildBody_ $ do
      ses <- bodyBind body
      let (ctx_ses, val_ses) = splitFromEnd (length rets) ses
      (ctx_ses ++) <$> zipWithM reshapeResult val_ses rets
    reshapeResult (SubExpRes cs (Var v)) t@Array {} = do
      v_t <- lookupType v
      let newshape = arrayDims $ removeExistentials t v_t
      SubExpRes cs
        <$> if newshape /= arrayDims v_t
          then letSubExp "branch_ctx_reshaped" (shapeCoerce newshape v)
          else pure $ Var v
    reshapeResult se _ =
      pure se

-- | Remove the return values of a branch that are not actually used
-- after the branch.  Standard dead code removal can remove the branch
-- if *none* of the return values are used, but this rule is more
-- precise.  The pruned results' computations are left in the branch
-- bodies as dead code, to be removed later.
removeDeadBranchResult :: (BuilderOps rep) => BottomUpRuleMatch rep
removeDeadBranchResult (_, used) pat _ (cond, cases, defbody, MatchDec rettype ifsort)
  | -- Skip patterns where some element's type mentions another
    -- pattern name (an existential binding).
    all (`notNameIn` foldMap freeIn (patElems pat)) (patNames pat),
    live <- map (`UT.isUsedDirectly` used) (patNames pat),
    -- Fire only if at least one result is dead.
    False `elem` live =
      let keepLive :: [a] -> [a]
          keepLive xs = [x | (True, x) <- zip live xs]
       in Simplify $ do
            cases' <- mapM (traverse $ restrictResult keepLive) cases
            defbody' <- restrictResult keepLive defbody
            letBind (Pat $ keepLive $ patElems pat) . Match cond cases' defbody' $
              MatchDec (keepLive rettype) ifsort
  | otherwise = Skip
  where
    -- Rebuild a body with the same statements but a filtered result.
    restrictResult f (Body _ stms res) = mkBodyM stms $ f res

-- | Top-down rules: the per-statement 'Match' rewrites first, then the
-- hoisting of branch-invariant results.
topDownRules :: (BuilderOps rep) => [TopDownRule rep]
topDownRules = map RuleMatch [ruleMatch, hoistBranchInvariant]

-- | Bottom-up rules: only the removal of unused branch results.
bottomUpRules :: (BuilderOps rep) => [BottomUpRule rep]
bottomUpRules = pure $ RuleMatch removeDeadBranchResult

-- | Every simplification rule for 'Match' expressions, combining the
-- top-down and bottom-up rule lists into one rule book.
matchRules :: (BuilderOps rep) => RuleBook rep
matchRules = ruleBook topDownRules bottomUpRules