{-# LANGUAGE TypeFamilies #-}
module Futhark.Optimise.Simplify.Rules
( standardRules,
removeUnnecessaryCopy,
)
where
import Control.Monad
import Control.Monad.State
import Data.List (insert, unzip4, zip4)
import Data.Map.Strict qualified as M
import Data.Maybe
import Futhark.Analysis.PrimExp.Convert
import Futhark.Analysis.SymbolTable qualified as ST
import Futhark.Analysis.UsageTable qualified as UT
import Futhark.Construct
import Futhark.IR
import Futhark.IR.Prop.Aliases
import Futhark.Optimise.Simplify.Rule
import Futhark.Optimise.Simplify.Rules.BasicOp
import Futhark.Optimise.Simplify.Rules.Index
import Futhark.Optimise.Simplify.Rules.Loop
import Futhark.Optimise.Simplify.Rules.Match
import Futhark.Util
-- | Generic top-down rules defined in this module: constant folding of
-- applications of primitive functions, and the top-down 'WithAcc' rule.
topDownRules :: (BuilderOps rep) => [TopDownRule rep]
topDownRules = map RuleGeneric [constantFoldPrimFun, withAccTopDown]
-- | Bottom-up rules defined in this module: the generic 'WithAcc' rule
-- for unused results, followed by the 'Index' simplification rule.
bottomUpRules :: (BuilderOps rep, TraverseOpStms rep) => [BottomUpRule rep]
bottomUpRules =
  RuleGeneric withAccBottomUp : [RuleBasicOp simplifyIndex]
-- | The standard set of simplification rules: the top-down and
-- bottom-up rules of this module, combined with the loop, basic
-- operation and match rules.
standardRules :: (BuilderOps rep, TraverseOpStms rep, Aliased rep) => RuleBook rep
standardRules = local_rules <> loopRules <> basicOpRules <> matchRules
  where
    local_rules = ruleBook topDownRules bottomUpRules
-- | Turn @copy(v)@ into a plain binding of @v@ when that is safe.
--
-- The rewrite fires only if all of the following hold:
--
-- * @v@ is not consumed according to the usage table;
--
-- * the copy's result does not appear in a result ('UT.isInResult'),
--   or the copy's result is consumed, or @v@ has no aliases and is
--   not used again;
--
-- * either @v@ is not used again and can be consumed at the current
--   loop depth, or the copy's result is not consumed.
removeUnnecessaryCopy :: (BuilderOps rep) => BottomUpRuleBasicOp rep
removeUnnecessaryCopy (vtable, used) (Pat [d]) aux (Copy v)
  | source_not_consumed,
    copy_removable,
    consumption_ok =
      Simplify . auxing aux . letBindNames [copy_name] . BasicOp . SubExp $ Var v
  where
    copy_name = patElemName d
    copy_consumed = copy_name `UT.isConsumed` used

    source_not_consumed = not $ v `UT.isConsumed` used
    copy_removable =
      not (copy_name `UT.isInResult` used)
        || copy_consumed
        || (v_is_fresh && v_not_used_again)
    consumption_ok = (v_not_used_again && consumable) || not copy_consumed

    v_not_used_again = not $ v `UT.used` used
    v_is_fresh = ST.lookupAliases v vtable == mempty

    -- 'v' counts as consumable when it is bound at the current loop
    -- depth and is either an alias-free statement result, or a
    -- function parameter whose declared type is unique.
    consumable = fromMaybe False $ do
      entry <- ST.lookup v vtable
      guard $ ST.entryDepth entry == ST.loopDepth vtable
      fromStm entry `mplus` fromFParam entry

    fromStm entry
      | isJust (ST.entryStm entry), v_is_fresh = Just True
      | otherwise = Nothing

    fromFParam entry =
      Just $ maybe False (unique . declTypeOf) (ST.entryFParam entry)
removeUnnecessaryCopy _ _ _ _ = Skip
-- | Evaluate a function application at compile time.
--
-- This applies when the callee's name is a key of 'primFuns', every
-- argument is a constant, and the primitive function returns a value
-- for those arguments.  The application's certificates and attributes
-- are carried over to the replacement binding.
constantFoldPrimFun :: (BuilderOps rep) => TopDownRuleGeneric rep
constantFoldPrimFun _ (Let pat (StmAux cs attrs _) (Apply fname args _ _))
  | Just arg_vals <- traverse (asConstant . fst) args,
    Just (_, _, eval_fun) <- M.lookup (nameToString fname) primFuns,
    Just result <- eval_fun arg_vals =
      Simplify . certifying cs . attributing attrs . letBind pat $
        BasicOp (SubExp (Constant result))
  where
    asConstant se = case se of
      Constant v -> Just v
      _ -> Nothing
constantFoldPrimFun _ _ = Skip
-- | Simplify an 'Index' with 'simplifyIndexing'.
--
-- The simplified form is either a plain sub-expression or a new
-- 'Index'.  In both cases the statement's certificates are combined
-- with those produced by the simplification, and the statement's
-- attributes are kept.
simplifyIndex :: (BuilderOps rep) => BottomUpRuleBasicOp rep
simplifyIndex (vtable, used) pat@(Pat [pe]) (StmAux cs attrs _) (Index idd inds)
  | Just m <- simplifyIndexing vtable seType idd inds consumed =
      Simplify $ m >>= attributing attrs . bindIndexResult
  where
    consumed = patElemName pe `UT.isConsumed` used

    seType (Var v) = ST.lookupType v vtable
    seType (Constant v) = Just $ Prim $ primValueType v

    bindIndexResult (SubExpResult extra_cs se) =
      certifying (cs <> extra_cs) $ bindPat $ SubExp se
    bindIndexResult (IndexResult extra_cs idd' inds') =
      certifying (cs <> extra_cs) $ bindPat $ Index idd' inds'

    bindPat op = letBindNames (patNames pat) $ BasicOp op
simplifyIndex _ _ _ _ = Skip
-- | Top-down simplification of 'WithAcc'.
--
-- A 'WithAcc' without inputs is inlined.  The lambda body is bound in
-- place, and each pattern element is bound to the corresponding
-- result under that result's certificates.
--
-- Otherwise, results are moved out of the lambda where possible:
--
-- * An accumulator result with no certificates that is just the
--   accumulator parameter itself is dropped together with its input
--   and parameters.  Its pattern elements are bound directly to the
--   input arrays.
--
-- * A non-accumulator result with no certificates that is a constant,
--   or a variable present in the outer symbol table, is dropped from
--   the lambda and bound directly.
--
-- The rule does not apply when nothing could be moved out.
withAccTopDown :: (BuilderOps rep) => TopDownRuleGeneric rep
withAccTopDown _ (Let pat aux (WithAcc [] lam)) =
  Simplify . auxing aux $ do
    lam_res <- bodyBind $ lambdaBody lam
    zipWithM_ bindResult (patNames pat) lam_res
  where
    bindResult v (SubExpRes cs se) =
      certifying cs $ letBindNames [v] $ BasicOp $ SubExp se
withAccTopDown vtable (Let pat aux (WithAcc inputs lam)) =
  Simplify . auxing aux $ do
    -- Accumulator results.  'tryMoveAcc' yields 'Nothing' for every
    -- accumulator it moved out.
    (acc_pes', inputs', params', acc_res') <-
      unzip4 . catMaybes
        <$> mapM tryMoveAcc (zip4 acc_pe_groups inputs (zip cert_params acc_params) acc_res)

    -- Non-accumulator results, handled the same way.
    (nonacc_pes', nonacc_res') <-
      unzip . catMaybes <$> mapM tryMoveNonAcc (zip nonacc_pes nonacc_res)

    when (concat acc_pes' == acc_pes && nonacc_pes' == nonacc_pes) cannotSimplify

    let (cert_params', acc_params') = unzip params'
    lam' <-
      mkLambda (cert_params' ++ acc_params') . bodyBind $
        (lambdaBody lam) {bodyResult = acc_res' <> nonacc_res'}

    letBind (Pat (concat acc_pes' <> nonacc_pes')) $ WithAcc inputs' lam'
  where
    num_accs = length inputs
    num_nonaccs = length (lambdaReturnType lam) - num_accs
    (cert_params, acc_params) = splitAt num_accs $ lambdaParams lam
    (acc_res, nonacc_res) = splitFromEnd num_nonaccs $ bodyResult $ lambdaBody lam
    (acc_pes, nonacc_pes) = splitFromEnd num_nonaccs $ patElems pat
    -- One group of pattern elements per input, sized by its array count.
    acc_pe_groups = chunks [length arrs | (_, arrs, _) <- inputs] acc_pes

    bindTo pe se = letBindNames [patElemName pe] $ BasicOp $ SubExp se

    tryMoveAcc entry@(pes, (_, arrs, _), (_, acc_p), SubExpRes cs se)
      | Var v <- se,
        v == paramName acc_p,
        cs == mempty = do
          zipWithM_ (\pe arr -> bindTo pe (Var arr)) pes arrs
          pure Nothing
      | otherwise = pure $ Just entry

    tryMoveNonAcc entry@(pe, SubExpRes cs se)
      | cs == mempty,
        outerAvailable se = do
          bindTo pe se
          pure Nothing
      | otherwise = pure $ Just entry

    -- Constants are always available.  Variables are available only if
    -- they are in the outer symbol table.
    outerAvailable (Var v) = v `ST.elem` vtable
    outerAvailable Constant {} = True
withAccTopDown _ _ = Skip
-- | Replace each single-element @UpdateAcc@ statement whose accumulator
-- type carries a certificate from the given list with a plain binding
-- of the unmodified accumulator.
--
-- The traversal recurses into nested bodies and into the statements of
-- operations.  The second component of the result holds the
-- certificate of every replaced statement, kept sorted through
-- 'insert'.  It has one entry per replacement, so duplicates are
-- possible.
elimUpdates ::
  (ASTRep rep, TraverseOpStms rep) =>
  [VName] ->
  Body rep ->
  (Body rep, [VName])
elimUpdates get_rid_of body = runState (onBody body) mempty
  where
    onBody b = (\stms -> b {bodyStms = stms}) <$> onStms (bodyStms b)

    onStms stms = traverse onStm stms

    onStm (Let pat aux e) = case (pat, e) of
      (Pat [PatElem _ dec], BasicOp (UpdateAcc acc _ _))
        | Acc c _ _ _ <- typeOf dec,
          c `elem` get_rid_of -> do
            modify (insert c)
            pure $ Let pat aux $ BasicOp $ SubExp $ Var acc
      _ -> Let pat aux <$> mapExpM mapper e

    mapper =
      identityMapper
        { mapOnOp = traverseOpStms (const onStms),
          mapOnBody = const onBody
        }
-- | Bottom-up simplification of a 'WithAcc' with unused results.
--
-- For accumulators none of whose pattern elements are used, the
-- @UpdateAcc@ statements in the lambda are replaced via 'elimUpdates'.
-- Unused non-accumulator results are removed from both the lambda
-- result and the pattern.  The rule does not apply if neither step
-- changes anything.
withAccBottomUp :: (TraverseOpStms rep, BuilderOps rep) => BottomUpRuleGeneric rep
withAccBottomUp (_, utable) (Let pat aux (WithAcc inputs lam))
  | any isUnused (patNames pat) = Simplify $ do
      -- Certificates of accumulators whose results are all unused, and
      -- the non-accumulator results whose pattern element is used.
      let dead_certs =
            [ paramName cert_p
              | (pes, cert_p) <- zip acc_pe_groups cert_params,
                all (isUnused . patElemName) pes
            ]
          (nonacc_pes', nonacc_res') =
            unzip
              [ (pe, res)
                | (pe, res) <- zip nonacc_pes nonacc_res,
                  patElemName pe `UT.used` utable
              ]
          nonacc_unchanged = nonacc_pes' == nonacc_pes

      when (null dead_certs && nonacc_unchanged) cannotSimplify

      let (body', eliminated) = elimUpdates dead_certs $ lambdaBody lam

      when (null eliminated && nonacc_unchanged) cannotSimplify

      lam' <- mkLambda (cert_params ++ acc_params) $ do
        void $ bodyBind body'
        pure $ acc_res ++ nonacc_res'

      auxing aux . letBind (Pat (acc_pes ++ nonacc_pes')) $ WithAcc inputs lam'
  where
    num_accs = length inputs
    num_nonaccs = length (lambdaReturnType lam) - num_accs
    (cert_params, acc_params) = splitAt num_accs $ lambdaParams lam
    (acc_res, nonacc_res) = splitFromEnd num_nonaccs $ bodyResult $ lambdaBody lam
    (acc_pes, nonacc_pes) = splitFromEnd num_nonaccs $ patElems pat
    -- One group of pattern elements per input, sized by its array count.
    acc_pe_groups = chunks [length arrs | (_, arrs, _) <- inputs] acc_pes
    isUnused v = not $ v `UT.used` utable
withAccBottomUp _ _ = Skip