{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE TypeFamilies #-}

-- | This module defines a collection of simplification rules, as per
-- "Futhark.Optimise.Simplify.Rule".  They are used in the
-- simplifier.
--
-- For performance reasons, many sufficiently simple logically
-- separate rules are merged into single "super-rules", like ruleIf
-- and ruleBasicOp.  This is because it is relatively expensive to
-- activate a rule just to determine that it does not apply.  Thus, it
-- is more efficient to have a few very fat rules than a lot of small
-- rules.  This does not affect the compiler result in any way; it is
-- purely an optimisation to speed up compilation.
module Futhark.Optimise.Simplify.Rules
  ( standardRules,
    removeUnnecessaryCopy,
  )
where

import Control.Monad
import Data.Either
import qualified Data.Map.Strict as M
import Data.Maybe
import Futhark.Analysis.PrimExp.Convert
import qualified Futhark.Analysis.SymbolTable as ST
import qualified Futhark.Analysis.UsageTable as UT
import Futhark.Construct
import Futhark.IR
import Futhark.IR.Prop.Aliases
import Futhark.Optimise.Simplify.Rule
import Futhark.Optimise.Simplify.Rules.BasicOp
import Futhark.Optimise.Simplify.Rules.Index
import Futhark.Optimise.Simplify.Rules.Loop
import Futhark.Util

-- | Rules applied in the top-down pass: constant folding of calls to
-- primitive functions, the @if@ rewrites in 'ruleIf', and hoisting of
-- branch-invariant @if@ results.
topDownRules :: BinderOps lore => [TopDownRule lore]
topDownRules =
  [ RuleGeneric constantFoldPrimFun,
    RuleIf ruleIf,
    RuleIf hoistBranchInvariant
  ]

-- | Rules applied in the bottom-up pass, where usage information about
-- the remaining code is available: removal of dead @if@ results
-- ('removeDeadBranchResult', defined elsewhere) and index
-- simplification ('simplifyIndex').
bottomUpRules :: BinderOps lore => [BottomUpRule lore]
bottomUpRules =
  [ RuleIf removeDeadBranchResult,
    RuleBasicOp simplifyIndex
  ]

-- | A set of standard simplification rules.  These assume pure
-- functional semantics, and so probably should not be applied after
-- memory block merging.
--
-- The rule book combines this module's top-down and bottom-up rules
-- with 'loopRules' and 'basicOpRules'.
standardRules :: (BinderOps lore, Aliased lore) => RuleBook lore
standardRules =
  ruleBook topDownRules bottomUpRules
    <> loopRules
    <> basicOpRules

-- | Turn @d = copy(x)@ into @d = x@ when the copy is not needed.
--
-- The rewrite requires that @x@ is not consumed by the remaining code,
-- and additionally that either
--
--   * the copy @d@ is never consumed (so sharing @x@ is harmless), or
--
--   * @x@ is not used afterwards at all and @x@ itself may be consumed.
--
-- This simplistic rule is only valid before we introduce memory.
removeUnnecessaryCopy :: BinderOps lore => BottomUpRuleBasicOp lore
removeUnnecessaryCopy (vtable, used) (Pattern [] [d]) _ (Copy v)
  | not (v `UT.isConsumed` used),
    (not (v `UT.used` used) && consumable) || not copyConsumed =
    Simplify $ letBindNames [patElemName d] $ BasicOp $ SubExp $ Var v
  where
    -- Whether the result of the copy is consumed later on.
    copyConsumed = patElemName d `UT.isConsumed` used

    -- We need to make sure we can even consume the original.
    -- This is currently a hacky check, much too conservative,
    -- because we don't have the information conveniently
    -- available: only unique function parameters qualify.
    consumable = case M.lookup v $ ST.toScope vtable of
      Just (FParamName info) -> unique $ declTypeOf info
      _ -> False
removeUnnecessaryCopy _ _ _ _ = Skip

-- | Constant-fold a call to a primitive function (one listed in
-- 'primFuns') whose arguments are all constants.  The call is replaced
-- by the computed constant, keeping the statement's certificates and
-- attributes.  If any argument is not a constant, the name is not a
-- primitive function, or the function's evaluator yields 'Nothing',
-- the rule does not apply.
constantFoldPrimFun :: BinderOps lore => TopDownRuleGeneric lore
constantFoldPrimFun _ (Let pat (StmAux cs attrs _) (Apply fname args _ _))
  | Just args' <- mapM (isConst . fst) args,
    Just (_, _, fun) <- M.lookup (nameToString fname) primFuns,
    Just result <- fun args' =
    Simplify $
      certifying cs $
        attributing attrs $
          letBind pat $ BasicOp $ SubExp $ Constant result
  where
    isConst (Constant v) = Just v
    isConst _ = Nothing
constantFoldPrimFun _ _ = Skip

-- | Simplify an array-indexing statement using 'simplifyIndexing'.
--
-- The simplification either replaces the index with a plain
-- subexpression ('SubExpResult') or with an index into a (possibly
-- different) array ('IndexResult').  In both cases the statement's own
-- certificates are combined with any extra certificates the
-- simplification requires, and the statement's attributes are kept.
-- Whether the result is consumed later is passed on to
-- 'simplifyIndexing'.
simplifyIndex :: BinderOps lore => BottomUpRuleBasicOp lore
simplifyIndex (vtable, used) pat@(Pattern [] [pe]) (StmAux cs attrs _) (Index idd inds)
  | Just m <- simplifyIndexing vtable seType idd inds consumed = Simplify $ do
    res <- m
    attributing attrs $ case res of
      SubExpResult cs' se ->
        bindWithCerts cs' $ SubExp se
      IndexResult extra_cs idd' inds' ->
        bindWithCerts extra_cs $ Index idd' inds'
  where
    consumed = patElemName pe `UT.isConsumed` used

    -- Bind the pattern to a basic operation under the original
    -- certificates plus those demanded by the simplification.
    bindWithCerts extra op =
      certifying (cs <> extra) $
        letBindNames (patternNames pat) $ BasicOp op

    seType (Var v) = ST.lookupType v vtable
    seType (Constant v) = Just $ Prim $ primValueType v
simplifyIndex _ _ _ _ = Skip

-- | Top-down simplification of @if@ expressions.  Each equation below
-- is an independent rewrite.  When an equation's guards fail, matching
-- falls through to the next equation; the final equation declines
-- with 'Skip'.
ruleIf :: BinderOps lore => TopDownRuleIf lore
-- A constant condition selects one branch statically: that branch's
-- statements are inlined and its results bound to the pattern.  For
-- an 'IfFallback' conditional this is only done when the condition is
-- constantly true.
ruleIf _ pat _ (e1, tb, fb, IfDec _ ifsort)
  | Just branch <- checkBranch,
    ifsort /= IfFallback || isCt1 e1 = Simplify $ do
    addStms $ bodyStms branch
    zipWithM_ bindResult (patternElements pat) (bodyResult branch)
  where
    checkBranch
      | isCt1 e1 = Just tb
      | isCt0 e1 = Just fb
      | otherwise = Nothing
    bindResult p se = letBindNames [patElemName p] $ BasicOp $ SubExp se

-- IMPROVE: the following two rules can be generalised to work in more
-- cases, especially when the branches have bindings, or return more
-- than one value.
--
-- if c then True else v == c || v
ruleIf
  _
  pat
  _
  ( cond,
    Body _ tstms [Constant (BoolValue True)],
    Body _ fstms [se],
    IfDec ts _
    )
    | null tstms,
      null fstms,
      [Prim Bool] <- map extTypeOf ts =
      Simplify $ letBind pat $ BasicOp $ BinOp LogOr cond se
-- When type(x)==bool, if c then x else y == (c && x) || (!c && y)
--
-- The statements of both branches are hoisted out and executed
-- unconditionally, which is why they must all be 'safeExp'.
ruleIf _ pat _ (cond, tb, fb, IfDec ts _)
  | Body _ tstms [tres] <- tb,
    Body _ fstms [fres] <- fb,
    all (safeExp . stmExp) $ tstms <> fstms,
    all ((== Prim Bool) . extTypeOf) ts = Simplify $ do
    addStms tstms
    addStms fstms
    e <-
      eBinOp
        LogOr
        (pure $ BasicOp $ BinOp LogAnd cond tres)
        ( eBinOp
            LogAnd
            (pure $ BasicOp $ UnOp Not cond)
            (pure $ BasicOp $ SubExp fres)
        )
    letBind pat e
-- A fallback conditional whose pattern has no context elements and
-- whose true branch contains only safe expressions is replaced by its
-- true branch.
ruleIf _ pat _ (_, tbranch, _, IfDec _ IfFallback)
  | null $ patternContextNames pat,
    all (safeExp . stmExp) $ bodyStms tbranch = Simplify $ do
    addStms $ bodyStms tbranch
    zipWithM_ bindResult (patternElements pat) (bodyResult tbranch)
  where
    bindResult p se = letBindNames [patElemName p] $ BasicOp $ SubExp se
-- if c then 1 else 0 == btoi(c), and
-- if c then 0 else 1 == btoi(!c),
-- where the integer type is taken from the true branch's constant.
ruleIf _ pat _ (cond, tb, fb, _)
  | Body _ _ [Constant (IntValue t)] <- tb,
    Body _ _ [Constant (IntValue f)] <- fb =
    if oneIshInt t && zeroIshInt f
      then
        Simplify $
          letBind pat $ BasicOp $ ConvOp (BToI (intValueType t)) cond
      else
        if zeroIshInt t && oneIshInt f
          then Simplify $ do
            cond_neg <- letSubExp "cond_neg" $ BasicOp $ UnOp Not cond
            letBind pat $ BasicOp $ ConvOp (BToI (intValueType t)) cond_neg
          else Skip
ruleIf _ _ _ _ = Skip

-- | Move out results of a conditional expression whose computation is
-- either invariant to the branches (only done for results in the
-- context), or the same in both branches.
hoistBranchInvariant :: BinderOps lore => TopDownRuleIf lore
hoistBranchInvariant :: TopDownRuleIf lore
hoistBranchInvariant TopDown lore
_ Pattern lore
pat StmAux (ExpDec lore)
_ (SubExp
cond, BodyT lore
tb, BodyT lore
fb, IfDec [BranchType lore]
ret IfSort
ifsort) = RuleM lore () -> Rule lore
forall lore. RuleM lore () -> Rule lore
Simplify (RuleM lore () -> Rule lore) -> RuleM lore () -> Rule lore
forall a b. (a -> b) -> a -> b
$ do
  let tses :: Result
tses = BodyT lore -> Result
forall lore. BodyT lore -> Result
bodyResult BodyT lore
tb
      fses :: Result
fses = BodyT lore -> Result
forall lore. BodyT lore -> Result
bodyResult BodyT lore
fb
  ([Maybe (Int, SubExp)]
hoistings, ([PatElemT (LetDec lore)]
pes, [Either Int (BranchType lore)]
ts, [(SubExp, SubExp)]
res)) <-
    ([Either
    (Maybe (Int, SubExp))
    (PatElemT (LetDec lore), Either Int (BranchType lore),
     (SubExp, SubExp))]
 -> ([Maybe (Int, SubExp)],
     ([PatElemT (LetDec lore)], [Either Int (BranchType lore)],
      [(SubExp, SubExp)])))
-> RuleM
     lore
     [Either
        (Maybe (Int, SubExp))
        (PatElemT (LetDec lore), Either Int (BranchType lore),
         (SubExp, SubExp))]
-> RuleM
     lore
     ([Maybe (Int, SubExp)],
      ([PatElemT (LetDec lore)], [Either Int (BranchType lore)],
       [(SubExp, SubExp)]))
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (([(PatElemT (LetDec lore), Either Int (BranchType lore),
   (SubExp, SubExp))]
 -> ([PatElemT (LetDec lore)], [Either Int (BranchType lore)],
     [(SubExp, SubExp)]))
-> ([Maybe (Int, SubExp)],
    [(PatElemT (LetDec lore), Either Int (BranchType lore),
      (SubExp, SubExp))])
-> ([Maybe (Int, SubExp)],
    ([PatElemT (LetDec lore)], [Either Int (BranchType lore)],
     [(SubExp, SubExp)]))
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap [(PatElemT (LetDec lore), Either Int (BranchType lore),
  (SubExp, SubExp))]
-> ([PatElemT (LetDec lore)], [Either Int (BranchType lore)],
    [(SubExp, SubExp)])
forall a b c. [(a, b, c)] -> ([a], [b], [c])
unzip3 (([Maybe (Int, SubExp)],
  [(PatElemT (LetDec lore), Either Int (BranchType lore),
    (SubExp, SubExp))])
 -> ([Maybe (Int, SubExp)],
     ([PatElemT (LetDec lore)], [Either Int (BranchType lore)],
      [(SubExp, SubExp)])))
-> ([Either
       (Maybe (Int, SubExp))
       (PatElemT (LetDec lore), Either Int (BranchType lore),
        (SubExp, SubExp))]
    -> ([Maybe (Int, SubExp)],
        [(PatElemT (LetDec lore), Either Int (BranchType lore),
          (SubExp, SubExp))]))
-> [Either
      (Maybe (Int, SubExp))
      (PatElemT (LetDec lore), Either Int (BranchType lore),
       (SubExp, SubExp))]
-> ([Maybe (Int, SubExp)],
    ([PatElemT (LetDec lore)], [Either Int (BranchType lore)],
     [(SubExp, SubExp)]))
forall b c a. (b -> c) -> (a -> b) -> a -> c
. [Either
   (Maybe (Int, SubExp))
   (PatElemT (LetDec lore), Either Int (BranchType lore),
    (SubExp, SubExp))]
-> ([Maybe (Int, SubExp)],
    [(PatElemT (LetDec lore), Either Int (BranchType lore),
      (SubExp, SubExp))])
forall a b. [Either a b] -> ([a], [b])
partitionEithers) (RuleM
   lore
   [Either
      (Maybe (Int, SubExp))
      (PatElemT (LetDec lore), Either Int (BranchType lore),
       (SubExp, SubExp))]
 -> RuleM
      lore
      ([Maybe (Int, SubExp)],
       ([PatElemT (LetDec lore)], [Either Int (BranchType lore)],
        [(SubExp, SubExp)])))
-> RuleM
     lore
     [Either
        (Maybe (Int, SubExp))
        (PatElemT (LetDec lore), Either Int (BranchType lore),
         (SubExp, SubExp))]
-> RuleM
     lore
     ([Maybe (Int, SubExp)],
      ([PatElemT (LetDec lore)], [Either Int (BranchType lore)],
       [(SubExp, SubExp)]))
forall a b. (a -> b) -> a -> b
$
      ((PatElemT (LetDec lore), Either Int (BranchType lore),
  (SubExp, SubExp))
 -> RuleM
      lore
      (Either
         (Maybe (Int, SubExp))
         (PatElemT (LetDec lore), Either Int (BranchType lore),
          (SubExp, SubExp))))
-> [(PatElemT (LetDec lore), Either Int (BranchType lore),
     (SubExp, SubExp))]
-> RuleM
     lore
     [Either
        (Maybe (Int, SubExp))
        (PatElemT (LetDec lore), Either Int (BranchType lore),
         (SubExp, SubExp))]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (PatElemT (LetDec lore), Either Int (BranchType lore),
 (SubExp, SubExp))
-> RuleM
     lore
     (Either
        (Maybe (Int, SubExp))
        (PatElemT (LetDec lore), Either Int (BranchType lore),
         (SubExp, SubExp)))
branchInvariant ([(PatElemT (LetDec lore), Either Int (BranchType lore),
   (SubExp, SubExp))]
 -> RuleM
      lore
      [Either
         (Maybe (Int, SubExp))
         (PatElemT (LetDec lore), Either Int (BranchType lore),
          (SubExp, SubExp))])
-> [(PatElemT (LetDec lore), Either Int (BranchType lore),
     (SubExp, SubExp))]
-> RuleM
     lore
     [Either
        (Maybe (Int, SubExp))
        (PatElemT (LetDec lore), Either Int (BranchType lore),
         (SubExp, SubExp))]
forall a b. (a -> b) -> a -> b
$
        [PatElemT (LetDec lore)]
-> [Either Int (BranchType lore)]
-> [(SubExp, SubExp)]
-> [(PatElemT (LetDec lore), Either Int (BranchType lore),
     (SubExp, SubExp))]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3
          (Pattern lore -> [PatElemT (LetDec lore)]
forall dec. PatternT dec -> [PatElemT dec]
patternElements Pattern lore
pat)
          ((Int -> Either Int (BranchType lore))
-> [Int] -> [Either Int (BranchType lore)]
forall a b. (a -> b) -> [a] -> [b]
map Int -> Either Int (BranchType lore)
forall a b. a -> Either a b
Left [Int
0 .. Int
num_ctx Int -> Int -> Int
forall a. Num a => a -> a -> a
-Int
1] [Either Int (BranchType lore)]
-> [Either Int (BranchType lore)] -> [Either Int (BranchType lore)]
forall a. [a] -> [a] -> [a]
++ (BranchType lore -> Either Int (BranchType lore))
-> [BranchType lore] -> [Either Int (BranchType lore)]
forall a b. (a -> b) -> [a] -> [b]
map BranchType lore -> Either Int (BranchType lore)
forall a b. b -> Either a b
Right [BranchType lore]
ret)
          (Result -> Result -> [(SubExp, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip Result
tses Result
fses)
  let ctx_fixes :: [(Int, SubExp)]
ctx_fixes = [Maybe (Int, SubExp)] -> [(Int, SubExp)]
forall a. [Maybe a] -> [a]
catMaybes [Maybe (Int, SubExp)]
hoistings
      (Result
tses', Result
fses') = [(SubExp, SubExp)] -> (Result, Result)
forall a b. [(a, b)] -> ([a], [b])
unzip [(SubExp, SubExp)]
res
      tb' :: BodyT lore
tb' = BodyT lore
tb {bodyResult :: Result
bodyResult = Result
tses'}
      fb' :: BodyT lore
fb' = BodyT lore
fb {bodyResult :: Result
bodyResult = Result
fses'}
      ret' :: [BranchType lore]
ret' = ((Int, SubExp) -> [BranchType lore] -> [BranchType lore])
-> [BranchType lore] -> [(Int, SubExp)] -> [BranchType lore]
forall (t :: * -> *) a b.
Foldable t =>
(a -> b -> b) -> b -> t a -> b
foldr ((Int -> SubExp -> [BranchType lore] -> [BranchType lore])
-> (Int, SubExp) -> [BranchType lore] -> [BranchType lore]
forall a b c. (a -> b -> c) -> (a, b) -> c
uncurry Int -> SubExp -> [BranchType lore] -> [BranchType lore]
forall t. FixExt t => Int -> SubExp -> t -> t
fixExt) ([Either Int (BranchType lore)] -> [BranchType lore]
forall a b. [Either a b] -> [b]
rights [Either Int (BranchType lore)]
ts) [(Int, SubExp)]
ctx_fixes
      ([PatElemT (LetDec lore)]
ctx_pes, [PatElemT (LetDec lore)]
val_pes) = Int
-> [PatElemT (LetDec lore)]
-> ([PatElemT (LetDec lore)], [PatElemT (LetDec lore)])
forall a. Int -> [a] -> ([a], [a])
splitFromEnd ([BranchType lore] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [BranchType lore]
ret') [PatElemT (LetDec lore)]
pes
  if Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ [Maybe (Int, SubExp)] -> Bool
forall (t :: * -> *) a. Foldable t => t a -> Bool
null [Maybe (Int, SubExp)]
hoistings -- Was something hoisted?
    then do
      -- We may have to add some reshapes if we made the type
      -- less existential.
      BodyT lore
tb'' <- BodyT (Lore (RuleM lore))
-> [ExtType] -> RuleM lore (BodyT (Lore (RuleM lore)))
forall (m :: * -> *).
MonadBinder m =>
BodyT (Lore m) -> [ExtType] -> m (BodyT (Lore m))
reshapeBodyResults BodyT lore
BodyT (Lore (RuleM lore))
tb' ([ExtType] -> RuleM lore (BodyT (Lore (RuleM lore))))
-> [ExtType] -> RuleM lore (BodyT (Lore (RuleM lore)))
forall a b. (a -> b) -> a -> b
$ (BranchType lore -> ExtType) -> [BranchType lore] -> [ExtType]
forall a b. (a -> b) -> [a] -> [b]
map BranchType lore -> ExtType
forall t. ExtTyped t => t -> ExtType
extTypeOf [BranchType lore]
ret'
      BodyT lore
fb'' <- BodyT (Lore (RuleM lore))
-> [ExtType] -> RuleM lore (BodyT (Lore (RuleM lore)))
forall (m :: * -> *).
MonadBinder m =>
BodyT (Lore m) -> [ExtType] -> m (BodyT (Lore m))
reshapeBodyResults BodyT lore
BodyT (Lore (RuleM lore))
fb' ([ExtType] -> RuleM lore (BodyT (Lore (RuleM lore))))
-> [ExtType] -> RuleM lore (BodyT (Lore (RuleM lore)))
forall a b. (a -> b) -> a -> b
$ (BranchType lore -> ExtType) -> [BranchType lore] -> [ExtType]
forall a b. (a -> b) -> [a] -> [b]
map BranchType lore -> ExtType
forall t. ExtTyped t => t -> ExtType
extTypeOf [BranchType lore]
ret'
      Pattern (Lore (RuleM lore))
-> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall (m :: * -> *).
MonadBinder m =>
Pattern (Lore m) -> Exp (Lore m) -> m ()
letBind ([PatElemT (LetDec lore)]
-> [PatElemT (LetDec lore)] -> Pattern lore
forall dec. [PatElemT dec] -> [PatElemT dec] -> PatternT dec
Pattern [PatElemT (LetDec lore)]
ctx_pes [PatElemT (LetDec lore)]
val_pes) (Exp (Lore (RuleM lore)) -> RuleM lore ())
-> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall a b. (a -> b) -> a -> b
$
        SubExp
-> BodyT lore -> BodyT lore -> IfDec (BranchType lore) -> ExpT lore
forall lore.
SubExp
-> BodyT lore -> BodyT lore -> IfDec (BranchType lore) -> ExpT lore
If SubExp
cond BodyT lore
tb'' BodyT lore
fb'' ([BranchType lore] -> IfSort -> IfDec (BranchType lore)
forall rt. [rt] -> IfSort -> IfDec rt
IfDec [BranchType lore]
ret' IfSort
ifsort)
    else RuleM lore ()
forall lore a. RuleM lore a
cannotSimplify
  where
    num_ctx :: Int
num_ctx = [PatElemT (LetDec lore)] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length ([PatElemT (LetDec lore)] -> Int)
-> [PatElemT (LetDec lore)] -> Int
forall a b. (a -> b) -> a -> b
$ Pattern lore -> [PatElemT (LetDec lore)]
forall dec. PatternT dec -> [PatElemT dec]
patternContextElements Pattern lore
pat
    bound_in_branches :: Names
bound_in_branches =
      [VName] -> Names
namesFromList ([VName] -> Names) -> [VName] -> Names
forall a b. (a -> b) -> a -> b
$
        (Stm lore -> [VName]) -> Seq (Stm lore) -> [VName]
forall (t :: * -> *) a b. Foldable t => (a -> [b]) -> t a -> [b]
concatMap (Pattern lore -> [VName]
forall dec. PatternT dec -> [VName]
patternNames (Pattern lore -> [VName])
-> (Stm lore -> Pattern lore) -> Stm lore -> [VName]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Stm lore -> Pattern lore
forall lore. Stm lore -> Pattern lore
stmPattern) (Seq (Stm lore) -> [VName]) -> Seq (Stm lore) -> [VName]
forall a b. (a -> b) -> a -> b
$
          BodyT lore -> Seq (Stm lore)
forall lore. BodyT lore -> Stms lore
bodyStms BodyT lore
tb Seq (Stm lore) -> Seq (Stm lore) -> Seq (Stm lore)
forall a. Semigroup a => a -> a -> a
<> BodyT lore -> Seq (Stm lore)
forall lore. BodyT lore -> Stms lore
bodyStms BodyT lore
fb
    mem_sizes :: Names
mem_sizes = [PatElemT (LetDec lore)] -> Names
forall a. FreeIn a => a -> Names
freeIn ([PatElemT (LetDec lore)] -> Names)
-> [PatElemT (LetDec lore)] -> Names
forall a b. (a -> b) -> a -> b
$ (PatElemT (LetDec lore) -> Bool)
-> [PatElemT (LetDec lore)] -> [PatElemT (LetDec lore)]
forall a. (a -> Bool) -> [a] -> [a]
filter (Type -> Bool
forall shape u. TypeBase shape u -> Bool
isMem (Type -> Bool)
-> (PatElemT (LetDec lore) -> Type)
-> PatElemT (LetDec lore)
-> Bool
forall b c a. (b -> c) -> (a -> b) -> a -> c
. PatElemT (LetDec lore) -> Type
forall dec. Typed dec => PatElemT dec -> Type
patElemType) ([PatElemT (LetDec lore)] -> [PatElemT (LetDec lore)])
-> [PatElemT (LetDec lore)] -> [PatElemT (LetDec lore)]
forall a b. (a -> b) -> a -> b
$ Pattern lore -> [PatElemT (LetDec lore)]
forall dec. PatternT dec -> [PatElemT dec]
patternElements Pattern lore
pat
    invariant :: SubExp -> Bool
invariant Constant {} = Bool
True
    invariant (Var VName
v) = Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ VName
v VName -> Names -> Bool
`nameIn` Names
bound_in_branches

    isMem :: TypeBase shape u -> Bool
isMem Mem {} = Bool
True
    isMem TypeBase shape u
_ = Bool
False
    sizeOfMem :: VName -> Bool
sizeOfMem VName
v = VName
v VName -> Names -> Bool
`nameIn` Names
mem_sizes

    branchInvariant :: (PatElemT (LetDec lore), Either Int (BranchType lore),
 (SubExp, SubExp))
-> RuleM
     lore
     (Either
        (Maybe (Int, SubExp))
        (PatElemT (LetDec lore), Either Int (BranchType lore),
         (SubExp, SubExp)))
branchInvariant (PatElemT (LetDec lore)
pe, Either Int (BranchType lore)
t, (SubExp
tse, SubExp
fse))
      -- Do both branches return the same value?
      | SubExp
tse SubExp -> SubExp -> Bool
forall a. Eq a => a -> a -> Bool
== SubExp
fse = do
        [VName] -> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall (m :: * -> *).
MonadBinder m =>
[VName] -> Exp (Lore m) -> m ()
letBindNames [PatElemT (LetDec lore) -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT (LetDec lore)
pe] (Exp (Lore (RuleM lore)) -> RuleM lore ())
-> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall a b. (a -> b) -> a -> b
$ BasicOp -> ExpT lore
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> ExpT lore) -> BasicOp -> ExpT lore
forall a b. (a -> b) -> a -> b
$ SubExp -> BasicOp
SubExp SubExp
tse
        PatElemT (LetDec lore)
-> Either Int (BranchType lore)
-> RuleM
     lore
     (Either
        (Maybe (Int, SubExp))
        (PatElemT (LetDec lore), Either Int (BranchType lore),
         (SubExp, SubExp)))
forall (m :: * -> *) dec a b b.
Monad m =>
PatElemT dec -> Either a b -> m (Either (Maybe (a, SubExp)) b)
hoisted PatElemT (LetDec lore)
pe Either Int (BranchType lore)
t

      -- Do both branches return values that are free in the
      -- branch, and are we not the only pattern element?  The
      -- latter is to avoid infinite application of this rule.
      | SubExp -> Bool
invariant SubExp
tse,
        SubExp -> Bool
invariant SubExp
fse,
        Pattern lore -> Int
forall dec. PatternT dec -> Int
patternSize Pattern lore
pat Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
> Int
1,
        Prim PrimType
_ <- PatElemT (LetDec lore) -> Type
forall dec. Typed dec => PatElemT dec -> Type
patElemType PatElemT (LetDec lore)
pe,
        Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ VName -> Bool
sizeOfMem (VName -> Bool) -> VName -> Bool
forall a b. (a -> b) -> a -> b
$ PatElemT (LetDec lore) -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT (LetDec lore)
pe = do
        [BranchType lore]
bt <- Pattern lore -> RuleM lore [BranchType lore]
forall lore (m :: * -> *).
(ASTLore lore, HasScope lore m, Monad m) =>
Pattern lore -> m [BranchType lore]
expTypesFromPattern (Pattern lore -> RuleM lore [BranchType lore])
-> Pattern lore -> RuleM lore [BranchType lore]
forall a b. (a -> b) -> a -> b
$ [PatElemT (LetDec lore)]
-> [PatElemT (LetDec lore)] -> Pattern lore
forall dec. [PatElemT dec] -> [PatElemT dec] -> PatternT dec
Pattern [] [PatElemT (LetDec lore)
pe]
        [VName] -> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall (m :: * -> *).
MonadBinder m =>
[VName] -> Exp (Lore m) -> m ()
letBindNames [PatElemT (LetDec lore) -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT (LetDec lore)
pe]
          (ExpT lore -> RuleM lore ())
-> RuleM lore (ExpT lore) -> RuleM lore ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< ( SubExp
-> BodyT lore -> BodyT lore -> IfDec (BranchType lore) -> ExpT lore
forall lore.
SubExp
-> BodyT lore -> BodyT lore -> IfDec (BranchType lore) -> ExpT lore
If SubExp
cond (BodyT lore -> BodyT lore -> IfDec (BranchType lore) -> ExpT lore)
-> RuleM lore (BodyT lore)
-> RuleM lore (BodyT lore -> IfDec (BranchType lore) -> ExpT lore)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Result -> RuleM lore (BodyT (Lore (RuleM lore)))
forall (m :: * -> *). MonadBinder m => Result -> m (Body (Lore m))
resultBodyM [SubExp
tse]
                  RuleM lore (BodyT lore -> IfDec (BranchType lore) -> ExpT lore)
-> RuleM lore (BodyT lore)
-> RuleM lore (IfDec (BranchType lore) -> ExpT lore)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> Result -> RuleM lore (BodyT (Lore (RuleM lore)))
forall (m :: * -> *). MonadBinder m => Result -> m (Body (Lore m))
resultBodyM [SubExp
fse]
                  RuleM lore (IfDec (BranchType lore) -> ExpT lore)
-> RuleM lore (IfDec (BranchType lore)) -> RuleM lore (ExpT lore)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> IfDec (BranchType lore) -> RuleM lore (IfDec (BranchType lore))
forall (f :: * -> *) a. Applicative f => a -> f a
pure ([BranchType lore] -> IfSort -> IfDec (BranchType lore)
forall rt. [rt] -> IfSort -> IfDec rt
IfDec [BranchType lore]
bt IfSort
ifsort)
              )
        PatElemT (LetDec lore)
-> Either Int (BranchType lore)
-> RuleM
     lore
     (Either
        (Maybe (Int, SubExp))
        (PatElemT (LetDec lore), Either Int (BranchType lore),
         (SubExp, SubExp)))
forall (m :: * -> *) dec a b b.
Monad m =>
PatElemT dec -> Either a b -> m (Either (Maybe (a, SubExp)) b)
hoisted PatElemT (LetDec lore)
pe Either Int (BranchType lore)
t
      | Bool
otherwise =
        Either
  (Maybe (Int, SubExp))
  (PatElemT (LetDec lore), Either Int (BranchType lore),
   (SubExp, SubExp))
-> RuleM
     lore
     (Either
        (Maybe (Int, SubExp))
        (PatElemT (LetDec lore), Either Int (BranchType lore),
         (SubExp, SubExp)))
forall (m :: * -> *) a. Monad m => a -> m a
return (Either
   (Maybe (Int, SubExp))
   (PatElemT (LetDec lore), Either Int (BranchType lore),
    (SubExp, SubExp))
 -> RuleM
      lore
      (Either
         (Maybe (Int, SubExp))
         (PatElemT (LetDec lore), Either Int (BranchType lore),
          (SubExp, SubExp))))
-> Either
     (Maybe (Int, SubExp))
     (PatElemT (LetDec lore), Either Int (BranchType lore),
      (SubExp, SubExp))
-> RuleM
     lore
     (Either
        (Maybe (Int, SubExp))
        (PatElemT (LetDec lore), Either Int (BranchType lore),
         (SubExp, SubExp)))
forall a b. (a -> b) -> a -> b
$ (PatElemT (LetDec lore), Either Int (BranchType lore),
 (SubExp, SubExp))
-> Either
     (Maybe (Int, SubExp))
     (PatElemT (LetDec lore), Either Int (BranchType lore),
      (SubExp, SubExp))
forall a b. b -> Either a b
Right (PatElemT (LetDec lore)
pe, Either Int (BranchType lore)
t, (SubExp
tse, SubExp
fse))

    hoisted :: PatElemT dec -> Either a b -> m (Either (Maybe (a, SubExp)) b)
hoisted PatElemT dec
pe (Left a
i) = Either (Maybe (a, SubExp)) b -> m (Either (Maybe (a, SubExp)) b)
forall (m :: * -> *) a. Monad m => a -> m a
return (Either (Maybe (a, SubExp)) b -> m (Either (Maybe (a, SubExp)) b))
-> Either (Maybe (a, SubExp)) b -> m (Either (Maybe (a, SubExp)) b)
forall a b. (a -> b) -> a -> b
$ Maybe (a, SubExp) -> Either (Maybe (a, SubExp)) b
forall a b. a -> Either a b
Left (Maybe (a, SubExp) -> Either (Maybe (a, SubExp)) b)
-> Maybe (a, SubExp) -> Either (Maybe (a, SubExp)) b
forall a b. (a -> b) -> a -> b
$ (a, SubExp) -> Maybe (a, SubExp)
forall a. a -> Maybe a
Just (a
i, VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ PatElemT dec -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT dec
pe)
    hoisted PatElemT dec
_ Right {} = Either (Maybe (a, SubExp)) b -> m (Either (Maybe (a, SubExp)) b)
forall (m :: * -> *) a. Monad m => a -> m a
return (Either (Maybe (a, SubExp)) b -> m (Either (Maybe (a, SubExp)) b))
-> Either (Maybe (a, SubExp)) b -> m (Either (Maybe (a, SubExp)) b)
forall a b. (a -> b) -> a -> b
$ Maybe (a, SubExp) -> Either (Maybe (a, SubExp)) b
forall a b. a -> Either a b
Left Maybe (a, SubExp)
forall a. Maybe a
Nothing

    reshapeBodyResults :: BodyT (Lore m) -> [ExtType] -> m (BodyT (Lore m))
reshapeBodyResults BodyT (Lore m)
body [ExtType]
rets = m (BodyT (Lore m)) -> m (BodyT (Lore m))
forall (m :: * -> *).
MonadBinder m =>
m (Body (Lore m)) -> m (Body (Lore m))
insertStmsM (m (BodyT (Lore m)) -> m (BodyT (Lore m)))
-> m (BodyT (Lore m)) -> m (BodyT (Lore m))
forall a b. (a -> b) -> a -> b
$ do
      Result
ses <- BodyT (Lore m) -> m Result
forall (m :: * -> *). MonadBinder m => Body (Lore m) -> m Result
bodyBind BodyT (Lore m)
body
      let (Result
ctx_ses, Result
val_ses) = Int -> Result -> (Result, Result)
forall a. Int -> [a] -> ([a], [a])
splitFromEnd ([ExtType] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [ExtType]
rets) Result
ses
      Result -> m (BodyT (Lore m))
forall (m :: * -> *). MonadBinder m => Result -> m (Body (Lore m))
resultBodyM (Result -> m (BodyT (Lore m)))
-> (Result -> Result) -> Result -> m (BodyT (Lore m))
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Result
ctx_ses Result -> Result -> Result
forall a. [a] -> [a] -> [a]
++) (Result -> m (BodyT (Lore m))) -> m Result -> m (BodyT (Lore m))
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< (SubExp -> ExtType -> m SubExp) -> Result -> [ExtType] -> m Result
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m [c]
zipWithM SubExp -> ExtType -> m SubExp
forall (m :: * -> *).
MonadBinder m =>
SubExp -> ExtType -> m SubExp
reshapeResult Result
val_ses [ExtType]
rets
    reshapeResult :: SubExp -> ExtType -> m SubExp
reshapeResult (Var VName
v) t :: ExtType
t@Array {} = do
      Type
v_t <- VName -> m Type
forall lore (m :: * -> *). HasScope lore m => VName -> m Type
lookupType VName
v
      let newshape :: Result
newshape = Type -> Result
forall u. TypeBase Shape u -> Result
arrayDims (Type -> Result) -> Type -> Result
forall a b. (a -> b) -> a -> b
$ ExtType -> Type -> Type
removeExistentials ExtType
t Type
v_t
      if Result
newshape Result -> Result -> Bool
forall a. Eq a => a -> a -> Bool
/= Type -> Result
forall u. TypeBase Shape u -> Result
arrayDims Type
v_t
        then String -> Exp (Lore m) -> m SubExp
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m SubExp
letSubExp String
"branch_ctx_reshaped" (Exp (Lore m) -> m SubExp) -> Exp (Lore m) -> m SubExp
forall a b. (a -> b) -> a -> b
$ Result -> VName -> Exp (Lore m)
forall lore. Result -> VName -> Exp lore
shapeCoerce Result
newshape VName
v
        else SubExp -> m SubExp
forall (m :: * -> *) a. Monad m => a -> m a
return (SubExp -> m SubExp) -> SubExp -> m SubExp
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
v
    reshapeResult SubExp
se ExtType
_ =
      SubExp -> m SubExp
forall (m :: * -> *) a. Monad m => a -> m a
return SubExp
se

-- | Remove the return values of a branch that are not actually used
-- after the branch.  Standard dead code removal can remove the branch
-- if /none/ of the return values are used, but this rule is more
-- precise: it drops exactly the unused results.
--
-- The rule fires only when
--
--   * the pattern has no existential context (its size equals the
--     number of branch return types), and
--
--   * at least one pattern name is not directly used according to the
--     usage table.
--
-- Only the result lists of the branch bodies are pruned; statements
-- computing the dropped results are left in place for later dead code
-- removal to clean up.
removeDeadBranchResult :: BinderOps lore => BottomUpRuleIf lore
removeDeadBranchResult (_, used) pat _ (branch_cond, tbody, fbody, IfDec rets ifsort)
  | patternSize pat /= length rets = Skip -- existential context present
  | and live = Skip -- every result is used; nothing to remove
  | otherwise =
      Simplify . letBind (Pattern [] (prune (patternElements pat))) $
        If branch_cond (pruneResult tbody) (pruneResult fbody) (IfDec (prune rets) ifsort)
  where
    -- One flag per pattern name: is it directly used after the branch?
    live = map (`UT.isUsedDirectly` used) (patternNames pat)

    -- Keep only the positions whose pattern name is live.  The explicit
    -- signature is needed because this is used at several element
    -- types, and MonoLocalBinds (implied by TypeFamilies) would
    -- otherwise prevent generalisation.
    prune :: [a] -> [a]
    prune xs = [x | (True, x) <- zip live xs]

    pruneResult body = body {bodyResult = prune (bodyResult body)}

-- Some helper functions

-- | Is this subexpression a constant for which 'oneIsh' holds?
-- Variables are never considered to be one.
isCt1 :: SubExp -> Bool
isCt1 se =
  case se of
    Constant v -> oneIsh v
    _ -> False

-- | Is this subexpression a constant for which 'zeroIsh' holds?
-- Variables are never considered to be zero.
isCt0 :: SubExp -> Bool
isCt0 se =
  case se of
    Constant v -> zeroIsh v
    _ -> False