{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE TypeFamilies #-}

-- | This module defines a collection of simplification rules, as per
-- "Futhark.Optimise.Simplify.Rule".  They are used in the
-- simplifier.
--
-- For performance reasons, many sufficiently simple logically
-- separate rules are merged into single "super-rules", like ruleIf
-- and ruleBasicOp.  This is because it is relatively expensive to
-- activate a rule just to determine that it does not apply.  Thus, it
-- is more efficient to have a few very fat rules than a lot of small
-- rules.  This does not affect the compiler result in any way; it is
-- purely an optimisation to speed up compilation.
module Futhark.Optimise.Simplify.Rules
  ( standardRules,
    removeUnnecessaryCopy,
  )
where

import Control.Monad
import Data.Either
import Data.List (find, unzip4, zip4)
import qualified Data.Map.Strict as M
import Data.Maybe
import Futhark.Analysis.PrimExp.Convert
import qualified Futhark.Analysis.SymbolTable as ST
import qualified Futhark.Analysis.UsageTable as UT
import Futhark.Construct
import Futhark.IR
import Futhark.IR.Prop.Aliases
import Futhark.Optimise.Simplify.Rule
import Futhark.Optimise.Simplify.Rules.BasicOp
import Futhark.Optimise.Simplify.Rules.Index
import Futhark.Optimise.Simplify.Rules.Loop
import Futhark.Util

-- | The rules this module contributes to the top-down pass, in the
-- order they are registered: constant folding of primitive function
-- calls, the merged 'If' simplifications, hoisting of branch-invariant
-- results, and the generic 'withAccTopDown' rule.
topDownRules :: BuilderOps rep => [TopDownRule rep]
topDownRules =
  RuleGeneric constantFoldPrimFun
    : RuleIf ruleIf
    : RuleIf hoistBranchInvariant
    : [RuleGeneric withAccTopDown]

-- | The rules this module contributes to the bottom-up pass, in the
-- order they are registered: 'removeDeadBranchResult' for 'If'
-- expressions, the generic 'withAccBottomUp' rule, and 'simplifyIndex'
-- for basic operations.
bottomUpRules :: BuilderOps rep => [BottomUpRule rep]
bottomUpRules =
  RuleIf removeDeadBranchResult
    : RuleGeneric withAccBottomUp
    : [RuleBasicOp simplifyIndex]

-- | A set of standard simplification rules.  These assume pure
-- functional semantics, and so probably should not be applied after
-- memory block merging.
--
-- The book is this module's own top-down and bottom-up rules,
-- combined with 'loopRules' and 'basicOpRules'.
standardRules :: (BuilderOps rep, Aliased rep) => RuleBook rep
standardRules =
  ruleBook topDownRules bottomUpRules
    <> (loopRules <> basicOpRules)

-- | Turn @copy(x)@ into @x@ iff @x@ is not used after this copy
-- statement and it can be consumed.
--
-- This simplistic rule is only valid before we introduce memory.
--
-- The copy is replaced by a plain binding of @x@ when all of these hold:
--
-- * @x@ itself is not consumed afterwards;
--
-- * the copy is not part of the result, unless it is consumed;
--
-- * either @x@ is unused afterwards and may be consumed, or the copy
--   is never consumed.
removeUnnecessaryCopy :: (BuilderOps rep, Aliased rep) => BottomUpRuleBasicOp rep
removeUnnecessaryCopy (vtable, used) (Pat [d]) _ (Copy v)
  | not sourceConsumed,
    -- This next line is too conservative, but the problem is that 'v'
    -- might not look like it has been consumed if it is consumed in
    -- an outer scope.  This is because the simplifier applies
    -- bottom-up rules in a kind of deepest-first order.
    not copyInResult || copyConsumed,
    (not sourceUsed && consumable) || not copyConsumed =
      Simplify $ letBindNames [copyName] $ BasicOp $ SubExp $ Var v
  where
    copyName = patElemName d
    sourceConsumed = v `UT.isConsumed` used
    sourceUsed = v `UT.used` used
    copyConsumed = copyName `UT.isConsumed` used
    copyInResult = copyName `UT.isInResult` used

    -- We need to make sure we can even consume the original.  The big
    -- missing piece here is that we cannot do copy removal inside of
    -- 'map' and other SOACs, but that is handled by SOAC-specific rules.
    -- The original must be bound at the current loop depth, and must be
    -- either an alias-free statement result or a unique function parameter.
    consumable = fromMaybe False $ do
      entry <- ST.lookup v vtable
      guard $ ST.entryDepth entry == ST.loopDepth vtable
      consumableStm entry `mplus` consumableFParam entry

    -- A function parameter is consumable exactly when its declared type
    -- is unique; non-parameters answer False.
    consumableFParam entry =
      Just $ case ST.entryFParam entry of
        Just fparam -> unique (declTypeOf fparam)
        Nothing -> False

    -- A statement-bound name is consumable when its pattern element has
    -- no aliases; otherwise this yields Nothing.
    consumableStm entry = do
      stm <- ST.entryStm entry
      pe <- find ((== v) . patElemName) (patElems (stmPat stm))
      guard $ aliasesOf pe == mempty
      pure True
removeUnnecessaryCopy _ _ _ _ = Skip

-- | Constant-fold a function application whose callee names an entry in
-- 'primFuns', provided every argument is a constant and evaluating the
-- primitive produces a value.  The statement's pattern, certificates and
-- attributes are kept on the replacement binding.
constantFoldPrimFun :: BuilderOps rep => TopDownRuleGeneric rep
constantFoldPrimFun _ (Let pat (StmAux cs attrs _) (Apply fname args _ _))
  | Just consts <- traverse (constValue . fst) args,
    Just (_, _, evaluate) <- M.lookup (nameToString fname) primFuns,
    Just folded <- evaluate consts =
      Simplify . certifying cs . attributing attrs . letBind pat $
        BasicOp (SubExp (Constant folded))
  where
    -- Only literal constants can be folded; variables abort the rule.
    constValue (Constant pv) = Just pv
    constValue _ = Nothing
constantFoldPrimFun _ _ = Skip

-- | Simplify an 'Index' operation with a single-element pattern using
-- 'simplifyIndexing'.  Subexpression types are looked up in the symbol
-- table, and whether the indexed result is consumed later is passed
-- along.  The simplified result is bound either as a plain subexpression
-- or as a new 'Index'.  Either way it is certified by the statement's
-- certificates plus any extra ones the simplification reports, and it
-- keeps the statement's attributes.
simplifyIndex :: BuilderOps rep => BottomUpRuleBasicOp rep
simplifyIndex (vtable, used) pat@(Pat [pe]) (StmAux cs attrs _) (Index idd inds)
  | Just simplified <- simplifyIndexing vtable typeOfSubExp idd inds resultConsumed =
      Simplify $ do
        outcome <- simplified
        attributing attrs $ case outcome of
          SubExpResult extra_cs se -> bindResult extra_cs $ SubExp se
          IndexResult extra_cs idd' inds' -> bindResult extra_cs $ Index idd' inds'
  where
    resultConsumed = patElemName pe `UT.isConsumed` used

    -- Bind the replacement operation to the original pattern names,
    -- keeping the original certificates ahead of the new ones.
    bindResult extra_cs op =
      certifying (cs <> extra_cs) $ letBindNames (patNames pat) $ BasicOp op

    typeOfSubExp (Var name) = ST.lookupType name vtable
    typeOfSubExp (Constant val) = Just $ Prim $ primValueType val
simplifyIndex _ _ _ _ = Skip

-- | The merged top-down rule for 'If' expressions (see the module
-- header for why several rules are merged).  Equations are tried in
-- order:
--
-- 1. If the condition satisfies 'isCt1' or 'isCt0', the selected
--    branch is inlined.  For an 'IfFallback', only the 'isCt1' case is
--    inlined.
--
-- 2. @if c then True else v@ with no branch statements becomes
--    @c || v@.
--
-- 3. A boolean-typed 'If' with single results and only 'safeExp'
--    branch statements becomes @(c && x) || (!c && y)@.
--
-- 4. An 'IfFallback' whose true branch contains only 'safeExp'
--    statements is replaced by that branch.
--
-- 5. An 'If' whose branches return certificate-free integer constants
--    that are one-ish/zero-ish ('oneIshInt'/'zeroIshInt') becomes a
--    'BToI' conversion of the condition, or of its negation.
ruleIf :: BuilderOps rep => TopDownRuleIf rep
ruleIf _ pat _ (e1, tb, fb, IfDec _ ifsort)
  | Just branch <- checkBranch,
    ifsort /= IfFallback || isCt1 e1 = Simplify $ do
      let ses = bodyResult branch
      addStms $ bodyStms branch
      sequence_
        [ certifying cs $ letBindNames [patElemName p] $ BasicOp $ SubExp se
          | (p, SubExpRes cs se) <- zip (patElems pat) ses
        ]
  where
    checkBranch
      | isCt1 e1 = Just tb
      | isCt0 e1 = Just fb
      | otherwise = Nothing

-- IMPROVE: the following two rules can be generalised to work in more
-- cases, especially when the branches have bindings, or return more
-- than one value.
--
-- if c then True else v == c || v
ruleIf
  _
  pat
  _
  ( cond,
    Body _ tstms [SubExpRes tcs (Constant (BoolValue True))],
    Body _ fstms [SubExpRes fcs se],
    IfDec ts _
    )
    | null tstms,
      null fstms,
      [Prim Bool] <- map extTypeOf ts =
        Simplify $ certifying (tcs <> fcs) $ letBind pat $ BasicOp $ BinOp LogOr cond se
-- When type(x)==bool, if c then x else y == (c && x) || (!c && y)
ruleIf _ pat _ (cond, tb, fb, IfDec ts _)
  | Body _ tstms [SubExpRes tcs tres] <- tb,
    Body _ fstms [SubExpRes fcs fres] <- fb,
    -- Both branches' statements are executed unconditionally below, so
    -- they must all be safe to evaluate.
    all (safeExp . stmExp) $ tstms <> fstms,
    all ((== Prim Bool) . extTypeOf) ts = Simplify $ do
      addStms tstms
      addStms fstms
      e <-
        eBinOp
          LogOr
          (pure $ BasicOp $ BinOp LogAnd cond tres)
          ( eBinOp
              LogAnd
              (pure $ BasicOp $ UnOp Not cond)
              (pure $ BasicOp $ SubExp fres)
          )
      certifying (tcs <> fcs) $ letBind pat e
ruleIf _ pat _ (_, tbranch, _, IfDec _ IfFallback)
  | all (safeExp . stmExp) $ bodyStms tbranch = Simplify $ do
      let ses = bodyResult tbranch
      addStms $ bodyStms tbranch
      sequence_
        [ certifying cs $ letBindNames [patElemName p] $ BasicOp $ SubExp se
          | (p, SubExpRes cs se) <- zip (patElems pat) ses
        ]
ruleIf _ pat _ (cond, tb, fb, _)
  | Body _ _ [SubExpRes tcs (Constant (IntValue t))] <- tb,
    Body _ _ [SubExpRes fcs (Constant (IntValue f))] <- fb,
    -- The branch statements are not re-emitted by this rewrite, so
    -- result certificates (which may refer to names bound in those
    -- statements) cannot be carried over.  Previously only the 1/0
    -- direction checked this, and the 0/1 direction silently dropped
    -- the certificates; require their absence for both.
    tcs == mempty,
    fcs == mempty =
      if oneIshInt t && zeroIshInt f
        then Simplify $ letBind pat $ BasicOp $ ConvOp (BToI (intValueType t)) cond
        else
          if zeroIshInt t && oneIshInt f
            then Simplify $ do
              cond_neg <- letSubExp "cond_neg" $ BasicOp $ UnOp Not cond
              letBind pat $ BasicOp $ ConvOp (BToI (intValueType t)) cond_neg
            else Skip
ruleIf _ _ _ _ = Skip

-- | Hoist results out of an @if@ when they do not depend on which
-- branch is taken.  A result is hoisted when either
--
--   * both branches return the very same 'SubExpRes'; the value is then
--     bound directly, under the certificates of both branch results; or
--
--   * both branch results are constants or variables not bound by any
--     statement inside either branch, the result has primitive type, and
--     the pattern has more than one element.  Such a result is moved into
--     its own single-result @if@.  The pattern-size restriction keeps
--     this rule from firing forever on that new @if@.
--
-- Hoisted pattern elements are substituted into the return types of
-- the remaining results (via 'fixExt').  Because this can make those
-- types less existential, the remaining branch results are coerced to
-- the new shapes where needed.  The rule declines ('cannotSimplify')
-- when nothing was hoisted.
hoistBranchInvariant :: BuilderOps rep => TopDownRuleIf rep
hoistBranchInvariant _ pat _ (cond, tb, fb, IfDec ret ifsort) = Simplify $ do
  classified <-
    mapM classify $
      zip4 [0 ..] (patElems pat) ret (zip (bodyResult tb) (bodyResult fb))
  let (hoistings, kept) = partitionEithers classified
      (pes, ts, res) = unzip3 kept
  case hoistings of
    [] -> cannotSimplify
    _ -> do
      let fixes = catMaybes hoistings
          ret' = foldr (uncurry fixExt) ts fixes
          exts = map extTypeOf ret'
          (tses', fses') = unzip res
      -- The return type may now be less existential than before, so
      -- the surviving results may need shape coercions.
      tb'' <- reshapeBodyResults tb {bodyResult = tses'} exts
      fb'' <- reshapeBodyResults fb {bodyResult = fses'} exts
      letBind (Pat pes) $ If cond tb'' fb'' $ IfDec ret' ifsort
  where
    -- Every name bound by a statement in either branch.
    boundInside =
      namesFromList $
        foldMap (patNames . stmPat) (bodyStms tb <> bodyStms fb)

    definedOutside (Var v) = not $ v `nameIn` boundInside
    definedOutside Constant {} = True

    -- Left: the result was hoisted (index and replacement value).
    -- Right: the result stays in the branch.
    classify (i, pe, t, (tse, fse))
      -- Both branches return the same value.
      | tse == fse = do
          certifying (resCerts tse <> resCerts fse) $
            letBindNames [patElemName pe] $ BasicOp $ SubExp $ resSubExp tse
          pure $ hoistedAs i pe
      -- Both values come from outside the branches, and we are not
      -- the sole pattern element (otherwise the rule would loop).
      | all (definedOutside . resSubExp) [tse, fse],
        patSize pat > 1,
        Prim _ <- patElemType pe = do
          bt <- expTypesFromPat $ Pat [pe]
          tbody <- resultBodyM [resSubExp tse]
          fbody <- resultBodyM [resSubExp fse]
          letBindNames [patElemName pe] $ If cond tbody fbody $ IfDec bt ifsort
          pure $ hoistedAs i pe
      | otherwise =
          pure $ Right (pe, t, (tse, fse))

    hoistedAs i pe = Left $ Just (i, Var $ patElemName pe)

    -- Rebuild a body so its trailing results match the given types.
    reshapeBodyResults body rets = buildBody_ $ do
      ses <- bodyBind body
      let (ctx, vals) = splitFromEnd (length rets) ses
      vals' <- zipWithM reshapeResult vals rets
      pure $ ctx ++ vals'

    reshapeResult (SubExpRes cs (Var v)) t@Array {} = do
      v_t <- lookupType v
      let wanted = arrayDims $ removeExistentials t v_t
      fmap (SubExpRes cs) $
        if wanted == arrayDims v_t
          then pure $ Var v
          else letSubExp "branch_ctx_reshaped" $ shapeCoerce wanted v
    reshapeResult se _ = pure se

-- | Drop the results of an @if@ that are not used after it.  The rule
-- applies only when no pattern name occurs free in any pattern
-- element (i.e. the pattern binds no existential sizes) and at least
-- one pattern name is not directly used.
--
-- General dead code removal can only delete the whole @if@ when none
-- of its results are used; this rule prunes individual results.
-- Statements in the branches that computed the dropped results are
-- left in place, to be cleaned up by later dead code removal.
removeDeadBranchResult :: BuilderOps rep => BottomUpRuleIf rep
removeDeadBranchResult (_, used) pat _ (e1, tb, fb, IfDec rettype ifsort)
  | hasExistentials = Skip
  | and patused = Skip
  | otherwise =
      Simplify . letBind (Pat (keep (patElems pat))) $
        If e1 (prune tb) (prune fb) (IfDec (keep rettype) ifsort)
  where
    names = patNames pat
    hasExistentials = any (`nameIn` foldMap freeIn (patElems pat)) names
    -- One flag per pattern element: is its name used directly?
    patused = map (`UT.isUsedDirectly` used) names
    -- Keep only the entries whose pattern element is used.  The
    -- signature is required: it is applied at several element types.
    keep :: [a] -> [a]
    keep xs = [x | (True, x) <- zip patused xs]
    prune body = body {bodyResult = keep $ bodyResult body}

withAccTopDown :: BuilderOps rep => TopDownRuleGeneric rep
-- A WithAcc with no accumulators is sent to Valhalla.
withAccTopDown :: TopDownRuleGeneric rep
withAccTopDown TopDown rep
_ (Let Pat rep
pat StmAux (ExpDec rep)
aux (WithAcc [] Lambda rep
lam)) = RuleM rep () -> Rule rep
forall rep. RuleM rep () -> Rule rep
Simplify (RuleM rep () -> Rule rep)
-> (RuleM rep () -> RuleM rep ()) -> RuleM rep () -> Rule rep
forall b c a. (b -> c) -> (a -> b) -> a -> c
. StmAux (ExpDec rep) -> RuleM rep () -> RuleM rep ()
forall (m :: * -> *) anyrep a.
MonadBuilder m =>
StmAux anyrep -> m a -> m a
auxing StmAux (ExpDec rep)
aux (RuleM rep () -> Rule rep) -> RuleM rep () -> Rule rep
forall a b. (a -> b) -> a -> b
$ do
  Result
lam_res <- Body (Rep (RuleM rep)) -> RuleM rep Result
forall (m :: * -> *). MonadBuilder m => Body (Rep m) -> m Result
bodyBind (Body (Rep (RuleM rep)) -> RuleM rep Result)
-> Body (Rep (RuleM rep)) -> RuleM rep Result
forall a b. (a -> b) -> a -> b
$ Lambda rep -> BodyT rep
forall rep. LambdaT rep -> BodyT rep
lambdaBody Lambda rep
lam
  [(VName, SubExpRes)]
-> ((VName, SubExpRes) -> RuleM rep ()) -> RuleM rep ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName] -> Result -> [(VName, SubExpRes)]
forall a b. [a] -> [b] -> [(a, b)]
zip (Pat rep -> [VName]
forall dec. PatT dec -> [VName]
patNames Pat rep
pat) Result
lam_res) (((VName, SubExpRes) -> RuleM rep ()) -> RuleM rep ())
-> ((VName, SubExpRes) -> RuleM rep ()) -> RuleM rep ()
forall a b. (a -> b) -> a -> b
$ \(VName
v, SubExpRes Certs
cs SubExp
se) ->
    Certs -> RuleM rep () -> RuleM rep ()
forall (m :: * -> *) a. MonadBuilder m => Certs -> m a -> m a
certifying Certs
cs (RuleM rep () -> RuleM rep ()) -> RuleM rep () -> RuleM rep ()
forall a b. (a -> b) -> a -> b
$ [VName] -> Exp (Rep (RuleM rep)) -> RuleM rep ()
forall (m :: * -> *).
MonadBuilder m =>
[VName] -> Exp (Rep m) -> m ()
letBindNames [VName
v] (Exp (Rep (RuleM rep)) -> RuleM rep ())
-> Exp (Rep (RuleM rep)) -> RuleM rep ()
forall a b. (a -> b) -> a -> b
$ BasicOp -> ExpT rep
forall rep. BasicOp -> ExpT rep
BasicOp (BasicOp -> ExpT rep) -> BasicOp -> ExpT rep
forall a b. (a -> b) -> a -> b
$ SubExp -> BasicOp
SubExp SubExp
se
-- Identify those results in 'lam' that are free and move them out.
withAccTopDown TopDown rep
vtable (Let Pat rep
pat StmAux (ExpDec rep)
aux (WithAcc [(Shape, [VName], Maybe (Lambda rep, [SubExp]))]
inputs Lambda rep
lam)) = RuleM rep () -> Rule rep
forall rep. RuleM rep () -> Rule rep
Simplify (RuleM rep () -> Rule rep)
-> (RuleM rep () -> RuleM rep ()) -> RuleM rep () -> Rule rep
forall b c a. (b -> c) -> (a -> b) -> a -> c
. StmAux (ExpDec rep) -> RuleM rep () -> RuleM rep ()
forall (m :: * -> *) anyrep a.
MonadBuilder m =>
StmAux anyrep -> m a -> m a
auxing StmAux (ExpDec rep)
aux (RuleM rep () -> Rule rep) -> RuleM rep () -> Rule rep
forall a b. (a -> b) -> a -> b
$ do
  let ([Param (LParamInfo rep)]
cert_params, [Param (LParamInfo rep)]
acc_params) =
        Int
-> [Param (LParamInfo rep)]
-> ([Param (LParamInfo rep)], [Param (LParamInfo rep)])
forall a. Int -> [a] -> ([a], [a])
splitAt ([(Shape, [VName], Maybe (Lambda rep, [SubExp]))] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [(Shape, [VName], Maybe (Lambda rep, [SubExp]))]
inputs) ([Param (LParamInfo rep)]
 -> ([Param (LParamInfo rep)], [Param (LParamInfo rep)]))
-> [Param (LParamInfo rep)]
-> ([Param (LParamInfo rep)], [Param (LParamInfo rep)])
forall a b. (a -> b) -> a -> b
$ Lambda rep -> [Param (LParamInfo rep)]
forall rep. LambdaT rep -> [LParam rep]
lambdaParams Lambda rep
lam
      (Result
acc_res, Result
nonacc_res) =
        Int -> Result -> (Result, Result)
forall a. Int -> [a] -> ([a], [a])
splitFromEnd Int
num_nonaccs (Result -> (Result, Result)) -> Result -> (Result, Result)
forall a b. (a -> b) -> a -> b
$ BodyT rep -> Result
forall rep. BodyT rep -> Result
bodyResult (BodyT rep -> Result) -> BodyT rep -> Result
forall a b. (a -> b) -> a -> b
$ Lambda rep -> BodyT rep
forall rep. LambdaT rep -> BodyT rep
lambdaBody Lambda rep
lam
      ([PatElemT (LetDec rep)]
acc_pes, [PatElemT (LetDec rep)]
nonacc_pes) =
        Int
-> [PatElemT (LetDec rep)]
-> ([PatElemT (LetDec rep)], [PatElemT (LetDec rep)])
forall a. Int -> [a] -> ([a], [a])
splitFromEnd Int
num_nonaccs ([PatElemT (LetDec rep)]
 -> ([PatElemT (LetDec rep)], [PatElemT (LetDec rep)]))
-> [PatElemT (LetDec rep)]
-> ([PatElemT (LetDec rep)], [PatElemT (LetDec rep)])
forall a b. (a -> b) -> a -> b
$ Pat rep -> [PatElemT (LetDec rep)]
forall dec. PatT dec -> [PatElemT dec]
patElems Pat rep
pat

  -- Look at accumulator results.
  ([[PatElemT (LetDec rep)]]
acc_pes', [(Shape, [VName], Maybe (Lambda rep, [SubExp]))]
inputs', [(Param (LParamInfo rep), Param (LParamInfo rep))]
params', Result
acc_res') <-
    ([Maybe
    ([PatElemT (LetDec rep)],
     (Shape, [VName], Maybe (Lambda rep, [SubExp])),
     (Param (LParamInfo rep), Param (LParamInfo rep)), SubExpRes)]
 -> ([[PatElemT (LetDec rep)]],
     [(Shape, [VName], Maybe (Lambda rep, [SubExp]))],
     [(Param (LParamInfo rep), Param (LParamInfo rep))], Result))
-> RuleM
     rep
     [Maybe
        ([PatElemT (LetDec rep)],
         (Shape, [VName], Maybe (Lambda rep, [SubExp])),
         (Param (LParamInfo rep), Param (LParamInfo rep)), SubExpRes)]
-> RuleM
     rep
     ([[PatElemT (LetDec rep)]],
      [(Shape, [VName], Maybe (Lambda rep, [SubExp]))],
      [(Param (LParamInfo rep), Param (LParamInfo rep))], Result)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap ([([PatElemT (LetDec rep)],
  (Shape, [VName], Maybe (Lambda rep, [SubExp])),
  (Param (LParamInfo rep), Param (LParamInfo rep)), SubExpRes)]
-> ([[PatElemT (LetDec rep)]],
    [(Shape, [VName], Maybe (Lambda rep, [SubExp]))],
    [(Param (LParamInfo rep), Param (LParamInfo rep))], Result)
forall a b c d. [(a, b, c, d)] -> ([a], [b], [c], [d])
unzip4 ([([PatElemT (LetDec rep)],
   (Shape, [VName], Maybe (Lambda rep, [SubExp])),
   (Param (LParamInfo rep), Param (LParamInfo rep)), SubExpRes)]
 -> ([[PatElemT (LetDec rep)]],
     [(Shape, [VName], Maybe (Lambda rep, [SubExp]))],
     [(Param (LParamInfo rep), Param (LParamInfo rep))], Result))
-> ([Maybe
       ([PatElemT (LetDec rep)],
        (Shape, [VName], Maybe (Lambda rep, [SubExp])),
        (Param (LParamInfo rep), Param (LParamInfo rep)), SubExpRes)]
    -> [([PatElemT (LetDec rep)],
         (Shape, [VName], Maybe (Lambda rep, [SubExp])),
         (Param (LParamInfo rep), Param (LParamInfo rep)), SubExpRes)])
-> [Maybe
      ([PatElemT (LetDec rep)],
       (Shape, [VName], Maybe (Lambda rep, [SubExp])),
       (Param (LParamInfo rep), Param (LParamInfo rep)), SubExpRes)]
-> ([[PatElemT (LetDec rep)]],
    [(Shape, [VName], Maybe (Lambda rep, [SubExp]))],
    [(Param (LParamInfo rep), Param (LParamInfo rep))], Result)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. [Maybe
   ([PatElemT (LetDec rep)],
    (Shape, [VName], Maybe (Lambda rep, [SubExp])),
    (Param (LParamInfo rep), Param (LParamInfo rep)), SubExpRes)]
-> [([PatElemT (LetDec rep)],
     (Shape, [VName], Maybe (Lambda rep, [SubExp])),
     (Param (LParamInfo rep), Param (LParamInfo rep)), SubExpRes)]
forall a. [Maybe a] -> [a]
catMaybes) (RuleM
   rep
   [Maybe
      ([PatElemT (LetDec rep)],
       (Shape, [VName], Maybe (Lambda rep, [SubExp])),
       (Param (LParamInfo rep), Param (LParamInfo rep)), SubExpRes)]
 -> RuleM
      rep
      ([[PatElemT (LetDec rep)]],
       [(Shape, [VName], Maybe (Lambda rep, [SubExp]))],
       [(Param (LParamInfo rep), Param (LParamInfo rep))], Result))
-> ([([PatElemT (LetDec rep)],
      (Shape, [VName], Maybe (Lambda rep, [SubExp])),
      (Param (LParamInfo rep), Param (LParamInfo rep)), SubExpRes)]
    -> RuleM
         rep
         [Maybe
            ([PatElemT (LetDec rep)],
             (Shape, [VName], Maybe (Lambda rep, [SubExp])),
             (Param (LParamInfo rep), Param (LParamInfo rep)), SubExpRes)])
-> [([PatElemT (LetDec rep)],
     (Shape, [VName], Maybe (Lambda rep, [SubExp])),
     (Param (LParamInfo rep), Param (LParamInfo rep)), SubExpRes)]
-> RuleM
     rep
     ([[PatElemT (LetDec rep)]],
      [(Shape, [VName], Maybe (Lambda rep, [SubExp]))],
      [(Param (LParamInfo rep), Param (LParamInfo rep))], Result)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (([PatElemT (LetDec rep)],
  (Shape, [VName], Maybe (Lambda rep, [SubExp])),
  (Param (LParamInfo rep), Param (LParamInfo rep)), SubExpRes)
 -> RuleM
      rep
      (Maybe
         ([PatElemT (LetDec rep)],
          (Shape, [VName], Maybe (Lambda rep, [SubExp])),
          (Param (LParamInfo rep), Param (LParamInfo rep)), SubExpRes)))
-> [([PatElemT (LetDec rep)],
     (Shape, [VName], Maybe (Lambda rep, [SubExp])),
     (Param (LParamInfo rep), Param (LParamInfo rep)), SubExpRes)]
-> RuleM
     rep
     [Maybe
        ([PatElemT (LetDec rep)],
         (Shape, [VName], Maybe (Lambda rep, [SubExp])),
         (Param (LParamInfo rep), Param (LParamInfo rep)), SubExpRes)]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM ([PatElemT (LetDec rep)],
 (Shape, [VName], Maybe (Lambda rep, [SubExp])),
 (Param (LParamInfo rep), Param (LParamInfo rep)), SubExpRes)
-> RuleM
     rep
     (Maybe
        ([PatElemT (LetDec rep)],
         (Shape, [VName], Maybe (Lambda rep, [SubExp])),
         (Param (LParamInfo rep), Param (LParamInfo rep)), SubExpRes))
forall (m :: * -> *) dec a c a dec.
MonadBuilder m =>
([PatElemT dec], (a, [VName], c), (a, Param dec), SubExpRes)
-> m (Maybe
        ([PatElemT dec], (a, [VName], c), (a, Param dec), SubExpRes))
tryMoveAcc ([([PatElemT (LetDec rep)],
   (Shape, [VName], Maybe (Lambda rep, [SubExp])),
   (Param (LParamInfo rep), Param (LParamInfo rep)), SubExpRes)]
 -> RuleM
      rep
      ([[PatElemT (LetDec rep)]],
       [(Shape, [VName], Maybe (Lambda rep, [SubExp]))],
       [(Param (LParamInfo rep), Param (LParamInfo rep))], Result))
-> [([PatElemT (LetDec rep)],
     (Shape, [VName], Maybe (Lambda rep, [SubExp])),
     (Param (LParamInfo rep), Param (LParamInfo rep)), SubExpRes)]
-> RuleM
     rep
     ([[PatElemT (LetDec rep)]],
      [(Shape, [VName], Maybe (Lambda rep, [SubExp]))],
      [(Param (LParamInfo rep), Param (LParamInfo rep))], Result)
forall a b. (a -> b) -> a -> b
$
      [[PatElemT (LetDec rep)]]
-> [(Shape, [VName], Maybe (Lambda rep, [SubExp]))]
-> [(Param (LParamInfo rep), Param (LParamInfo rep))]
-> Result
-> [([PatElemT (LetDec rep)],
     (Shape, [VName], Maybe (Lambda rep, [SubExp])),
     (Param (LParamInfo rep), Param (LParamInfo rep)), SubExpRes)]
forall a b c d. [a] -> [b] -> [c] -> [d] -> [(a, b, c, d)]
zip4
        ([Int] -> [PatElemT (LetDec rep)] -> [[PatElemT (LetDec rep)]]
forall a. [Int] -> [a] -> [[a]]
chunks (((Shape, [VName], Maybe (Lambda rep, [SubExp])) -> Int)
-> [(Shape, [VName], Maybe (Lambda rep, [SubExp]))] -> [Int]
forall a b. (a -> b) -> [a] -> [b]
map (Shape, [VName], Maybe (Lambda rep, [SubExp])) -> Int
forall (t :: * -> *) a a c. Foldable t => (a, t a, c) -> Int
inputArrs [(Shape, [VName], Maybe (Lambda rep, [SubExp]))]
inputs) [PatElemT (LetDec rep)]
acc_pes)
        [(Shape, [VName], Maybe (Lambda rep, [SubExp]))]
inputs
        ([Param (LParamInfo rep)]
-> [Param (LParamInfo rep)]
-> [(Param (LParamInfo rep), Param (LParamInfo rep))]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param (LParamInfo rep)]
cert_params [Param (LParamInfo rep)]
acc_params)
        Result
acc_res
  let ([Param (LParamInfo rep)]
cert_params', [Param (LParamInfo rep)]
acc_params') = [(Param (LParamInfo rep), Param (LParamInfo rep))]
-> ([Param (LParamInfo rep)], [Param (LParamInfo rep)])
forall a b. [(a, b)] -> ([a], [b])
unzip [(Param (LParamInfo rep), Param (LParamInfo rep))]
params'

  -- Look at non-accumulator results.
  ([PatElemT (LetDec rep)]
nonacc_pes', Result
nonacc_res') <-
    [(PatElemT (LetDec rep), SubExpRes)]
-> ([PatElemT (LetDec rep)], Result)
forall a b. [(a, b)] -> ([a], [b])
unzip ([(PatElemT (LetDec rep), SubExpRes)]
 -> ([PatElemT (LetDec rep)], Result))
-> ([Maybe (PatElemT (LetDec rep), SubExpRes)]
    -> [(PatElemT (LetDec rep), SubExpRes)])
-> [Maybe (PatElemT (LetDec rep), SubExpRes)]
-> ([PatElemT (LetDec rep)], Result)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. [Maybe (PatElemT (LetDec rep), SubExpRes)]
-> [(PatElemT (LetDec rep), SubExpRes)]
forall a. [Maybe a] -> [a]
catMaybes ([Maybe (PatElemT (LetDec rep), SubExpRes)]
 -> ([PatElemT (LetDec rep)], Result))
-> RuleM rep [Maybe (PatElemT (LetDec rep), SubExpRes)]
-> RuleM rep ([PatElemT (LetDec rep)], Result)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> ((PatElemT (LetDec rep), SubExpRes)
 -> RuleM rep (Maybe (PatElemT (LetDec rep), SubExpRes)))
-> [(PatElemT (LetDec rep), SubExpRes)]
-> RuleM rep [Maybe (PatElemT (LetDec rep), SubExpRes)]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (PatElemT (LetDec rep), SubExpRes)
-> RuleM rep (Maybe (PatElemT (LetDec rep), SubExpRes))
tryMoveNonAcc ([PatElemT (LetDec rep)]
-> Result -> [(PatElemT (LetDec rep), SubExpRes)]
forall a b. [a] -> [b] -> [(a, b)]
zip [PatElemT (LetDec rep)]
nonacc_pes Result
nonacc_res)

  Bool -> RuleM rep () -> RuleM rep ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
when ([[PatElemT (LetDec rep)]] -> [PatElemT (LetDec rep)]
forall (t :: * -> *) a. Foldable t => t [a] -> [a]
concat [[PatElemT (LetDec rep)]]
acc_pes' [PatElemT (LetDec rep)] -> [PatElemT (LetDec rep)] -> Bool
forall a. Eq a => a -> a -> Bool
== [PatElemT (LetDec rep)]
acc_pes Bool -> Bool -> Bool
&& [PatElemT (LetDec rep)]
nonacc_pes' [PatElemT (LetDec rep)] -> [PatElemT (LetDec rep)] -> Bool
forall a. Eq a => a -> a -> Bool
== [PatElemT (LetDec rep)]
nonacc_pes) RuleM rep ()
forall rep a. RuleM rep a
cannotSimplify

  Lambda rep
lam' <-
    [LParam (Rep (RuleM rep))]
-> RuleM rep Result -> RuleM rep (Lambda (Rep (RuleM rep)))
forall (m :: * -> *).
MonadBuilder m =>
[LParam (Rep m)] -> m Result -> m (Lambda (Rep m))
mkLambda ([Param (LParamInfo rep)]
cert_params' [Param (LParamInfo rep)]
-> [Param (LParamInfo rep)] -> [Param (LParamInfo rep)]
forall a. [a] -> [a] -> [a]
++ [Param (LParamInfo rep)]
acc_params') (RuleM rep Result -> RuleM rep (Lambda (Rep (RuleM rep))))
-> RuleM rep Result -> RuleM rep (Lambda (Rep (RuleM rep)))
forall a b. (a -> b) -> a -> b
$
      Body (Rep (RuleM rep)) -> RuleM rep Result
forall (m :: * -> *). MonadBuilder m => Body (Rep m) -> m Result
bodyBind (Body (Rep (RuleM rep)) -> RuleM rep Result)
-> Body (Rep (RuleM rep)) -> RuleM rep Result
forall a b. (a -> b) -> a -> b
$ (Lambda rep -> BodyT rep
forall rep. LambdaT rep -> BodyT rep
lambdaBody Lambda rep
lam) {bodyResult :: Result
bodyResult = Result
acc_res' Result -> Result -> Result
forall a. Semigroup a => a -> a -> a
<> Result
nonacc_res'}

  Pat (Rep (RuleM rep)) -> Exp (Rep (RuleM rep)) -> RuleM rep ()
forall (m :: * -> *).
MonadBuilder m =>
Pat (Rep m) -> Exp (Rep m) -> m ()
letBind ([PatElemT (LetDec rep)] -> Pat rep
forall dec. [PatElemT dec] -> PatT dec
Pat ([[PatElemT (LetDec rep)]] -> [PatElemT (LetDec rep)]
forall (t :: * -> *) a. Foldable t => t [a] -> [a]
concat [[PatElemT (LetDec rep)]]
acc_pes' [PatElemT (LetDec rep)]
-> [PatElemT (LetDec rep)] -> [PatElemT (LetDec rep)]
forall a. Semigroup a => a -> a -> a
<> [PatElemT (LetDec rep)]
nonacc_pes')) (Exp (Rep (RuleM rep)) -> RuleM rep ())
-> Exp (Rep (RuleM rep)) -> RuleM rep ()
forall a b. (a -> b) -> a -> b
$ [(Shape, [VName], Maybe (Lambda rep, [SubExp]))]
-> Lambda rep -> ExpT rep
forall rep.
[(Shape, [VName], Maybe (Lambda rep, [SubExp]))]
-> Lambda rep -> ExpT rep
WithAcc [(Shape, [VName], Maybe (Lambda rep, [SubExp]))]
inputs' Lambda rep
lam'
  where
    num_nonaccs :: Int
num_nonaccs = [Type] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length (Lambda rep -> [Type]
forall rep. LambdaT rep -> [Type]
lambdaReturnType Lambda rep
lam) Int -> Int -> Int
forall a. Num a => a -> a -> a
- [(Shape, [VName], Maybe (Lambda rep, [SubExp]))] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [(Shape, [VName], Maybe (Lambda rep, [SubExp]))]
inputs
    inputArrs :: (a, t a, c) -> Int
inputArrs (a
_, t a
arrs, c
_) = t a -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length t a
arrs

    tryMoveAcc :: ([PatElemT dec], (a, [VName], c), (a, Param dec), SubExpRes)
-> m (Maybe
        ([PatElemT dec], (a, [VName], c), (a, Param dec), SubExpRes))
tryMoveAcc ([PatElemT dec]
pes, (a
_, [VName]
arrs, c
_), (a
_, Param dec
acc_p), SubExpRes Certs
cs (Var VName
v))
      | Param dec -> VName
forall dec. Param dec -> VName
paramName Param dec
acc_p VName -> VName -> Bool
forall a. Eq a => a -> a -> Bool
== VName
v,
        Certs
cs Certs -> Certs -> Bool
forall a. Eq a => a -> a -> Bool
== Certs
forall a. Monoid a => a
mempty = do
        [(PatElemT dec, VName)] -> ((PatElemT dec, VName) -> m ()) -> m ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([PatElemT dec] -> [VName] -> [(PatElemT dec, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [PatElemT dec]
pes [VName]
arrs) (((PatElemT dec, VName) -> m ()) -> m ())
-> ((PatElemT dec, VName) -> m ()) -> m ()
forall a b. (a -> b) -> a -> b
$ \(PatElemT dec
pe, VName
arr) ->
          [VName] -> Exp (Rep m) -> m ()
forall (m :: * -> *).
MonadBuilder m =>
[VName] -> Exp (Rep m) -> m ()
letBindNames [PatElemT dec -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT dec
pe] (Exp (Rep m) -> m ()) -> Exp (Rep m) -> m ()
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp (Rep m)
forall rep. BasicOp -> ExpT rep
BasicOp (BasicOp -> Exp (Rep m)) -> BasicOp -> Exp (Rep m)
forall a b. (a -> b) -> a -> b
$ SubExp -> BasicOp
SubExp (SubExp -> BasicOp) -> SubExp -> BasicOp
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
arr
        Maybe ([PatElemT dec], (a, [VName], c), (a, Param dec), SubExpRes)
-> m (Maybe
        ([PatElemT dec], (a, [VName], c), (a, Param dec), SubExpRes))
forall (f :: * -> *) a. Applicative f => a -> f a
pure Maybe ([PatElemT dec], (a, [VName], c), (a, Param dec), SubExpRes)
forall a. Maybe a
Nothing
    tryMoveAcc ([PatElemT dec], (a, [VName], c), (a, Param dec), SubExpRes)
x =
      Maybe ([PatElemT dec], (a, [VName], c), (a, Param dec), SubExpRes)
-> m (Maybe
        ([PatElemT dec], (a, [VName], c), (a, Param dec), SubExpRes))
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Maybe ([PatElemT dec], (a, [VName], c), (a, Param dec), SubExpRes)
 -> m (Maybe
         ([PatElemT dec], (a, [VName], c), (a, Param dec), SubExpRes)))
-> Maybe
     ([PatElemT dec], (a, [VName], c), (a, Param dec), SubExpRes)
-> m (Maybe
        ([PatElemT dec], (a, [VName], c), (a, Param dec), SubExpRes))
forall a b. (a -> b) -> a -> b
$ ([PatElemT dec], (a, [VName], c), (a, Param dec), SubExpRes)
-> Maybe
     ([PatElemT dec], (a, [VName], c), (a, Param dec), SubExpRes)
forall a. a -> Maybe a
Just ([PatElemT dec], (a, [VName], c), (a, Param dec), SubExpRes)
x

    tryMoveNonAcc :: (PatElemT (LetDec rep), SubExpRes)
-> RuleM rep (Maybe (PatElemT (LetDec rep), SubExpRes))
tryMoveNonAcc (PatElemT (LetDec rep)
pe, SubExpRes Certs
cs (Var VName
v))
      | VName
v VName -> TopDown rep -> Bool
forall rep. VName -> SymbolTable rep -> Bool
`ST.elem` TopDown rep
vtable,
        Certs
cs Certs -> Certs -> Bool
forall a. Eq a => a -> a -> Bool
== Certs
forall a. Monoid a => a
mempty = do
        [VName] -> Exp (Rep (RuleM rep)) -> RuleM rep ()
forall (m :: * -> *).
MonadBuilder m =>
[VName] -> Exp (Rep m) -> m ()
letBindNames [PatElemT (LetDec rep) -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT (LetDec rep)
pe] (Exp (Rep (RuleM rep)) -> RuleM rep ())
-> Exp (Rep (RuleM rep)) -> RuleM rep ()
forall a b. (a -> b) -> a -> b
$ BasicOp -> ExpT rep
forall rep. BasicOp -> ExpT rep
BasicOp (BasicOp -> ExpT rep) -> BasicOp -> ExpT rep
forall a b. (a -> b) -> a -> b
$ SubExp -> BasicOp
SubExp (SubExp -> BasicOp) -> SubExp -> BasicOp
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
v
        Maybe (PatElemT (LetDec rep), SubExpRes)
-> RuleM rep (Maybe (PatElemT (LetDec rep), SubExpRes))
forall (f :: * -> *) a. Applicative f => a -> f a
pure Maybe (PatElemT (LetDec rep), SubExpRes)
forall a. Maybe a
Nothing
    tryMoveNonAcc (PatElemT (LetDec rep)
pe, SubExpRes Certs
cs (Constant PrimValue
v))
      | Certs
cs Certs -> Certs -> Bool
forall a. Eq a => a -> a -> Bool
== Certs
forall a. Monoid a => a
mempty = do
        [VName] -> Exp (Rep (RuleM rep)) -> RuleM rep ()
forall (m :: * -> *).
MonadBuilder m =>
[VName] -> Exp (Rep m) -> m ()
letBindNames [PatElemT (LetDec rep) -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT (LetDec rep)
pe] (Exp (Rep (RuleM rep)) -> RuleM rep ())
-> Exp (Rep (RuleM rep)) -> RuleM rep ()
forall a b. (a -> b) -> a -> b
$ BasicOp -> ExpT rep
forall rep. BasicOp -> ExpT rep
BasicOp (BasicOp -> ExpT rep) -> BasicOp -> ExpT rep
forall a b. (a -> b) -> a -> b
$ SubExp -> BasicOp
SubExp (SubExp -> BasicOp) -> SubExp -> BasicOp
forall a b. (a -> b) -> a -> b
$ PrimValue -> SubExp
Constant PrimValue
v
        Maybe (PatElemT (LetDec rep), SubExpRes)
-> RuleM rep (Maybe (PatElemT (LetDec rep), SubExpRes))
forall (f :: * -> *) a. Applicative f => a -> f a
pure Maybe (PatElemT (LetDec rep), SubExpRes)
forall a. Maybe a
Nothing
    tryMoveNonAcc (PatElemT (LetDec rep), SubExpRes)
x =
      Maybe (PatElemT (LetDec rep), SubExpRes)
-> RuleM rep (Maybe (PatElemT (LetDec rep), SubExpRes))
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Maybe (PatElemT (LetDec rep), SubExpRes)
 -> RuleM rep (Maybe (PatElemT (LetDec rep), SubExpRes)))
-> Maybe (PatElemT (LetDec rep), SubExpRes)
-> RuleM rep (Maybe (PatElemT (LetDec rep), SubExpRes))
forall a b. (a -> b) -> a -> b
$ (PatElemT (LetDec rep), SubExpRes)
-> Maybe (PatElemT (LetDec rep), SubExpRes)
forall a. a -> Maybe a
Just (PatElemT (LetDec rep), SubExpRes)
x
withAccTopDown TopDown rep
_ Stm rep
_ = Rule rep
forall rep. Rule rep
Skip

-- | Bottom-up rule for 'WithAcc': eliminate dead results.
--
-- The rule only fires when at least one name bound by the pattern is
-- unused according to the usage table.  Two kinds of entries are removed:
--
-- * An accumulator input, together with its certificate parameter,
--   accumulator parameter, pattern elements and lambda result.  This
--   happens when none of its pattern elements are used.
--
-- * A non-accumulator result, together with its pattern element.  This
--   happens when that element is unused.
--
-- If nothing ends up being removed, the rule gives up with
-- 'cannotSimplify'.
withAccBottomUp :: BuilderOps rep => BottomUpRuleGeneric rep
withAccBottomUp (_, utable) (Let pat aux (WithAcc inputs lam))
  | any (not . isUsed) (patNames pat) = Simplify $ do
      let numInputs = length inputs
          -- Whatever the lambda returns beyond one result per input is a
          -- non-accumulator result; these sit at the end of both the
          -- lambda result and the pattern.
          numNonAccs = length (lambdaReturnType lam) - numInputs
          (accRes, nonAccRes) =
            splitFromEnd numNonAccs $ bodyResult $ lambdaBody lam
          (accPes, nonAccPes) = splitFromEnd numNonAccs $ patElems pat
          -- The first 'numInputs' parameters are certificate parameters;
          -- the remaining ones are the accumulator parameters.
          (certParams, accParams) = splitAt numInputs $ lambdaParams lam
          -- Each input owns one pattern element per array it lists.
          accPesGroups = chunks [length arrs | (_, arrs, _) <- inputs] accPes
          accEntries =
            zip4 accPesGroups inputs (zip certParams accParams) accRes
          liveAccEntries =
            [ entry
              | entry@(pes, _, _, _) <- accEntries,
                any (isUsed . patElemName) pes
            ]
          (livePesGroups, liveInputs, liveParamPairs, liveAccRes) =
            unzip4 liveAccEntries
          (liveCertParams, liveAccParams) = unzip liveParamPairs
          (liveNonAccPes, liveNonAccRes) =
            unzip
              [ entry
                | entry@(pe, _) <- zip nonAccPes nonAccRes,
                  isUsed (patElemName pe)
              ]

      -- Bail out if filtering did not actually remove anything.
      when (concat livePesGroups == accPes && liveNonAccPes == nonAccPes)
        cannotSimplify

      -- NOTE(review): the original lambda body is re-bound unchanged while
      -- the parameters of dropped accumulators are removed.  If the body
      -- still mentions such a parameter, the new lambda refers to a name
      -- that is no longer bound.  Confirm this cannot happen, or that a
      -- later pass handles it.
      lam' <- mkLambda (liveCertParams ++ liveAccParams) $ do
        _ <- bodyBind $ lambdaBody lam
        pure $ liveAccRes ++ liveNonAccRes

      auxing aux . letBind (Pat (concat livePesGroups ++ liveNonAccPes)) $
        WithAcc liveInputs lam'
  where
    isUsed = (`UT.used` utable)
withAccBottomUp _ _ = Skip

-- Some helper functions

-- | Holds exactly when the subexpression is a constant that 'oneIsh'
-- accepts; any variable yields 'False'.
isCt1 :: SubExp -> Bool
isCt1 se = case se of
  Constant v -> oneIsh v
  _ -> False

-- | Holds exactly when the subexpression is a constant that 'zeroIsh'
-- accepts; any variable yields 'False'.
isCt0 :: SubExp -> Bool
isCt0 se = case se of
  Constant v -> zeroIsh v
  _ -> False