{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE TypeFamilies #-}

-- | This module defines a collection of simplification rules, as per
-- "Futhark.Optimise.Simplify.Rule".  They are used in the
-- simplifier.
--
-- For performance reasons, many sufficiently simple logically
-- separate rules are merged into single "super-rules", like ruleIf
-- and ruleBasicOp.  This is because it is relatively expensive to
-- activate a rule just to determine that it does not apply.  Thus, it
-- is more efficient to have a few very fat rules than a lot of small
-- rules.  This does not affect the compiler result in any way; it is
-- purely an optimisation to speed up compilation.
module Futhark.Optimise.Simplify.Rules
  ( standardRules,
    removeUnnecessaryCopy,
  )
where

import Control.Monad
import Control.Monad.State
import Data.Either
import Data.List (find, insert, unzip4, zip4)
import qualified Data.Map.Strict as M
import Data.Maybe
import Futhark.Analysis.PrimExp.Convert
import qualified Futhark.Analysis.SymbolTable as ST
import qualified Futhark.Analysis.UsageTable as UT
import Futhark.Construct
import Futhark.IR
import Futhark.IR.Prop.Aliases
import Futhark.Optimise.Simplify.Rule
import Futhark.Optimise.Simplify.Rules.BasicOp
import Futhark.Optimise.Simplify.Rules.Index
import Futhark.Optimise.Simplify.Rules.Loop
import Futhark.Util

-- | The top-down rules defined in (or re-exported through) this
-- module, each wrapped in the 'SimplificationRule' constructor that
-- matches the kind of statement it inspects.
topDownRules :: BuilderOps rep => [TopDownRule rep]
topDownRules =
  [ RuleGeneric constantFoldPrimFun,
    RuleIf ruleIf,
    RuleIf hoistBranchInvariant,
    RuleGeneric withAccTopDown
  ]

-- | The bottom-up rules defined in (or re-exported through) this
-- module, each wrapped in the 'SimplificationRule' constructor that
-- matches the kind of statement it inspects.
bottomUpRules :: (BuilderOps rep, TraverseOpStms rep) => [BottomUpRule rep]
bottomUpRules =
  [ RuleIf removeDeadBranchResult,
    RuleGeneric withAccBottomUp,
    RuleBasicOp simplifyIndex
  ]

-- | A set of standard simplification rules.  These assume pure
-- functional semantics, and so probably should not be applied after
-- memory block merging.
--
-- The rule book is the combination of the rules of this module
-- ('topDownRules' and 'bottomUpRules') with 'loopRules' and
-- 'basicOpRules'.
standardRules :: (BuilderOps rep, TraverseOpStms rep, Aliased rep) => RuleBook rep
standardRules =
  ruleBook topDownRules bottomUpRules
    <> loopRules
    <> basicOpRules

-- | Turn @copy(x)@ into @x@ iff @x@ is not used after this copy
-- statement and it can be consumed.
--
-- This simplistic rule is only valid before we introduce memory.
--
-- The rule only considers a 'Copy' bound by a single-element pattern;
-- every other statement is skipped.
removeUnnecessaryCopy :: (BuilderOps rep, Aliased rep) => BottomUpRuleBasicOp rep
removeUnnecessaryCopy (vtable, used) (Pat [d]) _ (Copy v)
  | not (v `UT.isConsumed` used),
    -- This next line is too conservative, but the problem is that 'v'
    -- might not look like it has been consumed if it is consumed in
    -- an outer scope.  This is because the simplifier applies
    -- bottom-up rules in a kind of deepest-first order.
    not (patElemName d `UT.isInResult` used)
      || (patElemName d `UT.isConsumed` used),
    -- Either the original is dead and may be consumed in place of the
    -- copy, or the copy itself is never consumed (so sharing is fine).
    (not (v `UT.used` used) && consumable)
      || not (patElemName d `UT.isConsumed` used) =
      Simplify $ letBindNames [patElemName d] $ BasicOp $ SubExp $ Var v
  where
    -- We need to make sure we can even consume the original.  The big
    -- missing piece here is that we cannot do copy removal inside of
    -- 'map' and other SOACs, but that is handled by SOAC-specific rules.
    --
    -- 'v' is consumable only if it is known to the symbol table, was
    -- bound at the current loop depth, and is either the alias-free
    -- result of a statement or a unique function parameter.
    consumable = fromMaybe False $ do
      e <- ST.lookup v vtable
      guard $ ST.entryDepth e == ST.loopDepth vtable
      consumableStm e `mplus` consumableFParam e

    -- Always yields 'Just': 'True' exactly when the entry is a function
    -- parameter whose declared type is unique.
    consumableFParam =
      Just . maybe False (unique . declTypeOf) . ST.entryFParam

    -- 'Just True' when the entry is bound by a statement and the
    -- pattern element naming 'v' has no aliases; 'Nothing' otherwise.
    consumableStm e = do
      pat <- stmPat <$> ST.entryStm e
      pe <- find ((== v) . patElemName) (patElems pat)
      guard $ aliasesOf pe == mempty
      pure True
removeUnnecessaryCopy _ _ _ _ = Skip

-- | Constant-fold applications of primitive functions.
--
-- An 'Apply' is replaced by a constant when all of the following hold:
-- every argument is a 'Constant', the function name is a key of
-- 'primFuns', and the evaluator stored there returns a result for
-- those arguments.  The certificates and attributes of the original
-- statement are carried over to the replacement.
constantFoldPrimFun :: BuilderOps rep => TopDownRuleGeneric rep
constantFoldPrimFun _ (Let pat (StmAux cs attrs _) (Apply fname args _ _))
  | Just args' <- mapM (isConst . fst) args,
    Just (_, _, fun) <- M.lookup (nameToString fname) primFuns,
    Just result <- fun args' =
      Simplify $
        certifying cs $
          attributing attrs $
            letBind pat $ BasicOp $ SubExp $ Constant result
  where
    -- Extract the value of a constant operand; variables yield 'Nothing'.
    isConst (Constant v) = Just v
    isConst _ = Nothing
constantFoldPrimFun _ _ = Skip

-- | Bottom-up simplification of 'Index' statements that bind a single
-- pattern element.
--
-- The actual analysis is delegated to 'simplifyIndexing'.  When it
-- produces a result, the statement is rebound either to a plain
-- 'SubExp' or to a new 'Index', with the certificates of the original
-- statement combined with those reported by the analysis, and with the
-- original attributes attached.
simplifyIndex :: BuilderOps rep => BottomUpRuleBasicOp rep
simplifyIndex (vtable, used) pat@(Pat [pe]) (StmAux cs attrs _) (Index idd inds)
  | Just m <- simplifyIndexing vtable seType idd inds consumed = Simplify $ do
      res <- m
      attributing attrs $ case res of
        SubExpResult cs' se ->
          certifying (cs <> cs') $
            letBindNames (patNames pat) $ BasicOp $ SubExp se
        IndexResult extra_cs idd' inds' ->
          certifying (cs <> extra_cs) $
            letBindNames (patNames pat) $ BasicOp $ Index idd' inds'
  where
    -- Whether the result of this indexing is consumed later on.
    consumed = patElemName pe `UT.isConsumed` used

    -- Type lookup handed to 'simplifyIndexing': variables are resolved
    -- through the symbol table, constants by their primitive type.
    seType (Var v) = ST.lookupType v vtable
    seType (Constant v) = Just $ Prim $ primValueType v
simplifyIndex _ _ _ _ = Skip

-- | The top-down "super-rule" for @if@ expressions.  The equations are
-- tried in order; the first whose guards succeed decides the outcome,
-- and the final equation skips the statement.
ruleIf :: BuilderOps rep => TopDownRuleIf rep
-- The condition satisfies 'isCt1' or 'isCt0': splice in the statements
-- of the selected branch and bind each pattern element to the
-- corresponding branch result.  For an 'IfFallback' the rule only fires
-- when the condition satisfies 'isCt1'.
ruleIf _ pat _ (e1, tb, fb, IfDec _ ifsort)
  | Just branch <- checkBranch,
    ifsort /= IfFallback || isCt1 e1 = Simplify $ do
      let ses = bodyResult branch
      addStms $ bodyStms branch
      sequence_
        [ certifying cs $ letBindNames [patElemName p] $ BasicOp $ SubExp se
          | (p, SubExpRes cs se) <- zip (patElems pat) ses
        ]
  where
    checkBranch
      | isCt1 e1 = Just tb
      | isCt0 e1 = Just fb
      | otherwise = Nothing

-- IMPROVE: the following two rules can be generalised to work in more
-- cases, especially when the branches have bindings, or return more
-- than one value.
--
-- if c then True else v == c || v
ruleIf
  _
  pat
  _
  ( cond,
    Body _ tstms [SubExpRes tcs (Constant (BoolValue True))],
    Body _ fstms [SubExpRes fcs se],
    IfDec ts _
    )
    | null tstms,
      null fstms,
      [Prim Bool] <- map extTypeOf ts =
        Simplify $
          certifying (tcs <> fcs) $
            letBind pat $ BasicOp $ BinOp LogOr cond se
-- When type(x)==bool, if c then x else y == (c && x) || (!c && y)
--
-- Both branches are hoisted out of the conditional, which is why every
-- statement in them must satisfy 'safeExp'.
ruleIf _ pat _ (cond, tb, fb, IfDec ts _)
  | Body _ tstms [SubExpRes tcs tres] <- tb,
    Body _ fstms [SubExpRes fcs fres] <- fb,
    all (safeExp . stmExp) $ tstms <> fstms,
    all ((== Prim Bool) . extTypeOf) ts = Simplify $ do
      addStms tstms
      addStms fstms
      e <-
        eBinOp
          LogOr
          (pure $ BasicOp $ BinOp LogAnd cond tres)
          ( eBinOp
              LogAnd
              (pure $ BasicOp $ UnOp Not cond)
              (pure $ BasicOp $ SubExp fres)
          )
      certifying (tcs <> fcs) $ letBind pat e
-- An 'IfFallback' whose true branch contains only statements satisfying
-- 'safeExp': replace the conditional by the true branch.
ruleIf _ pat _ (_, tbranch, _, IfDec _ IfFallback)
  | all (safeExp . stmExp) $ bodyStms tbranch = Simplify $ do
      let ses = bodyResult tbranch
      addStms $ bodyStms tbranch
      sequence_
        [ certifying cs $ letBindNames [patElemName p] $ BasicOp $ SubExp se
          | (p, SubExpRes cs se) <- zip (patElems pat) ses
        ]
-- Both branches return a single integer constant:
--
--   if c then 1 else 0 == btoi c
--   if c then 0 else 1 == btoi (!c)
--
-- The branch statements are discarded by this rewrite, so certificates
-- on the branch results (which may be bound by those statements) cannot
-- be carried over.  The rule therefore only fires when neither result
-- carries certificates; this applies to the negated form as well, which
-- would otherwise silently drop them.
ruleIf _ pat _ (cond, tb, fb, _)
  | Body _ _ [SubExpRes tcs (Constant (IntValue t))] <- tb,
    Body _ _ [SubExpRes fcs (Constant (IntValue f))] <- fb,
    tcs == mempty,
    fcs == mempty =
      if oneIshInt t && zeroIshInt f
        then
          Simplify $
            letBind pat $ BasicOp $ ConvOp (BToI (intValueType t)) cond
        else
          if zeroIshInt t && oneIshInt f
            then Simplify $ do
              cond_neg <- letSubExp "cond_neg" $ BasicOp $ UnOp Not cond
              letBind pat $ BasicOp $ ConvOp (BToI (intValueType t)) cond_neg
            else Skip
ruleIf _ _ _ _ = Skip

-- | Move out results of a conditional expression whose computation is
-- either invariant to the branches (only done for results used for
-- existentials), or the same in both branches.
--
-- Every pattern element is classified by @branchInvariant@, together
-- with its position, its branch type and the pair of results the two
-- branches produce for it.  Hoisted elements are bound in front of
-- the conditional; the remaining ones are kept in a rebuilt, smaller
-- 'If'.  When nothing could be hoisted the rule reports
-- 'cannotSimplify'.
hoistBranchInvariant :: BuilderOps rep => TopDownRuleIf rep
hoistBranchInvariant _ pat _ (cond, tb, fb, IfDec ret ifsort) = Simplify $ do
  let tses = bodyResult tb
      fses = bodyResult fb
  -- 'Left' entries are hoisted results, 'Right' entries are the
  -- (pattern element, type, result pair) triples that must stay.
  (hoistings, (pes, ts, res)) <-
    fmap (fmap unzip3 . partitionEithers) . mapM branchInvariant $
      zip4 [0 ..] (patElems pat) ret (zip tses fses)
  let ctx_fixes = catMaybes hoistings
      (tses', fses') = unzip res
      tb' = tb {bodyResult = tses'}
      fb' = fb {bodyResult = fses'}
      -- Apply each (position, replacement) pair to the remaining
      -- return types via 'fixExt'.
      ret' = foldr (uncurry fixExt) ts ctx_fixes
  if not $ null hoistings -- Was something hoisted?
    then do
      -- We may have to add some reshapes if we made the type
      -- less existential.
      tb'' <- reshapeBodyResults tb' $ map extTypeOf ret'
      fb'' <- reshapeBodyResults fb' $ map extTypeOf ret'
      letBind (Pat pes) $ If cond tb'' fb'' (IfDec ret' ifsort)
    else cannotSimplify
  where
    -- All names bound by statements of either branch body.
    bound_in_branches =
      namesFromList . concatMap (patNames . stmPat) $
        bodyStms tb <> bodyStms fb

    -- A sub-expression is invariant if it is a constant, or a
    -- variable that is not bound inside either branch.
    invariant Constant {} = True
    invariant (Var v) = not $ v `nameIn` bound_in_branches

    -- Decide the fate of a single result.  Returns 'Left' (via
    -- 'hoisted') after emitting a binding for the pattern element, or
    -- 'Right' with the untouched components when it must stay in the
    -- conditional.
    branchInvariant (i, pe, t, (tse, fse))
      -- Do both branches return the same value?
      | tse == fse = do
          -- Bind the pattern element directly to the shared result,
          -- under the certificates carried by both results.
          certifying (resCerts tse <> resCerts fse) $
            letBindNames [patElemName pe] $
              BasicOp $ SubExp $ resSubExp tse
          hoisted i pe

      -- Do both branches return values that are free in the
      -- branch, and are we not the only pattern element?  The
      -- latter is to avoid infinite application of this rule.
      | invariant $ resSubExp tse,
        invariant $ resSubExp fse,
        patSize pat > 1,
        Prim _ <- patElemType pe = do
          -- Split this (primitive-typed) result off into its own
          -- conditional whose bodies contain only the result.
          bt <- expTypesFromPat $ Pat [pe]
          letBindNames [patElemName pe]
            =<< ( If cond <$> resultBodyM [resSubExp tse]
                    <*> resultBodyM [resSubExp fse]
                    <*> pure (IfDec bt ifsort)
                )
          hoisted i pe
      | otherwise =
          pure $ Right (pe, t, (tse, fse))

    -- Record that position @i@ was hoisted and is now available as
    -- the variable named by the pattern element.
    hoisted i pe = pure $ Left $ Just (i, Var $ patElemName pe)

    -- Rebuild a body so that its last @length rets@ results are passed
    -- through 'reshapeResult' against the given types; any results
    -- before those are kept unchanged.
    reshapeBodyResults body rets = buildBody_ $ do
      ses <- bodyBind body
      let (ctx_ses, val_ses) = splitFromEnd (length rets) ses
      (ctx_ses ++) <$> zipWithM reshapeResult val_ses rets

    -- For a variable result of array type: if the dimensions obtained
    -- from 'removeExistentials' differ from the variable's current
    -- dimensions, bind a 'shapeCoerce' of it and return that instead.
    -- Certificates are preserved.  Anything else is returned as-is.
    reshapeResult (SubExpRes cs (Var v)) t@Array {} = do
      v_t <- lookupType v
      let newshape = arrayDims $ removeExistentials t v_t
      SubExpRes cs
        <$> if newshape /= arrayDims v_t
          then letSubExp "branch_ctx_reshaped" (shapeCoerce newshape v)
          else pure $ Var v
    reshapeResult se _ =
      pure se

-- | Remove the return values of a branch, that are not actually used
-- after a branch.  Standard dead code removal can remove the branch
-- if *none* of the return values are used, but this rule is more
-- precise.
--
-- The rule only fires when no name bound by the pattern occurs free
-- in any of the pattern's elements, and at least one bound name is
-- not directly used according to the usage table.  Otherwise it is
-- skipped.
removeDeadBranchResult :: BuilderOps rep => BottomUpRuleIf rep
removeDeadBranchResult (_, used) pat _ (e1, tb, fb, IfDec rettype ifsort)
  | -- Only if there is no existential binding...
    not $ any (`nameIn` foldMap freeIn (patElems pat)) (patNames pat),
    -- Figure out which of the names in 'pat' are used...
    patused <- map (`UT.isUsedDirectly` used) $ patNames pat,
    -- If they are not all used, then this rule applies.
    not (and patused) =
      -- Remove the parts of the branch-results that correspond to dead
      -- return value bindings.  Note that this leaves dead code in the
      -- branch bodies, but that will be removed later.
      let tses = bodyResult tb
          fses = bodyResult fb
          -- Keep only the elements at positions whose pattern name is
          -- used.  The explicit signature keeps 'pick' polymorphic, as
          -- it is applied to results, pattern elements and types.
          pick :: [a] -> [a]
          pick = map snd . filter fst . zip patused
          tb' = tb {bodyResult = pick tses}
          fb' = fb {bodyResult = pick fses}
          pat' = pick $ patElems pat
          rettype' = pick rettype
       in Simplify $
            letBind (Pat pat') $
              If e1 tb' fb' $ IfDec rettype' ifsort
  | otherwise = Skip

withAccTopDown :: BuilderOps rep => TopDownRuleGeneric rep
-- A WithAcc with no accumulators is sent to Valhalla.
withAccTopDown :: TopDownRuleGeneric rep
withAccTopDown TopDown rep
_ (Let Pat (LetDec rep)
pat StmAux (ExpDec rep)
aux (WithAcc [] Lambda rep
lam)) = RuleM rep () -> Rule rep
forall rep. RuleM rep () -> Rule rep
Simplify (RuleM rep () -> Rule rep)
-> (RuleM rep () -> RuleM rep ()) -> RuleM rep () -> Rule rep
forall b c a. (b -> c) -> (a -> b) -> a -> c
. StmAux (ExpDec rep) -> RuleM rep () -> RuleM rep ()
forall (m :: * -> *) anyrep a.
MonadBuilder m =>
StmAux anyrep -> m a -> m a
auxing StmAux (ExpDec rep)
aux (RuleM rep () -> Rule rep) -> RuleM rep () -> Rule rep
forall a b. (a -> b) -> a -> b
$ do
  Result
lam_res <- Body (Rep (RuleM rep)) -> RuleM rep Result
forall (m :: * -> *). MonadBuilder m => Body (Rep m) -> m Result
bodyBind (Body (Rep (RuleM rep)) -> RuleM rep Result)
-> Body (Rep (RuleM rep)) -> RuleM rep Result
forall a b. (a -> b) -> a -> b
$ Lambda rep -> Body rep
forall rep. Lambda rep -> Body rep
lambdaBody Lambda rep
lam
  [(VName, SubExpRes)]
-> ((VName, SubExpRes) -> RuleM rep ()) -> RuleM rep ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName] -> Result -> [(VName, SubExpRes)]
forall a b. [a] -> [b] -> [(a, b)]
zip (Pat (LetDec rep) -> [VName]
forall dec. Pat dec -> [VName]
patNames Pat (LetDec rep)
pat) Result
lam_res) (((VName, SubExpRes) -> RuleM rep ()) -> RuleM rep ())
-> ((VName, SubExpRes) -> RuleM rep ()) -> RuleM rep ()
forall a b. (a -> b) -> a -> b
$ \(VName
v, SubExpRes Certs
cs SubExp
se) ->
    Certs -> RuleM rep () -> RuleM rep ()
