{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE TypeFamilies #-}
module Futhark.Optimise.Simplify.Rules
( standardRules,
removeUnnecessaryCopy,
)
where
import Control.Monad
import Data.Either
import Data.List (find, unzip4, zip4)
import qualified Data.Map.Strict as M
import Data.Maybe
import Futhark.Analysis.PrimExp.Convert
import qualified Futhark.Analysis.SymbolTable as ST
import qualified Futhark.Analysis.UsageTable as UT
import Futhark.Construct
import Futhark.IR
import Futhark.IR.Prop.Aliases
import Futhark.Optimise.Simplify.Rule
import Futhark.Optimise.Simplify.Rules.BasicOp
import Futhark.Optimise.Simplify.Rules.Index
import Futhark.Optimise.Simplify.Rules.Loop
import Futhark.Util
-- | The top-down rules of this module: constant folding of calls to
-- primitive functions ('constantFoldPrimFun'), the @if@ simplifications
-- in 'ruleIf', 'hoistBranchInvariant', and the generic 'withAccTopDown'
-- rule.
topDownRules :: BuilderOps rep => [TopDownRule rep]
topDownRules =
  [ RuleGeneric constantFoldPrimFun,
    RuleIf ruleIf,
    RuleIf hoistBranchInvariant,
    RuleGeneric withAccTopDown
  ]
-- | The bottom-up rules of this module: 'removeDeadBranchResult' for
-- @if@ expressions, the generic 'withAccBottomUp' rule, and
-- 'simplifyIndex' for @Index@ statements.
bottomUpRules :: BuilderOps rep => [BottomUpRule rep]
bottomUpRules =
  [ RuleIf removeDeadBranchResult,
    RuleGeneric withAccBottomUp,
    RuleBasicOp simplifyIndex
  ]
-- | The standard rule book: 'topDownRules' and 'bottomUpRules' from this
-- module, combined with 'loopRules' and 'basicOpRules'.
--
-- Note that 'removeUnnecessaryCopy' is exported separately and is not
-- part of this rule book.
standardRules :: (BuilderOps rep, Aliased rep) => RuleBook rep
standardRules = ruleBook topDownRules bottomUpRules <> loopRules <> basicOpRules
-- | Turn @let d = copy v@ into @let d = v@ when the copy is not needed.
--
-- According to the usage table, the rewrite fires only when all of the
-- following hold:
--
-- * @v@ is not consumed;
--
-- * @d@ is not part of the result, or @d@ is consumed;
--
-- * either @d@ is not consumed, or @v@ is not used at all and @v@ is
--   consumable.
--
-- @v@ is consumable when its symbol-table entry is at the current loop
-- depth and it is bound either by a statement whose pattern element for
-- @v@ has no aliases, or as a function parameter with a unique type.
removeUnnecessaryCopy :: (BuilderOps rep, Aliased rep) => BottomUpRuleBasicOp rep
removeUnnecessaryCopy (vtable, used) (Pat [d]) _ (Copy v)
  | not (v `UT.isConsumed` used),
    not (patElemName d `UT.isInResult` used)
      || (patElemName d `UT.isConsumed` used),
    (not (v `UT.used` used) && consumable)
      || not (patElemName d `UT.isConsumed` used) =
      Simplify $ letBindNames [patElemName d] $ BasicOp $ SubExp $ Var v
  where
    -- Can the original array 'v' be consumed in place of the copy?
    consumable = fromMaybe False $ do
      e <- ST.lookup v vtable
      -- Only entries bound at the current loop depth qualify.
      guard $ ST.entryDepth e == ST.loopDepth vtable
      consumableStm e `mplus` consumableFParam e
    -- A function parameter is consumable iff its declared type is unique.
    consumableFParam =
      Just . maybe False (unique . declTypeOf) . ST.entryFParam
    -- A statement-bound variable is consumable iff its pattern element
    -- has no aliases.
    consumableStm e = do
      pat <- stmPat <$> ST.entryStm e
      pe <- find ((== v) . patElemName) (patElems pat)
      guard $ aliasesOf pe == mempty
      pure True
removeUnnecessaryCopy _ _ _ _ = Skip
-- | Fold a call to a function listed in 'primFuns' into a constant.
--
-- The rule fires when every argument is a constant and the function
-- produces a result for those arguments.  The statement's certificates
-- and attributes are kept on the new binding.
constantFoldPrimFun :: BuilderOps rep => TopDownRuleGeneric rep
constantFoldPrimFun _ (Let pat (StmAux cs attrs _) (Apply fname args _ _))
  | Just args' <- mapM (isConst . fst) args,
    Just (_, _, fun) <- M.lookup (nameToString fname) primFuns,
    Just result <- fun args' =
      Simplify $
        certifying cs $
          attributing attrs $
            letBind pat $ BasicOp $ SubExp $ Constant result
  where
    -- Only literal constants can be folded; variables abort the rule.
    isConst (Constant v) = Just v
    isConst _ = Nothing
constantFoldPrimFun _ _ = Skip
-- | Simplify an @Index@ statement with a single pattern element using
-- 'simplifyIndexing'.
--
-- The simplified form is either a plain 'SubExp' or a new 'Index'.  In
-- both cases the statement's attributes are kept, and its certificates
-- are combined with any certificates returned by 'simplifyIndexing'.
simplifyIndex :: BuilderOps rep => BottomUpRuleBasicOp rep
simplifyIndex (vtable, used) pat@(Pat [pe]) (StmAux cs attrs _) (Index idd inds)
  | Just m <- simplifyIndexing vtable seType idd inds consumed =
      Simplify $ do
        res <- m
        attributing attrs $ case res of
          SubExpResult cs' se ->
            certifying (cs <> cs') $
              letBindNames (patNames pat) $ BasicOp $ SubExp se
          IndexResult extra_cs idd' inds' ->
            certifying (cs <> extra_cs) $
              letBindNames (patNames pat) $ BasicOp $ Index idd' inds'
  where
    -- Whether the result of the indexing is consumed; passed on to
    -- 'simplifyIndexing'.
    consumed = patElemName pe `UT.isConsumed` used
    -- Look up the type of a subexpression: variables through the
    -- symbol table, constants from their primitive value.
    seType (Var v) = ST.lookupType v vtable
    seType (Constant v) = Just $ Prim $ primValueType v
simplifyIndex _ _ _ _ = Skip
-- | Simplifications of @if@ expressions.  The equations are tried in
-- order, and each one falls through to the next when its guards fail.
--
-- 1. A constant condition selects one branch, whose statements and
--    results are inlined.  An 'IfFallback' branch is only inlined this
--    way when the condition is constant true.
--
-- 2. @if c then true else se@, where both bodies have no statements and
--    a single boolean result, becomes @c || se@.
--
-- 3. A single boolean result, where both bodies contain only safe
--    statements, becomes @(c && tres) || (!c && fres)@.  The statements
--    of both bodies are hoisted out.
--
-- 4. An 'IfFallback' whose true branch contains only safe statements is
--    replaced by that branch.
--
-- 5. Branches returning the integer constants 1 and 0 (or 0 and 1),
--    with no certificates on either result, become a 'BToI' conversion
--    of @c@ (or of @!c@).
ruleIf :: BuilderOps rep => TopDownRuleIf rep
ruleIf _ pat _ (e1, tb, fb, IfDec _ ifsort)
  | Just branch <- checkBranch,
    ifsort /= IfFallback || isCt1 e1 =
      Simplify $ do
        let ses = bodyResult branch
        addStms $ bodyStms branch
        sequence_
          [ certifying cs $ letBindNames [patElemName p] $ BasicOp $ SubExp se
            | (p, SubExpRes cs se) <- zip (patElems pat) ses
          ]
  where
    checkBranch
      | isCt1 e1 = Just tb
      | isCt0 e1 = Just fb
      | otherwise = Nothing
-- if c then true else se  ==>  c || se
ruleIf
  _
  pat
  _
  ( cond,
    Body _ tstms [SubExpRes tcs (Constant (BoolValue True))],
    Body _ fstms [SubExpRes fcs se],
    IfDec ts _
    )
    | null tstms,
      null fstms,
      [Prim Bool] <- map extTypeOf ts =
        Simplify $
          certifying (tcs <> fcs) $
            letBind pat $ BasicOp $ BinOp LogOr cond se
-- if c then tres else fres  ==>  (c && tres) || (!c && fres)
-- for a single boolean result and only safe statements in both bodies.
ruleIf _ pat _ (cond, tb, fb, IfDec ts _)
  | Body _ tstms [SubExpRes tcs tres] <- tb,
    Body _ fstms [SubExpRes fcs fres] <- fb,
    all (safeExp . stmExp) $ tstms <> fstms,
    all ((== Prim Bool) . extTypeOf) ts =
      Simplify $ do
        addStms tstms
        addStms fstms
        e <-
          eBinOp
            LogOr
            (pure $ BasicOp $ BinOp LogAnd cond tres)
            ( eBinOp
                LogAnd
                (pure $ BasicOp $ UnOp Not cond)
                (pure $ BasicOp $ SubExp fres)
            )
        certifying (tcs <> fcs) $ letBind pat e
-- A fallback 'if' whose true branch has only safe statements is replaced
-- by that branch.
ruleIf _ pat _ (_, tbranch, _, IfDec _ IfFallback)
  | all (safeExp . stmExp) $ bodyStms tbranch =
      Simplify $ do
        let ses = bodyResult tbranch
        addStms $ bodyStms tbranch
        sequence_
          [ certifying cs $ letBindNames [patElemName p] $ BasicOp $ SubExp se
            | (p, SubExpRes cs se) <- zip (patElems pat) ses
          ]
-- if c then 1 else 0  ==>  btoi c
-- if c then 0 else 1  ==>  btoi (!c)
--
-- The branch statements are discarded, so both directions require that
-- neither result carries certificates; otherwise the dependency on those
-- certificates would be silently dropped.
ruleIf _ pat _ (cond, tb, fb, _)
  | Body _ _ [SubExpRes tcs (Constant (IntValue t))] <- tb,
    Body _ _ [SubExpRes fcs (Constant (IntValue f))] <- fb,
    tcs == mempty,
    fcs == mempty =
      if oneIshInt t && zeroIshInt f
        then
          Simplify $
            letBind pat $ BasicOp $ ConvOp (BToI (intValueType t)) cond
        else
          if zeroIshInt t && oneIshInt f
            then Simplify $ do
              cond_neg <- letSubExp "cond_neg" $ BasicOp $ UnOp Not cond
              letBind pat $ BasicOp $ ConvOp (BToI (intValueType t)) cond_neg
            else Skip
ruleIf _ _ _ _ = Skip
hoistBranchInvariant :: BuilderOps rep => TopDownRuleIf rep
hoistBranchInvariant :: TopDownRuleIf rep
hoistBranchInvariant TopDown rep
_ Pat rep
pat StmAux (ExpDec rep)
_ (SubExp
cond, BodyT rep
tb, BodyT rep
fb, IfDec [BranchType rep]
ret IfSort
ifsort) = RuleM rep () -> Rule rep
forall rep. RuleM rep () -> Rule rep
Simplify (RuleM rep () -> Rule rep) -> RuleM rep () -> Rule rep
forall a b. (a -> b) -> a -> b
$ do
let tses :: Result
tses = BodyT rep -> Result
forall rep. BodyT rep -> Result
bodyResult BodyT rep
tb
fses :: Result
fses = BodyT rep -> Result
forall rep. BodyT rep -> Result
bodyResult BodyT rep
fb
([Maybe (Int, SubExp)]
hoistings, ([PatElemT (LetDec rep)]
pes, [BranchType rep]
ts, [(SubExpRes, SubExpRes)]
res)) <-
([Either
(Maybe (Int, SubExp))
(PatElemT (LetDec rep), BranchType rep, (SubExpRes, SubExpRes))]
-> ([Maybe (Int, SubExp)],
([PatElemT (LetDec rep)], [BranchType rep],
[(SubExpRes, SubExpRes)])))
-> RuleM
rep
[Either
(Maybe (Int, SubExp))
(PatElemT (LetDec rep), BranchType rep, (SubExpRes, SubExpRes))]
-> RuleM
rep
([Maybe (Int, SubExp)],
([PatElemT (LetDec rep)], [BranchType rep],
[(SubExpRes, SubExpRes)]))
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (([(PatElemT (LetDec rep), BranchType rep, (SubExpRes, SubExpRes))]
-> ([PatElemT (LetDec rep)], [BranchType rep],
[(SubExpRes, SubExpRes)]))
-> ([Maybe (Int, SubExp)],
[(PatElemT (LetDec rep), BranchType rep, (SubExpRes, SubExpRes))])
-> ([Maybe (Int, SubExp)],
([PatElemT (LetDec rep)], [BranchType rep],
[(SubExpRes, SubExpRes)]))
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap [(PatElemT (LetDec rep), BranchType rep, (SubExpRes, SubExpRes))]
-> ([PatElemT (LetDec rep)], [BranchType rep],
[(SubExpRes, SubExpRes)])
forall a b c. [(a, b, c)] -> ([a], [b], [c])
unzip3 (([Maybe (Int, SubExp)],
[(PatElemT (LetDec rep), BranchType rep, (SubExpRes, SubExpRes))])
-> ([Maybe (Int, SubExp)],
([PatElemT (LetDec rep)], [BranchType rep],
[(SubExpRes, SubExpRes)])))
-> ([Either
(Maybe (Int, SubExp))
(PatElemT (LetDec rep), BranchType rep, (SubExpRes, SubExpRes))]
-> ([Maybe (Int, SubExp)],
[(PatElemT (LetDec rep), BranchType rep, (SubExpRes, SubExpRes))]))
-> [Either
(Maybe (Int, SubExp))
(PatElemT (LetDec rep), BranchType rep, (SubExpRes, SubExpRes))]
-> ([Maybe (Int, SubExp)],
([PatElemT (LetDec rep)], [BranchType rep],
[(SubExpRes, SubExpRes)]))
forall b c a. (b -> c) -> (a -> b) -> a -> c
. [Either
(Maybe (Int, SubExp))
(PatElemT (LetDec rep), BranchType rep, (SubExpRes, SubExpRes))]
-> ([Maybe (Int, SubExp)],
[(PatElemT (LetDec rep), BranchType rep, (SubExpRes, SubExpRes))])
forall a b. [Either a b] -> ([a], [b])
partitionEithers) (RuleM
rep
[Either
(Maybe (Int, SubExp))
(PatElemT (LetDec rep), BranchType rep, (SubExpRes, SubExpRes))]
-> RuleM
rep
([Maybe (Int, SubExp)],
([PatElemT (LetDec rep)], [BranchType rep],
[(SubExpRes, SubExpRes)])))
-> ([(Int, PatElemT (LetDec rep), BranchType rep,
(SubExpRes, SubExpRes))]
-> RuleM
rep
[Either
(Maybe (Int, SubExp))
(PatElemT (LetDec rep), BranchType rep, (SubExpRes, SubExpRes))])
-> [(Int, PatElemT (LetDec rep), BranchType rep,
(SubExpRes, SubExpRes))]
-> RuleM
rep
([Maybe (Int, SubExp)],
([PatElemT (LetDec rep)], [BranchType rep],
[(SubExpRes, SubExpRes)]))
forall b c a. (b -> c) -> (a -> b) -> a -> c
. ((Int, PatElemT (LetDec rep), BranchType rep,
(SubExpRes, SubExpRes))
-> RuleM
rep
(Either
(Maybe (Int, SubExp))
(PatElemT (LetDec rep), BranchType rep, (SubExpRes, SubExpRes))))
-> [(Int, PatElemT (LetDec rep), BranchType rep,
(SubExpRes, SubExpRes))]
-> RuleM
rep
[Either
(Maybe (Int, SubExp))
(PatElemT (LetDec rep), BranchType rep, (SubExpRes, SubExpRes))]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (Int, PatElemT (LetDec rep), BranchType rep,
(SubExpRes, SubExpRes))
-> RuleM
rep
(Either
(Maybe (Int, SubExp))
(PatElemT (LetDec rep), BranchType rep, (SubExpRes, SubExpRes)))
branchInvariant ([(Int, PatElemT (LetDec rep), BranchType rep,
(SubExpRes, SubExpRes))]
-> RuleM
rep
([Maybe (Int, SubExp)],
([PatElemT (LetDec rep)], [BranchType rep],
[(SubExpRes, SubExpRes)])))
-> [(Int, PatElemT (LetDec rep), BranchType rep,
(SubExpRes, SubExpRes))]
-> RuleM
rep
([Maybe (Int, SubExp)],
([PatElemT (LetDec rep)], [BranchType rep],
[(SubExpRes, SubExpRes)]))
forall a b. (a -> b) -> a -> b
$
[Int]
-> [PatElemT (LetDec rep)]
-> [BranchType rep]
-> [(SubExpRes, SubExpRes)]
-> [(Int, PatElemT (LetDec rep), BranchType rep,
(SubExpRes, SubExpRes))]
forall a b c d. [a] -> [b] -> [c] -> [d] -> [(a, b, c, d)]
zip4 [Int
0 ..] (Pat rep -> [PatElemT (LetDec rep)]
forall dec. PatT dec -> [PatElemT dec]
patElems Pat rep
pat) [BranchType rep]
ret (Result -> Result -> [(SubExpRes, SubExpRes)]
forall a b. [a] -> [b] -> [(a, b)]
zip Result
tses Result
fses)
let ctx_fixes :: [(Int, SubExp)]
ctx_fixes = [Maybe (Int, SubExp)] -> [(Int, SubExp)]
forall a. [Maybe a] -> [a]
catMaybes [Maybe (Int, SubExp)]
hoistings
(Result
tses', Result
fses') = [(SubExpRes, SubExpRes)] -> (Result, Result)
forall a b. [(a, b)] -> ([a], [b])
unzip [(SubExpRes, SubExpRes)]
res
tb' :: BodyT rep
tb' = BodyT rep
tb {bodyResult :: Result
bodyResult = Result
tses'}
fb' :: BodyT rep
fb' = BodyT rep
fb {bodyResult :: Result
bodyResult = Result
fses'}
ret' :: [BranchType rep]
ret' = ((Int, SubExp) -> [BranchType rep] -> [BranchType rep])
-> [BranchType rep] -> [(Int, SubExp)] -> [BranchType rep]
forall (t :: * -> *) a b.
Foldable t =>
(a -> b -> b) -> b -> t a -> b
foldr ((Int -> SubExp -> [BranchType rep] -> [BranchType rep])
-> (Int, SubExp) -> [BranchType rep] -> [BranchType rep]
forall a b c. (a -> b -> c) -> (a, b) -> c
uncurry Int -> SubExp -> [BranchType rep] -> [BranchType rep]
forall t. FixExt t => Int -> SubExp -> t -> t
fixExt) [BranchType rep]
ts [(Int, SubExp)]
ctx_fixes
if Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ [Maybe (Int, SubExp)] -> Bool
forall (t :: * -> *) a. Foldable t => t a -> Bool
null [Maybe (Int, SubExp)]
hoistings
then do
BodyT rep
tb'' <- BodyT (Rep (RuleM rep))
-> [ExtType] -> RuleM rep (BodyT (Rep (RuleM rep)))
forall (m :: * -> *).
MonadBuilder m =>
BodyT (Rep m) -> [ExtType] -> m (BodyT (Rep m))
reshapeBodyResults BodyT rep
BodyT (Rep (RuleM rep))
tb' ([ExtType] -> RuleM rep (BodyT (Rep (RuleM rep))))
-> [ExtType] -> RuleM rep (BodyT (Rep (RuleM rep)))
forall a b. (a -> b) -> a -> b
$ (BranchType rep -> ExtType) -> [BranchType rep] -> [ExtType]
forall a b. (a -> b) -> [a] -> [b]
map BranchType rep -> ExtType
forall t. ExtTyped t => t -> ExtType
extTypeOf [BranchType rep]
ret'
BodyT rep
fb'' <- BodyT (Rep (RuleM rep))
-> [ExtType] -> RuleM rep (BodyT (Rep (RuleM rep)))
forall (m :: * -> *).
MonadBuilder m =>
BodyT (Rep m) -> [ExtType] -> m (BodyT (Rep m))
reshapeBodyResults BodyT rep
BodyT (Rep (RuleM rep))
fb' ([ExtType] -> RuleM rep (BodyT (Rep (RuleM rep))))
-> [ExtType] -> RuleM rep (BodyT (Rep (RuleM rep)))
forall a b. (a -> b) -> a -> b
$ (BranchType rep -> ExtType) -> [BranchType rep] -> [ExtType]
forall a b. (a -> b) -> [a] -> [b]
map BranchType rep -> ExtType
forall t. ExtTyped t => t -> ExtType
extTypeOf [BranchType rep]
ret'
Pat (Rep (RuleM rep)) -> Exp (Rep (RuleM rep)) -> RuleM rep ()
forall (m :: * -> *).
MonadBuilder m =>
Pat (Rep m) -> Exp (Rep m) -> m ()
letBind ([PatElemT (LetDec rep)] -> Pat rep
forall dec. [PatElemT dec] -> PatT dec
Pat [PatElemT (LetDec rep)]
pes) (Exp (Rep (RuleM rep)) -> RuleM rep ())
-> Exp (Rep (RuleM rep)) -> RuleM rep ()
forall a b. (a -> b) -> a -> b
$ SubExp
-> BodyT rep -> BodyT rep -> IfDec (BranchType rep) -> ExpT rep
forall rep.
SubExp
-> BodyT rep -> BodyT rep -> IfDec (BranchType rep) -> ExpT rep
If SubExp
cond BodyT rep
tb'' BodyT rep
fb'' ([BranchType rep] -> IfSort -> IfDec (BranchType rep)
forall rt. [rt] -> IfSort -> IfDec rt
IfDec [BranchType rep]
ret' IfSort
ifsort)
else RuleM rep ()
forall rep a. RuleM rep a
cannotSimplify
where
bound_in_branches :: Names
bound_in_branches =
[VName] -> Names
namesFromList ([VName] -> Names)
-> (Seq (Stm rep) -> [VName]) -> Seq (Stm rep) -> Names
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Stm rep -> [VName]) -> Seq (Stm rep) -> [VName]
forall (t :: * -> *) a b. Foldable t => (a -> [b]) -> t a -> [b]
concatMap (Pat rep -> [VName]
forall dec. PatT dec -> [VName]
patNames (Pat rep -> [VName]) -> (Stm rep -> Pat rep) -> Stm rep -> [VName]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Stm rep -> Pat rep
forall rep. Stm rep -> Pat rep
stmPat) (Seq (Stm rep) -> Names) -> Seq (Stm rep) -> Names
forall a b. (a -> b) -> a -> b
$
BodyT rep -> Seq (Stm rep)
forall rep. BodyT rep -> Stms rep
bodyStms BodyT rep
tb Seq (Stm rep) -> Seq (Stm rep) -> Seq (Stm rep)
forall a. Semigroup a => a -> a -> a
<> BodyT rep -> Seq (Stm rep)
forall rep. BodyT rep -> Stms rep
bodyStms BodyT rep
fb
invariant :: SubExp -> Bool
invariant Constant {} = Bool
True
invariant (Var VName
v) = Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ VName
v VName -> Names -> Bool
`nameIn` Names
bound_in_branches
branchInvariant :: (Int, PatElemT (LetDec rep), BranchType rep,
(SubExpRes, SubExpRes))
-> RuleM
rep
(Either
(Maybe (Int, SubExp))
(PatElemT (LetDec rep), BranchType rep, (SubExpRes, SubExpRes)))
branchInvariant (Int
i, PatElemT (LetDec rep)
pe, BranchType rep
t, (SubExpRes
tse, SubExpRes
fse))
| SubExpRes
tse SubExpRes -> SubExpRes -> Bool
forall a. Eq a => a -> a -> Bool
== SubExpRes
fse = do
Certs -> RuleM rep () -> RuleM rep ()
forall (m :: * -> *) a. MonadBuilder m => Certs -> m a -> m a
certifying (SubExpRes -> Certs
resCerts SubExpRes
tse Certs -> Certs -> Certs
forall a. Semigroup a => a -> a -> a
<> SubExpRes -> Certs
resCerts SubExpRes
fse) (RuleM rep () -> RuleM rep ()) -> RuleM rep () -> RuleM rep ()
forall a b. (a -> b) -> a -> b
$
[VName] -> Exp (Rep (RuleM rep)) -> RuleM rep ()
forall (m :: * -> *).
MonadBuilder m =>
[VName] -> Exp (Rep m) -> m ()
letBindNames [PatElemT (LetDec rep) -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT (LetDec rep)
pe] (Exp (Rep (RuleM rep)) -> RuleM rep ())
-> Exp (Rep (RuleM rep)) -> RuleM rep ()
forall a b. (a -> b) -> a -> b
$ BasicOp -> ExpT rep
forall rep. BasicOp -> ExpT rep
BasicOp (BasicOp -> ExpT rep) -> BasicOp -> ExpT rep
forall a b. (a -> b) -> a -> b
$ SubExp -> BasicOp
SubExp (SubExp -> BasicOp) -> SubExp -> BasicOp
forall a b. (a -> b) -> a -> b
$ SubExpRes -> SubExp
resSubExp SubExpRes
tse
Int
-> PatElemT (LetDec rep)
-> RuleM
rep
(Either
(Maybe (Int, SubExp))
(PatElemT (LetDec rep), BranchType rep, (SubExpRes, SubExpRes)))
forall (f :: * -> *) a dec b.
Applicative f =>
a -> PatElemT dec -> f (Either (Maybe (a, SubExp)) b)
hoisted Int
i PatElemT (LetDec rep)
pe
| SubExp -> Bool
invariant (SubExp -> Bool) -> SubExp -> Bool
forall a b. (a -> b) -> a -> b
$ SubExpRes -> SubExp
resSubExp SubExpRes
tse,
SubExp -> Bool
invariant (SubExp -> Bool) -> SubExp -> Bool
forall a b. (a -> b) -> a -> b
$ SubExpRes -> SubExp
resSubExp SubExpRes
fse,
Pat rep -> Int
forall dec. PatT dec -> Int
patSize Pat rep
pat Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
> Int
1,
Prim PrimType
_ <- PatElemT (LetDec rep) -> Type
forall dec. Typed dec => PatElemT dec -> Type
patElemType PatElemT (LetDec rep)
pe = do
[BranchType rep]
bt <- Pat rep -> RuleM rep [BranchType rep]
forall rep (m :: * -> *).
(ASTRep rep, HasScope rep m, Monad m) =>
Pat rep -> m [BranchType rep]
expTypesFromPat (Pat rep -> RuleM rep [BranchType rep])
-> Pat rep -> RuleM rep [BranchType rep]
forall a b. (a -> b) -> a -> b
$ [PatElemT (LetDec rep)] -> Pat rep
forall dec. [PatElemT dec] -> PatT dec
Pat [PatElemT (LetDec rep)
pe]
[VName] -> Exp (Rep (RuleM rep)) -> RuleM rep ()
forall (m :: * -> *).
MonadBuilder m =>
[VName] -> Exp (Rep m) -> m ()
letBindNames [PatElemT (LetDec rep) -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT (LetDec rep)
pe]
(ExpT rep -> RuleM rep ()) -> RuleM rep (ExpT rep) -> RuleM rep ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< ( SubExp
-> BodyT rep -> BodyT rep -> IfDec (BranchType rep) -> ExpT rep
forall rep.
SubExp
-> BodyT rep -> BodyT rep -> IfDec (BranchType rep) -> ExpT rep
If SubExp
cond (BodyT rep -> BodyT rep -> IfDec (BranchType rep) -> ExpT rep)
-> RuleM rep (BodyT rep)
-> RuleM rep (BodyT rep -> IfDec (BranchType rep) -> ExpT rep)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> [SubExp] -> RuleM rep (BodyT (Rep (RuleM rep)))
forall (m :: * -> *).
MonadBuilder m =>
[SubExp] -> m (Body (Rep m))
resultBodyM [SubExpRes -> SubExp
resSubExp SubExpRes
tse]
RuleM rep (BodyT rep -> IfDec (BranchType rep) -> ExpT rep)
-> RuleM rep (BodyT rep)
-> RuleM rep (IfDec (BranchType rep) -> ExpT rep)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> [SubExp] -> RuleM rep (BodyT (Rep (RuleM rep)))
forall (m :: * -> *).
MonadBuilder m =>
[SubExp] -> m (Body (Rep m))
resultBodyM [SubExpRes -> SubExp
resSubExp SubExpRes
fse]
RuleM rep (IfDec (BranchType rep) -> ExpT rep)
-> RuleM rep (IfDec (BranchType rep)) -> RuleM rep (ExpT rep)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> IfDec (BranchType rep) -> RuleM rep (IfDec (BranchType rep))
forall (f :: * -> *) a. Applicative f => a -> f a
pure ([BranchType rep] -> IfSort -> IfDec (BranchType rep)
forall rt. [rt] -> IfSort -> IfDec rt
IfDec [BranchType rep]
bt IfSort
ifsort)
)
Int
-> PatElemT (LetDec rep)
-> RuleM
rep
(Either
(Maybe (Int, SubExp))
(PatElemT (LetDec rep), BranchType rep, (SubExpRes, SubExpRes)))
forall (f :: * -> *) a dec b.
Applicative f =>
a -> PatElemT dec -> f (Either (Maybe (a, SubExp)) b)
hoisted Int
i PatElemT (LetDec rep)
pe
| Bool
otherwise =
Either
(Maybe (Int, SubExp))
(PatElemT (LetDec rep), BranchType rep, (SubExpRes, SubExpRes))
-> RuleM
rep
(Either
(Maybe (Int, SubExp))
(PatElemT (LetDec rep), BranchType rep, (SubExpRes, SubExpRes)))
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Either
(Maybe (Int, SubExp))
(PatElemT (LetDec rep), BranchType rep, (SubExpRes, SubExpRes))
-> RuleM
rep
(Either
(Maybe (Int, SubExp))
(PatElemT (LetDec rep), BranchType rep, (SubExpRes, SubExpRes))))
-> Either
(Maybe (Int, SubExp))
(PatElemT (LetDec rep), BranchType rep, (SubExpRes, SubExpRes))
-> RuleM
rep
(Either
(Maybe (Int, SubExp))
(PatElemT (LetDec rep), BranchType rep, (SubExpRes, SubExpRes)))
forall a b. (a -> b) -> a -> b
$ (PatElemT (LetDec rep), BranchType rep, (SubExpRes, SubExpRes))
-> Either
(Maybe (Int, SubExp))
(PatElemT (LetDec rep), BranchType rep, (SubExpRes, SubExpRes))
forall a b. b -> Either a b
Right (PatElemT (LetDec rep)
pe, BranchType rep
t, (SubExpRes
tse, SubExpRes
fse))
hoisted :: a -> PatElemT dec -> f (Either (Maybe (a, SubExp)) b)
hoisted a
i PatElemT dec
pe = Either (Maybe (a, SubExp)) b -> f (Either (Maybe (a, SubExp)) b)
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Either (Maybe (a, SubExp)) b -> f (Either (Maybe (a, SubExp)) b))
-> Either (Maybe (a, SubExp)) b -> f (Either (Maybe (a, SubExp)) b)
forall a b. (a -> b) -> a -> b
$ Maybe (a, SubExp) -> Either (Maybe (a, SubExp)) b
forall a b. a -> Either a b
Left (Maybe (a, SubExp) -> Either (Maybe (a, SubExp)) b)
-> Maybe (a, SubExp) -> Either (Maybe (a, SubExp)) b
forall a b. (a -> b) -> a -> b
$ (a, SubExp) -> Maybe (a, SubExp)
forall a. a -> Maybe a
Just (a
i, VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ PatElemT dec -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT dec
pe)
reshapeBodyResults :: BodyT (Rep m) -> [ExtType] -> m (BodyT (Rep m))
reshapeBodyResults BodyT (Rep m)
body [ExtType]
rets = m Result -> m (BodyT (Rep m))
forall (m :: * -> *).
MonadBuilder m =>
m Result -> m (Body (Rep m))
buildBody_ (m Result -> m (BodyT (Rep m))) -> m Result -> m (BodyT (Rep m))
forall a b. (a -> b) -> a -> b
$ do
Result
ses <- BodyT (Rep m) -> m Result
forall (m :: * -> *). MonadBuilder m => Body (Rep m) -> m Result
bodyBind BodyT (Rep m)
body
let (Result
ctx_ses, Result
val_ses) = Int -> Result -> (Result, Result)
forall a. Int -> [a] -> ([a], [a])
splitFromEnd ([ExtType] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [ExtType]
rets) Result
ses
(Result
ctx_ses Result -> Result -> Result
forall a. [a] -> [a] -> [a]
++) (Result -> Result) -> m Result -> m Result
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> (SubExpRes -> ExtType -> m SubExpRes)
-> Result -> [ExtType] -> m Result
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m [c]
zipWithM SubExpRes -> ExtType -> m SubExpRes
forall (m :: * -> *).
MonadBuilder m =>
SubExpRes -> ExtType -> m SubExpRes
reshapeResult Result
val_ses [ExtType]
rets
reshapeResult :: SubExpRes -> ExtType -> m SubExpRes
reshapeResult (SubExpRes Certs
cs (Var VName
v)) t :: ExtType
t@Array {} = do
Type
v_t <- VName -> m Type
forall rep (m :: * -> *). HasScope rep m => VName -> m Type
lookupType VName
v
let newshape :: [SubExp]
newshape = Type -> [SubExp]
forall u. TypeBase Shape u -> [SubExp]
arrayDims (Type -> [SubExp]) -> Type -> [SubExp]
forall a b. (a -> b) -> a -> b
$ ExtType -> Type -> Type
removeExistentials ExtType
t Type
v_t
Certs -> SubExp -> SubExpRes
SubExpRes Certs
cs
(SubExp -> SubExpRes) -> m SubExp -> m SubExpRes
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> if [SubExp]
newshape [SubExp] -> [SubExp] -> Bool
forall a. Eq a => a -> a -> Bool
/= Type -> [SubExp]
forall u. TypeBase Shape u -> [SubExp]
arrayDims Type
v_t
then String -> Exp (Rep m) -> m SubExp
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m SubExp
letSubExp String
"branch_ctx_reshaped" ([SubExp] -> VName -> Exp (Rep m)
forall rep. [SubExp] -> VName -> Exp rep
shapeCoerce [SubExp]
newshape VName
v)
else SubExp -> m SubExp
forall (f :: * -> *) a. Applicative f => a -> f a
pure (SubExp -> m SubExp) -> SubExp -> m SubExp
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
v
reshapeResult SubExpRes
se ExtType
_ =
SubExpRes -> m SubExpRes
forall (m :: * -> *) a. Monad m => a -> m a
return SubExpRes
se
-- | Prune the results of an @if@ expression that nobody reads.
--
-- Each name bound by the pattern is tested with 'UT.isUsedDirectly'
-- against the usage table.  If at least one name is not directly
-- used, the @if@ is rebuilt so that:
--
--   * the pattern keeps only the used elements,
--   * both branch bodies return only the corresponding results, and
--   * the branch return types are filtered the same way.
--
-- Only the 'bodyResult' lists are changed; the statements that computed
-- the dropped results are left in the branch bodies untouched.
--
-- The rule is skipped ('Skip') in two cases:
--
--   * some pattern name occurs free in the pattern elements
--     (existential context), or
--   * every pattern name is already used.
removeDeadBranchResult :: BuilderOps rep => BottomUpRuleIf rep
removeDeadBranchResult (_, used) pat _ (e1, tb, fb, IfDec rettype ifsort)
  | hasExistentialContext || and keep = Skip
  | otherwise =
      Simplify . letBind (Pat (onlyLive (patElems pat))) $
        If e1 (pruneBody tb) (pruneBody fb) (IfDec (onlyLive rettype) ifsort)
  where
    names = patNames pat

    -- A pattern name mentioned by some pattern element (e.g. in its
    -- type) means positions depend on each other; leave such ifs alone.
    hasExistentialContext = any (`nameIn` foldMap freeIn (patElems pat)) names

    -- One flag per pattern position: True when the bound name is used.
    keep = map (`UT.isUsedDirectly` used) names

    -- Retain exactly the positions flagged in 'keep'.
    onlyLive :: [a] -> [a]
    onlyLive xs = [x | (True, x) <- zip keep xs]

    pruneBody body = body {bodyResult = onlyLive (bodyResult body)}
withAccTopDown :: BuilderOps rep => TopDownRuleGeneric rep
withAccTopDown :: TopDownRuleGeneric rep
withAccTopDown TopDown rep
_ (Let Pat rep
pat StmAux (ExpDec rep)
aux (WithAcc [] Lambda rep
lam)) = RuleM rep () -> Rule rep
forall rep. RuleM rep () -> Rule rep
Simplify (RuleM rep () -> Rule rep)
-> (RuleM rep () -> RuleM rep ()) -> RuleM rep () -> Rule rep
forall b c a. (b -> c) -> (a -> b) -> a -> c
. StmAux (ExpDec rep) -> RuleM rep () -> RuleM rep ()
forall (m :: * -> *) anyrep a.
MonadBuilder m =>
StmAux anyrep -> m a -> m a
auxing StmAux (ExpDec rep)
aux (RuleM rep () -> Rule rep) -> RuleM rep () -> Rule rep
forall a b. (a -> b) -> a -> b
$ do
Result
lam_res <- Body (Rep (RuleM rep)) -> RuleM rep Result
forall (m :: * -> *). MonadBuilder m => Body (Rep m) -> m Result
bodyBind (Body (Rep (RuleM rep)) -> RuleM rep Result)
-> Body (Rep (RuleM rep)) -> RuleM rep Result
forall a b. (a -> b) -> a -> b
$ Lambda rep -> BodyT rep
forall rep. LambdaT rep -> BodyT rep
lambdaBody Lambda rep
lam
[(VName, SubExpRes)]
-> ((VName, SubExpRes) -> RuleM rep ()) -> RuleM rep ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName] -> Result -> [(VName, SubExpRes)]
forall a b. [a] -> [b] -> [(a, b)]
zip (Pat rep -> [VName]
forall dec. PatT dec -> [VName]
patNames Pat rep
pat) Result
lam_res) (((VName, SubExpRes) -> RuleM rep ()) -> RuleM rep ())
-> ((VName, SubExpRes) -> RuleM rep ()) -> RuleM rep ()
forall a b. (a -> b) -> a -> b
$ \(VName
v, SubExpRes Certs
cs SubExp
se) ->
Certs -> RuleM rep () -> RuleM rep ()
forall (m :: * -> *) a. MonadBuilder m => Certs -> m a -> m a
certifying Certs
cs (RuleM rep () -> RuleM rep ()) -> RuleM rep () -> RuleM rep ()
forall a b. (a -> b) -> a -> b
$ [VName] -> Exp (Rep (RuleM rep)) -> RuleM rep ()
forall (m :: * -> *).
MonadBuilder m =>
[VName] -> Exp (Rep m) -> m ()
letBindNames [VName
v] (Exp (Rep (RuleM rep)) -> RuleM rep ())
-> Exp (Rep (RuleM rep)) -> RuleM rep ()
forall a b. (a -> b) -> a -> b
$ BasicOp -> ExpT rep
forall rep. BasicOp -> ExpT rep
BasicOp (BasicOp -> ExpT rep) -> BasicOp -> ExpT rep
forall a b. (a -> b) -> a -> b
$ SubExp -> BasicOp
SubExp SubExp
se
withAccTopDown TopDown rep
vtable (Let Pat rep
pat StmAux (ExpDec rep)
aux (WithAcc [(Shape, [VName], Maybe (Lambda rep, [SubExp]))]
inputs Lambda rep
lam)) = RuleM rep () -> Rule rep
forall rep. RuleM rep () -> Rule rep
Simplify (RuleM rep () -> Rule rep)
-> (RuleM rep () -> RuleM rep ()) -> RuleM rep () -> Rule rep
forall b c a. (b -> c) -> (a -> b) -> a -> c
. StmAux (ExpDec rep) -> RuleM rep () -> RuleM rep ()
forall (m :: * -> *) anyrep a.
MonadBuilder m =>
StmAux anyrep -> m a -> m a
auxing StmAux (ExpDec rep)
aux (RuleM rep () -> Rule rep) -> RuleM rep () -> Rule rep
forall a b. (a -> b) -> a -> b
$ do
let ([Param (LParamInfo rep)]
cert_params, [Param (LParamInfo rep)]
acc_params) =
Int
-> [Param (LParamInfo rep)]
-> ([Param (LParamInfo rep)], [Param (LParamInfo rep)])
forall a. Int -> [a] -> ([a], [a])
splitAt ([(Shape, [VName], Maybe (Lambda rep, [SubExp]))] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [(Shape, [VName], Maybe (Lambda rep, [SubExp]))]
inputs) ([Param (LParamInfo rep)]
-> ([Param (LParamInfo rep)], [Param (LParamInfo rep)]))
-> [Param (LParamInfo rep)]
-> ([Param (LParamInfo rep)], [Param (LParamInfo rep)])
forall a b. (a -> b) -> a -> b
$ Lambda rep -> [Param (LParamInfo rep)]
forall rep. LambdaT rep -> [LParam rep]
lambdaParams Lambda rep
lam
(Result
acc_res, Result
nonacc_res) =
Int -> Result -> (Result, Result)
forall a. Int -> [a] -> ([a], [a])
splitFromEnd Int
num_nonaccs (Result -> (Result, Result)) -> Result -> (Result, Result)
forall a b. (a -> b) -> a -> b
$ BodyT rep -> Result
forall rep. BodyT rep -> Result
bodyResult (BodyT rep -> Result) -> BodyT rep -> Result
forall a b. (a -> b) -> a -> b
$ Lambda rep -> BodyT rep
forall rep. LambdaT rep -> BodyT rep
lambdaBody Lambda rep
lam
([PatElemT (LetDec rep)]
acc_pes, [PatElemT (LetDec rep)]
nonacc_pes) =
Int
-> [PatElemT (LetDec rep)]
-> ([PatElemT (LetDec rep)], [PatElemT (LetDec rep)])
forall a. Int -> [a] -> ([a], [a])
splitFromEnd Int
num_nonaccs ([PatElemT (LetDec rep)]
-> ([PatElemT (LetDec rep)], [PatElemT (LetDec rep)]))
-> [PatElemT (LetDec rep)]
-> ([PatElemT (LetDec rep)], [PatElemT (LetDec rep)])
forall a b. (a -> b) -> a -> b
$ Pat rep -> [PatElemT (LetDec rep)]
forall dec. PatT dec -> [PatElemT dec]
patElems Pat rep
pat
([[PatElemT (LetDec rep)]]
acc_pes', [(Shape, [VName], Maybe (Lambda rep, [SubExp]))]
inputs', [(Param (LParamInfo rep), Param (LParamInfo rep))]
params', Result
acc_res') <-
([Maybe
([PatElemT (LetDec rep)],
(Shape, [VName], Maybe (Lambda rep, [SubExp])),
(Param (LParamInfo rep), Param (LParamInfo rep)), SubExpRes)]
-> ([[PatElemT (LetDec rep)]],
[(Shape, [VName], Maybe (Lambda rep, [SubExp]))],
[(Param (LParamInfo rep), Param (LParamInfo rep))], Result))
-> RuleM
rep
[Maybe
([PatElemT (LetDec rep)],
(Shape, [VName], Maybe (Lambda rep, [SubExp])),
(Param (LParamInfo rep), Param (LParamInfo rep)), SubExpRes)]
-> RuleM
rep
([[PatElemT (LetDec rep)]],
[(Shape, [VName], Maybe (Lambda rep, [SubExp]))],
[(Param (LParamInfo rep), Param (LParamInfo rep))], Result)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap ([([PatElemT (LetDec rep)],
(Shape, [VName], Maybe (Lambda rep, [SubExp])),
(Param (LParamInfo rep), Param (LParamInfo rep)), SubExpRes)]
-> ([[PatElemT (LetDec rep)]],
[(Shape, [VName], Maybe (Lambda rep, [SubExp]))],
[(Param (LParamInfo rep), Param (LParamInfo rep))], Result)
forall a b c d. [(a, b, c, d)] -> ([a], [b], [c], [d])
unzip4 ([([PatElemT (LetDec rep)],
(Shape, [VName], Maybe (Lambda rep, [SubExp])),
(Param (LParamInfo rep), Param (LParamInfo rep)), SubExpRes)]
-> ([[PatElemT (LetDec rep)]],
[(Shape, [VName], Maybe (Lambda rep, [SubExp]))],
[(Param (LParamInfo rep), Param (LParamInfo rep))], Result))
-> ([Maybe
([PatElemT (LetDec rep)],
(Shape, [VName], Maybe (Lambda rep, [SubExp])),
(Param (LParamInfo rep), Param (LParamInfo rep)), SubExpRes)]
-> [([PatElemT (LetDec rep)],
(Shape, [VName], Maybe (Lambda rep, [SubExp])),
(Param (LParamInfo rep), Param (LParamInfo rep)), SubExpRes)])
-> [Maybe
([PatElemT (LetDec rep)],
(Shape, [VName], Maybe (Lambda rep, [SubExp])),
(Param (LParamInfo rep), Param (LParamInfo rep)), SubExpRes)]
-> ([[PatElemT (LetDec rep)]],
[(Shape, [VName], Maybe (Lambda rep, [SubExp]))],
[(Param (LParamInfo rep), Param (LParamInfo rep))], Result)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. [Maybe
([PatElemT (LetDec rep)],
(Shape, [VName], Maybe (Lambda rep, [SubExp])),
(Param (LParamInfo rep), Param (LParamInfo rep)), SubExpRes)]
-> [([PatElemT (LetDec rep)],
(Shape, [VName], Maybe (Lambda rep, [SubExp])),
(Param (LParamInfo rep), Param (LParamInfo rep)), SubExpRes)]
forall a. [Maybe a] -> [a]
catMaybes) (RuleM
rep
[Maybe
([PatElemT (LetDec rep)],
(Shape, [VName], Maybe (Lambda rep, [SubExp])),
(Param (LParamInfo rep), Param (LParamInfo rep)), SubExpRes)]
-> RuleM
rep
([[PatElemT (LetDec rep)]],
[(Shape, [VName], Maybe (Lambda rep, [SubExp]))],
[(Param (LParamInfo rep), Param (LParamInfo rep))], Result))
-> ([([PatElemT (LetDec rep)],
(Shape, [VName], Maybe (Lambda rep, [SubExp])),
(Param (LParamInfo rep), Param (LParamInfo rep)), SubExpRes)]
-> RuleM
rep
[Maybe
([PatElemT (LetDec rep)],
(Shape, [VName], Maybe (Lambda rep, [SubExp])),
(Param (LParamInfo rep), Param (LParamInfo rep)), SubExpRes)])
-> [([PatElemT (LetDec rep)],
(Shape, [VName], Maybe (Lambda rep, [SubExp])),
(Param (LParamInfo rep), Param (LParamInfo rep)), SubExpRes)]
-> RuleM
rep
([[PatElemT (LetDec rep)]],
[(Shape, [VName], Maybe (Lambda rep, [SubExp]))],
[(Param (LParamInfo rep), Param (LParamInfo rep))], Result)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (([PatElemT (LetDec rep)],
(Shape, [VName], Maybe (Lambda rep, [SubExp])),
(Param (LParamInfo rep), Param (LParamInfo rep)), SubExpRes)
-> RuleM
rep
(Maybe
([PatElemT (LetDec rep)],
(Shape, [VName], Maybe (Lambda rep, [SubExp])),
(Param (LParamInfo rep), Param (LParamInfo rep)), SubExpRes)))
-> [([PatElemT (LetDec rep)],
(Shape, [VName], Maybe (Lambda rep, [SubExp])),
(Param (LParamInfo rep), Param (LParamInfo rep)), SubExpRes)]
-> RuleM
rep
[Maybe
([PatElemT (LetDec rep)],
(Shape, [VName], Maybe (Lambda rep, [SubExp])),
(Param (LParamInfo rep), Param (LParamInfo rep)), SubExpRes)]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM ([PatElemT (LetDec rep)],
(Shape, [VName], Maybe (Lambda rep, [SubExp])),
(Param (LParamInfo rep), Param (LParamInfo rep)), SubExpRes)
-> RuleM
rep
(Maybe
([PatElemT (LetDec rep)],
(Shape, [VName], Maybe (Lambda rep, [SubExp])),
(Param (LParamInfo rep), Param (LParamInfo rep)), SubExpRes))
forall (m :: * -> *) dec a c a dec.
MonadBuilder m =>
([PatElemT dec], (a, [VName], c), (a, Param dec), SubExpRes)
-> m (Maybe
([PatElemT dec], (a, [VName], c), (a, Param dec), SubExpRes))
tryMoveAcc ([([PatElemT (LetDec rep)],
(Shape, [VName], Maybe (Lambda rep, [SubExp])),
(Param (LParamInfo rep), Param (LParamInfo rep)), SubExpRes)]
-> RuleM
rep
([[PatElemT (LetDec rep)]],
[(Shape, [VName], Maybe (Lambda rep, [SubExp]))],
[(Param (LParamInfo rep), Param (LParamInfo rep))], Result))
-> [([PatElemT (LetDec rep)],
(Shape, [VName], Maybe (Lambda rep, [SubExp])),
(Param (LParamInfo rep), Param (LParamInfo rep)), SubExpRes)]
-> RuleM
rep
([[PatElemT (LetDec rep)]],
[(Shape, [VName], Maybe (Lambda rep, [SubExp]))],
[(Param (LParamInfo rep), Param (LParamInfo rep))], Result)
forall a b. (a -> b) -> a -> b
$
[[PatElemT (LetDec rep)]]
-> [(Shape, [VName], Maybe (Lambda rep, [SubExp]))]
-> [(Param (LParamInfo rep), Param (LParamInfo rep))]
-> Result
-> [([PatElemT (LetDec rep)],
(Shape, [VName], Maybe (Lambda rep, [SubExp])),
(Param (LParamInfo rep), Param (LParamInfo rep)), SubExpRes)]
forall a b c d. [a] -> [b] -> [c] -> [d] -> [(a, b, c, d)]
zip4
([Int] -> [PatElemT (LetDec rep)] -> [[PatElemT (LetDec rep)]]
forall a. [Int] -> [a] -> [[a]]
chunks (((Shape, [VName], Maybe (Lambda rep, [SubExp])) -> Int)
-> [(Shape, [VName], Maybe (Lambda rep, [SubExp]))] -> [Int]
forall a b. (a -> b) -> [a] -> [b]
map (Shape, [VName], Maybe (Lambda rep, [SubExp])) -> Int
forall (t :: * -> *) a a c. Foldable t => (a, t a, c) -> Int
inputArrs [(Shape, [VName], Maybe (Lambda rep, [SubExp]))]
inputs) [PatElemT (LetDec rep)]
acc_pes)
[(Shape, [VName], Maybe (Lambda rep, [SubExp]))]
inputs
([Param (LParamInfo rep)]
-> [Param (LParamInfo rep)]
-> [(Param (LParamInfo rep), Param (LParamInfo rep))]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param (LParamInfo rep)]
cert_params [Param (LParamInfo rep)]
acc_params)
Result
acc_res
let ([Param (LParamInfo rep)]
cert_params', [Param (LParamInfo rep)]
acc_params') = [(Param (LParamInfo rep), Param (LParamInfo rep))]
-> ([Param (LParamInfo rep)], [Param (LParamInfo rep)])
forall a b. [(a, b)] -> ([a], [b])
unzip [(Param (LParamInfo rep), Param (LParamInfo rep))]
params'
([PatElemT (LetDec rep)]
nonacc_pes', Result
nonacc_res') <-
[(PatElemT (LetDec rep), SubExpRes)]
-> ([PatElemT (LetDec rep)], Result)
forall a b. [(a, b)] -> ([a], [b])
unzip ([(PatElemT (LetDec rep), SubExpRes)]
-> ([PatElemT (LetDec rep)], Result))
-> ([Maybe (PatElemT (LetDec rep), SubExpRes)]
-> [(PatElemT (LetDec rep), SubExpRes)])
-> [Maybe (PatElemT (LetDec rep), SubExpRes)]
-> ([PatElemT (LetDec rep)], Result)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. [Maybe (PatElemT (LetDec rep), SubExpRes)]
-> [(PatElemT (LetDec rep), SubExpRes)]
forall a. [Maybe a] -> [a]
catMaybes ([Maybe (PatElemT (LetDec rep), SubExpRes)]
-> ([PatElemT (LetDec rep)], Result))
-> RuleM rep [Maybe (PatElemT (LetDec rep), SubExpRes)]
-> RuleM rep ([PatElemT (LetDec rep)], Result)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> ((PatElemT (LetDec rep), SubExpRes)
-> RuleM rep (Maybe (PatElemT (LetDec rep), SubExpRes)))
-> [(PatElemT (LetDec rep), SubExpRes)]
-> RuleM rep [Maybe (PatElemT (LetDec rep), SubExpRes)]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (PatElemT (LetDec rep), SubExpRes)
-> RuleM rep (Maybe (PatElemT (LetDec rep), SubExpRes))
tryMoveNonAcc ([PatElemT (LetDec rep)]
-> Result -> [(PatElemT (LetDec rep), SubExpRes)]
forall a b. [a] -> [b] -> [(a, b)]
zip [PatElemT (LetDec rep)]
nonacc_pes Result
nonacc_res)
Bool -> RuleM rep () -> RuleM rep ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
when ([[PatElemT (LetDec rep)]] -> [PatElemT (LetDec rep)]
forall (t :: * -> *) a. Foldable t => t [a] -> [a]
concat [[PatElemT (LetDec rep)]]
acc_pes' [PatElemT (LetDec rep)] -> [PatElemT (LetDec rep)] -> Bool
forall a. Eq a => a -> a -> Bool
== [PatElemT (LetDec rep)]
acc_pes Bool -> Bool -> Bool
&& [PatElemT (LetDec rep)]
nonacc_pes' [PatElemT (LetDec rep)] -> [PatElemT (LetDec rep)] -> Bool
forall a. Eq a => a -> a -> Bool
== [PatElemT (LetDec rep)]
nonacc_pes) RuleM rep ()
forall rep a. RuleM rep a
cannotSimplify
Lambda rep
lam' <-
[LParam (Rep (RuleM rep))]
-> RuleM rep Result -> RuleM rep (Lambda (Rep (RuleM rep)))
forall (m :: * -> *).
MonadBuilder m =>
[LParam (Rep m)] -> m Result -> m (Lambda (Rep m))
mkLambda ([Param (LParamInfo rep)]
cert_params' [Param (LParamInfo rep)]
-> [Param (LParamInfo rep)] -> [Param (LParamInfo rep)]
forall a. [a] -> [a] -> [a]
++ [Param (LParamInfo rep)]
acc_params') (RuleM rep Result -> RuleM rep (Lambda (Rep (RuleM rep))))
-> RuleM rep Result -> RuleM rep (Lambda (Rep (RuleM rep)))
forall a b. (a -> b) -> a -> b
$
Body (Rep (RuleM rep)) -> RuleM rep Result
forall (m :: * -> *). MonadBuilder m => Body (Rep m) -> m Result
bodyBind (Body (Rep (RuleM rep)) -> RuleM rep Result)
-> Body (Rep (RuleM rep)) -> RuleM rep Result
forall a b. (a -> b) -> a -> b
$ (Lambda rep -> BodyT rep
forall rep. LambdaT rep -> BodyT rep
lambdaBody Lambda rep
lam) {bodyResult :: Result
bodyResult = Result
acc_res' Result -> Result -> Result
forall a. Semigroup a => a -> a -> a
<> Result
nonacc_res'}
Pat (Rep (RuleM rep)) -> Exp (Rep (RuleM rep)) -> RuleM rep ()
forall (m :: * -> *).
MonadBuilder m =>
Pat (Rep m) -> Exp (Rep m) -> m ()
letBind ([PatElemT (LetDec rep)] -> Pat rep
forall dec. [PatElemT dec] -> PatT dec
Pat ([[PatElemT (LetDec rep)]] -> [PatElemT (LetDec rep)]
forall (t :: * -> *) a. Foldable t => t [a] -> [a]
concat [[PatElemT (LetDec rep)]]
acc_pes' [PatElemT (LetDec rep)]
-> [PatElemT (LetDec rep)] -> [PatElemT (LetDec rep)]
forall a. Semigroup a => a -> a -> a
<> [PatElemT (LetDec rep)]
nonacc_pes')) (Exp (Rep (RuleM rep)) -> RuleM rep ())
-> Exp (Rep (RuleM rep)) -> RuleM rep ()
forall a b. (a -> b) -> a -> b
$ [(Shape, [VName], Maybe (Lambda rep, [SubExp]))]
-> Lambda rep -> ExpT rep
forall rep.
[(Shape, [VName], Maybe (Lambda rep, [SubExp]))]
-> Lambda rep -> ExpT rep
WithAcc [(Shape, [VName], Maybe (Lambda rep, [SubExp]))]
inputs' Lambda rep
lam'
where
num_nonaccs :: Int
num_nonaccs = [Type] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length (Lambda rep -> [Type]
forall rep. LambdaT rep -> [Type]
lambdaReturnType Lambda rep
lam) Int -> Int -> Int
forall a. Num a => a -> a -> a
- [(Shape, [VName], Maybe (Lambda rep, [SubExp]))] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [(Shape, [VName], Maybe (Lambda rep, [SubExp]))]
inputs
inputArrs :: (a, t a, c) -> Int
inputArrs (a
_, t a
arrs, c
_) = t a -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length t a
arrs
tryMoveAcc :: ([PatElemT dec], (a, [VName], c), (a, Param dec), SubExpRes)
-> m (Maybe
([PatElemT dec], (a, [VName], c), (a, Param dec), SubExpRes))
tryMoveAcc ([PatElemT dec]
pes, (a
_, [VName]
arrs, c
_), (a
_, Param dec
acc_p), SubExpRes Certs
cs (Var VName
v))
| Param dec -> VName
forall dec. Param dec -> VName
paramName Param dec
acc_p VName -> VName -> Bool
forall a. Eq a => a -> a -> Bool
== VName
v,
Certs
cs Certs -> Certs -> Bool
forall a. Eq a => a -> a -> Bool
== Certs
forall a. Monoid a => a
mempty = do
[(PatElemT dec, VName)] -> ((PatElemT dec, VName) -> m ()) -> m ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([PatElemT dec] -> [VName] -> [(PatElemT dec, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [PatElemT dec]
pes [VName]
arrs) (((PatElemT dec, VName) -> m ()) -> m ())
-> ((PatElemT dec, VName) -> m ()) -> m ()
forall a b. (a -> b) -> a -> b
$ \(PatElemT dec
pe, VName
arr) ->
[VName] -> Exp (Rep m) -> m ()
forall (m :: * -> *).
MonadBuilder m =>
[VName] -> Exp (Rep m) -> m ()
letBindNames [PatElemT dec -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT dec
pe] (Exp (Rep m) -> m ()) -> Exp (Rep m) -> m ()
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp (Rep m)
forall rep. BasicOp -> ExpT rep
BasicOp (BasicOp -> Exp (Rep m)) -> BasicOp -> Exp (Rep m)
forall a b. (a -> b) -> a -> b
$ SubExp -> BasicOp
SubExp (SubExp -> BasicOp) -> SubExp -> BasicOp
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
arr
Maybe ([PatElemT dec], (a, [VName], c), (a, Param dec), SubExpRes)
-> m (Maybe
([PatElemT dec], (a, [VName], c), (a, Param dec), SubExpRes))
forall (f :: * -> *) a. Applicative f => a -> f a
pure Maybe ([PatElemT dec], (a, [VName], c), (a, Param dec), SubExpRes)
forall a. Maybe a
Nothing
tryMoveAcc ([PatElemT dec], (a, [VName], c), (a, Param dec), SubExpRes)
x =
Maybe ([PatElemT dec], (a, [VName], c), (a, Param dec), SubExpRes)
-> m (Maybe
([PatElemT dec], (a, [VName], c), (a, Param dec), SubExpRes))
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Maybe ([PatElemT dec], (a, [VName], c), (a, Param dec), SubExpRes)
-> m (Maybe
([PatElemT dec], (a, [VName], c), (a, Param dec), SubExpRes)))
-> Maybe
([PatElemT dec], (a, [VName], c), (a, Param dec), SubExpRes)
-> m (Maybe
([PatElemT dec], (a, [VName], c), (a, Param dec), SubExpRes))
forall a b. (a -> b) -> a -> b
$ ([PatElemT dec], (a, [VName], c), (a, Param dec), SubExpRes)
-> Maybe
([PatElemT dec], (a, [VName], c), (a, Param dec), SubExpRes)
forall a. a -> Maybe a
Just ([PatElemT dec], (a, [VName], c), (a, Param dec), SubExpRes)
x
tryMoveNonAcc :: (PatElemT (LetDec rep), SubExpRes)
-> RuleM rep (Maybe (PatElemT (LetDec rep), SubExpRes))
tryMoveNonAcc (PatElemT (LetDec rep)
pe, SubExpRes Certs
cs (Var VName
v))
| VName
v VName -> TopDown rep -> Bool
forall rep. VName -> SymbolTable rep -> Bool
`ST.elem` TopDown rep
vtable,
Certs
cs Certs -> Certs -> Bool
forall a. Eq a => a -> a -> Bool
== Certs
forall a. Monoid a => a
mempty = do
[VName] -> Exp (Rep (RuleM rep)) -> RuleM rep ()
forall (m :: * -> *).
MonadBuilder m =>
[VName] -> Exp (Rep m) -> m ()
letBindNames [PatElemT (LetDec rep) -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT (LetDec rep)
pe] (Exp (Rep (RuleM rep)) -> RuleM rep ())
-> Exp (Rep (RuleM rep)) -> RuleM rep ()
forall a b. (a -> b) -> a -> b
$ BasicOp -> ExpT rep
forall rep. BasicOp -> ExpT rep
BasicOp (BasicOp -> ExpT rep) -> BasicOp -> ExpT rep
forall a b. (a -> b) -> a -> b
$ SubExp -> BasicOp
SubExp (SubExp -> BasicOp) -> SubExp -> BasicOp
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
v
Maybe (PatElemT (LetDec rep), SubExpRes)
-> RuleM rep (Maybe (PatElemT (LetDec rep), SubExpRes))
forall (f :: * -> *) a. Applicative f => a -> f a
pure Maybe (PatElemT (LetDec rep), SubExpRes)
forall a. Maybe a
Nothing
tryMoveNonAcc (PatElemT (LetDec rep)
pe, SubExpRes Certs
cs (Constant PrimValue
v))
| Certs
cs Certs -> Certs -> Bool
forall a. Eq a => a -> a -> Bool
== Certs
forall a. Monoid a => a
mempty = do
[VName] -> Exp (Rep (RuleM rep)) -> RuleM rep ()
forall (m :: * -> *).
MonadBuilder m =>
[VName] -> Exp (Rep m) -> m ()
letBindNames [PatElemT (LetDec rep) -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT (LetDec rep)
pe] (Exp (Rep (RuleM rep)) -> RuleM rep ())
-> Exp (Rep (RuleM rep)) -> RuleM rep ()
forall a b. (a -> b) -> a -> b
$ BasicOp -> ExpT rep
forall rep. BasicOp -> ExpT rep
BasicOp (BasicOp -> ExpT rep) -> BasicOp -> ExpT rep
forall a b. (a -> b) -> a -> b
$ SubExp -> BasicOp
SubExp (SubExp -> BasicOp) -> SubExp -> BasicOp
forall a b. (a -> b) -> a -> b
$ PrimValue -> SubExp
Constant PrimValue
v
Maybe (PatElemT (LetDec rep), SubExpRes)
-> RuleM rep (Maybe (PatElemT (LetDec rep), SubExpRes))
forall (f :: * -> *) a. Applicative f => a -> f a
pure Maybe (PatElemT (LetDec rep), SubExpRes)
forall a. Maybe a
Nothing
tryMoveNonAcc (PatElemT (LetDec rep), SubExpRes)
x =
Maybe (PatElemT (LetDec rep), SubExpRes)
-> RuleM rep (Maybe (PatElemT (LetDec rep), SubExpRes))
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Maybe (PatElemT (LetDec rep), SubExpRes)
-> RuleM rep (Maybe (PatElemT (LetDec rep), SubExpRes)))
-> Maybe (PatElemT (LetDec rep), SubExpRes)
-> RuleM rep (Maybe (PatElemT (LetDec rep), SubExpRes))
forall a b. (a -> b) -> a -> b
$ (PatElemT (LetDec rep), SubExpRes)
-> Maybe (PatElemT (LetDec rep), SubExpRes)
forall a. a -> Maybe a
Just (PatElemT (LetDec rep), SubExpRes)
x
withAccTopDown TopDown rep
_ Stm rep
_ = Rule rep
forall rep. Rule rep
Skip
-- | Bottom-up rule for 'WithAcc': remove results whose pattern elements
-- the usage table reports as unused.
--
-- The lambda's results are split so that the last @num_nonaccs@ are the
-- non-accumulator results and the rest belong to the accumulator inputs.
--
-- An accumulator input is kept as a whole group, consisting of:
--
-- * its pattern elements (one per array in the input);
-- * the input itself;
-- * its (certificate, accumulator) parameter pair;
-- * its lambda result.
--
-- The group is kept if at least one of its pattern elements is used, and
-- dropped otherwise. Non-accumulator results are kept or dropped
-- individually.
--
-- The original lambda body is re-bound as-is; only the parameter list and
-- the result list of the new lambda are trimmed.
--
-- If nothing could be removed, the rule gives up via 'cannotSimplify'.
withAccBottomUp :: BuilderOps rep => BottomUpRuleGeneric rep
withAccBottomUp (_, utable) (Let pat aux (WithAcc inputs lam))
  | any isDead (patNames pat) = Simplify $ do
      let (acc_res, nonacc_res) =
            splitFromEnd num_nonaccs (bodyResult (lambdaBody lam))
          (acc_pes, nonacc_pes) =
            splitFromEnd num_nonaccs (patElems pat)
          (cert_params, acc_params) =
            splitAt (length inputs) (lambdaParams lam)

          -- One entry per accumulator input, bundling everything that must
          -- be kept or discarded together.
          acc_groups =
            zip4
              (chunks (map numInputArrs inputs) acc_pes)
              inputs
              (zip cert_params acc_params)
              acc_res
          (acc_pes', inputs', param_pairs, acc_res') =
            unzip4
              [ grp
                | grp@(grp_pes, _, _, _) <- acc_groups,
                  any isLive grp_pes
              ]
          (cert_params', acc_params') = unzip param_pairs

          (nonacc_pes', nonacc_res') =
            unzip
              [ entry
                | entry@(pe, _) <- zip nonacc_pes nonacc_res,
                  isLive pe
              ]

          nothing_removed =
            concat acc_pes' == acc_pes && nonacc_pes' == nonacc_pes

      -- The guard only established that some name is unused.  If that
      -- did not let us drop any group or result, abandon the rewrite.
      when nothing_removed cannotSimplify

      lam' <- mkLambda (cert_params' <> acc_params') $ do
        _ <- bodyBind (lambdaBody lam)
        pure (acc_res' <> nonacc_res')

      auxing aux . letBind (Pat (concat acc_pes' <> nonacc_pes')) $
        WithAcc inputs' lam'
  where
    -- Number of trailing lambda results that are not accumulators.
    num_nonaccs = length (lambdaReturnType lam) - length inputs
    -- How many arrays (and hence pattern elements) an input contributes.
    numInputArrs (_, arrs, _) = length arrs
    isLive pe = patElemName pe `UT.used` utable
    isDead v = not (v `UT.used` utable)
withAccBottomUp _ _ = Skip
-- | True exactly when the subexpression is a constant accepted by
-- 'oneIsh'; a variable is never treated as one.
isCt1 :: SubExp -> Bool
isCt1 se = case se of
  Constant v -> oneIsh v
  _ -> False
-- | True exactly when the subexpression is a constant accepted by
-- 'zeroIsh'; a variable is never treated as zero.
isCt0 :: SubExp -> Bool
isCt0 se = case se of
  Constant v -> zeroIsh v
  _ -> False