forall (m :: * -> *) a. MonadBuilder m => Certs -> m a -> m a
certifying Certs
cs (RuleM rep () -> RuleM rep ()) -> RuleM rep () -> RuleM rep ()
forall a b. (a -> b) -> a -> b
$ [VName] -> Exp (Rep (RuleM rep)) -> RuleM rep ()
forall (m :: * -> *).
MonadBuilder m =>
[VName] -> Exp (Rep m) -> m ()
letBindNames [VName
v] (Exp (Rep (RuleM rep)) -> RuleM rep ())
-> Exp (Rep (RuleM rep)) -> RuleM rep ()
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp rep
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp rep) -> BasicOp -> Exp rep
forall a b. (a -> b) -> a -> b
$ SubExp -> BasicOp
SubExp SubExp
se
-- Identify those results in 'lam' that are free and move them out.
withAccTopDown TopDown rep
vtable (Let Pat (LetDec rep)
pat StmAux (ExpDec rep)
aux (WithAcc [WithAccInput rep]
inputs Lambda rep
lam)) = RuleM rep () -> Rule rep
forall rep. RuleM rep () -> Rule rep
Simplify (RuleM rep () -> Rule rep)
-> (RuleM rep () -> RuleM rep ()) -> RuleM rep () -> Rule rep
forall b c a. (b -> c) -> (a -> b) -> a -> c
. StmAux (ExpDec rep) -> RuleM rep () -> RuleM rep ()
forall (m :: * -> *) anyrep a.
MonadBuilder m =>
StmAux anyrep -> m a -> m a
auxing StmAux (ExpDec rep)
aux (RuleM rep () -> Rule rep) -> RuleM rep () -> Rule rep
forall a b. (a -> b) -> a -> b
$ do
  let ([Param (LParamInfo rep)]
cert_params, [Param (LParamInfo rep)]
acc_params) =
        Int
-> [Param (LParamInfo rep)]
-> ([Param (LParamInfo rep)], [Param (LParamInfo rep)])
forall a. Int -> [a] -> ([a], [a])
splitAt ([WithAccInput rep] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [WithAccInput rep]
inputs) ([Param (LParamInfo rep)]
 -> ([Param (LParamInfo rep)], [Param (LParamInfo rep)]))
-> [Param (LParamInfo rep)]
-> ([Param (LParamInfo rep)], [Param (LParamInfo rep)])
forall a b. (a -> b) -> a -> b
$ Lambda rep -> [Param (LParamInfo rep)]
forall rep. Lambda rep -> [LParam rep]
lambdaParams Lambda rep
lam
      (Result
acc_res, Result
nonacc_res) =
        Int -> Result -> (Result, Result)
forall a. Int -> [a] -> ([a], [a])
splitFromEnd Int
num_nonaccs (Result -> (Result, Result)) -> Result -> (Result, Result)
forall a b. (a -> b) -> a -> b
$ Body rep -> Result
forall rep. Body rep -> Result
bodyResult (Body rep -> Result) -> Body rep -> Result
forall a b. (a -> b) -> a -> b
$ Lambda rep -> Body rep
forall rep. Lambda rep -> Body rep
lambdaBody Lambda rep
lam
      ([PatElem (LetDec rep)]
acc_pes, [PatElem (LetDec rep)]
nonacc_pes) =
        Int
-> [PatElem (LetDec rep)]
-> ([PatElem (LetDec rep)], [PatElem (LetDec rep)])
forall a. Int -> [a] -> ([a], [a])
splitFromEnd Int
num_nonaccs ([PatElem (LetDec rep)]
 -> ([PatElem (LetDec rep)], [PatElem (LetDec rep)]))
-> [PatElem (LetDec rep)]
-> ([PatElem (LetDec rep)], [PatElem (LetDec rep)])
forall a b. (a -> b) -> a -> b
$ Pat (LetDec rep) -> [PatElem (LetDec rep)]
forall dec. Pat dec -> [PatElem dec]
patElems Pat (LetDec rep)
pat

  -- Look at accumulator results.
  ([[PatElem (LetDec rep)]]
acc_pes', [WithAccInput rep]
inputs', [(Param (LParamInfo rep), Param (LParamInfo rep))]
params', Result
acc_res') <-
    ([Maybe
    ([PatElem (LetDec rep)], WithAccInput rep,
     (Param (LParamInfo rep), Param (LParamInfo rep)), SubExpRes)]
 -> ([[PatElem (LetDec rep)]], [WithAccInput rep],
     [(Param (LParamInfo rep), Param (LParamInfo rep))], Result))
-> RuleM
     rep
     [Maybe
        ([PatElem (LetDec rep)], WithAccInput rep,
         (Param (LParamInfo rep), Param (LParamInfo rep)), SubExpRes)]
-> RuleM
     rep
     ([[PatElem (LetDec rep)]], [WithAccInput rep],
      [(Param (LParamInfo rep), Param (LParamInfo rep))], Result)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap ([([PatElem (LetDec rep)], WithAccInput rep,
  (Param (LParamInfo rep), Param (LParamInfo rep)), SubExpRes)]
-> ([[PatElem (LetDec rep)]], [WithAccInput rep],
    [(Param (LParamInfo rep), Param (LParamInfo rep))], Result)
forall a b c d. [(a, b, c, d)] -> ([a], [b], [c], [d])
unzip4 ([([PatElem (LetDec rep)], WithAccInput rep,
   (Param (LParamInfo rep), Param (LParamInfo rep)), SubExpRes)]
 -> ([[PatElem (LetDec rep)]], [WithAccInput rep],
     [(Param (LParamInfo rep), Param (LParamInfo rep))], Result))
-> ([Maybe
       ([PatElem (LetDec rep)], WithAccInput rep,
        (Param (LParamInfo rep), Param (LParamInfo rep)), SubExpRes)]
    -> [([PatElem (LetDec rep)], WithAccInput rep,
         (Param (LParamInfo rep), Param (LParamInfo rep)), SubExpRes)])
-> [Maybe
      ([PatElem (LetDec rep)], WithAccInput rep,
       (Param (LParamInfo rep), Param (LParamInfo rep)), SubExpRes)]
-> ([[PatElem (LetDec rep)]], [WithAccInput rep],
    [(Param (LParamInfo rep), Param (LParamInfo rep))], Result)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. [Maybe
   ([PatElem (LetDec rep)], WithAccInput rep,
    (Param (LParamInfo rep), Param (LParamInfo rep)), SubExpRes)]
-> [([PatElem (LetDec rep)], WithAccInput rep,
     (Param (LParamInfo rep), Param (LParamInfo rep)), SubExpRes)]
forall a. [Maybe a] -> [a]
catMaybes) (RuleM
   rep
   [Maybe
      ([PatElem (LetDec rep)], WithAccInput rep,
       (Param (LParamInfo rep), Param (LParamInfo rep)), SubExpRes)]
 -> RuleM
      rep
      ([[PatElem (LetDec rep)]], [WithAccInput rep],
       [(Param (LParamInfo rep), Param (LParamInfo rep))], Result))
-> ([([PatElem (LetDec rep)], WithAccInput rep,
      (Param (LParamInfo rep), Param (LParamInfo rep)), SubExpRes)]
    -> RuleM
         rep
         [Maybe
            ([PatElem (LetDec rep)], WithAccInput rep,
             (Param (LParamInfo rep), Param (LParamInfo rep)), SubExpRes)])
-> [([PatElem (LetDec rep)], WithAccInput rep,
     (Param (LParamInfo rep), Param (LParamInfo rep)), SubExpRes)]
-> RuleM
     rep
     ([[PatElem (LetDec rep)]], [WithAccInput rep],
      [(Param (LParamInfo rep), Param (LParamInfo rep))], Result)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (([PatElem (LetDec rep)], WithAccInput rep,
  (Param (LParamInfo rep), Param (LParamInfo rep)), SubExpRes)
 -> RuleM
      rep
      (Maybe
         ([PatElem (LetDec rep)], WithAccInput rep,
          (Param (LParamInfo rep), Param (LParamInfo rep)), SubExpRes)))
-> [([PatElem (LetDec rep)], WithAccInput rep,
     (Param (LParamInfo rep), Param (LParamInfo rep)), SubExpRes)]
-> RuleM
     rep
     [Maybe
        ([PatElem (LetDec rep)], WithAccInput rep,
         (Param (LParamInfo rep), Param (LParamInfo rep)), SubExpRes)]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM ([PatElem (LetDec rep)], WithAccInput rep,
 (Param (LParamInfo rep), Param (LParamInfo rep)), SubExpRes)
-> RuleM
     rep
     (Maybe
        ([PatElem (LetDec rep)], WithAccInput rep,
         (Param (LParamInfo rep), Param (LParamInfo rep)), SubExpRes))
forall (m :: * -> *) dec a c a dec.
MonadBuilder m =>
([PatElem dec], (a, [VName], c), (a, Param dec), SubExpRes)
-> m (Maybe
        ([PatElem dec], (a, [VName], c), (a, Param dec), SubExpRes))
tryMoveAcc ([([PatElem (LetDec rep)], WithAccInput rep,
   (Param (LParamInfo rep), Param (LParamInfo rep)), SubExpRes)]
 -> RuleM
      rep
      ([[PatElem (LetDec rep)]], [WithAccInput rep],
       [(Param (LParamInfo rep), Param (LParamInfo rep))], Result))
-> [([PatElem (LetDec rep)], WithAccInput rep,
     (Param (LParamInfo rep), Param (LParamInfo rep)), SubExpRes)]
-> RuleM
     rep
     ([[PatElem (LetDec rep)]], [WithAccInput rep],
      [(Param (LParamInfo rep), Param (LParamInfo rep))], Result)
forall a b. (a -> b) -> a -> b
$
      [[PatElem (LetDec rep)]]
-> [WithAccInput rep]
-> [(Param (LParamInfo rep), Param (LParamInfo rep))]
-> Result
-> [([PatElem (LetDec rep)], WithAccInput rep,
     (Param (LParamInfo rep), Param (LParamInfo rep)), SubExpRes)]
forall a b c d. [a] -> [b] -> [c] -> [d] -> [(a, b, c, d)]
zip4
        ([Int] -> [PatElem (LetDec rep)] -> [[PatElem (LetDec rep)]]
forall a. [Int] -> [a] -> [[a]]
chunks ((WithAccInput rep -> Int) -> [WithAccInput rep] -> [Int]
forall a b. (a -> b) -> [a] -> [b]
map WithAccInput rep -> Int
forall (t :: * -> *) a a c. Foldable t => (a, t a, c) -> Int
inputArrs [WithAccInput rep]
inputs) [PatElem (LetDec rep)]
acc_pes)
        [WithAccInput rep]
inputs
        ([Param (LParamInfo rep)]
-> [Param (LParamInfo rep)]
-> [(Param (LParamInfo rep), Param (LParamInfo rep))]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param (LParamInfo rep)]
cert_params [Param (LParamInfo rep)]
acc_params)
        Result
acc_res
  let ([Param (LParamInfo rep)]
cert_params', [Param (LParamInfo rep)]
acc_params') = [(Param (LParamInfo rep), Param (LParamInfo rep))]
-> ([Param (LParamInfo rep)], [Param (LParamInfo rep)])
forall a b. [(a, b)] -> ([a], [b])
unzip [(Param (LParamInfo rep), Param (LParamInfo rep))]
params'

  -- Look at non-accumulator results.
  ([PatElem (LetDec rep)]
nonacc_pes', Result
nonacc_res') <-
    [(PatElem (LetDec rep), SubExpRes)]
-> ([PatElem (LetDec rep)], Result)
forall a b. [(a, b)] -> ([a], [b])
unzip ([(PatElem (LetDec rep), SubExpRes)]
 -> ([PatElem (LetDec rep)], Result))
-> ([Maybe (PatElem (LetDec rep), SubExpRes)]
    -> [(PatElem (LetDec rep), SubExpRes)])
-> [Maybe (PatElem (LetDec rep), SubExpRes)]
-> ([PatElem (LetDec rep)], Result)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. [Maybe (PatElem (LetDec rep), SubExpRes)]
-> [(PatElem (LetDec rep), SubExpRes)]
forall a. [Maybe a] -> [a]
catMaybes ([Maybe (PatElem (LetDec rep), SubExpRes)]
 -> ([PatElem (LetDec rep)], Result))
-> RuleM rep [Maybe (PatElem (LetDec rep), SubExpRes)]
-> RuleM rep ([PatElem (LetDec rep)], Result)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> ((PatElem (LetDec rep), SubExpRes)
 -> RuleM rep (Maybe (PatElem (LetDec rep), SubExpRes)))
-> [(PatElem (LetDec rep), SubExpRes)]
-> RuleM rep [Maybe (PatElem (LetDec rep), SubExpRes)]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (PatElem (LetDec rep), SubExpRes)
-> RuleM rep (Maybe (PatElem (LetDec rep), SubExpRes))
tryMoveNonAcc ([PatElem (LetDec rep)]
-> Result -> [(PatElem (LetDec rep), SubExpRes)]
forall a b. [a] -> [b] -> [(a, b)]
zip [PatElem (LetDec rep)]
nonacc_pes Result
nonacc_res)

  Bool -> RuleM rep () -> RuleM rep ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
when ([[PatElem (LetDec rep)]] -> [PatElem (LetDec rep)]
forall (t :: * -> *) a. Foldable t => t [a] -> [a]
concat [[PatElem (LetDec rep)]]
acc_pes' [PatElem (LetDec rep)] -> [PatElem (LetDec rep)] -> Bool
forall a. Eq a => a -> a -> Bool
== [PatElem (LetDec rep)]
acc_pes Bool -> Bool -> Bool
&& [PatElem (LetDec rep)]
nonacc_pes' [PatElem (LetDec rep)] -> [PatElem (LetDec rep)] -> Bool
forall a. Eq a => a -> a -> Bool
== [PatElem (LetDec rep)]
nonacc_pes) RuleM rep ()
forall rep a. RuleM rep a
cannotSimplify

  Lambda rep
lam' <-
    [LParam (Rep (RuleM rep))]
-> RuleM rep Result -> RuleM rep (Lambda (Rep (RuleM rep)))
forall (m :: * -> *).
MonadBuilder m =>
[LParam (Rep m)] -> m Result -> m (Lambda (Rep m))
mkLambda ([Param (LParamInfo rep)]
cert_params' [Param (LParamInfo rep)]
-> [Param (LParamInfo rep)] -> [Param (LParamInfo rep)]
forall a. [a] -> [a] -> [a]
++ [Param (LParamInfo rep)]
acc_params') (RuleM rep Result -> RuleM rep (Lambda (Rep (RuleM rep))))
-> RuleM rep Result -> RuleM rep (Lambda (Rep (RuleM rep)))
forall a b. (a -> b) -> a -> b
$
      Body (Rep (RuleM rep)) -> RuleM rep Result
forall (m :: * -> *). MonadBuilder m => Body (Rep m) -> m Result
bodyBind (Body (Rep (RuleM rep)) -> RuleM rep Result)
-> Body (Rep (RuleM rep)) -> RuleM rep Result
forall a b. (a -> b) -> a -> b
$ (Lambda rep -> Body rep
forall rep. Lambda rep -> Body rep
lambdaBody Lambda rep
lam) {bodyResult :: Result
bodyResult = Result
acc_res' Result -> Result -> Result
forall a. Semigroup a => a -> a -> a
<> Result
nonacc_res'}

  Pat (LetDec (Rep (RuleM rep)))
-> Exp (Rep (RuleM rep)) -> RuleM rep ()
forall (m :: * -> *).
MonadBuilder m =>
Pat (LetDec (Rep m)) -> Exp (Rep m) -> m ()
letBind ([PatElem (LetDec rep)] -> Pat (LetDec rep)
forall dec. [PatElem dec] -> Pat dec
Pat ([[PatElem (LetDec rep)]] -> [PatElem (LetDec rep)]
forall (t :: * -> *) a. Foldable t => t [a] -> [a]
concat [[PatElem (LetDec rep)]]
acc_pes' [PatElem (LetDec rep)]
-> [PatElem (LetDec rep)] -> [PatElem (LetDec rep)]
forall a. Semigroup a => a -> a -> a
<> [PatElem (LetDec rep)]
nonacc_pes')) (Exp (Rep (RuleM rep)) -> RuleM rep ())
-> Exp (Rep (RuleM rep)) -> RuleM rep ()
forall a b. (a -> b) -> a -> b
$ [WithAccInput rep] -> Lambda rep -> Exp rep
forall rep. [WithAccInput rep] -> Lambda rep -> Exp rep
WithAcc [WithAccInput rep]
inputs' Lambda rep
lam'
  where
    num_nonaccs :: Int
num_nonaccs = [Type] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length (Lambda rep -> [Type]
forall rep. Lambda rep -> [Type]
lambdaReturnType Lambda rep
lam) Int -> Int -> Int
forall a. Num a => a -> a -> a
- [WithAccInput rep] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [WithAccInput rep]
inputs
    inputArrs :: (a, t a, c) -> Int
inputArrs (a
_, t a
arrs, c
_) = t a -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length t a
arrs

    tryMoveAcc :: ([PatElem dec], (a, [VName], c), (a, Param dec), SubExpRes)
-> m (Maybe
        ([PatElem dec], (a, [VName], c), (a, Param dec), SubExpRes))
tryMoveAcc ([PatElem dec]
pes, (a
_, [VName]
arrs, c
_), (a
_, Param dec
acc_p), SubExpRes Certs
cs (Var VName
v))
      | Param dec -> VName
forall dec. Param dec -> VName
paramName Param dec
acc_p VName -> VName -> Bool
forall a. Eq a => a -> a -> Bool
== VName
v,
        Certs
cs Certs -> Certs -> Bool
forall a. Eq a => a -> a -> Bool
== Certs
forall a. Monoid a => a
mempty = do
        [(PatElem dec, VName)] -> ((PatElem dec, VName) -> m ()) -> m ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([PatElem dec] -> [VName] -> [(PatElem dec, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [PatElem dec]
pes [VName]
arrs) (((PatElem dec, VName) -> m ()) -> m ())
-> ((PatElem dec, VName) -> m ()) -> m ()
forall a b. (a -> b) -> a -> b
$ \(PatElem dec
pe, VName
arr) ->
          [VName] -> Exp (Rep m) -> m ()
forall (m :: * -> *).
MonadBuilder m =>
[VName] -> Exp (Rep m) -> m ()
letBindNames [PatElem dec -> VName
forall dec. PatElem dec -> VName
patElemName PatElem dec
pe] (Exp (Rep m) -> m ()) -> Exp (Rep m) -> m ()
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp (Rep m)
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp (Rep m)) -> BasicOp -> Exp (Rep m)
forall a b. (a -> b) -> a -> b
$ SubExp -> BasicOp
SubExp (SubExp -> BasicOp) -> SubExp -> BasicOp
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
arr
        Maybe ([PatElem dec], (a, [VName], c), (a, Param dec), SubExpRes)
-> m (Maybe
        ([PatElem dec], (a, [VName], c), (a, Param dec), SubExpRes))
forall (f :: * -> *) a. Applicative f => a -> f a
pure Maybe ([PatElem dec], (a, [VName], c), (a, Param dec), SubExpRes)
forall a. Maybe a
Nothing
    tryMoveAcc ([PatElem dec], (a, [VName], c), (a, Param dec), SubExpRes)
x =
      Maybe ([PatElem dec], (a, [VName], c), (a, Param dec), SubExpRes)
-> m (Maybe
        ([PatElem dec], (a, [VName], c), (a, Param dec), SubExpRes))
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Maybe ([PatElem dec], (a, [VName], c), (a, Param dec), SubExpRes)
 -> m (Maybe
         ([PatElem dec], (a, [VName], c), (a, Param dec), SubExpRes)))
-> Maybe
     ([PatElem dec], (a, [VName], c), (a, Param dec), SubExpRes)
-> m (Maybe
        ([PatElem dec], (a, [VName], c), (a, Param dec), SubExpRes))
forall a b. (a -> b) -> a -> b
$ ([PatElem dec], (a, [VName], c), (a, Param dec), SubExpRes)
-> Maybe
     ([PatElem dec], (a, [VName], c), (a, Param dec), SubExpRes)
forall a. a -> Maybe a
Just ([PatElem dec], (a, [VName], c), (a, Param dec), SubExpRes)
x

    tryMoveNonAcc :: (PatElem (LetDec rep), SubExpRes)
-> RuleM rep (Maybe (PatElem (LetDec rep), SubExpRes))
tryMoveNonAcc (PatElem (LetDec rep)
pe, SubExpRes Certs
cs (Var VName
v))
      | VName
v VName -> TopDown rep -> Bool
forall rep. VName -> SymbolTable rep -> Bool
`ST.elem` TopDown rep
vtable,
        Certs
cs Certs -> Certs -> Bool
forall a. Eq a => a -> a -> Bool
== Certs
forall a. Monoid a => a
mempty = do
        [VName] -> Exp (Rep (RuleM rep)) -> RuleM rep ()
forall (m :: * -> *).
MonadBuilder m =>
[VName] -> Exp (Rep m) -> m ()
letBindNames [PatElem (LetDec rep) -> VName
forall dec. PatElem dec -> VName
patElemName PatElem (LetDec rep)
pe] (Exp (Rep (RuleM rep)) -> RuleM rep ())
-> Exp (Rep (RuleM rep)) -> RuleM rep ()
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp rep
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp rep) -> BasicOp -> Exp rep
forall a b. (a -> b) -> a -> b
$ SubExp -> BasicOp
SubExp (SubExp -> BasicOp) -> SubExp -> BasicOp
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
v
        Maybe (PatElem (LetDec rep), SubExpRes)
-> RuleM rep (Maybe (PatElem (LetDec rep), SubExpRes))
forall (f :: * -> *) a. Applicative f => a -> f a
pure Maybe (PatElem (LetDec rep), SubExpRes)
forall a. Maybe a
Nothing
    tryMoveNonAcc (PatElem (LetDec rep)
pe, SubExpRes Certs
cs (Constant PrimValue
v))
      | Certs
cs Certs -> Certs -> Bool
forall a. Eq a => a -> a -> Bool
== Certs
forall a. Monoid a => a
mempty = do
        [VName] -> Exp (Rep (RuleM rep)) -> RuleM rep ()
forall (m :: * -> *).
MonadBuilder m =>
[VName] -> Exp (Rep m) -> m ()
letBindNames [PatElem (LetDec rep) -> VName
forall dec. PatElem dec -> VName
patElemName PatElem (LetDec rep)
pe] (Exp (Rep (RuleM rep)) -> RuleM rep ())
-> Exp (Rep (RuleM rep)) -> RuleM rep ()
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp rep
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp rep) -> BasicOp -> Exp rep
forall a b. (a -> b) -> a -> b
$ SubExp -> BasicOp
SubExp (SubExp -> BasicOp) -> SubExp -> BasicOp
forall a b. (a -> b) -> a -> b
$ PrimValue -> SubExp
Constant PrimValue
v
        Maybe (PatElem (LetDec rep), SubExpRes)
-> RuleM rep (Maybe (PatElem (LetDec rep), SubExpRes))
forall (f :: * -> *) a. Applicative f => a -> f a
pure Maybe (PatElem (LetDec rep), SubExpRes)
forall a. Maybe a
Nothing
    tryMoveNonAcc (PatElem (LetDec rep), SubExpRes)
x =
      Maybe (PatElem (LetDec rep), SubExpRes)
-> RuleM rep (Maybe (PatElem (LetDec rep), SubExpRes))
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Maybe (PatElem (LetDec rep), SubExpRes)
 -> RuleM rep (Maybe (PatElem (LetDec rep), SubExpRes)))
-> Maybe (PatElem (LetDec rep), SubExpRes)
-> RuleM rep (Maybe (PatElem (LetDec rep), SubExpRes))
forall a b. (a -> b) -> a -> b
$ (PatElem (LetDec rep), SubExpRes)
-> Maybe (PatElem (LetDec rep), SubExpRes)
forall a. a -> Maybe a
Just (PatElem (LetDec rep), SubExpRes)
x
withAccTopDown TopDown rep
_ Stm rep
_ = Rule rep
forall rep. Rule rep
Skip

-- | Neutralise accumulator updates inside a body.
--
-- @elimUpdates get_rid_of body@ walks every statement of @body@
-- (including statements nested in sub-bodies and in operations, via
-- 'mapOnBody' and 'traverseOpStms').  Any statement of the form
--
-- > let b = update_acc(acc, ...)
--
-- that binds a single pattern element whose type is an accumulator
-- @Acc c ...@ with @c@ a member of @get_rid_of@ is rewritten to the
-- plain copy
--
-- > let b = acc
--
-- The result is the rewritten body together with the accumulator
-- tokens for which at least one update was rewritten.  The tokens are
-- collected with 'insert', so a token is listed once per rewritten
-- update; callers should only rely on membership/emptiness.
elimUpdates :: (ASTRep rep, TraverseOpStms rep) => [VName] -> Body rep -> (Body rep, [VName])
elimUpdates get_rid_of = flip runState mempty . onBody
  where
    -- Rewrite the statements of a body, leaving its result untouched.
    onBody body = do
      stms' <- onStms $ bodyStms body
      pure body {bodyStms = stms'}

    onStms = traverse onStm

    -- An UpdateAcc on an accumulator we want gone becomes a copy of
    -- the incoming accumulator; the token is recorded in the state.
    onStm (Let pat@(Pat [PatElem _ dec]) aux (BasicOp (UpdateAcc acc _ _)))
      | Acc c _ _ _ <- typeOf dec,
        c `elem` get_rid_of = do
        modify (insert c)
        pure $ Let pat aux $ BasicOp $ SubExp $ Var acc
    -- Every other statement is kept, but its expression is searched
    -- recursively for nested updates.
    onStm (Let pat aux e) = Let pat aux <$> onExp e

    onExp = mapExpM mapper
      where
        mapper =
          identityMapper
            { mapOnOp = traverseOpStms (\_ stms -> onStms stms),
              mapOnBody = \_ body -> onBody body
            }

-- | Eliminate dead results of a 'WithAcc'.  See Note [Dead Code
-- Elimination for WithAcc].
--
-- The rule only fires when at least one name bound by the pattern is
-- unused according to the usage table.  Unused non-accumulator
-- results are dropped from both the pattern and the lambda result.
-- For accumulators whose array results are all unused, the
-- corresponding 'UpdateAcc' statements in the lambda body are turned
-- into copies by 'elimUpdates'; the accumulator itself (inputs,
-- parameters, pattern elements and results) is kept as-is.
withAccBottomUp :: (TraverseOpStms rep, BuilderOps rep) => BottomUpRuleGeneric rep
withAccBottomUp (_, utable) (Let pat aux (WithAcc inputs lam))
  | not $ all (`UT.used` utable) $ patNames pat = Simplify $ do
    -- Separate the accumulator part from the non-accumulator part of
    -- the results, the pattern, and the lambda parameters.
    let (acc_res, nonacc_res) =
          splitFromEnd num_nonaccs $ bodyResult $ lambdaBody lam
        (acc_pes, nonacc_pes) =
          splitFromEnd num_nonaccs $ patElems pat
        (cert_params, acc_params) =
          splitAt (length inputs) $ lambdaParams lam

    -- Eliminate unused accumulator results: collect the certificate
    -- parameter names of those inputs none of whose pattern elements
    -- are used.  The pattern elements are grouped per input by the
    -- number of arrays of that input.
    let get_rid_of =
          map snd . filter getRidOf $
            zip
              (chunks (map inputArrs inputs) acc_pes)
              $ map paramName cert_params

    -- Eliminate unused non-accumulator results.
    let (nonacc_pes', nonacc_res') =
          unzip $ filter keepNonAccRes $ zip nonacc_pes nonacc_res

    -- Nothing to remove at all: bail out before traversing the body.
    when (null get_rid_of && nonacc_pes' == nonacc_pes) cannotSimplify

    let (body', eliminated) = elimUpdates get_rid_of $ lambdaBody lam

    -- No update was actually rewritten and no result was dropped, so
    -- applying the rule would not change anything.
    when (null eliminated && nonacc_pes' == nonacc_pes) cannotSimplify

    let pes' = acc_pes ++ nonacc_pes'

    -- Rebuild the lambda from the rewritten body; the original body
    -- result is discarded in favour of the pruned result list.
    lam' <- mkLambda (cert_params ++ acc_params) $ do
      void $ bodyBind body'
      pure $ acc_res ++ nonacc_res'

    auxing aux $ letBind (Pat pes') $ WithAcc inputs lam'
  where
    -- Number of lambda results beyond one per accumulator input.
    num_nonaccs = length (lambdaReturnType lam) - length inputs
    -- Number of arrays belonging to a single accumulator input.
    inputArrs (_, arrs, _) = length arrs
    -- An accumulator can go when none of its pattern elements is used.
    getRidOf (pes, _) = not $ any ((`UT.used` utable) . patElemName) pes
    -- A non-accumulator result is kept only if its name is used.
    keepNonAccRes (pe, _) = patElemName pe `UT.used` utable
withAccBottomUp _ _ = Skip

-- Some helper functions

-- | Is this subexpression a constant that 'oneIsh' accepts?
-- Variables are never considered to be one.
isCt1 :: SubExp -> Bool
isCt1 (Constant v) = oneIsh v
isCt1 _ = False

-- | Is this subexpression a constant that 'zeroIsh' accepts?
-- Variables are never considered to be zero.
isCt0 :: SubExp -> Bool
isCt0 (Constant v) = zeroIsh v
isCt0 _ = False

-- Note [Dead Code Elimination for WithAcc]
--
-- Our static semantics for accumulators are basically those of linear
-- types.  This makes dead code elimination somewhat tricky.  First,
-- what we consider dead code is when we have a WithAcc where at least
-- one of the array results (each of which internally corresponds to
-- an accumulator) is unused.  E.g.
--
-- let {X',Y'} =
--   with_acc {X, Y} (\X_p Y_p X_acc Y_acc -> ... {X_acc', Y_acc'})
--
-- where X' is not used later.  Note that Y' is still used later.  If
-- none of the results of the WithAcc are used, then the Stm as a
-- whole is dead and can be removed.  That's the trivial case, done
-- implicitly by the simplifier.  The interesting case is exactly when
-- some of the results are unused.  How do we get rid of them?
--
-- Naively, we might just remove them:
--
-- let Y' =
--   with_acc Y (\Y_p Y_acc -> ... Y_acc')
--
-- This is safe *only* if X_acc is used *only* in the result (i.e. an
-- "identity" WithAcc).  Otherwise we end up with references to X_acc,
-- which no longer exists.  This simple case is actually handled in
-- the withAccTopDown rule, and is easy enough.
--
-- What we actually do when we decide to eliminate X_acc is that we
-- inspect the body of the WithAcc and eliminate all UpdateAcc
-- operations that refer to the same accumulator as X_acc (identified
-- by the X_p token).  I.e. we turn every
--
-- let B = update_acc(A, ...)
--
-- where 'A' is ultimately derived from X_acc into
--
-- let B = A
--
-- That's it!  We then let ordinary dead code elimination eventually
-- simplify the body enough that we have an "identity" WithAcc.  There
-- is no _guarantee_ that this will happen, but our general dead code
-- elimination tends to be pretty good.