{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE Safe #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeFamilies #-}
module Futhark.Optimise.Simplify.Rules
( standardRules,
removeUnnecessaryCopy,
)
where
import Control.Monad
import Data.Either
import Data.List (find, foldl', isSuffixOf, partition, sort)
import qualified Data.Map.Strict as M
import Data.Maybe
import Futhark.Analysis.DataDependencies
import Futhark.Analysis.PrimExp.Convert
import qualified Futhark.Analysis.SymbolTable as ST
import qualified Futhark.Analysis.UsageTable as UT
import Futhark.Construct
import Futhark.IR
import Futhark.IR.Prop.Aliases
import Futhark.Optimise.Simplify.ClosedForm
import Futhark.Optimise.Simplify.Rule
import Futhark.Transform.Rename
import Futhark.Util
-- | The top-down simplification rules defined in this module.  Each rule
-- is wrapped in the 'SimplificationRule' constructor for the kind of
-- expression it inspects: loops ('RuleDoLoop'), branches ('RuleIf'),
-- basic operations ('RuleBasicOp'), or any expression ('RuleGeneric').
topDownRules :: (BinderOps lore, Aliased lore) => [TopDownRule lore]
topDownRules =
  [ RuleDoLoop hoistLoopInvariantMergeVariables,
    RuleDoLoop simplifyClosedFormLoop,
    RuleDoLoop simplifyKnownIterationLoop,
    RuleDoLoop simplifyLoopVariables,
    RuleDoLoop narrowLoopType,
    RuleGeneric constantFoldPrimFun,
    RuleIf ruleIf,
    RuleIf hoistBranchInvariant,
    RuleBasicOp ruleBasicOp
  ]
-- | The bottom-up simplification rules defined in this module.  Each rule
-- is wrapped in the 'SimplificationRule' constructor for the kind of
-- expression it inspects.
bottomUpRules :: BinderOps lore => [BottomUpRule lore]
bottomUpRules =
  [ RuleDoLoop removeRedundantMergeVariables,
    RuleIf removeDeadBranchResult,
    RuleBasicOp simplifyIndex,
    RuleBasicOp simplifyConcat
  ]
-- | The standard rule book of this module: 'topDownRules' and
-- 'bottomUpRules' combined with 'ruleBook'.
standardRules :: (BinderOps lore, Aliased lore) => RuleBook lore
standardRules = ruleBook topDownRules bottomUpRules
-- | Remove loop merge parameters, together with their body results and
-- pattern elements, when they are not needed.
--
-- The rule only fires when the loop has no context merge parameters and
-- at least one value result is not directly used after the loop
-- (according to the usage table).  A merge parameter is then kept if any
-- of the following holds:
--
-- * its loop result is used after the loop;
--
-- * it is in the set computed by 'findNecessaryForReturned' from the
--   body's 'dataDependencies'.  That set is seeded with the parameters
--   that are used after the loop or occur free in the loop form;
--
-- * its name occurs free in the merge parameters themselves (presumably
--   through their types);
--
-- * its name occurs free in the loop form.
--
-- All other merge parameters are dropped.  A context pattern element is
-- kept only if it is still referenced by the kept value pattern elements
-- or by another context pattern element.  If nothing would be removed,
-- the rule returns 'Skip'.
removeRedundantMergeVariables :: BinderOps lore => BottomUpRuleDoLoop lore
removeRedundantMergeVariables (_, used) pat aux (ctx, val, form, body)
  -- Only loops without context merge parameters are handled; removing
  -- value parameters while context parameters remain is not attempted.
  | not $ all (usedAfterLoop . fst) val,
    null ctx =
    let (ctx_es, val_es) = splitAt (length ctx) $ bodyResult body
        necessaryForReturned =
          findNecessaryForReturned
            usedAfterLoopOrInForm
            (zip (map fst $ ctx ++ val) $ ctx_es ++ val_es)
            (dataDependencies body)

        resIsNecessary ((v, _), _) =
          usedAfterLoop v
            || paramName v `nameIn` necessaryForReturned
            || referencedInPat v
            || referencedInForm v

        (keep_ctx, discard_ctx) =
          partition resIsNecessary $ zip ctx ctx_es
        (keep_valpart, discard_valpart) =
          partition (resIsNecessary . snd) $
            zip (patternValueElements pat) $ zip val val_es

        (keep_valpatelems, keep_val) = unzip keep_valpart
        (_discard_valpatelems, discard_val) = unzip discard_valpart
        (ctx', ctx_es') = unzip keep_ctx
        (val', val_es') = unzip keep_val

        body' = body {bodyResult = ctx_es' ++ val_es'}
        free_in_keeps = freeIn keep_valpatelems

        -- A context element survives if a kept value element or another
        -- context element still mentions its name.
        stillUsedContext pat_elem =
          patElemName pat_elem
            `nameIn` ( free_in_keeps
                         <> freeIn (filter (/= pat_elem) $ patternContextElements pat)
                     )

        pat' =
          pat
            { patternValueElements = keep_valpatelems,
              patternContextElements =
                filter stillUsedContext $ patternContextElements pat
            }
     in if ctx' ++ val' == ctx ++ val
          then Skip
          else Simplify $ do
            -- Rebind every removed merge parameter at the start of the loop
            -- body (via 'insertStmsM') so that its name stays bound for any
            -- remaining uses inside the body.
            -- NOTE(review): presumably later dead-code removal deletes
            -- these bindings once they are unused.
            body'' <- insertStmsM $ do
              mapM_ (uncurry letBindNames) $ dummyStms discard_ctx
              mapM_ (uncurry letBindNames) $ dummyStms discard_val
              return body'
            auxing aux $ letBind pat' $ DoLoop ctx' val' form body''
  where
    -- For each value pattern element: is it used directly after the loop?
    pat_used = map (`UT.isUsedDirectly` used) $ patternValueNames pat
    -- Names of the value merge parameters whose pattern element is used.
    used_vals = map fst $ filter snd $ zip (map (paramName . fst) val) pat_used
    usedAfterLoop = flip elem used_vals . paramName
    usedAfterLoopOrInForm p =
      usedAfterLoop p || paramName p `nameIn` freeIn form
    patAnnotNames = freeIn $ map fst $ ctx ++ val
    referencedInPat = (`nameIn` patAnnotNames) . paramName
    referencedInForm = (`nameIn` freeIn form) . paramName

    dummyStms = map dummyStm
    -- Bind the parameter name to its initial value.  A unique parameter
    -- initialised from a variable gets a 'Copy' of that variable instead.
    -- NOTE(review): presumably this avoids consuming the original array.
    dummyStm ((p, e), _)
      | unique (paramDeclType p),
        Var v <- e =
        ([paramName p], BasicOp $ Copy v)
      | otherwise = ([paramName p], BasicOp $ SubExp e)
removeRedundantMergeVariables _ _ _ _ =
  Skip
hoistLoopInvariantMergeVariables :: BinderOps lore => TopDownRuleDoLoop lore
hoistLoopInvariantMergeVariables :: TopDownRuleDoLoop lore
hoistLoopInvariantMergeVariables TopDown lore
vtable Pattern lore
pat StmAux (ExpDec lore)
aux ([(FParam lore, SubExp)]
ctx, [(FParam lore, SubExp)]
val, LoopForm lore
form, BodyT lore
loopbody) =
case ((VName, (FParam lore, SubExp), SubExp)
-> ([(Ident, SubExp)], [(PatElemT (LetDec lore), VName)],
[(FParam lore, SubExp)], [SubExp])
-> ([(Ident, SubExp)], [(PatElemT (LetDec lore), VName)],
[(FParam lore, SubExp)], [SubExp]))
-> ([(Ident, SubExp)], [(PatElemT (LetDec lore), VName)],
[(FParam lore, SubExp)], [SubExp])
-> [(VName, (FParam lore, SubExp), SubExp)]
-> ([(Ident, SubExp)], [(PatElemT (LetDec lore), VName)],
[(FParam lore, SubExp)], [SubExp])
forall (t :: * -> *) a b.
Foldable t =>
(a -> b -> b) -> b -> t a -> b
foldr (VName, (FParam lore, SubExp), SubExp)
-> ([(Ident, SubExp)], [(PatElemT (LetDec lore), VName)],
[(FParam lore, SubExp)], [SubExp])
-> ([(Ident, SubExp)], [(PatElemT (LetDec lore), VName)],
[(FParam lore, SubExp)], [SubExp])
checkInvariance ([], [(PatElemT (LetDec lore), VName)]
explpat, [], []) ([(VName, (FParam lore, SubExp), SubExp)]
-> ([(Ident, SubExp)], [(PatElemT (LetDec lore), VName)],
[(FParam lore, SubExp)], [SubExp]))
-> [(VName, (FParam lore, SubExp), SubExp)]
-> ([(Ident, SubExp)], [(PatElemT (LetDec lore), VName)],
[(FParam lore, SubExp)], [SubExp])
forall a b. (a -> b) -> a -> b
$
[VName]
-> [(FParam lore, SubExp)]
-> [SubExp]
-> [(VName, (FParam lore, SubExp), SubExp)]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 (Pattern lore -> [VName]
forall dec. PatternT dec -> [VName]
patternNames Pattern lore
pat) [(FParam lore, SubExp)]
merge [SubExp]
res of
([], [(PatElemT (LetDec lore), VName)]
_, [(FParam lore, SubExp)]
_, [SubExp]
_) ->
Rule lore
forall lore. Rule lore
Skip
([(Ident, SubExp)]
invariant, [(PatElemT (LetDec lore), VName)]
explpat', [(FParam lore, SubExp)]
merge', [SubExp]
res') -> RuleM lore () -> Rule lore
forall lore. RuleM lore () -> Rule lore
Simplify (RuleM lore () -> Rule lore) -> RuleM lore () -> Rule lore
forall a b. (a -> b) -> a -> b
$ do
let loopbody' :: BodyT lore
loopbody' = BodyT lore
loopbody {bodyResult :: [SubExp]
bodyResult = [SubExp]
res'}
invariantShape :: (a, VName) -> Bool
invariantShape :: (a, VName) -> Bool
invariantShape (a
_, VName
shapemerge) =
VName
shapemerge
VName -> [VName] -> Bool
forall (t :: * -> *) a. (Foldable t, Eq a) => a -> t a -> Bool
`elem` ((FParam lore, SubExp) -> VName)
-> [(FParam lore, SubExp)] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map (FParam lore -> VName
forall dec. Param dec -> VName
paramName (FParam lore -> VName)
-> ((FParam lore, SubExp) -> FParam lore)
-> (FParam lore, SubExp)
-> VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (FParam lore, SubExp) -> FParam lore
forall a b. (a, b) -> a
fst) [(FParam lore, SubExp)]
merge'
([(PatElemT (LetDec lore), VName)]
implpat', [(PatElemT (LetDec lore), VName)]
implinvariant) = ((PatElemT (LetDec lore), VName) -> Bool)
-> [(PatElemT (LetDec lore), VName)]
-> ([(PatElemT (LetDec lore), VName)],
[(PatElemT (LetDec lore), VName)])
forall a. (a -> Bool) -> [a] -> ([a], [a])
partition (PatElemT (LetDec lore), VName) -> Bool
forall a. (a, VName) -> Bool
invariantShape [(PatElemT (LetDec lore), VName)]
implpat
implinvariant' :: [(Ident, SubExp)]
implinvariant' = [(PatElemT (LetDec lore) -> Ident
forall dec. Typed dec => PatElemT dec -> Ident
patElemIdent PatElemT (LetDec lore)
p, VName -> SubExp
Var VName
v) | (PatElemT (LetDec lore)
p, VName
v) <- [(PatElemT (LetDec lore), VName)]
implinvariant]
implpat'' :: [PatElemT (LetDec lore)]
implpat'' = ((PatElemT (LetDec lore), VName) -> PatElemT (LetDec lore))
-> [(PatElemT (LetDec lore), VName)] -> [PatElemT (LetDec lore)]
forall a b. (a -> b) -> [a] -> [b]
map (PatElemT (LetDec lore), VName) -> PatElemT (LetDec lore)
forall a b. (a, b) -> a
fst [(PatElemT (LetDec lore), VName)]
implpat'
explpat'' :: [PatElemT (LetDec lore)]
explpat'' = ((PatElemT (LetDec lore), VName) -> PatElemT (LetDec lore))
-> [(PatElemT (LetDec lore), VName)] -> [PatElemT (LetDec lore)]
forall a b. (a -> b) -> [a] -> [b]
map (PatElemT (LetDec lore), VName) -> PatElemT (LetDec lore)
forall a b. (a, b) -> a
fst [(PatElemT (LetDec lore), VName)]
explpat'
([(FParam lore, SubExp)]
ctx', [(FParam lore, SubExp)]
val') = Int
-> [(FParam lore, SubExp)]
-> ([(FParam lore, SubExp)], [(FParam lore, SubExp)])
forall a. Int -> [a] -> ([a], [a])
splitAt ([(PatElemT (LetDec lore), VName)] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [(PatElemT (LetDec lore), VName)]
implpat') [(FParam lore, SubExp)]
merge'
[(Ident, SubExp)]
-> ((Ident, SubExp) -> RuleM lore ()) -> RuleM lore ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([(Ident, SubExp)]
invariant [(Ident, SubExp)] -> [(Ident, SubExp)] -> [(Ident, SubExp)]
forall a. [a] -> [a] -> [a]
++ [(Ident, SubExp)]
implinvariant') (((Ident, SubExp) -> RuleM lore ()) -> RuleM lore ())
-> ((Ident, SubExp) -> RuleM lore ()) -> RuleM lore ()
forall a b. (a -> b) -> a -> b
$ \(Ident
v1, SubExp
v2) ->
[VName] -> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall (m :: * -> *).
MonadBinder m =>
[VName] -> Exp (Lore m) -> m ()
letBindNames [Ident -> VName
identName Ident
v1] (Exp (Lore (RuleM lore)) -> RuleM lore ())
-> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall a b. (a -> b) -> a -> b
$ BasicOp -> ExpT lore
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> ExpT lore) -> BasicOp -> ExpT lore
forall a b. (a -> b) -> a -> b
$ SubExp -> BasicOp
SubExp SubExp
v2
StmAux (ExpDec lore) -> RuleM lore () -> RuleM lore ()
forall (m :: * -> *) anylore a.
MonadBinder m =>
StmAux anylore -> m a -> m a
auxing StmAux (ExpDec lore)
aux (RuleM lore () -> RuleM lore ()) -> RuleM lore () -> RuleM lore ()
forall a b. (a -> b) -> a -> b
$
Pattern (Lore (RuleM lore))
-> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall (m :: * -> *).
MonadBinder m =>
Pattern (Lore m) -> Exp (Lore m) -> m ()
letBind ([PatElemT (LetDec lore)]
-> [PatElemT (LetDec lore)] -> Pattern lore
forall dec. [PatElemT dec] -> [PatElemT dec] -> PatternT dec
Pattern [PatElemT (LetDec lore)]
implpat'' [PatElemT (LetDec lore)]
explpat'') (Exp (Lore (RuleM lore)) -> RuleM lore ())
-> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall a b. (a -> b) -> a -> b
$
[(FParam lore, SubExp)]
-> [(FParam lore, SubExp)]
-> LoopForm lore
-> BodyT lore
-> ExpT lore
forall lore.
[(FParam lore, SubExp)]
-> [(FParam lore, SubExp)]
-> LoopForm lore
-> BodyT lore
-> ExpT lore
DoLoop [(FParam lore, SubExp)]
ctx' [(FParam lore, SubExp)]
val' LoopForm lore
form BodyT lore
loopbody'
where
merge :: [(FParam lore, SubExp)]
merge = [(FParam lore, SubExp)]
ctx [(FParam lore, SubExp)]
-> [(FParam lore, SubExp)] -> [(FParam lore, SubExp)]
forall a. [a] -> [a] -> [a]
++ [(FParam lore, SubExp)]
val
res :: [SubExp]
res = BodyT lore -> [SubExp]
forall lore. BodyT lore -> [SubExp]
bodyResult BodyT lore
loopbody
implpat :: [(PatElemT (LetDec lore), VName)]
implpat =
[PatElemT (LetDec lore)]
-> [VName] -> [(PatElemT (LetDec lore), VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip (Pattern lore -> [PatElemT (LetDec lore)]
forall dec. PatternT dec -> [PatElemT dec]
patternContextElements Pattern lore
pat) ([VName] -> [(PatElemT (LetDec lore), VName)])
-> [VName] -> [(PatElemT (LetDec lore), VName)]
forall a b. (a -> b) -> a -> b
$
((FParam lore, SubExp) -> VName)
-> [(FParam lore, SubExp)] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map (FParam lore -> VName
forall dec. Param dec -> VName
paramName (FParam lore -> VName)
-> ((FParam lore, SubExp) -> FParam lore)
-> (FParam lore, SubExp)
-> VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (FParam lore, SubExp) -> FParam lore
forall a b. (a, b) -> a
fst) [(FParam lore, SubExp)]
ctx
explpat :: [(PatElemT (LetDec lore), VName)]
explpat =
[PatElemT (LetDec lore)]
-> [VName] -> [(PatElemT (LetDec lore), VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip (Pattern lore -> [PatElemT (LetDec lore)]
forall dec. PatternT dec -> [PatElemT dec]
patternValueElements Pattern lore
pat) ([VName] -> [(PatElemT (LetDec lore), VName)])
-> [VName] -> [(PatElemT (LetDec lore), VName)]
forall a b. (a -> b) -> a -> b
$
((FParam lore, SubExp) -> VName)
-> [(FParam lore, SubExp)] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map (FParam lore -> VName
forall dec. Param dec -> VName
paramName (FParam lore -> VName)
-> ((FParam lore, SubExp) -> FParam lore)
-> (FParam lore, SubExp)
-> VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (FParam lore, SubExp) -> FParam lore
forall a b. (a, b) -> a
fst) [(FParam lore, SubExp)]
val
namesOfMergeParams :: Names
namesOfMergeParams = [VName] -> Names
namesFromList ([VName] -> Names) -> [VName] -> Names
forall a b. (a -> b) -> a -> b
$ ((FParam lore, SubExp) -> VName)
-> [(FParam lore, SubExp)] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map (FParam lore -> VName
forall dec. Param dec -> VName
paramName (FParam lore -> VName)
-> ((FParam lore, SubExp) -> FParam lore)
-> (FParam lore, SubExp)
-> VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (FParam lore, SubExp) -> FParam lore
forall a b. (a, b) -> a
fst) ([(FParam lore, SubExp)] -> [VName])
-> [(FParam lore, SubExp)] -> [VName]
forall a b. (a -> b) -> a -> b
$ [(FParam lore, SubExp)]
ctx [(FParam lore, SubExp)]
-> [(FParam lore, SubExp)] -> [(FParam lore, SubExp)]
forall a. [a] -> [a] -> [a]
++ [(FParam lore, SubExp)]
val
removeFromResult :: (Param dec, b)
-> [(PatElemT dec, VName)]
-> (Maybe (Ident, b), [(PatElemT dec, VName)])
removeFromResult (Param dec
mergeParam, b
mergeInit) [(PatElemT dec, VName)]
explpat' =
case ((PatElemT dec, VName) -> Bool)
-> [(PatElemT dec, VName)]
-> ([(PatElemT dec, VName)], [(PatElemT dec, VName)])
forall a. (a -> Bool) -> [a] -> ([a], [a])
partition ((VName -> VName -> Bool
forall a. Eq a => a -> a -> Bool
== Param dec -> VName
forall dec. Param dec -> VName
paramName Param dec
mergeParam) (VName -> Bool)
-> ((PatElemT dec, VName) -> VName)
-> (PatElemT dec, VName)
-> Bool
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (PatElemT dec, VName) -> VName
forall a b. (a, b) -> b
snd) [(PatElemT dec, VName)]
explpat' of
([(PatElemT dec
patelem, VName
_)], [(PatElemT dec, VName)]
rest) ->
((Ident, b) -> Maybe (Ident, b)
forall a. a -> Maybe a
Just (PatElemT dec -> Ident
forall dec. Typed dec => PatElemT dec -> Ident
patElemIdent PatElemT dec
patelem, b
mergeInit), [(PatElemT dec, VName)]
rest)
([(PatElemT dec, VName)]
_, [(PatElemT dec, VName)]
_) ->
(Maybe (Ident, b)
forall a. Maybe a
Nothing, [(PatElemT dec, VName)]
explpat')
checkInvariance :: (VName, (FParam lore, SubExp), SubExp)
-> ([(Ident, SubExp)], [(PatElemT (LetDec lore), VName)],
[(FParam lore, SubExp)], [SubExp])
-> ([(Ident, SubExp)], [(PatElemT (LetDec lore), VName)],
[(FParam lore, SubExp)], [SubExp])
checkInvariance
(VName
pat_name, (FParam lore
mergeParam, SubExp
mergeInit), SubExp
resExp)
([(Ident, SubExp)]
invariant, [(PatElemT (LetDec lore), VName)]
explpat', [(FParam lore, SubExp)]
merge', [SubExp]
resExps)
| Bool -> Bool
not (TypeBase Shape Uniqueness -> Bool
forall shape. TypeBase shape Uniqueness -> Bool
unique (FParam lore -> TypeBase Shape Uniqueness
forall dec. DeclTyped dec => Param dec -> TypeBase Shape Uniqueness
paramDeclType FParam lore
mergeParam))
Bool -> Bool -> Bool
|| TypeBase Shape Uniqueness -> Int
forall shape u. ArrayShape shape => TypeBase shape u -> Int
arrayRank (FParam lore -> TypeBase Shape Uniqueness
forall dec. DeclTyped dec => Param dec -> TypeBase Shape Uniqueness
paramDeclType FParam lore
mergeParam) Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== Int
1,
Bool
isInvariant,
Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ FParam lore -> VName
forall dec. Param dec -> VName
paramName FParam lore
mergeParam VName -> Names -> Bool
`nameIn` LoopForm lore -> Names
forall a. FreeIn a => a -> Names
freeIn LoopForm lore
form =
let (Maybe (Ident, SubExp)
bnd, [(PatElemT (LetDec lore), VName)]
explpat'') =
(FParam lore, SubExp)
-> [(PatElemT (LetDec lore), VName)]
-> (Maybe (Ident, SubExp), [(PatElemT (LetDec lore), VName)])
forall dec dec b.
Typed dec =>
(Param dec, b)
-> [(PatElemT dec, VName)]
-> (Maybe (Ident, b), [(PatElemT dec, VName)])
removeFromResult (FParam lore
mergeParam, SubExp
mergeInit) [(PatElemT (LetDec lore), VName)]
explpat'
in ( ([(Ident, SubExp)] -> [(Ident, SubExp)])
-> ((Ident, SubExp) -> [(Ident, SubExp)] -> [(Ident, SubExp)])
-> Maybe (Ident, SubExp)
-> [(Ident, SubExp)]
-> [(Ident, SubExp)]
forall b a. b -> (a -> b) -> Maybe a -> b
maybe [(Ident, SubExp)] -> [(Ident, SubExp)]
forall a. a -> a
id (:) Maybe (Ident, SubExp)
bnd ([(Ident, SubExp)] -> [(Ident, SubExp)])
-> [(Ident, SubExp)] -> [(Ident, SubExp)]
forall a b. (a -> b) -> a -> b
$ (FParam lore -> Ident
forall dec. Typed dec => Param dec -> Ident
paramIdent FParam lore
mergeParam, SubExp
mergeInit) (Ident, SubExp) -> [(Ident, SubExp)] -> [(Ident, SubExp)]
forall a. a -> [a] -> [a]
: [(Ident, SubExp)]
invariant,
[(PatElemT (LetDec lore), VName)]
explpat'',
[(FParam lore, SubExp)]
merge',
[SubExp]
resExps
)
where
isInvariant :: Bool
isInvariant
| Var VName
v2 <- SubExp
resExp,
FParam lore -> VName
forall dec. Param dec -> VName
paramName FParam lore
mergeParam VName -> VName -> Bool
forall a. Eq a => a -> a -> Bool
== VName
v2 =
Names -> FParam lore -> Bool
allExistentialInvariant
([VName] -> Names
namesFromList ([VName] -> Names) -> [VName] -> Names
forall a b. (a -> b) -> a -> b
$ ((Ident, SubExp) -> VName) -> [(Ident, SubExp)] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map (Ident -> VName
identName (Ident -> VName)
-> ((Ident, SubExp) -> Ident) -> (Ident, SubExp) -> VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Ident, SubExp) -> Ident
forall a b. (a, b) -> a
fst) [(Ident, SubExp)]
invariant)
FParam lore
mergeParam
| SubExp
mergeInit SubExp -> SubExp -> Bool
forall a. Eq a => a -> a -> Bool
== SubExp
resExp = Bool
True
| Var VName
init_v <- SubExp
mergeInit,
Just (SubExp
p_init, SubExp
p_res) <- VName -> TopDown lore -> Maybe (SubExp, SubExp)
forall lore. VName -> SymbolTable lore -> Maybe (SubExp, SubExp)
ST.lookupLoopParam VName
init_v TopDown lore
vtable,
SubExp
p_init SubExp -> SubExp -> Bool
forall a. Eq a => a -> a -> Bool
== SubExp
resExp,
SubExp
p_res SubExp -> SubExp -> Bool
forall a. Eq a => a -> a -> Bool
== VName -> SubExp
Var VName
pat_name =
Bool
True
| Bool
otherwise = Bool
False
checkInvariance
(VName
_pat_name, (FParam lore
mergeParam, SubExp
mergeInit), SubExp
resExp)
([(Ident, SubExp)]
invariant, [(PatElemT (LetDec lore), VName)]
explpat', [(FParam lore, SubExp)]
merge', [SubExp]
resExps) =
([(Ident, SubExp)]
invariant, [(PatElemT (LetDec lore), VName)]
explpat', (FParam lore
mergeParam, SubExp
mergeInit) (FParam lore, SubExp)
-> [(FParam lore, SubExp)] -> [(FParam lore, SubExp)]
forall a. a -> [a] -> [a]
: [(FParam lore, SubExp)]
merge', SubExp
resExp SubExp -> [SubExp] -> [SubExp]
forall a. a -> [a] -> [a]
: [SubExp]
resExps)
allExistentialInvariant :: Names -> FParam lore -> Bool
allExistentialInvariant Names
namesOfInvariant FParam lore
mergeParam =
(VName -> Bool) -> [VName] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all (Names -> VName -> Bool
invariantOrNotMergeParam Names
namesOfInvariant) ([VName] -> Bool) -> [VName] -> Bool
forall a b. (a -> b) -> a -> b
$
Names -> [VName]
namesToList (Names -> [VName]) -> Names -> [VName]
forall a b. (a -> b) -> a -> b
$
FParam lore -> Names
forall a. FreeIn a => a -> Names
freeIn FParam lore
mergeParam Names -> Names -> Names
`namesSubtract` VName -> Names
oneName (FParam lore -> VName
forall dec. Param dec -> VName
paramName FParam lore
mergeParam)
invariantOrNotMergeParam :: Names -> VName -> Bool
invariantOrNotMergeParam Names
namesOfInvariant VName
name =
Bool -> Bool
not (VName
name VName -> Names -> Bool
`nameIn` Names
namesOfMergeParams)
Bool -> Bool -> Bool
|| VName
name VName -> Names -> Bool
`nameIn` Names
namesOfInvariant
-- | Looks up the type of a 'SubExp'; presumably 'Nothing' when the type
-- is not known.
type TypeLookup = SubExp -> Maybe Type
-- | A rule that rewrites a single 'BasicOp', given a variable lookup and a
-- type lookup.  A successful rewrite yields the replacement 'BasicOp'
-- together with the 'Certificates' it needs; 'Nothing' means the rule does
-- not apply.
type SimpleRule lore = VarLookup lore -> TypeLookup -> BasicOp -> Maybe (BasicOp, Certificates)
-- | The 'SimpleRule's defined in this module.  They cover binary,
-- comparison, unary and conversion operators, assertions, scratch-to-scratch
-- copies, and several 'Reshape' simplifications.  Their definitions are
-- outside this chunk.
-- NOTE(review): the order in which callers try these rules (and whether the
-- first match wins) is not visible here — confirm at the use site.
simpleRules :: [SimpleRule lore]
simpleRules :: [SimpleRule lore]
simpleRules =
[ SimpleRule lore
forall lore. SimpleRule lore
simplifyBinOp,
SimpleRule lore
forall lore. SimpleRule lore
simplifyCmpOp,
SimpleRule lore
forall lore. SimpleRule lore
simplifyUnOp,
SimpleRule lore
forall lore. SimpleRule lore
simplifyConvOp,
SimpleRule lore
forall lore. SimpleRule lore
simplifyAssert,
SimpleRule lore
forall lore. SimpleRule lore
copyScratchToScratch,
SimpleRule lore
forall lore. SimpleRule lore
simplifyIdentityReshape,
SimpleRule lore
forall lore. SimpleRule lore
simplifyReshapeReshape,
SimpleRule lore
forall lore. SimpleRule lore
simplifyReshapeScratch,
SimpleRule lore
forall lore. SimpleRule lore
simplifyReshapeReplicate,
SimpleRule lore
forall lore. SimpleRule lore
simplifyReshapeIota,
SimpleRule lore
forall lore. SimpleRule lore
improveReshape
]
-- | Top-down rule that tries to replace a loop by a closed form.
-- It applies only to loops with no context merge parameters and a for-loop
-- form without loop variables.  It then delegates to 'loopClosedForm',
-- passing:
--   * the pattern,
--   * the value merge parameters,
--   * the loop index as a singleton 'Names',
--   * the index type,
--   * the bound,
--   * the body.
-- All other loop shapes are skipped.
simplifyClosedFormLoop :: BinderOps lore => TopDownRuleDoLoop lore
simplifyClosedFormLoop :: TopDownRuleDoLoop lore
simplifyClosedFormLoop TopDown lore
_ Pattern lore
pat StmAux (ExpDec lore)
_ ([], [(FParam lore, SubExp)]
val, ForLoop VName
i IntType
it SubExp
bound [], BodyT lore
body) =
RuleM lore () -> Rule lore
forall lore. RuleM lore () -> Rule lore
Simplify (RuleM lore () -> Rule lore) -> RuleM lore () -> Rule lore
forall a b. (a -> b) -> a -> b
$ Pattern lore
-> [(FParam lore, SubExp)]
-> Names
-> IntType
-> SubExp
-> BodyT lore
-> RuleM lore ()
forall lore.
(ASTLore lore, BinderOps lore) =>
Pattern lore
-> [(FParam lore, SubExp)]
-> Names
-> IntType
-> SubExp
-> Body lore
-> RuleM lore ()
loopClosedForm Pattern lore
pat [(FParam lore, SubExp)]
val (VName -> Names
oneName VName
i) IntType
it SubExp
bound BodyT lore
body
-- Any other loop shape (context merge parameters, while-loop, or loop
-- variables present): rule does not apply.
simplifyClosedFormLoop TopDown lore
_ Pattern lore
_ StmAux (ExpDec lore)
_ ([(FParam lore, SubExp)], [(FParam lore, SubExp)], LoopForm lore,
BodyT lore)
_ = Rule lore
forall lore. Rule lore
Skip
-- | Simplifies the loop variables, i.e. the @(param, array)@ pairs, of a
-- for-loop.  For each loop variable, 'simplifyIndexing' is tried on the
-- indexing @arr[i, <full slice of param type>]@ that the loop performs
-- implicitly.  When that succeeds, the loop variable is handled in one of
-- two ways:
--   * 'IndexResult': it is replaced by a loop variable over a
--     @for_in_partial@ slice of another array;
--   * 'SubExpResult': it is removed entirely, and its parameter is bound to
--     the computed 'SubExp' at the start of the body.
-- If no loop variable ends up changed, 'cannotSimplify' is reported.
simplifyLoopVariables :: (BinderOps lore, Aliased lore) => TopDownRuleDoLoop lore
simplifyLoopVariables :: TopDownRuleDoLoop lore
simplifyLoopVariables TopDown lore
vtable Pattern lore
pat StmAux (ExpDec lore)
aux ([(FParam lore, SubExp)]
ctx, [(FParam lore, SubExp)]
val, form :: LoopForm lore
form@(ForLoop VName
i IntType
it SubExp
num_iters [(LParam lore, VName)]
loop_vars), BodyT lore
body)
| [Maybe (RuleM lore IndexResult)]
simplifiable <- ((LParam lore, VName) -> Maybe (RuleM lore IndexResult))
-> [(LParam lore, VName)] -> [Maybe (RuleM lore IndexResult)]
forall a b. (a -> b) -> [a] -> [b]
map (LParam lore, VName) -> Maybe (RuleM lore IndexResult)
checkIfSimplifiable [(LParam lore, VName)]
loop_vars,
Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ (Maybe (RuleM lore IndexResult) -> Bool)
-> [Maybe (RuleM lore IndexResult)] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all Maybe (RuleM lore IndexResult) -> Bool
forall a. Maybe a -> Bool
isNothing [Maybe (RuleM lore IndexResult)]
simplifiable = RuleM lore () -> Rule lore
forall lore. RuleM lore () -> Rule lore
Simplify (RuleM lore () -> Rule lore) -> RuleM lore () -> Rule lore
forall a b. (a -> b) -> a -> b
$ do
-- Run the per-variable rewrites inside the loop's scope.  Each yields the
-- (possibly replaced, possibly removed) loop variable and the statements
-- that must be prepended to the body.
([Maybe (LParam lore, VName)]
maybe_loop_vars, [Stms lore]
body_prefix_stms) <-
Scope lore
-> RuleM lore ([Maybe (LParam lore, VName)], [Stms lore])
-> RuleM lore ([Maybe (LParam lore, VName)], [Stms lore])
forall lore (m :: * -> *) a.
LocalScope lore m =>
Scope lore -> m a -> m a
localScope (LoopForm lore -> Scope lore
forall lore a. Scoped lore a => a -> Scope lore
scopeOf LoopForm lore
form) (RuleM lore ([Maybe (LParam lore, VName)], [Stms lore])
-> RuleM lore ([Maybe (LParam lore, VName)], [Stms lore]))
-> RuleM lore ([Maybe (LParam lore, VName)], [Stms lore])
-> RuleM lore ([Maybe (LParam lore, VName)], [Stms lore])
forall a b. (a -> b) -> a -> b
$
[(Maybe (LParam lore, VName), Stms lore)]
-> ([Maybe (LParam lore, VName)], [Stms lore])
forall a b. [(a, b)] -> ([a], [b])
unzip ([(Maybe (LParam lore, VName), Stms lore)]
-> ([Maybe (LParam lore, VName)], [Stms lore]))
-> RuleM lore [(Maybe (LParam lore, VName), Stms lore)]
-> RuleM lore ([Maybe (LParam lore, VName)], [Stms lore])
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> ((LParam lore, VName)
-> Maybe (RuleM lore IndexResult)
-> RuleM lore (Maybe (LParam lore, VName), Stms lore))
-> [(LParam lore, VName)]
-> [Maybe (RuleM lore IndexResult)]
-> RuleM lore [(Maybe (LParam lore, VName), Stms lore)]
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m [c]
zipWithM (LParam lore, VName)
-> Maybe (RuleM lore IndexResult)
-> RuleM lore (Maybe (LParam lore, VName), Stms lore)
onLoopVar [(LParam lore, VName)]
loop_vars [Maybe (RuleM lore IndexResult)]
simplifiable
-- Every loop variable survived unchanged: nothing was gained.
if [Maybe (LParam lore, VName)]
maybe_loop_vars [Maybe (LParam lore, VName)]
-> [Maybe (LParam lore, VName)] -> Bool
forall a. Eq a => a -> a -> Bool
== ((LParam lore, VName) -> Maybe (LParam lore, VName))
-> [(LParam lore, VName)] -> [Maybe (LParam lore, VName)]
forall a b. (a -> b) -> [a] -> [b]
map (LParam lore, VName) -> Maybe (LParam lore, VName)
forall a. a -> Maybe a
Just [(LParam lore, VName)]
loop_vars
then RuleM lore ()
forall lore a. RuleM lore a
cannotSimplify
else do
-- New body: the collected prefix statements followed by the original body.
BodyT lore
body' <- RuleM lore (Body (Lore (RuleM lore)))
-> RuleM lore (Body (Lore (RuleM lore)))
forall (m :: * -> *).
MonadBinder m =>
m (Body (Lore m)) -> m (Body (Lore m))
insertStmsM (RuleM lore (Body (Lore (RuleM lore)))
-> RuleM lore (Body (Lore (RuleM lore))))
-> RuleM lore (Body (Lore (RuleM lore)))
-> RuleM lore (Body (Lore (RuleM lore)))
forall a b. (a -> b) -> a -> b
$ do
Stms (Lore (RuleM lore)) -> RuleM lore ()
forall (m :: * -> *). MonadBinder m => Stms (Lore m) -> m ()
addStms (Stms (Lore (RuleM lore)) -> RuleM lore ())
-> Stms (Lore (RuleM lore)) -> RuleM lore ()
forall a b. (a -> b) -> a -> b
$ [Stms lore] -> Stms lore
forall a. Monoid a => [a] -> a
mconcat [Stms lore]
body_prefix_stms
[SubExp] -> RuleM lore (BodyT lore)
forall (m :: * -> *).
MonadBinder m =>
[SubExp] -> m (Body (Lore m))
resultBodyM ([SubExp] -> RuleM lore (BodyT lore))
-> RuleM lore [SubExp] -> RuleM lore (BodyT lore)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< Body (Lore (RuleM lore)) -> RuleM lore [SubExp]
forall (m :: * -> *). MonadBinder m => Body (Lore m) -> m [SubExp]
bodyBind BodyT lore
Body (Lore (RuleM lore))
body
-- Rebind the original pattern to the rewritten loop, keeping the original
-- statement's aux data; removed loop variables are dropped via catMaybes.
StmAux (ExpDec lore) -> RuleM lore () -> RuleM lore ()
forall (m :: * -> *) anylore a.
MonadBinder m =>
StmAux anylore -> m a -> m a
auxing StmAux (ExpDec lore)
aux (RuleM lore () -> RuleM lore ()) -> RuleM lore () -> RuleM lore ()
forall a b. (a -> b) -> a -> b
$
Pattern (Lore (RuleM lore))
-> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall (m :: * -> *).
MonadBinder m =>
Pattern (Lore m) -> Exp (Lore m) -> m ()
letBind Pattern lore
Pattern (Lore (RuleM lore))
pat (Exp (Lore (RuleM lore)) -> RuleM lore ())
-> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall a b. (a -> b) -> a -> b
$
[(FParam lore, SubExp)]
-> [(FParam lore, SubExp)]
-> LoopForm lore
-> BodyT lore
-> ExpT lore
forall lore.
[(FParam lore, SubExp)]
-> [(FParam lore, SubExp)]
-> LoopForm lore
-> BodyT lore
-> ExpT lore
DoLoop
[(FParam lore, SubExp)]
ctx
[(FParam lore, SubExp)]
val
(VName
-> IntType -> SubExp -> [(LParam lore, VName)] -> LoopForm lore
forall lore.
VName
-> IntType -> SubExp -> [(LParam lore, VName)] -> LoopForm lore
ForLoop VName
i IntType
it SubExp
num_iters ([(LParam lore, VName)] -> LoopForm lore)
-> [(LParam lore, VName)] -> LoopForm lore
forall a b. (a -> b) -> a -> b
$ [Maybe (LParam lore, VName)] -> [(LParam lore, VName)]
forall a. [Maybe a] -> [a]
catMaybes [Maybe (LParam lore, VName)]
maybe_loop_vars)
BodyT lore
body'
where
-- Type lookup handed to simplifyIndexing.  The loop index @i@ has the
-- loop's integer type; other variables are looked up in the outer
-- symbol table @vtable@; constants get the type of their value.
seType :: SubExp -> Maybe (TypeBase Shape NoUniqueness)
seType (Var VName
v)
| VName
v VName -> VName -> Bool
forall a. Eq a => a -> a -> Bool
== VName
i = TypeBase Shape NoUniqueness -> Maybe (TypeBase Shape NoUniqueness)
forall a. a -> Maybe a
Just (TypeBase Shape NoUniqueness
-> Maybe (TypeBase Shape NoUniqueness))
-> TypeBase Shape NoUniqueness
-> Maybe (TypeBase Shape NoUniqueness)
forall a b. (a -> b) -> a -> b
$ PrimType -> TypeBase Shape NoUniqueness
forall shape u. PrimType -> TypeBase shape u
Prim (PrimType -> TypeBase Shape NoUniqueness)
-> PrimType -> TypeBase Shape NoUniqueness
forall a b. (a -> b) -> a -> b
$ IntType -> PrimType
IntType IntType
it
| Bool
otherwise = VName -> TopDown lore -> Maybe (TypeBase Shape NoUniqueness)
forall lore.
ASTLore lore =>
VName -> SymbolTable lore -> Maybe (TypeBase Shape NoUniqueness)
ST.lookupType VName
v TopDown lore
vtable
seType (Constant PrimValue
v) = TypeBase Shape NoUniqueness -> Maybe (TypeBase Shape NoUniqueness)
forall a. a -> Maybe a
Just (TypeBase Shape NoUniqueness
-> Maybe (TypeBase Shape NoUniqueness))
-> TypeBase Shape NoUniqueness
-> Maybe (TypeBase Shape NoUniqueness)
forall a b. (a -> b) -> a -> b
$ PrimType -> TypeBase Shape NoUniqueness
forall shape u. PrimType -> TypeBase shape u
Prim (PrimType -> TypeBase Shape NoUniqueness)
-> PrimType -> TypeBase Shape NoUniqueness
forall a b. (a -> b) -> a -> b
$ PrimValue -> PrimType
primValueType PrimValue
v
-- Names consumed anywhere in the loop body.
consumed_in_body :: Names
consumed_in_body = BodyT lore -> Names
forall lore. Aliased lore => Body lore -> Names
consumedInBody BodyT lore
body
-- Outer symbol table extended with the loop form's own scope.
vtable' :: TopDown lore
vtable' = Scope lore -> TopDown lore
forall lore. ASTLore lore => Scope lore -> SymbolTable lore
ST.fromScope (LoopForm lore -> Scope lore
forall lore a. Scoped lore a => a -> Scope lore
scopeOf LoopForm lore
form) TopDown lore -> TopDown lore -> TopDown lore
forall a. Semigroup a => a -> a -> a
<> TopDown lore
vtable
-- Ask simplifyIndexing whether @arr[i, <full slice>]@ can be simplified.
-- The final Bool says whether the loop parameter is consumed in the body;
-- its exact effect is defined by simplifyIndexing (not visible here).
checkIfSimplifiable :: (LParam lore, VName) -> Maybe (RuleM lore IndexResult)
checkIfSimplifiable (LParam lore
p, VName
arr) =
SymbolTable (Lore (RuleM lore))
-> (SubExp -> Maybe (TypeBase Shape NoUniqueness))
-> VName
-> Slice SubExp
-> Bool
-> Maybe (RuleM lore IndexResult)
forall (m :: * -> *).
MonadBinder m =>
SymbolTable (Lore m)
-> (SubExp -> Maybe (TypeBase Shape NoUniqueness))
-> VName
-> Slice SubExp
-> Bool
-> Maybe (m IndexResult)
simplifyIndexing
TopDown lore
SymbolTable (Lore (RuleM lore))
vtable'
SubExp -> Maybe (TypeBase Shape NoUniqueness)
seType
VName
arr
(SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix (VName -> SubExp
Var VName
i) DimIndex SubExp -> Slice SubExp -> Slice SubExp
forall a. a -> [a] -> [a]
: TypeBase Shape NoUniqueness -> Slice SubExp -> Slice SubExp
fullSlice (LParam lore -> TypeBase Shape NoUniqueness
forall dec. Typed dec => Param dec -> TypeBase Shape NoUniqueness
paramType LParam lore
p) [])
(Bool -> Maybe (RuleM lore IndexResult))
-> Bool -> Maybe (RuleM lore IndexResult)
forall a b. (a -> b) -> a -> b
$ LParam lore -> VName
forall dec. Param dec -> VName
paramName LParam lore
p VName -> Names -> Bool
`nameIn` Names
consumed_in_body
-- Apply the (optional) simplification for one loop variable.  It returns
-- the loop variable to keep (Nothing = remove it) and the statements to
-- prepend to the loop body.  The loop variable is kept unchanged when
-- there is no simplification, or when the result does not fit one of the
-- two handled shapes.
onLoopVar :: (LParam lore, VName)
-> Maybe (RuleM lore IndexResult)
-> RuleM lore (Maybe (LParam lore, VName), Stms lore)
onLoopVar (LParam lore
p, VName
arr) Maybe (RuleM lore IndexResult)
Nothing =
(Maybe (LParam lore, VName), Stms lore)
-> RuleM lore (Maybe (LParam lore, VName), Stms lore)
forall (m :: * -> *) a. Monad m => a -> m a
return ((LParam lore, VName) -> Maybe (LParam lore, VName)
forall a. a -> Maybe a
Just (LParam lore
p, VName
arr), Stms lore
forall a. Monoid a => a
mempty)
onLoopVar (LParam lore
p, VName
arr) (Just RuleM lore IndexResult
m) = do
(IndexResult
x, Stms lore
x_stms) <- RuleM lore IndexResult
-> RuleM lore (IndexResult, Stms (Lore (RuleM lore)))
forall (m :: * -> *) a.
MonadBinder m =>
m a -> m (a, Stms (Lore m))
collectStms RuleM lore IndexResult
m
case IndexResult
x of
-- The indexing became @arr'[i, slice']@ with helper statements that do not
-- mention @i@: hoist those statements, slice @arr'@ along its first
-- dimension as @for_in_partial@, and iterate over that instead.
IndexResult Certificates
cs VName
arr' Slice SubExp
slice
| Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ (Stm lore -> Bool) -> Stms lore -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
any ((VName
i VName -> Names -> Bool
`nameIn`) (Names -> Bool) -> (Stm lore -> Names) -> Stm lore -> Bool
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Stm lore -> Names
forall a. FreeIn a => a -> Names
freeIn) Stms lore
x_stms,
DimFix (Var VName
j) : Slice SubExp
slice' <- Slice SubExp
slice,
VName
j VName -> VName -> Bool
forall a. Eq a => a -> a -> Bool
== VName
i,
-- NOTE(review): the guards above establish that @slice@ starts with
-- @DimFix (Var i)@, so @i@ is presumably always free in @slice@ and this
-- test always fails, leaving this branch unreachable.  Checking @slice'@
-- may have been intended — verify before changing.
Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ VName
i VName -> Names -> Bool
`nameIn` Slice SubExp -> Names
forall a. FreeIn a => a -> Names
freeIn Slice SubExp
slice -> do
Stms (Lore (RuleM lore)) -> RuleM lore ()
forall (m :: * -> *). MonadBinder m => Stms (Lore m) -> m ()
addStms Stms lore
Stms (Lore (RuleM lore))
x_stms
SubExp
w <- Int -> TypeBase Shape NoUniqueness -> SubExp
forall u. Int -> TypeBase Shape u -> SubExp
arraySize Int
0 (TypeBase Shape NoUniqueness -> SubExp)
-> RuleM lore (TypeBase Shape NoUniqueness) -> RuleM lore SubExp
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> VName -> RuleM lore (TypeBase Shape NoUniqueness)
forall lore (m :: * -> *).
HasScope lore m =>
VName -> m (TypeBase Shape NoUniqueness)
lookupType VName
arr'
VName
for_in_partial <-
Certificates -> RuleM lore VName -> RuleM lore VName
forall (m :: * -> *) a. MonadBinder m => Certificates -> m a -> m a
certifying Certificates
cs (RuleM lore VName -> RuleM lore VName)
-> RuleM lore VName -> RuleM lore VName
forall a b. (a -> b) -> a -> b
$
String -> Exp (Lore (RuleM lore)) -> RuleM lore VName
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m VName
letExp String
"for_in_partial" (Exp (Lore (RuleM lore)) -> RuleM lore VName)
-> Exp (Lore (RuleM lore)) -> RuleM lore VName
forall a b. (a -> b) -> a -> b
$
BasicOp -> ExpT lore
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> ExpT lore) -> BasicOp -> ExpT lore
forall a b. (a -> b) -> a -> b
$
VName -> Slice SubExp -> BasicOp
Index VName
arr' (Slice SubExp -> BasicOp) -> Slice SubExp -> BasicOp
forall a b. (a -> b) -> a -> b
$
SubExp -> SubExp -> SubExp -> DimIndex SubExp
forall d. d -> d -> d -> DimIndex d
DimSlice (IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
0) SubExp
w (IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
1) DimIndex SubExp -> Slice SubExp -> Slice SubExp
forall a. a -> [a] -> [a]
: Slice SubExp
slice'
(Maybe (LParam lore, VName), Stms lore)
-> RuleM lore (Maybe (LParam lore, VName), Stms lore)
forall (m :: * -> *) a. Monad m => a -> m a
return ((LParam lore, VName) -> Maybe (LParam lore, VName)
forall a. a -> Maybe a
Just (LParam lore
p, VName
for_in_partial), Stms lore
forall a. Monoid a => a
mempty)
-- The indexing reduced to a plain SubExp computed without any Index
-- statements: drop the loop variable and bind its parameter to that value
-- (under the result's certificates) at the start of the body.
SubExpResult Certificates
cs SubExp
se
| (Stm lore -> Bool) -> Stms lore -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all (ExpT lore -> Bool
forall lore. ExpT lore -> Bool
notIndex (ExpT lore -> Bool) -> (Stm lore -> ExpT lore) -> Stm lore -> Bool
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Stm lore -> ExpT lore
forall lore. Stm lore -> Exp lore
stmExp) Stms lore
x_stms -> do
Stms lore
x_stms' <- RuleM lore () -> RuleM lore (Stms (Lore (RuleM lore)))
forall (m :: * -> *) a. MonadBinder m => m a -> m (Stms (Lore m))
collectStms_ (RuleM lore () -> RuleM lore (Stms (Lore (RuleM lore))))
-> RuleM lore () -> RuleM lore (Stms (Lore (RuleM lore)))
forall a b. (a -> b) -> a -> b
$
Certificates -> RuleM lore () -> RuleM lore ()
forall (m :: * -> *) a. MonadBinder m => Certificates -> m a -> m a
certifying Certificates
cs (RuleM lore () -> RuleM lore ()) -> RuleM lore () -> RuleM lore ()
forall a b. (a -> b) -> a -> b
$ do
Stms (Lore (RuleM lore)) -> RuleM lore ()
forall (m :: * -> *). MonadBinder m => Stms (Lore m) -> m ()
addStms Stms lore
Stms (Lore (RuleM lore))
x_stms
[VName] -> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall (m :: * -> *).
MonadBinder m =>
[VName] -> Exp (Lore m) -> m ()
letBindNames [LParam lore -> VName
forall dec. Param dec -> VName
paramName LParam lore
p] (Exp (Lore (RuleM lore)) -> RuleM lore ())
-> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall a b. (a -> b) -> a -> b
$ BasicOp -> ExpT lore
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> ExpT lore) -> BasicOp -> ExpT lore
forall a b. (a -> b) -> a -> b
$ SubExp -> BasicOp
SubExp SubExp
se
(Maybe (LParam lore, VName), Stms lore)
-> RuleM lore (Maybe (LParam lore, VName), Stms lore)
forall (m :: * -> *) a. Monad m => a -> m a
return (Maybe (LParam lore, VName)
forall a. Maybe a
Nothing, Stms lore
x_stms')
-- Any other result: keep the loop variable as it was.
IndexResult
_ -> (Maybe (LParam lore, VName), Stms lore)
-> RuleM lore (Maybe (LParam lore, VName), Stms lore)
forall (m :: * -> *) a. Monad m => a -> m a
return ((LParam lore, VName) -> Maybe (LParam lore, VName)
forall a. a -> Maybe a
Just (LParam lore
p, VName
arr), Stms lore
forall a. Monoid a => a
mempty)
-- True for every expression except a BasicOp 'Index'.
notIndex :: ExpT lore -> Bool
notIndex (BasicOp Index {}) = Bool
False
notIndex ExpT lore
_ = Bool
True
-- Not a for-loop, or no loop variable can be simplified: rule does not apply.
simplifyLoopVariables TopDown lore
_ Pattern lore
_ StmAux (ExpDec lore)
_ ([(FParam lore, SubExp)], [(FParam lore, SubExp)], LoopForm lore,
BodyT lore)
_ = Rule lore
forall lore. Rule lore
Skip
-- | Narrows the index type of a for-loop.  It applies to loops whose
-- iteration type is 'Int64' and which have no loop variables.
--
-- A smaller iteration type is found in one of two cases:
--   * the iteration count is a variable that the symbol table records as
--     @SExt it' _ n''@; the loop then counts to @n''@ in type @it'@, under
--     that binding's certificates;
--   * the iteration count is an 'Int64' constant no larger than
--     @maxBound :: Int32@; the loop then counts in 'Int32'.
--
-- The loop is then rewritten with a fresh index @i'@.  The original index
-- @i@ is re-bound at the top of the body as @SExt it' Int64 i'@, so the
-- body's statements are otherwise unchanged.
-- NOTE(review): the variable case does not check that @it'@ is actually
-- narrower than 'Int64' — confirm that such sign extensions cannot occur.
narrowLoopType :: (BinderOps lore) => TopDownRuleDoLoop lore
narrowLoopType :: TopDownRuleDoLoop lore
narrowLoopType TopDown lore
vtable Pattern lore
pat StmAux (ExpDec lore)
aux ([(FParam lore, SubExp)]
ctx, [(FParam lore, SubExp)]
val, ForLoop VName
i IntType
Int64 SubExp
n [], BodyT lore
body)
| Just (SubExp
n', IntType
it', Certificates
cs) <- Maybe (SubExp, IntType, Certificates)
smallerType =
RuleM lore () -> Rule lore
forall lore. RuleM lore () -> Rule lore
Simplify (RuleM lore () -> Rule lore) -> RuleM lore () -> Rule lore
forall a b. (a -> b) -> a -> b
$ do
-- Fresh index name with the same base name as the original index.
VName
i' <- String -> RuleM lore VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName (String -> RuleM lore VName) -> String -> RuleM lore VName
forall a b. (a -> b) -> a -> b
$ VName -> String
baseString VName
i
let form' :: LoopForm lore
form' = VName
-> IntType -> SubExp -> [(LParam lore, VName)] -> LoopForm lore
forall lore.
VName
-> IntType -> SubExp -> [(LParam lore, VName)] -> LoopForm lore
ForLoop VName
i' IntType
it' SubExp
n' []
-- Prefix the body with @i = sext i'@ so existing uses of @i@ still see an
-- Int64 value.
BodyT lore
body' <- RuleM lore (Body (Lore (RuleM lore)))
-> RuleM lore (Body (Lore (RuleM lore)))
forall (m :: * -> *).
MonadBinder m =>
m (Body (Lore m)) -> m (Body (Lore m))
insertStmsM (RuleM lore (Body (Lore (RuleM lore)))
-> RuleM lore (Body (Lore (RuleM lore))))
-> RuleM lore (Body (Lore (RuleM lore)))
-> RuleM lore (Body (Lore (RuleM lore)))
forall a b. (a -> b) -> a -> b
$
LoopForm lore -> RuleM lore (BodyT lore) -> RuleM lore (BodyT lore)
forall lore a (m :: * -> *) b.
(Scoped lore a, LocalScope lore m) =>
a -> m b -> m b
inScopeOf LoopForm lore
form' (RuleM lore (BodyT lore) -> RuleM lore (BodyT lore))
-> RuleM lore (BodyT lore) -> RuleM lore (BodyT lore)
forall a b. (a -> b) -> a -> b
$ do
[VName] -> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall (m :: * -> *).
MonadBinder m =>
[VName] -> Exp (Lore m) -> m ()
letBindNames [VName
i] (Exp (Lore (RuleM lore)) -> RuleM lore ())
-> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall a b. (a -> b) -> a -> b
$ BasicOp -> ExpT lore
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> ExpT lore) -> BasicOp -> ExpT lore
forall a b. (a -> b) -> a -> b
$ ConvOp -> SubExp -> BasicOp
ConvOp (IntType -> IntType -> ConvOp
SExt IntType
it' IntType
Int64) (VName -> SubExp
Var VName
i')
BodyT lore -> RuleM lore (BodyT lore)
forall (f :: * -> *) a. Applicative f => a -> f a
pure BodyT lore
body
-- Rebind the pattern to the narrowed loop, keeping the original aux data
-- and adding the certificates found by smallerType.
StmAux (ExpDec lore) -> RuleM lore () -> RuleM lore ()
forall (m :: * -> *) anylore a.
MonadBinder m =>
StmAux anylore -> m a -> m a
auxing StmAux (ExpDec lore)
aux (RuleM lore () -> RuleM lore ()) -> RuleM lore () -> RuleM lore ()
forall a b. (a -> b) -> a -> b
$
Certificates -> RuleM lore () -> RuleM lore ()
forall (m :: * -> *) a. MonadBinder m => Certificates -> m a -> m a
certifying Certificates
cs (RuleM lore () -> RuleM lore ()) -> RuleM lore () -> RuleM lore ()
forall a b. (a -> b) -> a -> b
$
Pattern (Lore (RuleM lore))
-> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall (m :: * -> *).
MonadBinder m =>
Pattern (Lore m) -> Exp (Lore m) -> m ()
letBind Pattern lore
Pattern (Lore (RuleM lore))
pat (Exp (Lore (RuleM lore)) -> RuleM lore ())
-> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall a b. (a -> b) -> a -> b
$ [(FParam lore, SubExp)]
-> [(FParam lore, SubExp)]
-> LoopForm lore
-> BodyT lore
-> ExpT lore
forall lore.
[(FParam lore, SubExp)]
-> [(FParam lore, SubExp)]
-> LoopForm lore
-> BodyT lore
-> ExpT lore
DoLoop [(FParam lore, SubExp)]
ctx [(FParam lore, SubExp)]
val LoopForm lore
form' BodyT lore
body'
where
-- The narrower iteration count, its type, and the certificates it needs,
-- or Nothing if no narrowing is possible.
smallerType :: Maybe (SubExp, IntType, Certificates)
smallerType
| Var VName
n' <- SubExp
n,
Just (ConvOp (SExt IntType
it' IntType
_) SubExp
n'', Certificates
cs) <- VName -> TopDown lore -> Maybe (BasicOp, Certificates)
forall lore.
VName -> SymbolTable lore -> Maybe (BasicOp, Certificates)
ST.lookupBasicOp VName
n' TopDown lore
vtable =
(SubExp, IntType, Certificates)
-> Maybe (SubExp, IntType, Certificates)
forall a. a -> Maybe a
Just (SubExp
n'', IntType
it', Certificates
cs)
| Constant (IntValue (Int64Value Int64
n')) <- SubExp
n,
Int64 -> Integer
forall a. Integral a => a -> Integer
toInteger Int64
n' Integer -> Integer -> Bool
forall a. Ord a => a -> a -> Bool
<= Int32 -> Integer
forall a. Integral a => a -> Integer
toInteger (Int32
forall a. Bounded a => a
maxBound :: Int32) =
(SubExp, IntType, Certificates)
-> Maybe (SubExp, IntType, Certificates)
forall a. a -> Maybe a
Just (IntType -> Integer -> SubExp
intConst IntType
Int32 (Int64 -> Integer
forall a. Integral a => a -> Integer
toInteger Int64
n'), IntType
Int32, Certificates
forall a. Monoid a => a
mempty)
| Bool
otherwise =
Maybe (SubExp, IntType, Certificates)
forall a. Maybe a
Nothing
-- Any other loop shape: rule does not apply.
narrowLoopType TopDown lore
_ Pattern lore
_ StmAux (ExpDec lore)
_ ([(FParam lore, SubExp)], [(FParam lore, SubExp)], LoopForm lore,
BodyT lore)
_ = Rule lore
forall lore. Rule lore
Skip
-- | Fully unrolls a for-loop, emitting one copy of the body per recursive
-- call.
--
-- Arguments:
--   * @n@: the number of iterations;
--   * the merge parameters paired with their current values;
--   * the loop index name, its integer type, and the current iteration
--     number;
--   * the loop-variable parameters paired with the arrays they iterate over;
--   * the loop body.
--
-- Once the current iteration reaches @n@, the current merge values are
-- returned.  Otherwise one renamed copy of the body is added to the current
-- statements, and unrolling continues with the body's results as the new
-- merge values.
unroll ::
BinderOps lore =>
Integer ->
[(FParam lore, SubExp)] ->
(VName, IntType, Integer) ->
[(LParam lore, VName)] ->
Body lore ->
RuleM lore [SubExp]
unroll :: Integer
-> [(FParam lore, SubExp)]
-> (VName, IntType, Integer)
-> [(LParam lore, VName)]
-> Body lore
-> RuleM lore [SubExp]
unroll Integer
n [(FParam lore, SubExp)]
merge (VName
iv, IntType
it, Integer
i) [(LParam lore, VName)]
loop_vars Body lore
body
| Integer
i Integer -> Integer -> Bool
forall a. Ord a => a -> a -> Bool
>= Integer
n =
[SubExp] -> RuleM lore [SubExp]
forall (m :: * -> *) a. Monad m => a -> m a
return ([SubExp] -> RuleM lore [SubExp])
-> [SubExp] -> RuleM lore [SubExp]
forall a b. (a -> b) -> a -> b
$ ((FParam lore, SubExp) -> SubExp)
-> [(FParam lore, SubExp)] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map (FParam lore, SubExp) -> SubExp
forall a b. (a, b) -> b
snd [(FParam lore, SubExp)]
merge
-- One iteration.  Bind each merge parameter to its current value and the
-- index to the constant @i@ (in type @it@), then the body follows.
| Bool
otherwise = do
Body lore
iter_body <- RuleM lore (Body (Lore (RuleM lore)))
-> RuleM lore (Body (Lore (RuleM lore)))
forall (m :: * -> *).
MonadBinder m =>
m (Body (Lore m)) -> m (Body (Lore m))
insertStmsM (RuleM lore (Body (Lore (RuleM lore)))
-> RuleM lore (Body (Lore (RuleM lore))))
-> RuleM lore (Body (Lore (RuleM lore)))
-> RuleM lore (Body (Lore (RuleM lore)))
forall a b. (a -> b) -> a -> b
$ do
[(FParam lore, SubExp)]
-> ((FParam lore, SubExp) -> RuleM lore ()) -> RuleM lore ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ [(FParam lore, SubExp)]
merge (((FParam lore, SubExp) -> RuleM lore ()) -> RuleM lore ())
-> ((FParam lore, SubExp) -> RuleM lore ()) -> RuleM lore ()
forall a b. (a -> b) -> a -> b
$ \(FParam lore
mergevar, SubExp
mergeinit) ->
[VName] -> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall (m :: * -> *).
MonadBinder m =>
[VName] -> Exp (Lore m) -> m ()
letBindNames [FParam lore -> VName
forall dec. Param dec -> VName
paramName FParam lore
mergevar] (Exp (Lore (RuleM lore)) -> RuleM lore ())
-> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall a b. (a -> b) -> a -> b
$ BasicOp -> ExpT lore
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> ExpT lore) -> BasicOp -> ExpT lore
forall a b. (a -> b) -> a -> b
$ SubExp -> BasicOp
SubExp SubExp
mergeinit
[VName] -> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall (m :: * -> *).
MonadBinder m =>
[VName] -> Exp (Lore m) -> m ()
letBindNames [VName
iv] (Exp (Lore (RuleM lore)) -> RuleM lore ())
-> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall a b. (a -> b) -> a -> b
$ BasicOp -> ExpT lore
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> ExpT lore) -> BasicOp -> ExpT lore
forall a b. (a -> b) -> a -> b
$ SubExp -> BasicOp
SubExp (SubExp -> BasicOp) -> SubExp -> BasicOp
forall a b. (a -> b) -> a -> b
$ IntType -> Integer -> SubExp
intConst IntType
it Integer
i
-- Bind each loop-variable parameter to row @i@ of its array.  The row index
-- is an Int64 constant, independent of @it@.
[(LParam lore, VName)]
-> ((LParam lore, VName) -> RuleM lore ()) -> RuleM lore ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ [(LParam lore, VName)]
loop_vars (((LParam lore, VName) -> RuleM lore ()) -> RuleM lore ())
-> ((LParam lore, VName) -> RuleM lore ()) -> RuleM lore ()
forall a b. (a -> b) -> a -> b
$ \(LParam lore
p, VName
arr) ->
[VName] -> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall (m :: * -> *).
MonadBinder m =>
[VName] -> Exp (Lore m) -> m ()
letBindNames [LParam lore -> VName
forall dec. Param dec -> VName
paramName LParam lore
p] (Exp (Lore (RuleM lore)) -> RuleM lore ())
-> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall a b. (a -> b) -> a -> b
$
BasicOp -> ExpT lore
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> ExpT lore) -> BasicOp -> ExpT lore
forall a b. (a -> b) -> a -> b
$
VName -> Slice SubExp -> BasicOp
Index VName
arr (Slice SubExp -> BasicOp) -> Slice SubExp -> BasicOp
forall a b. (a -> b) -> a -> b
$
SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix (IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
i) DimIndex SubExp -> Slice SubExp -> Slice SubExp
forall a. a -> [a] -> [a]
: TypeBase Shape NoUniqueness -> Slice SubExp -> Slice SubExp
fullSlice (LParam lore -> TypeBase Shape NoUniqueness
forall dec. Typed dec => Param dec -> TypeBase Shape NoUniqueness
paramType LParam lore
p) []
Body lore -> RuleM lore (Body lore)
forall (f :: * -> *) a. Applicative f => a -> f a
pure Body lore
body
-- Rename the iteration body before splicing in its statements.  This is
-- presumably so that the unrolled copies do not rebind the same names.
Body lore
iter_body' <- Body lore -> RuleM lore (Body lore)
forall lore (m :: * -> *).
(Renameable lore, MonadFreshNames m) =>
Body lore -> m (Body lore)
renameBody Body lore
iter_body
Stms (Lore (RuleM lore)) -> RuleM lore ()
forall (m :: * -> *). MonadBinder m => Stms (Lore m) -> m ()
addStms (Stms (Lore (RuleM lore)) -> RuleM lore ())
-> Stms (Lore (RuleM lore)) -> RuleM lore ()
forall a b. (a -> b) -> a -> b
$ Body lore -> Stms lore
forall lore. BodyT lore -> Stms lore
bodyStms Body lore
iter_body'
-- The results of this iteration become the merge values of the next one.
let merge' :: [(FParam lore, SubExp)]
merge' = [FParam lore] -> [SubExp] -> [(FParam lore, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip (((FParam lore, SubExp) -> FParam lore)
-> [(FParam lore, SubExp)] -> [FParam lore]
forall a b. (a -> b) -> [a] -> [b]
map (FParam lore, SubExp) -> FParam lore
forall a b. (a, b) -> a
fst [(FParam lore, SubExp)]
merge) ([SubExp] -> [(FParam lore, SubExp)])
-> [SubExp] -> [(FParam lore, SubExp)]
forall a b. (a -> b) -> a -> b
$ Body lore -> [SubExp]
forall lore. BodyT lore -> [SubExp]
bodyResult Body lore
iter_body'
Integer
-> [(FParam lore, SubExp)]
-> (VName, IntType, Integer)
-> [(LParam lore, VName)]
-> Body lore
-> RuleM lore [SubExp]
forall lore.
BinderOps lore =>
Integer
-> [(FParam lore, SubExp)]
-> (VName, IntType, Integer)
-> [(LParam lore, VName)]
-> Body lore
-> RuleM lore [SubExp]
unroll Integer
n [(FParam lore, SubExp)]
merge' (VName
iv, IntType
it, Integer
i Integer -> Integer -> Integer
forall a. Num a => a -> a -> a
+ Integer
1) [(LParam lore, VName)]
loop_vars Body lore
body
-- | Completely unroll a @for@ loop whose iteration count is a known integer
-- constant.  Unrolling happens when the count is zero or one, or when the
-- loop statement carries the @"unroll"@ attribute.
--
-- 'unroll' is invoked starting from iteration index 0 with the context merge
-- parameters followed by the value merge parameters; each name of the loop's
-- pattern is then bound to the corresponding result 'SubExp'.
simplifyKnownIterationLoop :: BinderOps lore => TopDownRuleDoLoop lore
simplifyKnownIterationLoop _ pat aux (ctx, val, ForLoop i it (Constant (IntValue n)) loop_vars, body)
  | shouldUnroll =
    Simplify $ do
      results <- unroll (valueIntegral n) (ctx ++ val) (i, it, 0) loop_vars body
      zipWithM_ bindResult (patternNames pat) results
  where
    shouldUnroll =
      zeroIshInt n || oneIshInt n || "unroll" `inAttrs` stmAuxAttrs aux
    bindResult name se = letBindNames [name] $ BasicOp $ SubExp se
simplifyKnownIterationLoop _ _ _ _ = Skip
-- | Turn @let d = copy v@ into @let d = v@ when the copy can be dropped.
--
-- Only statements whose pattern has no context elements and exactly one
-- value element are considered.  The rewrite requires that @v@ is not
-- consumed, and additionally that either
--
--   * @v@ has no other recorded use and is a function parameter whose
--     declared type is unique, or
--   * the result @d@ is never consumed.
removeUnnecessaryCopy :: BinderOps lore => BottomUpRuleBasicOp lore
removeUnnecessaryCopy (vtable, used) (Pattern [] [d]) _ (Copy v)
  | not srcConsumed && (srcUnusedAndUnique || not dstConsumed) =
    Simplify $ letBindNames [patElemName d] $ BasicOp $ SubExp $ Var v
  where
    srcConsumed = v `UT.isConsumed` used
    dstConsumed = patElemName d `UT.isConsumed` used
    srcUnusedAndUnique = not (v `UT.used` used) && isUniqueFParam
    -- True only when 'v' is in scope as a function parameter whose declared
    -- type is unique.
    isUniqueFParam =
      case M.lookup v (ST.toScope vtable) of
        Just (FParamName info) -> unique (declTypeOf info)
        _ -> False
removeUnnecessaryCopy _ _ _ _ = Skip
-- | Simplify comparison operators.
--
-- * Comparing an operand with itself folds to a constant: the non-strict
--   comparisons (@==@, @<=@) become true and the strict ones (@<@) false.
-- * Comparing two constants is evaluated with 'doCmpOp'.
-- * @c == btoi b@ with an integer constant @c@ becomes @b@ when @c@ is 1,
--   @not b@ when @c@ is 0, and @false@ for any other @c@; the certificates
--   of the @btoi@ definition are kept.
simplifyCmpOp :: SimpleRule lore
simplifyCmpOp _ _ (CmpOp cmp e1 e2)
  | e1 == e2 = constRes $ BoolValue $ reflexive cmp
  where
    -- NOTE(review): for floating-point operands this assumes the value is not
    -- NaN (under IEEE 754, NaN == NaN and NaN <= NaN are false) — confirm
    -- this is intended.
    reflexive CmpEq {} = True
    reflexive CmpSlt {} = False
    reflexive CmpUlt {} = False
    reflexive CmpSle {} = True
    reflexive CmpUle {} = True
    reflexive FCmpLt {} = False
    reflexive FCmpLe {} = True
    reflexive CmpLlt = False
    reflexive CmpLle = True
simplifyCmpOp _ _ (CmpOp cmp (Constant v1) (Constant v2)) =
  doCmpOp cmp v1 v2 >>= constRes . BoolValue
simplifyCmpOp look _ (CmpOp CmpEq {} (Constant (IntValue x)) (Var v))
  | Just (BasicOp (ConvOp BToI {} b), cs) <- look v =
    let simplified = case valueIntegral x :: Int of
          1 -> SubExp b
          0 -> UnOp Not b
          _ -> SubExp $ Constant $ BoolValue False
     in Just (simplified, cs)
simplifyCmpOp _ _ _ = Nothing
-- | Algebraic simplification of binary operators.
--
-- Returns the replacement expression together with the certificates it
-- depends on, or 'Nothing' when no rule applies.  Rules that look through
-- an operand's definition (via the 'VarLookup' argument) carry over the
-- certificates of that definition; purely local rules use 'mempty'.
simplifyBinOp :: SimpleRule lore
-- Constant folding.
simplifyBinOp _ _ (BinOp op (Constant v1) (Constant v2))
  | Just res <- doBinOp op v1 v2 =
    constRes res
simplifyBinOp look _ (BinOp Add {} e1 e2)
  | isCt0 e1 = subExpRes e2
  | isCt0 e2 = subExpRes e1
  -- x + (y - x) => y
  | Var v2 <- e2,
    Just (BasicOp (BinOp Sub {} e2_a e2_b), cs) <- look v2,
    e2_b == e1 =
    Just (SubExp e2_a, cs)
simplifyBinOp _ _ (BinOp FAdd {} e1 e2)
  | isCt0 e1 = subExpRes e2
  | isCt0 e2 = subExpRes e1
simplifyBinOp look _ (BinOp Sub {} e1 e2)
  | isCt0 e2 = subExpRes e1
  -- (a + b) - a => b
  | Var v1 <- e1,
    Just (BasicOp (BinOp Add {} e1_a e1_b), cs) <- look v1,
    e1_a == e2 =
    Just (SubExp e1_b, cs)
  -- (a + b) - b => a
  | Var v1 <- e1,
    Just (BasicOp (BinOp Add {} e1_a e1_b), cs) <- look v1,
    e1_b == e2 =
    Just (SubExp e1_a, cs)
-- Note: a - (a + b) equals -b, not b, so it must not be rewritten to an
-- operand of the addition.
simplifyBinOp _ _ (BinOp FSub {} e1 e2)
  | isCt0 e2 = subExpRes e1
simplifyBinOp _ _ (BinOp Mul {} e1 e2)
  | isCt0 e1 = subExpRes e1
  | isCt0 e2 = subExpRes e2
  | isCt1 e1 = subExpRes e2
  | isCt1 e2 = subExpRes e1
simplifyBinOp _ _ (BinOp FMul {} e1 e2)
  | isCt0 e1 = subExpRes e1
  | isCt0 e2 = subExpRes e2
  | isCt1 e1 = subExpRes e2
  | isCt1 e2 = subExpRes e1
simplifyBinOp look _ (BinOp (SMod t _) e1 e2)
  | isCt1 e2 = constRes $ IntValue $ intValue t (0 :: Int)
  | e1 == e2 = constRes $ IntValue $ intValue t (0 :: Int)
  -- (x `mod` y) `mod` y => x `mod` y
  | Var v1 <- e1,
    Just (BasicOp (BinOp SMod {} _ e4), v1_cs) <- look v1,
    e4 == e2 =
    Just (SubExp e1, v1_cs)
simplifyBinOp _ _ (BinOp SDiv {} e1 e2)
  | isCt0 e1 = subExpRes e1
  | isCt1 e2 = subExpRes e1
  | isCt0 e2 = Nothing
simplifyBinOp _ _ (BinOp SDivUp {} e1 e2)
  | isCt0 e1 = subExpRes e1
  | isCt1 e2 = subExpRes e1
  | isCt0 e2 = Nothing
simplifyBinOp _ _ (BinOp FDiv {} e1 e2)
  | isCt0 e1 = subExpRes e1
  | isCt1 e2 = subExpRes e1
  | isCt0 e2 = Nothing
simplifyBinOp _ _ (BinOp (SRem t _) e1 e2)
  | isCt1 e2 = constRes $ IntValue $ intValue t (0 :: Int)
  -- x `rem` x has no remainder: the result is 0, as for SMod above.
  | e1 == e2 = constRes $ IntValue $ intValue t (0 :: Int)
simplifyBinOp _ _ (BinOp SQuot {} e1 e2)
  | isCt1 e2 = subExpRes e1
  | isCt0 e2 = Nothing
simplifyBinOp _ _ (BinOp (FPow t) e1 e2)
  | isCt0 e2 = subExpRes $ floatConst t 1
  | isCt0 e1 || isCt1 e1 || isCt1 e2 = subExpRes e1
simplifyBinOp _ _ (BinOp (Shl t) e1 e2)
  | isCt0 e2 = subExpRes e1
  | isCt0 e1 = subExpRes $ intConst t 0
simplifyBinOp _ _ (BinOp AShr {} e1 e2)
  | isCt0 e2 = subExpRes e1
simplifyBinOp _ _ (BinOp (And t) e1 e2)
  | isCt0 e1 = subExpRes $ intConst t 0
  | isCt0 e2 = subExpRes $ intConst t 0
  | e1 == e2 = subExpRes e1
simplifyBinOp _ _ (BinOp Or {} e1 e2)
  | isCt0 e1 = subExpRes e2
  | isCt0 e2 = subExpRes e1
  | e1 == e2 = subExpRes e1
simplifyBinOp _ _ (BinOp (Xor t) e1 e2)
  | isCt0 e1 = subExpRes e2
  | isCt0 e2 = subExpRes e1
  | e1 == e2 = subExpRes $ intConst t 0
simplifyBinOp defOf _ (BinOp LogAnd e1 e2)
  | isCt0 e1 = constRes $ BoolValue False
  | isCt0 e2 = constRes $ BoolValue False
  | isCt1 e1 = subExpRes e2
  | isCt1 e2 = subExpRes e1
  -- (not x) && x => false
  | Var v <- e1,
    Just (BasicOp (UnOp Not e1'), v_cs) <- defOf v,
    e1' == e2 =
    Just (SubExp $ Constant $ BoolValue False, v_cs)
  -- x && (not x) => false
  | Var v <- e2,
    Just (BasicOp (UnOp Not e2'), v_cs) <- defOf v,
    e2' == e1 =
    Just (SubExp $ Constant $ BoolValue False, v_cs)
simplifyBinOp defOf _ (BinOp LogOr e1 e2)
  | isCt0 e1 = subExpRes e2
  | isCt0 e2 = subExpRes e1
  | isCt1 e1 = constRes $ BoolValue True
  | isCt1 e2 = constRes $ BoolValue True
  -- (not x) || x => true
  | Var v <- e1,
    Just (BasicOp (UnOp Not e1'), v_cs) <- defOf v,
    e1' == e2 =
    Just (SubExp $ Constant $ BoolValue True, v_cs)
  -- x || (not x) => true
  | Var v <- e2,
    Just (BasicOp (UnOp Not e2'), v_cs) <- defOf v,
    e2' == e1 =
    Just (SubExp $ Constant $ BoolValue True, v_cs)
simplifyBinOp defOf _ (BinOp (SMax it) e1 e2)
  | e1 == e2 =
    subExpRes e1
  -- smax (smax a b) a => smax b a
  | Var v1 <- e1,
    Just (BasicOp (BinOp (SMax _) e1_1 e1_2), v1_cs) <- defOf v1,
    e1_1 == e2 =
    Just (BinOp (SMax it) e1_2 e2, v1_cs)
  -- smax (smax a b) b => smax a b
  | Var v1 <- e1,
    Just (BasicOp (BinOp (SMax _) e1_1 e1_2), v1_cs) <- defOf v1,
    e1_2 == e2 =
    Just (BinOp (SMax it) e1_1 e2, v1_cs)
  -- smax a (smax a b) => smax b a
  | Var v2 <- e2,
    Just (BasicOp (BinOp (SMax _) e2_1 e2_2), v2_cs) <- defOf v2,
    e2_1 == e1 =
    Just (BinOp (SMax it) e2_2 e1, v2_cs)
  -- smax b (smax a b) => smax a b
  | Var v2 <- e2,
    Just (BasicOp (BinOp (SMax _) e2_1 e2_2), v2_cs) <- defOf v2,
    e2_2 == e1 =
    Just (BinOp (SMax it) e2_1 e1, v2_cs)
simplifyBinOp _ _ _ = Nothing
-- | A simplification result that is the given primitive constant, with no
-- certificates attached.
constRes :: PrimValue -> Maybe (BasicOp, Certificates)
constRes value = Just (SubExp (Constant value), mempty)
-- | A simplification result that is the given sub-expression itself, with no
-- certificates attached.
subExpRes :: SubExp -> Maybe (BasicOp, Certificates)
subExpRes se = Just (SubExp se, mempty)
-- | Simplify unary operators.
--
-- * A unary operator applied to a constant is evaluated with 'doUnOp'.
-- * Double boolean negation cancels: @not (not x)@ becomes @x@, keeping the
--   certificates of the inner negation's definition.
simplifyUnOp :: SimpleRule lore
simplifyUnOp _ _ (UnOp op (Constant value)) =
  doUnOp op value >>= constRes
simplifyUnOp defOf _ (UnOp Not (Var negated))
  | Just (BasicOp (UnOp Not inner), cs) <- defOf negated =
    Just (SubExp inner, cs)
simplifyUnOp _ _ _ = Nothing
-- | Simplify a primitive conversion.
--
-- In order of precedence:
--
-- 1. A conversion of a constant is evaluated with 'doConvOp'. If 'doConvOp'
--    cannot evaluate it, the rule yields 'Nothing'.
-- 2. A conversion whose source and target types are equal is replaced by
--    its operand.
-- 3. A conversion applied to a variable that is itself defined by a
--    conversion is collapsed into one conversion from the innermost source
--    type. This covers SExt after SExt, ZExt after ZExt, SIToFP after SExt,
--    UIToFP after ZExt, and FPConv after FPConv. It only fires when the
--    intermediate type compares '>=' the innermost source type, i.e. the
--    inner conversion did not narrow (assuming the type ordering follows bit
--    width — TODO confirm). The certificates of the inner definition are
--    kept.
--
-- Any other expression yields 'Nothing'.
simplifyConvOp :: SimpleRule lore
simplifyConvOp look _ e = case e of
  ConvOp op (Constant c) ->
    doConvOp op c >>= constRes
  ConvOp op se
    | uncurry (==) (convOpType op) ->
      subExpRes se
  ConvOp (SExt t_mid t_out) (Var x)
    | Just (SExt t_in _, src, src_cs) <- innerConv x,
      t_mid >= t_in ->
      Just (ConvOp (SExt t_in t_out) src, src_cs)
  ConvOp (ZExt t_mid t_out) (Var x)
    | Just (ZExt t_in _, src, src_cs) <- innerConv x,
      t_mid >= t_in ->
      Just (ConvOp (ZExt t_in t_out) src, src_cs)
  ConvOp (SIToFP t_mid t_out) (Var x)
    | Just (SExt t_in _, src, src_cs) <- innerConv x,
      t_mid >= t_in ->
      Just (ConvOp (SIToFP t_in t_out) src, src_cs)
  ConvOp (UIToFP t_mid t_out) (Var x)
    | Just (ZExt t_in _, src, src_cs) <- innerConv x,
      t_mid >= t_in ->
      Just (ConvOp (UIToFP t_in t_out) src, src_cs)
  ConvOp (FPConv t_mid t_out) (Var x)
    | Just (FPConv t_in _, src, src_cs) <- innerConv x,
      t_mid >= t_in ->
      Just (ConvOp (FPConv t_in t_out) src, src_cs)
  _ -> Nothing
  where
    -- The conversion (and its operand and certificates) that defines a
    -- variable, if the variable is defined by a conversion at all.
    innerConv x = case look x of
      Just (BasicOp (ConvOp conv src), src_cs) -> Just (conv, src, src_cs)
      _ -> Nothing
-- | An assertion whose condition is the literal constant @true@ always
-- holds, so it is replaced by the 'Checked' constant. Every other
-- expression is left alone ('Nothing').
simplifyAssert :: SimpleRule lore
simplifyAssert _ _ op = case op of
  Assert (Constant (BoolValue True)) _ _ -> constRes Checked
  _ -> Nothing
-- | Constant-fold a function call at compile time when all of these hold:
--
-- * the callee's name is a key in 'primFuns';
-- * every argument is a 'Constant';
-- * the primitive's evaluator returns a value for those arguments.
--
-- The folded constant is bound to the call's pattern. The statement's
-- original certificates and attributes are kept. Every other statement is
-- skipped.
constantFoldPrimFun :: BinderOps lore => TopDownRuleGeneric lore
constantFoldPrimFun _ (Let pat (StmAux cs attrs _) (Apply fname args _ _))
  | Just (_, _, evalPrim) <- M.lookup (nameToString fname) primFuns,
    Just vals <- traverse (constantArg . fst) args,
    Just folded <- evalPrim vals =
    Simplify . certifying cs . attributing attrs $
      letBind pat (BasicOp (SubExp (Constant folded)))
  where
    -- Only literal constants can take part in folding.
    constantArg (Constant v) = Just v
    constantArg _ = Nothing
constantFoldPrimFun _ _ = Skip
-- | Bottom-up rule for an 'Index' expression bound to a single-element
-- pattern.
--
-- The rewrite is delegated to 'simplifyIndexing'. That function is told
-- whether the result is consumed, according to the usage table. The result
-- is bound to the original pattern, either as a plain 'SubExp' or as a new
-- 'Index'.
--
-- The whole rewrite runs under the original statement's certificates @cs@.
-- This includes any auxiliary statements that 'simplifyIndexing' emits
-- before producing its result. Any certificates returned by
-- 'simplifyIndexing' are also attached to the final binding. The original
-- attributes are attached to the final binding only.
simplifyIndex :: BinderOps lore => BottomUpRuleBasicOp lore
simplifyIndex (vtable, used) pat@(Pattern [] [pe]) (StmAux cs attrs _) (Index idd inds)
  | Just m <- simplifyIndexing vtable seType idd inds consumed =
    Simplify $
      -- Certify the entire action, not just the final binding. 'm' may emit
      -- auxiliary statements (index arithmetic such as an unchecked SMod by
      -- an array dimension, a sliced iota, a smaller replicate, ...). Those
      -- statements were guarded by @cs@ in the original program and must not
      -- be hoisted above the checks that @cs@ stands for.
      certifying cs $ do
        res <- m
        attributing attrs $ case res of
          SubExpResult cs' se ->
            certifying cs' $
              letBindNames (patternNames pat) $ BasicOp $ SubExp se
          IndexResult extra_cs idd' inds' ->
            certifying extra_cs $
              letBindNames (patternNames pat) $ BasicOp $ Index idd' inds'
  where
    -- Whether the indexing result is consumed later on.
    consumed = patElemName pe `UT.isConsumed` used
    -- Types of subexpressions: variables are looked up in the symbol table;
    -- constants have their primitive type.
    seType (Var v) = ST.lookupType v vtable
    seType (Constant v) = Just $ Prim $ primValueType v
simplifyIndex _ _ _ _ = Skip
-- | The outcome of a successful 'simplifyIndexing'.
data IndexResult
  = -- | Index an array (possibly a different one from the original) with a
    -- new slice. The certificates are attached to the resulting 'Index'
    -- binding.
    IndexResult Certificates VName (Slice SubExp)
  | -- | The indexing reduces to a plain subexpression. The certificates are
    -- attached to the binding of that subexpression.
    SubExpResult Certificates SubExp
simplifyIndexing ::
MonadBinder m =>
ST.SymbolTable (Lore m) ->
TypeLookup ->
VName ->
Slice SubExp ->
Bool ->
Maybe (m IndexResult)
simplifyIndexing :: SymbolTable (Lore m)
-> (SubExp -> Maybe (TypeBase Shape NoUniqueness))
-> VName
-> Slice SubExp
-> Bool
-> Maybe (m IndexResult)
simplifyIndexing SymbolTable (Lore m)
vtable SubExp -> Maybe (TypeBase Shape NoUniqueness)
seType VName
idd Slice SubExp
inds Bool
consuming =
case VName -> Maybe (BasicOp, Certificates)
defOf VName
idd of
Maybe (BasicOp, Certificates)
_
| Just TypeBase Shape NoUniqueness
t <- SubExp -> Maybe (TypeBase Shape NoUniqueness)
seType (VName -> SubExp
Var VName
idd),
Slice SubExp
inds Slice SubExp -> Slice SubExp -> Bool
forall a. Eq a => a -> a -> Bool
== TypeBase Shape NoUniqueness -> Slice SubExp -> Slice SubExp
fullSlice TypeBase Shape NoUniqueness
t [] ->
m IndexResult -> Maybe (m IndexResult)
forall a. a -> Maybe a
Just (m IndexResult -> Maybe (m IndexResult))
-> m IndexResult -> Maybe (m IndexResult)
forall a b. (a -> b) -> a -> b
$ IndexResult -> m IndexResult
forall (f :: * -> *) a. Applicative f => a -> f a
pure (IndexResult -> m IndexResult) -> IndexResult -> m IndexResult
forall a b. (a -> b) -> a -> b
$ Certificates -> SubExp -> IndexResult
SubExpResult Certificates
forall a. Monoid a => a
mempty (SubExp -> IndexResult) -> SubExp -> IndexResult
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
idd
| Just [SubExp]
inds' <- Slice SubExp -> Maybe [SubExp]
forall d. Slice d -> Maybe [d]
sliceIndices Slice SubExp
inds,
Just (ST.Indexed Certificates
cs PrimExp VName
e) <- VName -> [SubExp] -> SymbolTable (Lore m) -> Maybe Indexed
forall lore.
ASTLore lore =>
VName -> [SubExp] -> SymbolTable lore -> Maybe Indexed
ST.index VName
idd [SubExp]
inds' SymbolTable (Lore m)
vtable,
PrimExp VName -> Bool
forall v. PrimExp v -> Bool
worthInlining PrimExp VName
e,
(VName -> Bool) -> [VName] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all (VName -> SymbolTable (Lore m) -> Bool
forall lore. VName -> SymbolTable lore -> Bool
`ST.elem` SymbolTable (Lore m)
vtable) (Certificates -> [VName]
unCertificates Certificates
cs) ->
m IndexResult -> Maybe (m IndexResult)
forall a. a -> Maybe a
Just (m IndexResult -> Maybe (m IndexResult))
-> m IndexResult -> Maybe (m IndexResult)
forall a b. (a -> b) -> a -> b
$ Certificates -> SubExp -> IndexResult
SubExpResult Certificates
cs (SubExp -> IndexResult) -> m SubExp -> m IndexResult
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> String -> PrimExp VName -> m SubExp
forall (m :: * -> *) a.
(MonadBinder m, ToExp a) =>
String -> a -> m SubExp
toSubExp String
"index_primexp" PrimExp VName
e
| Just [SubExp]
inds' <- Slice SubExp -> Maybe [SubExp]
forall d. Slice d -> Maybe [d]
sliceIndices Slice SubExp
inds,
Just (ST.IndexedArray Certificates
cs VName
arr [TPrimExp Int64 VName]
inds'') <- VName -> [SubExp] -> SymbolTable (Lore m) -> Maybe Indexed
forall lore.
ASTLore lore =>
VName -> [SubExp] -> SymbolTable lore -> Maybe Indexed
ST.index VName
idd [SubExp]
inds' SymbolTable (Lore m)
vtable,
(TPrimExp Int64 VName -> Bool) -> [TPrimExp Int64 VName] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all (PrimExp VName -> Bool
forall v. PrimExp v -> Bool
worthInlining (PrimExp VName -> Bool)
-> (TPrimExp Int64 VName -> PrimExp VName)
-> TPrimExp Int64 VName
-> Bool
forall b c a. (b -> c) -> (a -> b) -> a -> c
. TPrimExp Int64 VName -> PrimExp VName
forall t v. TPrimExp t v -> PrimExp v
untyped) [TPrimExp Int64 VName]
inds'',
(VName -> Bool) -> [VName] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all (VName -> SymbolTable (Lore m) -> Bool
forall lore. VName -> SymbolTable lore -> Bool
`ST.elem` SymbolTable (Lore m)
vtable) (Certificates -> [VName]
unCertificates Certificates
cs) ->
m IndexResult -> Maybe (m IndexResult)
forall a. a -> Maybe a
Just (m IndexResult -> Maybe (m IndexResult))
-> m IndexResult -> Maybe (m IndexResult)
forall a b. (a -> b) -> a -> b
$
Certificates -> VName -> Slice SubExp -> IndexResult
IndexResult Certificates
cs VName
arr (Slice SubExp -> IndexResult)
-> ([SubExp] -> Slice SubExp) -> [SubExp] -> IndexResult
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (SubExp -> DimIndex SubExp) -> [SubExp] -> Slice SubExp
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix
([SubExp] -> IndexResult) -> m [SubExp] -> m IndexResult
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> (TPrimExp Int64 VName -> m SubExp)
-> [TPrimExp Int64 VName] -> m [SubExp]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (String -> TPrimExp Int64 VName -> m SubExp
forall (m :: * -> *) a.
(MonadBinder m, ToExp a) =>
String -> a -> m SubExp
toSubExp String
"index_primexp") [TPrimExp Int64 VName]
inds''
Maybe (BasicOp, Certificates)
Nothing -> Maybe (m IndexResult)
forall a. Maybe a
Nothing
Just (SubExp (Var VName
v), Certificates
cs) -> m IndexResult -> Maybe (m IndexResult)
forall a. a -> Maybe a
Just (m IndexResult -> Maybe (m IndexResult))
-> m IndexResult -> Maybe (m IndexResult)
forall a b. (a -> b) -> a -> b
$ IndexResult -> m IndexResult
forall (f :: * -> *) a. Applicative f => a -> f a
pure (IndexResult -> m IndexResult) -> IndexResult -> m IndexResult
forall a b. (a -> b) -> a -> b
$ Certificates -> VName -> Slice SubExp -> IndexResult
IndexResult Certificates
cs VName
v Slice SubExp
inds
Just (Iota SubExp
_ SubExp
x SubExp
s IntType
to_it, Certificates
cs)
| [DimFix SubExp
ii] <- Slice SubExp
inds,
Just (Prim (IntType IntType
from_it)) <- SubExp -> Maybe (TypeBase Shape NoUniqueness)
seType SubExp
ii ->
m IndexResult -> Maybe (m IndexResult)
forall a. a -> Maybe a
Just (m IndexResult -> Maybe (m IndexResult))
-> m IndexResult -> Maybe (m IndexResult)
forall a b. (a -> b) -> a -> b
$
let mul :: PrimExp VName -> PrimExp VName -> PrimExp VName
mul = BinOp -> PrimExp VName -> PrimExp VName -> PrimExp VName
forall v. BinOp -> PrimExp v -> PrimExp v -> PrimExp v
BinOpExp (BinOp -> PrimExp VName -> PrimExp VName -> PrimExp VName)
-> BinOp -> PrimExp VName -> PrimExp VName -> PrimExp VName
forall a b. (a -> b) -> a -> b
$ IntType -> Overflow -> BinOp
Mul IntType
to_it Overflow
OverflowWrap
add :: PrimExp VName -> PrimExp VName -> PrimExp VName
add = BinOp -> PrimExp VName -> PrimExp VName -> PrimExp VName
forall v. BinOp -> PrimExp v -> PrimExp v -> PrimExp v
BinOpExp (BinOp -> PrimExp VName -> PrimExp VName -> PrimExp VName)
-> BinOp -> PrimExp VName -> PrimExp VName -> PrimExp VName
forall a b. (a -> b) -> a -> b
$ IntType -> Overflow -> BinOp
Add IntType
to_it Overflow
OverflowWrap
in (SubExp -> IndexResult) -> m SubExp -> m IndexResult
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (Certificates -> SubExp -> IndexResult
SubExpResult Certificates
cs) (m SubExp -> m IndexResult) -> m SubExp -> m IndexResult
forall a b. (a -> b) -> a -> b
$
String -> PrimExp VName -> m SubExp
forall (m :: * -> *) a.
(MonadBinder m, ToExp a) =>
String -> a -> m SubExp
toSubExp String
"index_iota" (PrimExp VName -> m SubExp) -> PrimExp VName -> m SubExp
forall a b. (a -> b) -> a -> b
$
( IntType -> PrimExp VName -> PrimExp VName
forall v. IntType -> PrimExp v -> PrimExp v
sExt IntType
to_it (PrimType -> SubExp -> PrimExp VName
primExpFromSubExp (IntType -> PrimType
IntType IntType
from_it) SubExp
ii)
PrimExp VName -> PrimExp VName -> PrimExp VName
`mul` PrimType -> SubExp -> PrimExp VName
primExpFromSubExp (IntType -> PrimType
IntType IntType
to_it) SubExp
s
)
PrimExp VName -> PrimExp VName -> PrimExp VName
`add` PrimType -> SubExp -> PrimExp VName
primExpFromSubExp (IntType -> PrimType
IntType IntType
to_it) SubExp
x
| [DimSlice SubExp
i_offset SubExp
i_n SubExp
i_stride] <- Slice SubExp
inds ->
m IndexResult -> Maybe (m IndexResult)
forall a. a -> Maybe a
Just (m IndexResult -> Maybe (m IndexResult))
-> m IndexResult -> Maybe (m IndexResult)
forall a b. (a -> b) -> a -> b
$ do
SubExp
i_offset' <- IntType -> SubExp -> m SubExp
forall (m :: * -> *).
MonadBinder m =>
IntType -> SubExp -> m SubExp
asIntS IntType
to_it SubExp
i_offset
SubExp
i_stride' <- IntType -> SubExp -> m SubExp
forall (m :: * -> *).
MonadBinder m =>
IntType -> SubExp -> m SubExp
asIntS IntType
to_it SubExp
i_stride
let mul :: PrimExp VName -> PrimExp VName -> PrimExp VName
mul = BinOp -> PrimExp VName -> PrimExp VName -> PrimExp VName
forall v. BinOp -> PrimExp v -> PrimExp v -> PrimExp v
BinOpExp (BinOp -> PrimExp VName -> PrimExp VName -> PrimExp VName)
-> BinOp -> PrimExp VName -> PrimExp VName -> PrimExp VName
forall a b. (a -> b) -> a -> b
$ IntType -> Overflow -> BinOp
Mul IntType
to_it Overflow
OverflowWrap
add :: PrimExp VName -> PrimExp VName -> PrimExp VName
add = BinOp -> PrimExp VName -> PrimExp VName -> PrimExp VName
forall v. BinOp -> PrimExp v -> PrimExp v -> PrimExp v
BinOpExp (BinOp -> PrimExp VName -> PrimExp VName -> PrimExp VName)
-> BinOp -> PrimExp VName -> PrimExp VName -> PrimExp VName
forall a b. (a -> b) -> a -> b
$ IntType -> Overflow -> BinOp
Add IntType
to_it Overflow
OverflowWrap
SubExp
i_offset'' <-
String -> PrimExp VName -> m SubExp
forall (m :: * -> *) a.
(MonadBinder m, ToExp a) =>
String -> a -> m SubExp
toSubExp String
"iota_offset" (PrimExp VName -> m SubExp) -> PrimExp VName -> m SubExp
forall a b. (a -> b) -> a -> b
$
( PrimType -> SubExp -> PrimExp VName
primExpFromSubExp (IntType -> PrimType
IntType IntType
to_it) SubExp
x
PrimExp VName -> PrimExp VName -> PrimExp VName
`mul` PrimType -> SubExp -> PrimExp VName
primExpFromSubExp (IntType -> PrimType
IntType IntType
to_it) SubExp
s
)
PrimExp VName -> PrimExp VName -> PrimExp VName
`add` PrimType -> SubExp -> PrimExp VName
primExpFromSubExp (IntType -> PrimType
IntType IntType
to_it) SubExp
i_offset'
SubExp
i_stride'' <-
String -> Exp (Lore m) -> m SubExp
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m SubExp
letSubExp String
"iota_offset" (Exp (Lore m) -> m SubExp) -> Exp (Lore m) -> m SubExp
forall a b. (a -> b) -> a -> b
$
BasicOp -> Exp (Lore m)
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> Exp (Lore m)) -> BasicOp -> Exp (Lore m)
forall a b. (a -> b) -> a -> b
$ BinOp -> SubExp -> SubExp -> BasicOp
BinOp (IntType -> Overflow -> BinOp
Mul IntType
Int64 Overflow
OverflowWrap) SubExp
s SubExp
i_stride'
(SubExp -> IndexResult) -> m SubExp -> m IndexResult
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (Certificates -> SubExp -> IndexResult
SubExpResult Certificates
cs) (m SubExp -> m IndexResult) -> m SubExp -> m IndexResult
forall a b. (a -> b) -> a -> b
$
String -> Exp (Lore m) -> m SubExp
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m SubExp
letSubExp String
"slice_iota" (Exp (Lore m) -> m SubExp) -> Exp (Lore m) -> m SubExp
forall a b. (a -> b) -> a -> b
$
BasicOp -> Exp (Lore m)
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> Exp (Lore m)) -> BasicOp -> Exp (Lore m)
forall a b. (a -> b) -> a -> b
$ SubExp -> SubExp -> SubExp -> IntType -> BasicOp
Iota SubExp
i_n SubExp
i_offset'' SubExp
i_stride'' IntType
to_it
Just (Rotate [SubExp]
offsets VName
a, Certificates
cs)
| Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ [Bool] -> Bool
forall (t :: * -> *). Foldable t => t Bool -> Bool
or ([Bool] -> Bool) -> [Bool] -> Bool
forall a b. (a -> b) -> a -> b
$ (SubExp -> DimIndex SubExp -> Bool)
-> [SubExp] -> Slice SubExp -> [Bool]
forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith SubExp -> DimIndex SubExp -> Bool
forall d. SubExp -> DimIndex d -> Bool
rotateAndSlice [SubExp]
offsets Slice SubExp
inds -> m IndexResult -> Maybe (m IndexResult)
forall a. a -> Maybe a
Just (m IndexResult -> Maybe (m IndexResult))
-> m IndexResult -> Maybe (m IndexResult)
forall a b. (a -> b) -> a -> b
$ do
[SubExp]
dims <- TypeBase Shape NoUniqueness -> [SubExp]
forall u. TypeBase Shape u -> [SubExp]
arrayDims (TypeBase Shape NoUniqueness -> [SubExp])
-> m (TypeBase Shape NoUniqueness) -> m [SubExp]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> VName -> m (TypeBase Shape NoUniqueness)
forall lore (m :: * -> *).
HasScope lore m =>
VName -> m (TypeBase Shape NoUniqueness)
lookupType VName
a
let adjustI :: SubExp -> SubExp -> SubExp -> m SubExp
adjustI SubExp
i SubExp
o SubExp
d = do
SubExp
i_p_o <- String -> Exp (Lore m) -> m SubExp
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m SubExp
letSubExp String
"i_p_o" (Exp (Lore m) -> m SubExp) -> Exp (Lore m) -> m SubExp
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp (Lore m)
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> Exp (Lore m)) -> BasicOp -> Exp (Lore m)
forall a b. (a -> b) -> a -> b
$ BinOp -> SubExp -> SubExp -> BasicOp
BinOp (IntType -> Overflow -> BinOp
Add IntType
Int64 Overflow
OverflowWrap) SubExp
i SubExp
o
String -> Exp (Lore m) -> m SubExp
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m SubExp
letSubExp String
"rot_i" (BasicOp -> Exp (Lore m)
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> Exp (Lore m)) -> BasicOp -> Exp (Lore m)
forall a b. (a -> b) -> a -> b
$ BinOp -> SubExp -> SubExp -> BasicOp
BinOp (IntType -> Safety -> BinOp
SMod IntType
Int64 Safety
Unsafe) SubExp
i_p_o SubExp
d)
adjust :: (DimIndex SubExp, SubExp, SubExp) -> f (DimIndex SubExp)
adjust (DimFix SubExp
i, SubExp
o, SubExp
d) =
SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix (SubExp -> DimIndex SubExp) -> f SubExp -> f (DimIndex SubExp)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> SubExp -> SubExp -> SubExp -> f SubExp
forall (m :: * -> *).
MonadBinder m =>
SubExp -> SubExp -> SubExp -> m SubExp
adjustI SubExp
i SubExp
o SubExp
d
adjust (DimSlice SubExp
i SubExp
n SubExp
s, SubExp
o, SubExp
d) =
SubExp -> SubExp -> SubExp -> DimIndex SubExp
forall d. d -> d -> d -> DimIndex d
DimSlice (SubExp -> SubExp -> SubExp -> DimIndex SubExp)
-> f SubExp -> f (SubExp -> SubExp -> DimIndex SubExp)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> SubExp -> SubExp -> SubExp -> f SubExp
forall (m :: * -> *).
MonadBinder m =>
SubExp -> SubExp -> SubExp -> m SubExp
adjustI SubExp
i SubExp
o SubExp
d f (SubExp -> SubExp -> DimIndex SubExp)
-> f SubExp -> f (SubExp -> DimIndex SubExp)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> SubExp -> f SubExp
forall (f :: * -> *) a. Applicative f => a -> f a
pure SubExp
n f (SubExp -> DimIndex SubExp) -> f SubExp -> f (DimIndex SubExp)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> SubExp -> f SubExp
forall (f :: * -> *) a. Applicative f => a -> f a
pure SubExp
s
Certificates -> VName -> Slice SubExp -> IndexResult
IndexResult Certificates
cs VName
a (Slice SubExp -> IndexResult) -> m (Slice SubExp) -> m IndexResult
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> ((DimIndex SubExp, SubExp, SubExp) -> m (DimIndex SubExp))
-> [(DimIndex SubExp, SubExp, SubExp)] -> m (Slice SubExp)
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (DimIndex SubExp, SubExp, SubExp) -> m (DimIndex SubExp)
forall (f :: * -> *).
MonadBinder f =>
(DimIndex SubExp, SubExp, SubExp) -> f (DimIndex SubExp)
adjust (Slice SubExp
-> [SubExp] -> [SubExp] -> [(DimIndex SubExp, SubExp, SubExp)]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 Slice SubExp
inds [SubExp]
offsets [SubExp]
dims)
where
rotateAndSlice :: SubExp -> DimIndex d -> Bool
rotateAndSlice SubExp
r DimSlice {} = Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ SubExp -> Bool
isCt0 SubExp
r
rotateAndSlice SubExp
_ DimIndex d
_ = Bool
False
Just (Index VName
aa Slice SubExp
ais, Certificates
cs) ->
m IndexResult -> Maybe (m IndexResult)
forall a. a -> Maybe a
Just (m IndexResult -> Maybe (m IndexResult))
-> m IndexResult -> Maybe (m IndexResult)
forall a b. (a -> b) -> a -> b
$
Certificates -> VName -> Slice SubExp -> IndexResult
IndexResult Certificates
cs VName
aa
(Slice SubExp -> IndexResult) -> m (Slice SubExp) -> m IndexResult
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Slice (TPrimExp Int64 VName) -> m (Slice SubExp)
forall (m :: * -> *).
MonadBinder m =>
Slice (TPrimExp Int64 VName) -> m (Slice SubExp)
subExpSlice (Slice (TPrimExp Int64 VName)
-> Slice (TPrimExp Int64 VName) -> Slice (TPrimExp Int64 VName)
forall d. Num d => Slice d -> Slice d -> Slice d
sliceSlice (Slice SubExp -> Slice (TPrimExp Int64 VName)
primExpSlice Slice SubExp
ais) (Slice SubExp -> Slice (TPrimExp Int64 VName)
primExpSlice Slice SubExp
inds))
Just (Replicate (Shape [SubExp
_]) (Var VName
vv), Certificates
cs)
| [DimFix {}] <- Slice SubExp
inds, Bool -> Bool
not Bool
consuming -> m IndexResult -> Maybe (m IndexResult)
forall a. a -> Maybe a
Just (m IndexResult -> Maybe (m IndexResult))
-> m IndexResult -> Maybe (m IndexResult)
forall a b. (a -> b) -> a -> b
$ IndexResult -> m IndexResult
forall (f :: * -> *) a. Applicative f => a -> f a
pure (IndexResult -> m IndexResult) -> IndexResult -> m IndexResult
forall a b. (a -> b) -> a -> b
$ Certificates -> SubExp -> IndexResult
SubExpResult Certificates
cs (SubExp -> IndexResult) -> SubExp -> IndexResult
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
vv
| DimFix {} : Slice SubExp
is' <- Slice SubExp
inds, Bool -> Bool
not Bool
consuming -> m IndexResult -> Maybe (m IndexResult)
forall a. a -> Maybe a
Just (m IndexResult -> Maybe (m IndexResult))
-> m IndexResult -> Maybe (m IndexResult)
forall a b. (a -> b) -> a -> b
$ IndexResult -> m IndexResult
forall (f :: * -> *) a. Applicative f => a -> f a
pure (IndexResult -> m IndexResult) -> IndexResult -> m IndexResult
forall a b. (a -> b) -> a -> b
$ Certificates -> VName -> Slice SubExp -> IndexResult
IndexResult Certificates
cs VName
vv Slice SubExp
is'
Just (Replicate (Shape [SubExp
_]) val :: SubExp
val@(Constant PrimValue
_), Certificates
cs)
| [DimFix {}] <- Slice SubExp
inds, Bool -> Bool
not Bool
consuming -> m IndexResult -> Maybe (m IndexResult)
forall a. a -> Maybe a
Just (m IndexResult -> Maybe (m IndexResult))
-> m IndexResult -> Maybe (m IndexResult)
forall a b. (a -> b) -> a -> b
$ IndexResult -> m IndexResult
forall (f :: * -> *) a. Applicative f => a -> f a
pure (IndexResult -> m IndexResult) -> IndexResult -> m IndexResult
forall a b. (a -> b) -> a -> b
$ Certificates -> SubExp -> IndexResult
SubExpResult Certificates
cs SubExp
val
Just (Replicate (Shape [SubExp]
ds) SubExp
v, Certificates
cs)
| (Slice SubExp
ds_inds, Slice SubExp
rest_inds) <- Int -> Slice SubExp -> (Slice SubExp, Slice SubExp)
forall a. Int -> [a] -> ([a], [a])
splitAt ([SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
ds) Slice SubExp
inds,
([SubExp]
ds', Slice SubExp
ds_inds') <- [(SubExp, DimIndex SubExp)] -> ([SubExp], Slice SubExp)
forall a b. [(a, b)] -> ([a], [b])
unzip ([(SubExp, DimIndex SubExp)] -> ([SubExp], Slice SubExp))
-> [(SubExp, DimIndex SubExp)] -> ([SubExp], Slice SubExp)
forall a b. (a -> b) -> a -> b
$ (DimIndex SubExp -> Maybe (SubExp, DimIndex SubExp))
-> Slice SubExp -> [(SubExp, DimIndex SubExp)]
forall a b. (a -> Maybe b) -> [a] -> [b]
mapMaybe DimIndex SubExp -> Maybe (SubExp, DimIndex SubExp)
index Slice SubExp
ds_inds,
[SubExp]
ds' [SubExp] -> [SubExp] -> Bool
forall a. Eq a => a -> a -> Bool
/= [SubExp]
ds ->
m IndexResult -> Maybe (m IndexResult)
forall a. a -> Maybe a
Just (m IndexResult -> Maybe (m IndexResult))
-> m IndexResult -> Maybe (m IndexResult)
forall a b. (a -> b) -> a -> b
$ do
VName
arr <- String -> Exp (Lore m) -> m VName
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m VName
letExp String
"smaller_replicate" (Exp (Lore m) -> m VName) -> Exp (Lore m) -> m VName
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp (Lore m)
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> Exp (Lore m)) -> BasicOp -> Exp (Lore m)
forall a b. (a -> b) -> a -> b
$ Shape -> SubExp -> BasicOp
Replicate ([SubExp] -> Shape
forall d. [d] -> ShapeBase d
Shape [SubExp]
ds') SubExp
v
IndexResult -> m IndexResult
forall (m :: * -> *) a. Monad m => a -> m a
return (IndexResult -> m IndexResult) -> IndexResult -> m IndexResult
forall a b. (a -> b) -> a -> b
$ Certificates -> VName -> Slice SubExp -> IndexResult
IndexResult Certificates
cs VName
arr (Slice SubExp -> IndexResult) -> Slice SubExp -> IndexResult
forall a b. (a -> b) -> a -> b
$ Slice SubExp
ds_inds' Slice SubExp -> Slice SubExp -> Slice SubExp
forall a. [a] -> [a] -> [a]
++ Slice SubExp
rest_inds
where
index :: DimIndex SubExp -> Maybe (SubExp, DimIndex SubExp)
index DimFix {} = Maybe (SubExp, DimIndex SubExp)
forall a. Maybe a
Nothing
index (DimSlice SubExp
_ SubExp
n SubExp
s) = (SubExp, DimIndex SubExp) -> Maybe (SubExp, DimIndex SubExp)
forall a. a -> Maybe a
Just (SubExp
n, SubExp -> SubExp -> SubExp -> DimIndex SubExp
forall d. d -> d -> d -> DimIndex d
DimSlice (Int64 -> SubExp
forall v. IsValue v => v -> SubExp
constant (Int64
0 :: Int64)) SubExp
n SubExp
s)
Just (Rearrange [Int]
perm VName
src, Certificates
cs)
| [Int] -> Int
rearrangeReach [Int]
perm Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
<= Slice SubExp -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length ((DimIndex SubExp -> Bool) -> Slice SubExp -> Slice SubExp
forall a. (a -> Bool) -> [a] -> [a]
takeWhile DimIndex SubExp -> Bool
forall d. DimIndex d -> Bool
isIndex Slice SubExp
inds) ->
let inds' :: Slice SubExp
inds' = [Int] -> Slice SubExp -> Slice SubExp
forall a. [Int] -> [a] -> [a]
rearrangeShape ([Int] -> [Int]
rearrangeInverse [Int]
perm) Slice SubExp
inds
in m IndexResult -> Maybe (m IndexResult)
forall a. a -> Maybe a
Just (m IndexResult -> Maybe (m IndexResult))
-> m IndexResult -> Maybe (m IndexResult)
forall a b. (a -> b) -> a -> b
$ IndexResult -> m IndexResult
forall (f :: * -> *) a. Applicative f => a -> f a
pure (IndexResult -> m IndexResult) -> IndexResult -> m IndexResult
forall a b. (a -> b) -> a -> b
$ Certificates -> VName -> Slice SubExp -> IndexResult
IndexResult Certificates
cs VName
src Slice SubExp
inds'
where
isIndex :: DimIndex d -> Bool
isIndex DimFix {} = Bool
True
isIndex DimIndex d
_ = Bool
False
Just (Copy VName
src, Certificates
cs)
| Just [SubExp]
dims <- TypeBase Shape NoUniqueness -> [SubExp]
forall u. TypeBase Shape u -> [SubExp]
arrayDims (TypeBase Shape NoUniqueness -> [SubExp])
-> Maybe (TypeBase Shape NoUniqueness) -> Maybe [SubExp]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> SubExp -> Maybe (TypeBase Shape NoUniqueness)
seType (VName -> SubExp
Var VName
src),
Slice SubExp -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length Slice SubExp
inds Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== [SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
dims,
Bool -> Bool
not Bool
consuming,
VName -> SymbolTable (Lore m) -> Bool
forall lore. VName -> SymbolTable lore -> Bool
ST.available VName
src SymbolTable (Lore m)
vtable ->
m IndexResult -> Maybe (m IndexResult)
forall a. a -> Maybe a
Just (m IndexResult -> Maybe (m IndexResult))
-> m IndexResult -> Maybe (m IndexResult)
forall a b. (a -> b) -> a -> b
$ IndexResult -> m IndexResult
forall (f :: * -> *) a. Applicative f => a -> f a
pure (IndexResult -> m IndexResult) -> IndexResult -> m IndexResult
forall a b. (a -> b) -> a -> b
$ Certificates -> VName -> Slice SubExp -> IndexResult
IndexResult Certificates
cs VName
src Slice SubExp
inds
Just (Reshape ShapeChange SubExp
newshape VName
src, Certificates
cs)
| Just [SubExp]
newdims <- ShapeChange SubExp -> Maybe [SubExp]
forall d. ShapeChange d -> Maybe [d]
shapeCoercion ShapeChange SubExp
newshape,
Just [SubExp]
olddims <- TypeBase Shape NoUniqueness -> [SubExp]
forall u. TypeBase Shape u -> [SubExp]
arrayDims (TypeBase Shape NoUniqueness -> [SubExp])
-> Maybe (TypeBase Shape NoUniqueness) -> Maybe [SubExp]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> SubExp -> Maybe (TypeBase Shape NoUniqueness)
seType (VName -> SubExp
Var VName
src),
[Bool]
changed_dims <- (SubExp -> SubExp -> Bool) -> [SubExp] -> [SubExp] -> [Bool]
forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith SubExp -> SubExp -> Bool
forall a. Eq a => a -> a -> Bool
(/=) [SubExp]
newdims [SubExp]
olddims,
Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ [Bool] -> Bool
forall (t :: * -> *). Foldable t => t Bool -> Bool
or ([Bool] -> Bool) -> [Bool] -> Bool
forall a b. (a -> b) -> a -> b
$ Int -> [Bool] -> [Bool]
forall a. Int -> [a] -> [a]
drop (Slice SubExp -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length Slice SubExp
inds) [Bool]
changed_dims ->
m IndexResult -> Maybe (m IndexResult)
forall a. a -> Maybe a
Just (m IndexResult -> Maybe (m IndexResult))
-> m IndexResult -> Maybe (m IndexResult)
forall a b. (a -> b) -> a -> b
$ IndexResult -> m IndexResult
forall (f :: * -> *) a. Applicative f => a -> f a
pure (IndexResult -> m IndexResult) -> IndexResult -> m IndexResult
forall a b. (a -> b) -> a -> b
$ Certificates -> VName -> Slice SubExp -> IndexResult
IndexResult Certificates
cs VName
src Slice SubExp
inds
| Just [SubExp]
newdims <- ShapeChange SubExp -> Maybe [SubExp]
forall d. ShapeChange d -> Maybe [d]
shapeCoercion ShapeChange SubExp
newshape,
Just [SubExp]
olddims <- TypeBase Shape NoUniqueness -> [SubExp]
forall u. TypeBase Shape u -> [SubExp]
arrayDims (TypeBase Shape NoUniqueness -> [SubExp])
-> Maybe (TypeBase Shape NoUniqueness) -> Maybe [SubExp]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> SubExp -> Maybe (TypeBase Shape NoUniqueness)
seType (VName -> SubExp
Var VName
src),
ShapeChange SubExp -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length ShapeChange SubExp
newshape Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== Slice SubExp -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length Slice SubExp
inds,
[SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
olddims Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== [SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
newdims ->
m IndexResult -> Maybe (m IndexResult)
forall a. a -> Maybe a
Just (m IndexResult -> Maybe (m IndexResult))
-> m IndexResult -> Maybe (m IndexResult)
forall a b. (a -> b) -> a -> b
$ IndexResult -> m IndexResult
forall (f :: * -> *) a. Applicative f => a -> f a
pure (IndexResult -> m IndexResult) -> IndexResult -> m IndexResult
forall a b. (a -> b) -> a -> b
$ Certificates -> VName -> Slice SubExp -> IndexResult
IndexResult Certificates
cs VName
src Slice SubExp
inds
Just (Reshape [DimChange SubExp
_] VName
v2, Certificates
cs)
| Just [SubExp
_] <- TypeBase Shape NoUniqueness -> [SubExp]
forall u. TypeBase Shape u -> [SubExp]
arrayDims (TypeBase Shape NoUniqueness -> [SubExp])
-> Maybe (TypeBase Shape NoUniqueness) -> Maybe [SubExp]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> SubExp -> Maybe (TypeBase Shape NoUniqueness)
seType (VName -> SubExp
Var VName
v2) ->
m IndexResult -> Maybe (m IndexResult)
forall a. a -> Maybe a
Just (m IndexResult -> Maybe (m IndexResult))
-> m IndexResult -> Maybe (m IndexResult)
forall a b. (a -> b) -> a -> b
$ IndexResult -> m IndexResult
forall (f :: * -> *) a. Applicative f => a -> f a
pure (IndexResult -> m IndexResult) -> IndexResult -> m IndexResult
forall a b. (a -> b) -> a -> b
$ Certificates -> VName -> Slice SubExp -> IndexResult
IndexResult Certificates
cs VName
v2 Slice SubExp
inds
Just (Concat Int
d VName
x [VName]
xs SubExp
_, Certificates
cs)
|
Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ (VName -> Bool) -> [VName] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
any VName -> Bool
isConcat ([VName] -> Bool) -> [VName] -> Bool
forall a b. (a -> b) -> a -> b
$ VName
x VName -> [VName] -> [VName]
forall a. a -> [a] -> [a]
: [VName]
xs,
Just (Slice SubExp
ibef, DimFix SubExp
i, Slice SubExp
iaft) <- Int
-> Slice SubExp
-> Maybe (Slice SubExp, DimIndex SubExp, Slice SubExp)
forall int a. Integral int => int -> [a] -> Maybe ([a], a, [a])
focusNth Int
d Slice SubExp
inds,
Just (Prim PrimType
res_t) <-
(TypeBase Shape NoUniqueness
-> [SubExp] -> TypeBase Shape NoUniqueness
forall oldshape u.
TypeBase oldshape u -> [SubExp] -> TypeBase Shape u
`setArrayDims` Slice SubExp -> [SubExp]
forall d. Slice d -> [d]
sliceDims Slice SubExp
inds)
(TypeBase Shape NoUniqueness -> TypeBase Shape NoUniqueness)
-> Maybe (TypeBase Shape NoUniqueness)
-> Maybe (TypeBase Shape NoUniqueness)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> VName
-> SymbolTable (Lore m) -> Maybe (TypeBase Shape NoUniqueness)
forall lore.
ASTLore lore =>
VName -> SymbolTable lore -> Maybe (TypeBase Shape NoUniqueness)
ST.lookupType VName
x SymbolTable (Lore m)
vtable -> m IndexResult -> Maybe (m IndexResult)
forall a. a -> Maybe a
Just (m IndexResult -> Maybe (m IndexResult))
-> m IndexResult -> Maybe (m IndexResult)
forall a b. (a -> b) -> a -> b
$ do
SubExp
x_len <- Int -> TypeBase Shape NoUniqueness -> SubExp
forall u. Int -> TypeBase Shape u -> SubExp
arraySize Int
d (TypeBase Shape NoUniqueness -> SubExp)
-> m (TypeBase Shape NoUniqueness) -> m SubExp
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> VName -> m (TypeBase Shape NoUniqueness)
forall lore (m :: * -> *).
HasScope lore m =>
VName -> m (TypeBase Shape NoUniqueness)
lookupType VName
x
[SubExp]
xs_lens <- (VName -> m SubExp) -> [VName] -> m [SubExp]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM ((TypeBase Shape NoUniqueness -> SubExp)
-> m (TypeBase Shape NoUniqueness) -> m SubExp
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (Int -> TypeBase Shape NoUniqueness -> SubExp
forall u. Int -> TypeBase Shape u -> SubExp
arraySize Int
d) (m (TypeBase Shape NoUniqueness) -> m SubExp)
-> (VName -> m (TypeBase Shape NoUniqueness)) -> VName -> m SubExp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> m (TypeBase Shape NoUniqueness)
forall lore (m :: * -> *).
HasScope lore m =>
VName -> m (TypeBase Shape NoUniqueness)
lookupType) [VName]
xs
let add :: SubExp -> SubExp -> m (SubExp, SubExp)
add SubExp
n SubExp
m = do
SubExp
added <- String -> Exp (Lore m) -> m SubExp
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m SubExp
letSubExp String
"index_concat_add" (Exp (Lore m) -> m SubExp) -> Exp (Lore m) -> m SubExp
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp (Lore m)
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> Exp (Lore m)) -> BasicOp -> Exp (Lore m)
forall a b. (a -> b) -> a -> b
$ BinOp -> SubExp -> SubExp -> BasicOp
BinOp (IntType -> Overflow -> BinOp
Add IntType
Int64 Overflow
OverflowWrap) SubExp
n SubExp
m
(SubExp, SubExp) -> m (SubExp, SubExp)
forall (m :: * -> *) a. Monad m => a -> m a
return (SubExp
added, SubExp
n)
(SubExp
_, [SubExp]
starts) <- (SubExp -> SubExp -> m (SubExp, SubExp))
-> SubExp -> [SubExp] -> m (SubExp, [SubExp])
forall (m :: * -> *) acc x y.
Monad m =>
(acc -> x -> m (acc, y)) -> acc -> [x] -> m (acc, [y])
mapAccumLM SubExp -> SubExp -> m (SubExp, SubExp)
forall (m :: * -> *).
MonadBinder m =>
SubExp -> SubExp -> m (SubExp, SubExp)
add SubExp
x_len [SubExp]
xs_lens
let xs_and_starts :: [(VName, SubExp)]
xs_and_starts = [(VName, SubExp)] -> [(VName, SubExp)]
forall a. [a] -> [a]
reverse ([(VName, SubExp)] -> [(VName, SubExp)])
-> [(VName, SubExp)] -> [(VName, SubExp)]
forall a b. (a -> b) -> a -> b
$ [VName] -> [SubExp] -> [(VName, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
xs [SubExp]
starts
let mkBranch :: [(VName, SubExp)] -> m SubExp
mkBranch [] =
String -> Exp (Lore m) -> m SubExp
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m SubExp
letSubExp String
"index_concat" (Exp (Lore m) -> m SubExp) -> Exp (Lore m) -> m SubExp
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp (Lore m)
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> Exp (Lore m)) -> BasicOp -> Exp (Lore m)
forall a b. (a -> b) -> a -> b
$ VName -> Slice SubExp -> BasicOp
Index VName
x (Slice SubExp -> BasicOp) -> Slice SubExp -> BasicOp
forall a b. (a -> b) -> a -> b
$ Slice SubExp
ibef Slice SubExp -> Slice SubExp -> Slice SubExp
forall a. [a] -> [a] -> [a]
++ SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix SubExp
i DimIndex SubExp -> Slice SubExp -> Slice SubExp
forall a. a -> [a] -> [a]
: Slice SubExp
iaft
mkBranch ((VName
x', SubExp
start) : [(VName, SubExp)]
xs_and_starts') = do
SubExp
cmp <- String -> Exp (Lore m) -> m SubExp
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m SubExp
letSubExp String
"index_concat_cmp" (Exp (Lore m) -> m SubExp) -> Exp (Lore m) -> m SubExp
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp (Lore m)
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> Exp (Lore m)) -> BasicOp -> Exp (Lore m)
forall a b. (a -> b) -> a -> b
$ CmpOp -> SubExp -> SubExp -> BasicOp
CmpOp (IntType -> CmpOp
CmpSle IntType
Int64) SubExp
start SubExp
i
(SubExp
thisres, Stms (Lore m)
thisbnds) <- m SubExp -> m (SubExp, Stms (Lore m))
forall (m :: * -> *) a.
MonadBinder m =>
m a -> m (a, Stms (Lore m))
collectStms (m SubExp -> m (SubExp, Stms (Lore m)))
-> m SubExp -> m (SubExp, Stms (Lore m))
forall a b. (a -> b) -> a -> b
$ do
SubExp
i' <- String -> Exp (Lore m) -> m SubExp
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m SubExp
letSubExp String
"index_concat_i" (Exp (Lore m) -> m SubExp) -> Exp (Lore m) -> m SubExp
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp (Lore m)
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> Exp (Lore m)) -> BasicOp -> Exp (Lore m)
forall a b. (a -> b) -> a -> b
$ BinOp -> SubExp -> SubExp -> BasicOp
BinOp (IntType -> Overflow -> BinOp
Sub IntType
Int64 Overflow
OverflowWrap) SubExp
i SubExp
start
String -> Exp (Lore m) -> m SubExp
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m SubExp
letSubExp String
"index_concat" (Exp (Lore m) -> m SubExp) -> Exp (Lore m) -> m SubExp
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp (Lore m)
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> Exp (Lore m)) -> BasicOp -> Exp (Lore m)
forall a b. (a -> b) -> a -> b
$ VName -> Slice SubExp -> BasicOp
Index VName
x' (Slice SubExp -> BasicOp) -> Slice SubExp -> BasicOp
forall a b. (a -> b) -> a -> b
$ Slice SubExp
ibef Slice SubExp -> Slice SubExp -> Slice SubExp
forall a. [a] -> [a] -> [a]
++ SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix SubExp
i' DimIndex SubExp -> Slice SubExp -> Slice SubExp
forall a. a -> [a] -> [a]
: Slice SubExp
iaft
BodyT (Lore m)
thisbody <- Stms (Lore m) -> [SubExp] -> m (BodyT (Lore m))
forall (m :: * -> *).
MonadBinder m =>
Stms (Lore m) -> [SubExp] -> m (Body (Lore m))
mkBodyM Stms (Lore m)
thisbnds [SubExp
thisres]
(SubExp
altres, Stms (Lore m)
altbnds) <- m SubExp -> m (SubExp, Stms (Lore m))
forall (m :: * -> *) a.
MonadBinder m =>
m a -> m (a, Stms (Lore m))
collectStms (m SubExp -> m (SubExp, Stms (Lore m)))
-> m SubExp -> m (SubExp, Stms (Lore m))
forall a b. (a -> b) -> a -> b
$ [(VName, SubExp)] -> m SubExp
mkBranch [(VName, SubExp)]
xs_and_starts'
BodyT (Lore m)
altbody <- Stms (Lore m) -> [SubExp] -> m (BodyT (Lore m))
forall (m :: * -> *).
MonadBinder m =>
Stms (Lore m) -> [SubExp] -> m (Body (Lore m))
mkBodyM Stms (Lore m)
altbnds [SubExp
altres]
String -> Exp (Lore m) -> m SubExp
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m SubExp
letSubExp String
"index_concat_branch" (Exp (Lore m) -> m SubExp) -> Exp (Lore m) -> m SubExp
forall a b. (a -> b) -> a -> b
$
SubExp
-> BodyT (Lore m)
-> BodyT (Lore m)
-> IfDec (BranchType (Lore m))
-> Exp (Lore m)
forall lore.
SubExp
-> BodyT lore -> BodyT lore -> IfDec (BranchType lore) -> ExpT lore
If SubExp
cmp BodyT (Lore m)
thisbody BodyT (Lore m)
altbody (IfDec (BranchType (Lore m)) -> Exp (Lore m))
-> IfDec (BranchType (Lore m)) -> Exp (Lore m)
forall a b. (a -> b) -> a -> b
$
[BranchType (Lore m)] -> IfSort -> IfDec (BranchType (Lore m))
forall rt. [rt] -> IfSort -> IfDec rt
IfDec [PrimType -> BranchType (Lore m)
forall rt. IsBodyType rt => PrimType -> rt
primBodyType PrimType
res_t] IfSort
IfNormal
Certificates -> SubExp -> IndexResult
SubExpResult Certificates
cs (SubExp -> IndexResult) -> m SubExp -> m IndexResult
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> [(VName, SubExp)] -> m SubExp
mkBranch [(VName, SubExp)]
xs_and_starts
Just (ArrayLit [SubExp]
ses TypeBase Shape NoUniqueness
_, Certificates
cs)
| DimFix (Constant (IntValue (Int64Value Int64
i))) : Slice SubExp
inds' <- Slice SubExp
inds,
Just SubExp
se <- Int64 -> [SubExp] -> Maybe SubExp
forall int a. Integral int => int -> [a] -> Maybe a
maybeNth Int64
i [SubExp]
ses ->
case Slice SubExp
inds' of
[] -> m IndexResult -> Maybe (m IndexResult)
forall a. a -> Maybe a
Just (m IndexResult -> Maybe (m IndexResult))
-> m IndexResult -> Maybe (m IndexResult)
forall a b. (a -> b) -> a -> b
$ IndexResult -> m IndexResult
forall (f :: * -> *) a. Applicative f => a -> f a
pure (IndexResult -> m IndexResult) -> IndexResult -> m IndexResult
forall a b. (a -> b) -> a -> b
$ Certificates -> SubExp -> IndexResult
SubExpResult Certificates
cs SubExp
se
Slice SubExp
_ | Var VName
v2 <- SubExp
se -> m IndexResult -> Maybe (m IndexResult)
forall a. a -> Maybe a
Just (m IndexResult -> Maybe (m IndexResult))
-> m IndexResult -> Maybe (m IndexResult)
forall a b. (a -> b) -> a -> b
$ IndexResult -> m IndexResult
forall (f :: * -> *) a. Applicative f => a -> f a
pure (IndexResult -> m IndexResult) -> IndexResult -> m IndexResult
forall a b. (a -> b) -> a -> b
$ Certificates -> VName -> Slice SubExp -> IndexResult
IndexResult Certificates
cs VName
v2 Slice SubExp
inds'
Slice SubExp
_ -> Maybe (m IndexResult)
forall a. Maybe a
Nothing
Maybe (BasicOp, Certificates)
_
| Just TypeBase Shape NoUniqueness
t <- SubExp -> Maybe (TypeBase Shape NoUniqueness)
seType (SubExp -> Maybe (TypeBase Shape NoUniqueness))
-> SubExp -> Maybe (TypeBase Shape NoUniqueness)
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
idd,
SubExp -> Bool
isCt1 (SubExp -> Bool) -> SubExp -> Bool
forall a b. (a -> b) -> a -> b
$ Int -> TypeBase Shape NoUniqueness -> SubExp
forall u. Int -> TypeBase Shape u -> SubExp
arraySize Int
0 TypeBase Shape NoUniqueness
t,
DimFix SubExp
i : Slice SubExp
inds' <- Slice SubExp
inds,
Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ SubExp -> Bool
isCt0 SubExp
i ->
m IndexResult -> Maybe (m IndexResult)
forall a. a -> Maybe a
Just (m IndexResult -> Maybe (m IndexResult))
-> m IndexResult -> Maybe (m IndexResult)
forall a b. (a -> b) -> a -> b
$
IndexResult -> m IndexResult
forall (f :: * -> *) a. Applicative f => a -> f a
pure (IndexResult -> m IndexResult) -> IndexResult -> m IndexResult
forall a b. (a -> b) -> a -> b
$
Certificates -> VName -> Slice SubExp -> IndexResult
IndexResult Certificates
forall a. Monoid a => a
mempty VName
idd (Slice SubExp -> IndexResult) -> Slice SubExp -> IndexResult
forall a b. (a -> b) -> a -> b
$
SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix (Int64 -> SubExp
forall v. IsValue v => v -> SubExp
constant (Int64
0 :: Int64)) DimIndex SubExp -> Slice SubExp -> Slice SubExp
forall a. a -> [a] -> [a]
: Slice SubExp
inds'
Maybe (BasicOp, Certificates)
_ -> Maybe (m IndexResult)
forall a. Maybe a
Nothing
where
defOf :: VName -> Maybe (BasicOp, Certificates)
defOf VName
v = do
(BasicOp BasicOp
op, Certificates
def_cs) <- VName -> SymbolTable (Lore m) -> Maybe (Exp (Lore m), Certificates)
forall lore.
VName -> SymbolTable lore -> Maybe (Exp lore, Certificates)
ST.lookupExp VName
v SymbolTable (Lore m)
vtable
(BasicOp, Certificates) -> Maybe (BasicOp, Certificates)
forall (m :: * -> *) a. Monad m => a -> m a
return (BasicOp
op, Certificates
def_cs)
worthInlining :: PrimExp v -> Bool
worthInlining PrimExp v
e
| Int -> PrimExp v -> Bool
forall v. Int -> PrimExp v -> Bool
primExpSizeAtLeast Int
20 PrimExp v
e = Bool
False
| Bool
otherwise = PrimExp v -> Bool
forall v. PrimExp v -> Bool
worthInlining' PrimExp v
e
worthInlining' :: PrimExp v -> Bool
worthInlining' (BinOpExp Pow {} PrimExp v
_ PrimExp v
_) = Bool
False
worthInlining' (BinOpExp FPow {} PrimExp v
_ PrimExp v
_) = Bool
False
worthInlining' (BinOpExp BinOp
_ PrimExp v
x PrimExp v
y) = PrimExp v -> Bool
worthInlining' PrimExp v
x Bool -> Bool -> Bool
&& PrimExp v -> Bool
worthInlining' PrimExp v
y
worthInlining' (CmpOpExp CmpOp
_ PrimExp v
x PrimExp v
y) = PrimExp v -> Bool
worthInlining' PrimExp v
x Bool -> Bool -> Bool
&& PrimExp v -> Bool
worthInlining' PrimExp v
y
worthInlining' (ConvOpExp ConvOp
_ PrimExp v
x) = PrimExp v -> Bool
worthInlining' PrimExp v
x
worthInlining' (UnOpExp UnOp
_ PrimExp v
x) = PrimExp v -> Bool
worthInlining' PrimExp v
x
worthInlining' FunExp {} = Bool
False
worthInlining' PrimExp v
_ = Bool
True
isConcat :: VName -> Bool
isConcat VName
v
| Just (Concat {}, Certificates
_) <- VName -> Maybe (BasicOp, Certificates)
defOf VName
v =
Bool
True
| Bool
otherwise =
Bool
False
-- | An operand of a concatenation, classified by how it was defined.
data ConcatArg
  = -- | An array literal with these elements.
    ArgArrayLit [SubExp]
  | -- | A replicate of the given value.  The list holds the outer widths
    -- of one or more adjacent replicates of that same value that have
    -- been fused together.
    ArgReplicate [SubExp] SubExp
  | -- | Any other array, referred to by name.
    ArgVar VName
-- | Classify a concatenation operand by looking up its definition in the
-- symbol table, so that neighbouring literal and replicate operands can
-- later be fused.
--
-- The certificates of the defining statement are returned alongside the
-- classification.  Operands that are not array literals or (fusable)
-- replicates are kept as plain variables, with no certificates.
toConcatArg :: ST.SymbolTable lore -> VName -> (ConcatArg, Certificates)
toConcatArg vtable v =
  case ST.lookupBasicOp v vtable of
    Just (ArrayLit ses _, cs) ->
      (ArgArrayLit ses, cs)
    Just (Replicate shape se, cs)
      | fusableReplicate shape ->
          (ArgReplicate [shapeSize 0 shape] se, cs)
    _ ->
      (ArgVar v, mempty)
  where
    -- 'fuseConcatArg' rewrites a width-1 replicate into the one-element
    -- literal @[se]@.  That is only well-typed when @se@ is a whole row
    -- of the operand, i.e. when the replicate is one-dimensional.  A
    -- multi-dimensional replicate with outer width 1 (e.g. a replicate
    -- with shape @[1][m]@ of a scalar) is therefore kept as a variable.
    fusableReplicate shape =
      shapeRank shape == 1 || not (isCt1 (shapeSize 0 shape))
-- | Materialise a (possibly fused) concatenation operand as an array
-- variable, binding whatever statements are needed to construct it.
--
-- The 'Type' argument is an array type whose row type is used as the
-- element type of literals, and whose shape supplies the non-outer
-- dimensions of replicates (presumably the type of one of the
-- concatenated operands; see the caller).  The certificates are attached
-- to the statements bound for literals and replicates.  Plain variables
-- are returned unchanged and their certificates are ignored.
fromConcatArg ::
  MonadBinder m =>
  Type ->
  (ConcatArg, Certificates) ->
  m VName
fromConcatArg t (ArgArrayLit ses, cs) =
  certifying cs $ letExp "concat_lit" $ BasicOp $ ArrayLit ses $ rowType t
fromConcatArg elem_type (ArgReplicate ws se, cs) = do
  -- The replicated value may itself be an array.  Its own dimensions then
  -- form the innermost part of the result and must not be replicated
  -- again, so only the leading dimensions of the operand type belong to
  -- the replicate shape.  For a scalar 'se' this is the full shape.
  se_rank <- arrayRank <$> subExpType se
  let rep_rank = max 1 $ arrayRank elem_type - se_rank
      rep_shape = Shape $ take rep_rank $ arrayDims elem_type
  certifying cs $ do
    -- The fused replicate is as wide as all of its pieces together.
    w <- letSubExp "concat_rep_w" =<< toExp (sum $ map pe64 ws)
    letExp "concat_rep" $ BasicOp $ Replicate (setDim 0 rep_shape w) se
fromConcatArg _ (ArgVar v, _) =
  pure v
-- | Add one more operand to an accumulator of concatenation operands.
-- The most recently added operand is expected at the head of the
-- accumulator.
--
-- Empty literals and zero-width replicates are dropped, and a width-one
-- replicate is treated as a one-element literal.  The new operand is
-- merged into the head of the accumulator when both are literals, or
-- when both are replicates of the same value.  When operands are merged,
-- their certificates are combined.  Otherwise the new operand is pushed
-- onto the front.
fuseConcatArg ::
  [(ConcatArg, Certificates)] ->
  (ConcatArg, Certificates) ->
  [(ConcatArg, Certificates)]
fuseConcatArg acc new =
  case (acc, new) of
    (_, (ArgArrayLit [], _)) ->
      acc
    (_, (ArgReplicate [width] val, certs))
      | isCt0 width -> acc
      | isCt1 width -> fuseConcatArg acc (ArgArrayLit [val], certs)
    ((ArgArrayLit prev_ses, prev_cs) : rest, (ArgArrayLit new_ses, new_cs)) ->
      (ArgArrayLit (prev_ses ++ new_ses), prev_cs <> new_cs) : rest
    ( (ArgReplicate prev_ws prev_val, prev_cs) : rest,
      (ArgReplicate new_ws new_val, new_cs)
      )
        | prev_val == new_val ->
            (ArgReplicate (prev_ws ++ new_ws) prev_val, prev_cs <> new_cs) : rest
    _ ->
      new : acc
-- | Bottom-up simplification rules for 'Concat'.  The equations are
-- tried in order; the first one that matches produces the rewrite, and
-- if none matches the rule is skipped.
simplifyConcat :: BinderOps lore => BottomUpRuleBasicOp lore
-- concat@i(rearrange(perm, x), rearrange(perm, y), ...)
--   ==> rearrange(perm, concat@0(x, y, ...))
--
-- The inner concatenation is along the outermost dimension, so 'perm'
-- must move dimension 0 of the operands to position i, i.e.
-- @perm !! i == 0@.  For i <= 1 this permutation is its own inverse;
-- for larger i, using the inverse permutation instead would concatenate
-- along the wrong dimension.
-- NOTE(review): this relies on dimension k of @Rearrange perm v@ being
-- dimension @perm !! k@ of @v@ (as computed by 'rearrangeShape') -- confirm.
--
-- The original statement's certificates and attributes are carried over
-- to the new statements, as in the other equations of this rule.
simplifyConcat (vtable, _) pat aux (Concat i x xs new_d)
  | Just r <- arrayRank <$> ST.lookupType x vtable,
    let perm = [1 .. i] ++ [0] ++ [i + 1 .. r - 1],
    Just (x', x_cs) <- transposedBy perm x,
    Just (xs', xs_cs) <- unzip <$> mapM (transposedBy perm) xs =
    Simplify $
      auxing aux $ do
        concat_rearrange <-
          certifying (x_cs <> mconcat xs_cs) $
            letExp "concat_rearrange" $ BasicOp $ Concat 0 x' xs' new_d
        letBind pat $ BasicOp $ Rearrange perm concat_rearrange
  where
    -- The array that v rearranges, provided v is bound to a 'Rearrange'
    -- with exactly the permutation perm1 (plus that binding's certificates).
    transposedBy :: [Int] -> VName -> Maybe (VName, Certificates)
    transposedBy perm1 v =
      case ST.lookupExp v vtable of
        Just (BasicOp (Rearrange perm2 v'), vcs)
          | perm1 == perm2 -> Just (v', vcs)
        _ -> Nothing
-- A concatenation with a single operand is a copy of that operand.
simplifyConcat _ pat aux (Concat _ x [] _) =
  Simplify $ auxing aux $ letBind pat $ BasicOp $ Copy x
-- Splice in operands that are themselves concatenations along the same
-- dimension: concat@i(concat@i(a, b), c) ==> concat@i(a, b, c).
simplifyConcat (vtable, _) pat (StmAux cs attrs _) (Concat i x xs new_d)
  | x' /= x || concat xs' /= xs =
    Simplify $
      certifying (cs <> x_cs <> mconcat xs_cs) $
        attributing attrs $
          letBind pat $ BasicOp $ Concat i x' (zs ++ concat xs') new_d
  where
    (x', zs, x_cs) = isConcat x
    xs_parts = map isConcat xs
    xs' = [y : ys | (y, ys, _) <- xs_parts]
    xs_cs = [v_cs | (_, _, v_cs) <- xs_parts]
    -- The operands of v (first, rest, certificates) if v is bound to a
    -- concatenation along dimension i; otherwise v alone, uncertified.
    isConcat :: VName -> (VName, [VName], Certificates)
    isConcat v = case ST.lookupBasicOp v vtable of
      Just (Concat j y ys _, v_cs) | j == i -> (y, ys, v_cs)
      _ -> (v, [], mempty)
-- Along the outer dimension, merge adjacent operands that
-- 'fuseConcatArg' can fuse (for example, replicates of the same value).
-- The rewrite only fires when fusion changed the number of operands.
simplifyConcat (vtable, _) pat aux (Concat 0 x xs outer_w)
  | y : ys <-
      reverse $ foldl' fuseConcatArg mempty $ map (toConcatArg vtable) $ x : xs,
    length xs /= length ys =
    Simplify $ do
      elem_type <- lookupType x
      y' <- fromConcatArg elem_type y
      ys' <- mapM (fromConcatArg elem_type) ys
      auxing aux $ letBind pat $ BasicOp $ Concat 0 y' ys' outer_w
simplifyConcat _ _ _ _ = Skip
-- | Top-down simplification rules for 'If'.  The equations are tried
-- in order; the first one that matches produces the rewrite, and if
-- none matches the rule is skipped.
ruleIf :: BinderOps lore => TopDownRuleIf lore
-- The condition is a known constant: inline the selected branch and
-- bind its results to the pattern.  For an 'IfFallback' this is only
-- done when the condition is known to be true.
ruleIf _ pat _ (e1, tb, fb, IfDec _ ifsort)
  | Just taken <- selected,
    isCt1 e1 || ifsort /= IfFallback =
    Simplify $ do
      addStms $ bodyStms taken
      zipWithM_
        (\pe se -> letBindNames [patElemName pe] $ BasicOp $ SubExp se)
        (patternElements pat)
        (bodyResult taken)
  where
    selected
      | isCt1 e1 = Just tb
      | isCt0 e1 = Just fb
      | otherwise = Nothing
-- if c then true else v  ==>  c || v
-- (both branches statement-free, single boolean result).
ruleIf _ pat _ (cond, tb, fb, IfDec ts _)
  | Body _ tstms [Constant (BoolValue True)] <- tb,
    Body _ fstms [se] <- fb,
    null tstms && null fstms,
    [Prim Bool] <- map extTypeOf ts =
    Simplify . letBind pat . BasicOp $ BinOp LogOr cond se
-- For boolean results:  if c then x else y  ==>  (c && x) || (!c && y).
-- Both branches' statements are emitted unconditionally, so each of
-- them must be safe to execute.
ruleIf _ pat _ (cond, Body _ tstms [tres], Body _ fstms [fres], IfDec ts _)
  | all (safeExp . stmExp) (tstms <> fstms),
    all ((== Prim Bool) . extTypeOf) ts =
    Simplify $ do
      addStms tstms
      addStms fstms
      let whenTrue = pure $ BasicOp $ BinOp LogAnd cond tres
          negCond = pure $ BasicOp $ UnOp Not cond
          whenFalse = eBinOp LogAnd negCond (pure $ BasicOp $ SubExp fres)
      letBind pat =<< eBinOp LogOr whenTrue whenFalse
-- A fallback 'If' whose pattern has no context names and whose true
-- branch contains only safe statements is replaced by its true branch.
ruleIf _ pat _ (_, tbranch, _, IfDec _ IfFallback)
  | null (patternContextNames pat),
    all (safeExp . stmExp) (bodyStms tbranch) =
    Simplify $ do
      addStms $ bodyStms tbranch
      zipWithM_
        (\pe se -> letBindNames [patElemName pe] $ BasicOp $ SubExp se)
        (patternElements pat)
        (bodyResult tbranch)
-- Integer-constant branches:
--   if c then 1 else 0  ==>  btoi(c)
--   if c then 0 else 1  ==>  btoi(!c)
ruleIf
  _
  pat
  _
  ( cond,
    Body _ _ [Constant (IntValue t)],
    Body _ _ [Constant (IntValue f)],
    _
    )
    | oneIshInt t, zeroIshInt f = Simplify $ bindBoolAsInt cond
    | zeroIshInt t, oneIshInt f =
      Simplify $
        bindBoolAsInt =<< letSubExp "cond_neg" (BasicOp $ UnOp Not cond)
    where
      bindBoolAsInt c = letBind pat $ BasicOp $ ConvOp (BToI (intValueType t)) c
ruleIf _ _ _ _ = Skip
hoistBranchInvariant :: BinderOps lore => TopDownRuleIf lore
hoistBranchInvariant :: TopDownRuleIf lore
hoistBranchInvariant TopDown lore
_ Pattern lore
pat StmAux (ExpDec lore)
_ (SubExp
cond, BodyT lore
tb, BodyT lore
fb, IfDec [BranchType lore]
ret IfSort
ifsort) = RuleM lore () -> Rule lore
forall lore. RuleM lore () -> Rule lore
Simplify (RuleM lore () -> Rule lore) -> RuleM lore () -> Rule lore
forall a b. (a -> b) -> a -> b
$ do
let tses :: [SubExp]
tses = BodyT lore -> [SubExp]
forall lore. BodyT lore -> [SubExp]
bodyResult BodyT lore
tb
fses :: [SubExp]
fses = BodyT lore -> [SubExp]
forall lore. BodyT lore -> [SubExp]
bodyResult BodyT lore
fb
([Maybe (Int, SubExp)]
hoistings, ([PatElemT (LetDec lore)]
pes, [Either Int (BranchType lore)]
ts, [(SubExp, SubExp)]
res)) <-
([Either
(Maybe (Int, SubExp))
(PatElemT (LetDec lore), Either Int (BranchType lore),
(SubExp, SubExp))]
-> ([Maybe (Int, SubExp)],
([PatElemT (LetDec lore)], [Either Int (BranchType lore)],
[(SubExp, SubExp)])))
-> RuleM
lore
[Either
(Maybe (Int, SubExp))
(PatElemT (LetDec lore), Either Int (BranchType lore),
(SubExp, SubExp))]
-> RuleM
lore
([Maybe (Int, SubExp)],
([PatElemT (LetDec lore)], [Either Int (BranchType lore)],
[(SubExp, SubExp)]))
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (([(PatElemT (LetDec lore), Either Int (BranchType lore),
(SubExp, SubExp))]
-> ([PatElemT (LetDec lore)], [Either Int (BranchType lore)],
[(SubExp, SubExp)]))
-> ([Maybe (Int, SubExp)],
[(PatElemT (LetDec lore), Either Int (BranchType lore),
(SubExp, SubExp))])
-> ([Maybe (Int, SubExp)],
([PatElemT (LetDec lore)], [Either Int (BranchType lore)],
[(SubExp, SubExp)]))
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap [(PatElemT (LetDec lore), Either Int (BranchType lore),
(SubExp, SubExp))]
-> ([PatElemT (LetDec lore)], [Either Int (BranchType lore)],
[(SubExp, SubExp)])
forall a b c. [(a, b, c)] -> ([a], [b], [c])
unzip3 (([Maybe (Int, SubExp)],
[(PatElemT (LetDec lore), Either Int (BranchType lore),
(SubExp, SubExp))])
-> ([Maybe (Int, SubExp)],
([PatElemT (LetDec lore)], [Either Int (BranchType lore)],
[(SubExp, SubExp)])))
-> ([Either
(Maybe (Int, SubExp))
(PatElemT (LetDec lore), Either Int (BranchType lore),
(SubExp, SubExp))]
-> ([Maybe (Int, SubExp)],
[(PatElemT (LetDec lore), Either Int (BranchType lore),
(SubExp, SubExp))]))
-> [Either
(Maybe (Int, SubExp))
(PatElemT (LetDec lore), Either Int (BranchType lore),
(SubExp, SubExp))]
-> ([Maybe (Int, SubExp)],
([PatElemT (LetDec lore)], [Either Int (BranchType lore)],
[(SubExp, SubExp)]))
forall b c a. (b -> c) -> (a -> b) -> a -> c
. [Either
(Maybe (Int, SubExp))
(PatElemT (LetDec lore), Either Int (BranchType lore),
(SubExp, SubExp))]
-> ([Maybe (Int, SubExp)],
[(PatElemT (LetDec lore), Either Int (BranchType lore),
(SubExp, SubExp))])
forall a b. [Either a b] -> ([a], [b])
partitionEithers) (RuleM
lore
[Either
(Maybe (Int, SubExp))
(PatElemT (LetDec lore), Either Int (BranchType lore),
(SubExp, SubExp))]
-> RuleM
lore
([Maybe (Int, SubExp)],
([PatElemT (LetDec lore)], [Either Int (BranchType lore)],
[(SubExp, SubExp)])))
-> RuleM
lore
[Either
(Maybe (Int, SubExp))
(PatElemT (LetDec lore), Either Int (BranchType lore),
(SubExp, SubExp))]
-> RuleM
lore
([Maybe (Int, SubExp)],
([PatElemT (LetDec lore)], [Either Int (BranchType lore)],
[(SubExp, SubExp)]))
forall a b. (a -> b) -> a -> b
$
((PatElemT (LetDec lore), Either Int (BranchType lore),
(SubExp, SubExp))
-> RuleM
lore
(Either
(Maybe (Int, SubExp))
(PatElemT (LetDec lore), Either Int (BranchType lore),
(SubExp, SubExp))))
-> [(PatElemT (LetDec lore), Either Int (BranchType lore),
(SubExp, SubExp))]
-> RuleM
lore
[Either
(Maybe (Int, SubExp))
(PatElemT (LetDec lore), Either Int (BranchType lore),
(SubExp, SubExp))]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (PatElemT (LetDec lore), Either Int (BranchType lore),
(SubExp, SubExp))
-> RuleM
lore
(Either
(Maybe (Int, SubExp))
(PatElemT (LetDec lore), Either Int (BranchType lore),
(SubExp, SubExp)))
branchInvariant ([(PatElemT (LetDec lore), Either Int (BranchType lore),
(SubExp, SubExp))]
-> RuleM
lore
[Either
(Maybe (Int, SubExp))
(PatElemT (LetDec lore), Either Int (BranchType lore),
(SubExp, SubExp))])
-> [(PatElemT (LetDec lore), Either Int (BranchType lore),
(SubExp, SubExp))]
-> RuleM
lore
[Either
(Maybe (Int, SubExp))
(PatElemT (LetDec lore), Either Int (BranchType lore),
(SubExp, SubExp))]
forall a b. (a -> b) -> a -> b
$
[PatElemT (LetDec lore)]
-> [Either Int (BranchType lore)]
-> [(SubExp, SubExp)]
-> [(PatElemT (LetDec lore), Either Int (BranchType lore),
(SubExp, SubExp))]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3
(Pattern lore -> [PatElemT (LetDec lore)]
forall dec. PatternT dec -> [PatElemT dec]
patternElements Pattern lore
pat)
((Int -> Either Int (BranchType lore))
-> [Int] -> [Either Int (BranchType lore)]
forall a b. (a -> b) -> [a] -> [b]
map Int -> Either Int (BranchType lore)
forall a b. a -> Either a b
Left [Int
0 .. Int
num_ctx Int -> Int -> Int
forall a. Num a => a -> a -> a
-Int
1] [Either Int (BranchType lore)]
-> [Either Int (BranchType lore)] -> [Either Int (BranchType lore)]
forall a. [a] -> [a] -> [a]
++ (BranchType lore -> Either Int (BranchType lore))
-> [BranchType lore] -> [Either Int (BranchType lore)]
forall a b. (a -> b) -> [a] -> [b]
map BranchType lore -> Either Int (BranchType lore)
forall a b. b -> Either a b
Right [BranchType lore]
ret)
([SubExp] -> [SubExp] -> [(SubExp, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [SubExp]
tses [SubExp]
fses)
let ctx_fixes :: [(Int, SubExp)]
ctx_fixes = [Maybe (Int, SubExp)] -> [(Int, SubExp)]
forall a. [Maybe a] -> [a]
catMaybes [Maybe (Int, SubExp)]
hoistings
([SubExp]
tses', [SubExp]
fses') = [(SubExp, SubExp)] -> ([SubExp], [SubExp])
forall a b. [(a, b)] -> ([a], [b])
unzip [(SubExp, SubExp)]
res
tb' :: BodyT lore
tb' = BodyT lore
tb {bodyResult :: [SubExp]
bodyResult = [SubExp]
tses'}
fb' :: BodyT lore
fb' = BodyT lore
fb {bodyResult :: [SubExp]
bodyResult = [SubExp]
fses'}
ret' :: [BranchType lore]
ret' = ((Int, SubExp) -> [BranchType lore] -> [BranchType lore])
-> [BranchType lore] -> [(Int, SubExp)] -> [BranchType lore]
forall (t :: * -> *) a b.
Foldable t =>
(a -> b -> b) -> b -> t a -> b
foldr ((Int -> SubExp -> [BranchType lore] -> [BranchType lore])
-> (Int, SubExp) -> [BranchType lore] -> [BranchType lore]
forall a b c. (a -> b -> c) -> (a, b) -> c
uncurry Int -> SubExp -> [BranchType lore] -> [BranchType lore]
forall t. FixExt t => Int -> SubExp -> t -> t
fixExt) ([Either Int (BranchType lore)] -> [BranchType lore]
forall a b. [Either a b] -> [b]
rights [Either Int (BranchType lore)]
ts) [(Int, SubExp)]
ctx_fixes
([PatElemT (LetDec lore)]
ctx_pes, [PatElemT (LetDec lore)]
val_pes) = Int
-> [PatElemT (LetDec lore)]
-> ([PatElemT (LetDec lore)], [PatElemT (LetDec lore)])
forall a. Int -> [a] -> ([a], [a])
splitFromEnd ([BranchType lore] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [BranchType lore]
ret') [PatElemT (LetDec lore)]
pes
if Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ [Maybe (Int, SubExp)] -> Bool
forall (t :: * -> *) a. Foldable t => t a -> Bool
null [Maybe (Int, SubExp)]
hoistings
then do
BodyT lore
tb'' <- BodyT (Lore (RuleM lore))
-> [ExtType] -> RuleM lore (BodyT (Lore (RuleM lore)))
forall (m :: * -> *).
MonadBinder m =>
BodyT (Lore m) -> [ExtType] -> m (BodyT (Lore m))
reshapeBodyResults BodyT lore
BodyT (Lore (RuleM lore))
tb' ([ExtType] -> RuleM lore (BodyT (Lore (RuleM lore))))
-> [ExtType] -> RuleM lore (BodyT (Lore (RuleM lore)))
forall a b. (a -> b) -> a -> b
$ (BranchType lore -> ExtType) -> [BranchType lore] -> [ExtType]
forall a b. (a -> b) -> [a] -> [b]
map BranchType lore -> ExtType
forall t. ExtTyped t => t -> ExtType
extTypeOf [BranchType lore]
ret'
BodyT lore
fb'' <- BodyT (Lore (RuleM lore))
-> [ExtType] -> RuleM lore (BodyT (Lore (RuleM lore)))
forall (m :: * -> *).
MonadBinder m =>
BodyT (Lore m) -> [ExtType] -> m (BodyT (Lore m))
reshapeBodyResults BodyT lore
BodyT (Lore (RuleM lore))
fb' ([ExtType] -> RuleM lore (BodyT (Lore (RuleM lore))))
-> [ExtType] -> RuleM lore (BodyT (Lore (RuleM lore)))
forall a b. (a -> b) -> a -> b
$ (BranchType lore -> ExtType) -> [BranchType lore] -> [ExtType]
forall a b. (a -> b) -> [a] -> [b]
map BranchType lore -> ExtType
forall t. ExtTyped t => t -> ExtType
extTypeOf [BranchType lore]
ret'
Pattern (Lore (RuleM lore))
-> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall (m :: * -> *).
MonadBinder m =>
Pattern (Lore m) -> Exp (Lore m) -> m ()
letBind ([PatElemT (LetDec lore)]
-> [PatElemT (LetDec lore)] -> Pattern lore
forall dec. [PatElemT dec] -> [PatElemT dec] -> PatternT dec
Pattern [PatElemT (LetDec lore)]
ctx_pes [PatElemT (LetDec lore)]
val_pes) (Exp (Lore (RuleM lore)) -> RuleM lore ())
-> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall a b. (a -> b) -> a -> b
$
SubExp
-> BodyT lore -> BodyT lore -> IfDec (BranchType lore) -> ExpT lore
forall lore.
SubExp
-> BodyT lore -> BodyT lore -> IfDec (BranchType lore) -> ExpT lore
If SubExp
cond BodyT lore
tb'' BodyT lore
fb'' ([BranchType lore] -> IfSort -> IfDec (BranchType lore)
forall rt. [rt] -> IfSort -> IfDec rt
IfDec [BranchType lore]
ret' IfSort
ifsort)
else RuleM lore ()
forall lore a. RuleM lore a
cannotSimplify
where
num_ctx :: Int
num_ctx = [PatElemT (LetDec lore)] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length ([PatElemT (LetDec lore)] -> Int)
-> [PatElemT (LetDec lore)] -> Int
forall a b. (a -> b) -> a -> b
$ Pattern lore -> [PatElemT (LetDec lore)]
forall dec. PatternT dec -> [PatElemT dec]
patternContextElements Pattern lore
pat
bound_in_branches :: Names
bound_in_branches =
[VName] -> Names
namesFromList ([VName] -> Names) -> [VName] -> Names
forall a b. (a -> b) -> a -> b
$
(Stm lore -> [VName]) -> Seq (Stm lore) -> [VName]
forall (t :: * -> *) a b. Foldable t => (a -> [b]) -> t a -> [b]
concatMap (Pattern lore -> [VName]
forall dec. PatternT dec -> [VName]
patternNames (Pattern lore -> [VName])
-> (Stm lore -> Pattern lore) -> Stm lore -> [VName]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Stm lore -> Pattern lore
forall lore. Stm lore -> Pattern lore
stmPattern) (Seq (Stm lore) -> [VName]) -> Seq (Stm lore) -> [VName]
forall a b. (a -> b) -> a -> b
$
BodyT lore -> Seq (Stm lore)
forall lore. BodyT lore -> Stms lore
bodyStms BodyT lore
tb Seq (Stm lore) -> Seq (Stm lore) -> Seq (Stm lore)
forall a. Semigroup a => a -> a -> a
<> BodyT lore -> Seq (Stm lore)
forall lore. BodyT lore -> Stms lore
bodyStms BodyT lore
fb
mem_sizes :: Names
mem_sizes = [PatElemT (LetDec lore)] -> Names
forall a. FreeIn a => a -> Names
freeIn ([PatElemT (LetDec lore)] -> Names)
-> [PatElemT (LetDec lore)] -> Names
forall a b. (a -> b) -> a -> b
$ (PatElemT (LetDec lore) -> Bool)
-> [PatElemT (LetDec lore)] -> [PatElemT (LetDec lore)]
forall a. (a -> Bool) -> [a] -> [a]
filter (TypeBase Shape NoUniqueness -> Bool
forall shape u. TypeBase shape u -> Bool
isMem (TypeBase Shape NoUniqueness -> Bool)
-> (PatElemT (LetDec lore) -> TypeBase Shape NoUniqueness)
-> PatElemT (LetDec lore)
-> Bool
forall b c a. (b -> c) -> (a -> b) -> a -> c
. PatElemT (LetDec lore) -> TypeBase Shape NoUniqueness
forall dec.
Typed dec =>
PatElemT dec -> TypeBase Shape NoUniqueness
patElemType) ([PatElemT (LetDec lore)] -> [PatElemT (LetDec lore)])
-> [PatElemT (LetDec lore)] -> [PatElemT (LetDec lore)]
forall a b. (a -> b) -> a -> b
$ Pattern lore -> [PatElemT (LetDec lore)]
forall dec. PatternT dec -> [PatElemT dec]
patternElements Pattern lore
pat
invariant :: SubExp -> Bool
invariant Constant {} = Bool
True
invariant (Var VName
v) = Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ VName
v VName -> Names -> Bool
`nameIn` Names
bound_in_branches
isMem :: TypeBase shape u -> Bool
isMem Mem {} = Bool
True
isMem TypeBase shape u
_ = Bool
False
sizeOfMem :: VName -> Bool
sizeOfMem VName
v = VName
v VName -> Names -> Bool
`nameIn` Names
mem_sizes
branchInvariant :: (PatElemT (LetDec lore), Either Int (BranchType lore),
(SubExp, SubExp))
-> RuleM
lore
(Either
(Maybe (Int, SubExp))
(PatElemT (LetDec lore), Either Int (BranchType lore),
(SubExp, SubExp)))
branchInvariant (PatElemT (LetDec lore)
pe, Either Int (BranchType lore)
t, (SubExp
tse, SubExp
fse))
| SubExp
tse SubExp -> SubExp -> Bool
forall a. Eq a => a -> a -> Bool
== SubExp
fse = do
[VName] -> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall (m :: * -> *).
MonadBinder m =>
[VName] -> Exp (Lore m) -> m ()
letBindNames [PatElemT (LetDec lore) -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT (LetDec lore)
pe] (Exp (Lore (RuleM lore)) -> RuleM lore ())
-> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall a b. (a -> b) -> a -> b
$ BasicOp -> ExpT lore
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> ExpT lore) -> BasicOp -> ExpT lore
forall a b. (a -> b) -> a -> b
$ SubExp -> BasicOp
SubExp SubExp
tse
PatElemT (LetDec lore)
-> Either Int (BranchType lore)
-> RuleM
lore
(Either
(Maybe (Int, SubExp))
(PatElemT (LetDec lore), Either Int (BranchType lore),
(SubExp, SubExp)))
forall (m :: * -> *) dec a b b.
Monad m =>
PatElemT dec -> Either a b -> m (Either (Maybe (a, SubExp)) b)
hoisted PatElemT (LetDec lore)
pe Either Int (BranchType lore)
t
| SubExp -> Bool
invariant SubExp
tse,
SubExp -> Bool
invariant SubExp
fse,
Pattern lore -> Int
forall dec. PatternT dec -> Int
patternSize Pattern lore
pat Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
> Int
1,
Prim PrimType
_ <- PatElemT (LetDec lore) -> TypeBase Shape NoUniqueness
forall dec.
Typed dec =>
PatElemT dec -> TypeBase Shape NoUniqueness
patElemType PatElemT (LetDec lore)
pe,
Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ VName -> Bool
sizeOfMem (VName -> Bool) -> VName -> Bool
forall a b. (a -> b) -> a -> b
$ PatElemT (LetDec lore) -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT (LetDec lore)
pe = do
[BranchType lore]
bt <- Pattern lore -> RuleM lore [BranchType lore]
forall lore (m :: * -> *).
(ASTLore lore, HasScope lore m, Monad m) =>
Pattern lore -> m [BranchType lore]
expTypesFromPattern (Pattern lore -> RuleM lore [BranchType lore])
-> Pattern lore -> RuleM lore [BranchType lore]
forall a b. (a -> b) -> a -> b
$ [PatElemT (LetDec lore)]
-> [PatElemT (LetDec lore)] -> Pattern lore
forall dec. [PatElemT dec] -> [PatElemT dec] -> PatternT dec
Pattern [] [PatElemT (LetDec lore)
pe]
[VName] -> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall (m :: * -> *).
MonadBinder m =>
[VName] -> Exp (Lore m) -> m ()
letBindNames [PatElemT (LetDec lore) -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT (LetDec lore)
pe]
(ExpT lore -> RuleM lore ())
-> RuleM lore (ExpT lore) -> RuleM lore ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< ( SubExp
-> BodyT lore -> BodyT lore -> IfDec (BranchType lore) -> ExpT lore
forall lore.
SubExp
-> BodyT lore -> BodyT lore -> IfDec (BranchType lore) -> ExpT lore
If SubExp
cond (BodyT lore -> BodyT lore -> IfDec (BranchType lore) -> ExpT lore)
-> RuleM lore (BodyT lore)
-> RuleM lore (BodyT lore -> IfDec (BranchType lore) -> ExpT lore)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> [SubExp] -> RuleM lore (BodyT (Lore (RuleM lore)))
forall (m :: * -> *).
MonadBinder m =>
[SubExp] -> m (Body (Lore m))
resultBodyM [SubExp
tse]
RuleM lore (BodyT lore -> IfDec (BranchType lore) -> ExpT lore)
-> RuleM lore (BodyT lore)
-> RuleM lore (IfDec (BranchType lore) -> ExpT lore)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> [SubExp] -> RuleM lore (BodyT (Lore (RuleM lore)))
forall (m :: * -> *).
MonadBinder m =>
[SubExp] -> m (Body (Lore m))
resultBodyM [SubExp
fse]
RuleM lore (IfDec (BranchType lore) -> ExpT lore)
-> RuleM lore (IfDec (BranchType lore)) -> RuleM lore (ExpT lore)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> IfDec (BranchType lore) -> RuleM lore (IfDec (BranchType lore))
forall (f :: * -> *) a. Applicative f => a -> f a
pure ([BranchType lore] -> IfSort -> IfDec (BranchType lore)
forall rt. [rt] -> IfSort -> IfDec rt
IfDec [BranchType lore]
bt IfSort
ifsort)
)
PatElemT (LetDec lore)
-> Either Int (BranchType lore)
-> RuleM
lore
(Either
(Maybe (Int, SubExp))
(PatElemT (LetDec lore), Either Int (BranchType lore),
(SubExp, SubExp)))
forall (m :: * -> *) dec a b b.
Monad m =>
PatElemT dec -> Either a b -> m (Either (Maybe (a, SubExp)) b)
hoisted PatElemT (LetDec lore)
pe Either Int (BranchType lore)
t
| Bool
otherwise =
Either
(Maybe (Int, SubExp))
(PatElemT (LetDec lore), Either Int (BranchType lore),
(SubExp, SubExp))
-> RuleM
lore
(Either
(Maybe (Int, SubExp))
(PatElemT (LetDec lore), Either Int (BranchType lore),
(SubExp, SubExp)))
forall (m :: * -> *) a. Monad m => a -> m a
return (Either
(Maybe (Int, SubExp))
(PatElemT (LetDec lore), Either Int (BranchType lore),
(SubExp, SubExp))
-> RuleM
lore
(Either
(Maybe (Int, SubExp))
(PatElemT (LetDec lore), Either Int (BranchType lore),
(SubExp, SubExp))))
-> Either
(Maybe (Int, SubExp))
(PatElemT (LetDec lore), Either Int (BranchType lore),
(SubExp, SubExp))
-> RuleM
lore
(Either
(Maybe (Int, SubExp))
(PatElemT (LetDec lore), Either Int (BranchType lore),
(SubExp, SubExp)))
forall a b. (a -> b) -> a -> b
$ (PatElemT (LetDec lore), Either Int (BranchType lore),
(SubExp, SubExp))
-> Either
(Maybe (Int, SubExp))
(PatElemT (LetDec lore), Either Int (BranchType lore),
(SubExp, SubExp))
forall a b. b -> Either a b
Right (PatElemT (LetDec lore)
pe, Either Int (BranchType lore)
t, (SubExp
tse, SubExp
fse))
hoisted :: PatElemT dec -> Either a b -> m (Either (Maybe (a, SubExp)) b)
hoisted PatElemT dec
pe (Left a
i) = Either (Maybe (a, SubExp)) b -> m (Either (Maybe (a, SubExp)) b)
forall (m :: * -> *) a. Monad m => a -> m a
return (Either (Maybe (a, SubExp)) b -> m (Either (Maybe (a, SubExp)) b))
-> Either (Maybe (a, SubExp)) b -> m (Either (Maybe (a, SubExp)) b)
forall a b. (a -> b) -> a -> b
$ Maybe (a, SubExp) -> Either (Maybe (a, SubExp)) b
forall a b. a -> Either a b
Left (Maybe (a, SubExp) -> Either (Maybe (a, SubExp)) b)
-> Maybe (a, SubExp) -> Either (Maybe (a, SubExp)) b
forall a b. (a -> b) -> a -> b
$ (a, SubExp) -> Maybe (a, SubExp)
forall a. a -> Maybe a
Just (a
i, VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ PatElemT dec -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT dec
pe)
hoisted PatElemT dec
_ Right {} = Either (Maybe (a, SubExp)) b -> m (Either (Maybe (a, SubExp)) b)
forall (m :: * -> *) a. Monad m => a -> m a
return (Either (Maybe (a, SubExp)) b -> m (Either (Maybe (a, SubExp)) b))
-> Either (Maybe (a, SubExp)) b -> m (Either (Maybe (a, SubExp)) b)
forall a b. (a -> b) -> a -> b
$ Maybe (a, SubExp) -> Either (Maybe (a, SubExp)) b
forall a b. a -> Either a b
Left Maybe (a, SubExp)
forall a. Maybe a
Nothing
reshapeBodyResults :: BodyT (Lore m) -> [ExtType] -> m (BodyT (Lore m))
reshapeBodyResults BodyT (Lore m)
body [ExtType]
rets = m (BodyT (Lore m)) -> m (BodyT (Lore m))
forall (m :: * -> *).
MonadBinder m =>
m (Body (Lore m)) -> m (Body (Lore m))
insertStmsM (m (BodyT (Lore m)) -> m (BodyT (Lore m)))
-> m (BodyT (Lore m)) -> m (BodyT (Lore m))
forall a b. (a -> b) -> a -> b
$ do
[SubExp]
ses <- BodyT (Lore m) -> m [SubExp]
forall (m :: * -> *). MonadBinder m => Body (Lore m) -> m [SubExp]
bodyBind BodyT (Lore m)
body
let ([SubExp]
ctx_ses, [SubExp]
val_ses) = Int -> [SubExp] -> ([SubExp], [SubExp])
forall a. Int -> [a] -> ([a], [a])
splitFromEnd ([ExtType] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [ExtType]
rets) [SubExp]
ses
[SubExp] -> m (BodyT (Lore m))
forall (m :: * -> *).
MonadBinder m =>
[SubExp] -> m (Body (Lore m))
resultBodyM ([SubExp] -> m (BodyT (Lore m)))
-> ([SubExp] -> [SubExp]) -> [SubExp] -> m (BodyT (Lore m))
forall b c a. (b -> c) -> (a -> b) -> a -> c
. ([SubExp]
ctx_ses [SubExp] -> [SubExp] -> [SubExp]
forall a. [a] -> [a] -> [a]
++) ([SubExp] -> m (BodyT (Lore m)))
-> m [SubExp] -> m (BodyT (Lore m))
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< (SubExp -> ExtType -> m SubExp)
-> [SubExp] -> [ExtType] -> m [SubExp]
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m [c]
zipWithM SubExp -> ExtType -> m SubExp
forall (m :: * -> *).
MonadBinder m =>
SubExp -> ExtType -> m SubExp
reshapeResult [SubExp]
val_ses [ExtType]
rets
reshapeResult :: SubExp -> ExtType -> m SubExp
reshapeResult (Var VName
v) t :: ExtType
t@Array {} = do
TypeBase Shape NoUniqueness
v_t <- VName -> m (TypeBase Shape NoUniqueness)
forall lore (m :: * -> *).
HasScope lore m =>
VName -> m (TypeBase Shape NoUniqueness)
lookupType VName
v
let newshape :: [SubExp]
newshape = TypeBase Shape NoUniqueness -> [SubExp]
forall u. TypeBase Shape u -> [SubExp]
arrayDims (TypeBase Shape NoUniqueness -> [SubExp])
-> TypeBase Shape NoUniqueness -> [SubExp]
forall a b. (a -> b) -> a -> b
$ ExtType
-> TypeBase Shape NoUniqueness -> TypeBase Shape NoUniqueness
removeExistentials ExtType
t TypeBase Shape NoUniqueness
v_t
if [SubExp]
newshape [SubExp] -> [SubExp] -> Bool
forall a. Eq a => a -> a -> Bool
/= TypeBase Shape NoUniqueness -> [SubExp]
forall u. TypeBase Shape u -> [SubExp]
arrayDims TypeBase Shape NoUniqueness
v_t
then String -> Exp (Lore m) -> m SubExp
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m SubExp
letSubExp String
"branch_ctx_reshaped" (Exp (Lore m) -> m SubExp) -> Exp (Lore m) -> m SubExp
forall a b. (a -> b) -> a -> b
$ [SubExp] -> VName -> Exp (Lore m)
forall lore. [SubExp] -> VName -> Exp lore
shapeCoerce [SubExp]
newshape VName
v
else SubExp -> m SubExp
forall (m :: * -> *) a. Monad m => a -> m a
return (SubExp -> m SubExp) -> SubExp -> m SubExp
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
v
reshapeResult SubExp
se ExtType
_ =
SubExp -> m SubExp
forall (m :: * -> *) a. Monad m => a -> m a
return SubExp
se
-- | A reshape whose requested dimensions are exactly the current
-- dimensions of its argument does nothing, so it is replaced by the
-- argument itself. Any other operation is left alone.
simplifyIdentityReshape :: SimpleRule lore
simplifyIdentityReshape _ seType (Reshape newshape v) =
  case seType (Var v) of
    Just v_t
      | arrayDims v_t == newDims newshape -> subExpRes (Var v)
    _ -> Nothing
simplifyIdentityReshape _ _ _ = Nothing
-- | Collapse a reshape of a reshape into one reshape of the innermost
-- array. The shape changes are combined with 'fuseReshape', and the
-- certificates of the inner reshape statement are kept.
simplifyReshapeReshape :: SimpleRule lore
simplifyReshapeReshape defOf _ (Reshape outer_shape arr) =
  case defOf arr of
    Just (BasicOp (Reshape inner_shape inner_arr), inner_cs) ->
      Just (Reshape (fuseReshape inner_shape outer_shape) inner_arr, inner_cs)
    _ -> Nothing
simplifyReshapeReshape _ _ _ = Nothing
-- | Reshaping a 'Scratch' array is the same as allocating a 'Scratch'
-- array with the new dimensions directly. The element type and the
-- certificates of the original 'Scratch' statement are kept.
simplifyReshapeScratch :: SimpleRule lore
simplifyReshapeScratch defOf _ (Reshape target_shape arr) =
  case defOf arr of
    Just (BasicOp (Scratch elem_t _), scratch_cs) ->
      Just (Scratch elem_t (newDims target_shape), scratch_cs)
    _ -> Nothing
simplifyReshapeScratch _ _ _ = Nothing
-- | A reshape of a 'Replicate' whose target dimensions end with the
-- shape of the replicated value is turned into a single 'Replicate'.
-- The new replicate uses the leading (outer) target dimensions and
-- keeps the certificates of the original 'Replicate' statement.
simplifyReshapeReplicate :: SimpleRule lore
simplifyReshapeReplicate defOf seType (Reshape newshape v)
  | Just (BasicOp (Replicate _ replicated), rep_cs) <- defOf v,
    Just elem_shape <- arrayShape <$> seType replicated,
    let target_dims = newDims newshape,
    shapeDims elem_shape `isSuffixOf` target_dims =
    let outer_rank = length newshape - shapeRank elem_shape
        outer_dims = take outer_rank target_dims
     in Just (Replicate (Shape outer_dims) replicated, rep_cs)
simplifyReshapeReplicate _ _ _ = Nothing
-- | A reshape of an 'Iota' to a single dimension @n@ becomes an 'Iota'
-- of length @n@. It keeps the same offset, stride and integer type, and
-- the certificates of the original 'Iota' statement.
simplifyReshapeIota :: SimpleRule lore
simplifyReshapeIota defOf _ (Reshape newshape v) =
  case (defOf v, newDims newshape) of
    (Just (BasicOp (Iota _ start step int_t), iota_cs), [len]) ->
      Just (Iota len start step int_t, iota_cs)
    _ -> Nothing
simplifyReshapeIota _ _ _ = Nothing
-- | Refine a reshape's shape change using the known dimensions of its
-- argument, via 'informReshape'. The rule fires only if the refined
-- shape change differs from the original one, and no extra
-- certificates are attached.
improveReshape :: SimpleRule lore
improveReshape _ seType (Reshape newshape v) = do
  arr_t <- seType (Var v)
  let refined = informReshape (arrayDims arr_t) newshape
  if refined /= newshape
    then Just (Reshape refined v, mempty)
    else Nothing
improveReshape _ _ _ = Nothing
-- | If the source of a 'Copy' is a 'Scratch' array, either directly or
-- through a chain of 'Rearrange'/'Reshape' operations, replace the copy
-- with a fresh 'Scratch' of the source's element type and dimensions.
-- The rule needs the source's type to be known.
copyScratchToScratch :: SimpleRule lore
copyScratchToScratch defOf seType (Copy src) =
  seType (Var src) >>= scratchLike
  where
    -- Build the replacement only once the source's type is known and
    -- the source has been traced back to a Scratch.
    scratchLike src_t
      | isActuallyScratch src =
        Just (Scratch (elemType src_t) (arrayDims src_t), mempty)
      | otherwise = Nothing

    -- Follow Rearrange/Reshape definitions back towards their origin.
    isActuallyScratch :: VName -> Bool
    isActuallyScratch name =
      case defOf name >>= asBasicOp . fst of
        Just Scratch {} -> True
        Just (Rearrange _ name') -> isActuallyScratch name'
        Just (Reshape _ name') -> isActuallyScratch name'
        _ -> False
copyScratchToScratch _ _ _ = Nothing
ruleBasicOp :: BinderOps lore => TopDownRuleBasicOp lore
ruleBasicOp :: TopDownRuleBasicOp lore
ruleBasicOp TopDown lore
vtable Pattern lore
pat StmAux (ExpDec lore)
aux BasicOp
op
| Just (BasicOp
op', Certificates
cs) <- [Maybe (BasicOp, Certificates)] -> Maybe (BasicOp, Certificates)
forall (t :: * -> *) (m :: * -> *) a.
(Foldable t, MonadPlus m) =>
t (m a) -> m a
msum [SimpleRule lore
rule VName -> Maybe (Exp lore, Certificates)
defOf SubExp -> Maybe (TypeBase Shape NoUniqueness)
seType BasicOp
op | SimpleRule lore
rule <- [SimpleRule lore]
forall lore. [SimpleRule lore]
simpleRules] =
RuleM lore () -> Rule lore
forall lore. RuleM lore () -> Rule lore
Simplify (RuleM lore () -> Rule lore) -> RuleM lore () -> Rule lore
forall a b. (a -> b) -> a -> b
$ Certificates -> RuleM lore () -> RuleM lore ()
forall (m :: * -> *) a. MonadBinder m => Certificates -> m a -> m a
certifying (Certificates
cs Certificates -> Certificates -> Certificates
forall a. Semigroup a => a -> a -> a
<> StmAux (ExpDec lore) -> Certificates
forall dec. StmAux dec -> Certificates
stmAuxCerts StmAux (ExpDec lore)
aux) (RuleM lore () -> RuleM lore ()) -> RuleM lore () -> RuleM lore ()
forall a b. (a -> b) -> a -> b
$ Pattern (Lore (RuleM lore))
-> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall (m :: * -> *).
MonadBinder m =>
Pattern (Lore m) -> Exp (Lore m) -> m ()
letBind Pattern lore
Pattern (Lore (RuleM lore))
pat (Exp (Lore (RuleM lore)) -> RuleM lore ())
-> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp lore
forall lore. BasicOp -> ExpT lore
BasicOp BasicOp
op'
where
defOf :: VName -> Maybe (Exp lore, Certificates)
defOf = (VName -> TopDown lore -> Maybe (Exp lore, Certificates)
forall lore.
VName -> SymbolTable lore -> Maybe (Exp lore, Certificates)
`ST.lookupExp` TopDown lore
vtable)
seType :: SubExp -> Maybe (TypeBase Shape NoUniqueness)
seType (Var VName
v) = VName -> TopDown lore -> Maybe (TypeBase Shape NoUniqueness)
forall lore.
ASTLore lore =>
VName -> SymbolTable lore -> Maybe (TypeBase Shape NoUniqueness)
ST.lookupType VName
v TopDown lore
vtable
seType (Constant PrimValue
v) = TypeBase Shape NoUniqueness -> Maybe (TypeBase Shape NoUniqueness)
forall a. a -> Maybe a
Just (TypeBase Shape NoUniqueness
-> Maybe (TypeBase Shape NoUniqueness))
-> TypeBase Shape NoUniqueness
-> Maybe (TypeBase Shape NoUniqueness)
forall a b. (a -> b) -> a -> b
$ PrimType -> TypeBase Shape NoUniqueness
forall shape u. PrimType -> TypeBase shape u
Prim (PrimType -> TypeBase Shape NoUniqueness)
-> PrimType -> TypeBase Shape NoUniqueness
forall a b. (a -> b) -> a -> b
$ PrimValue -> PrimType
primValueType PrimValue
v
ruleBasicOp TopDown lore
vtable Pattern lore
pat StmAux (ExpDec lore)
_ (Update VName
src Slice SubExp
_ (Var VName
v))
| Just (BasicOp Scratch {}, Certificates
_) <- VName -> TopDown lore -> Maybe (Exp lore, Certificates)
forall lore.
VName -> SymbolTable lore -> Maybe (Exp lore, Certificates)
ST.lookupExp VName
v TopDown lore
vtable =
RuleM lore () -> Rule lore
forall lore. RuleM lore () -> Rule lore
Simplify (RuleM lore () -> Rule lore) -> RuleM lore () -> Rule lore
forall a b. (a -> b) -> a -> b
$ Pattern (Lore (RuleM lore))
-> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall (m :: * -> *).
MonadBinder m =>
Pattern (Lore m) -> Exp (Lore m) -> m ()
letBind Pattern lore
Pattern (Lore (RuleM lore))
pat (Exp (Lore (RuleM lore)) -> RuleM lore ())
-> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp lore
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> Exp lore) -> BasicOp -> Exp lore
forall a b. (a -> b) -> a -> b
$ SubExp -> BasicOp
SubExp (SubExp -> BasicOp) -> SubExp -> BasicOp
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
src
ruleBasicOp TopDown lore
vtable Pattern lore
pat StmAux (ExpDec lore)
aux (Update VName
src [DimSlice SubExp
i SubExp
n SubExp
s] (Var VName
v))
| SubExp -> Bool
isCt1 SubExp
n,
SubExp -> Bool
isCt1 SubExp
s,
Just (ST.Indexed Certificates
cs PrimExp VName
e) <- VName -> [SubExp] -> TopDown lore -> Maybe Indexed
forall lore.
ASTLore lore =>
VName -> [SubExp] -> SymbolTable lore -> Maybe Indexed
ST.index VName
v [IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
0] TopDown lore
vtable =
RuleM lore () -> Rule lore
forall lore. RuleM lore () -> Rule lore
Simplify (RuleM lore () -> Rule lore) -> RuleM lore () -> Rule lore
forall a b. (a -> b) -> a -> b
$ do
SubExp
e' <- String -> PrimExp VName -> RuleM lore SubExp
forall (m :: * -> *) a.
(MonadBinder m, ToExp a) =>
String -> a -> m SubExp
toSubExp String
"update_elem" PrimExp VName
e
StmAux (ExpDec lore) -> RuleM lore () -> RuleM lore ()
forall (m :: * -> *) anylore a.
MonadBinder m =>
StmAux anylore -> m a -> m a
auxing StmAux (ExpDec lore)
aux (RuleM lore () -> RuleM lore ()) -> RuleM lore () -> RuleM lore ()
forall a b. (a -> b) -> a -> b
$
Certificates -> RuleM lore () -> RuleM lore ()
forall (m :: * -> *) a. MonadBinder m => Certificates -> m a -> m a
certifying Certificates
cs (RuleM lore () -> RuleM lore ()) -> RuleM lore () -> RuleM lore ()
forall a b. (a -> b) -> a -> b
$
Pattern (Lore (RuleM lore))
-> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall (m :: * -> *).
MonadBinder m =>
Pattern (Lore m) -> Exp (Lore m) -> m ()
letBind Pattern lore
Pattern (Lore (RuleM lore))
pat (Exp (Lore (RuleM lore)) -> RuleM lore ())
-> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp lore
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> Exp lore) -> BasicOp -> Exp lore
forall a b. (a -> b) -> a -> b
$ VName -> Slice SubExp -> SubExp -> BasicOp
Update VName
src [SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix SubExp
i] SubExp
e'
-- Writing into @dest@ a value that is already present at @destis@ leaves
-- the array unchanged, so the update is replaced by @dest@ itself.
ruleBasicOp vtable pat _ (Update dest destis (Var v))
  | Just (v_e, _) <- ST.lookupExp v vtable,
    holdsSameElems v_e =
    Simplify . letBind pat . BasicOp . SubExp $ Var dest
  where
    -- Does this expression evaluate to what @dest@ already holds at
    -- @destis@?  Looks through copies, recognises a read of exactly
    -- @dest[destis]@, and recognises a replicate of the same element that
    -- @dest@ is itself a replicate of (with a suffix of its shape).
    holdsSameElems (BasicOp op) = case op of
      Copy copy_v ->
        maybe False (holdsSameElems . fst) $ ST.lookupExp copy_v vtable
      Index src srcis ->
        src == dest && srcis == destis
      Replicate v_shape v_se
        | Just (Replicate dest_shape dest_se, _) <- ST.lookupBasicOp dest vtable ->
          v_se == dest_se
            && shapeDims v_shape `isSuffixOf` shapeDims dest_shape
      _ -> False
    holdsSameElems _ = False
-- When @is@ is a full slice of @dest@ (per 'isFullSlice'), the update
-- overwrites @dest@ entirely, so the result is built from @se@ alone.
-- A variable written through a slice that has dimensions is reshaped to
-- @dest@'s dimensions and copied. Anything else becomes a one-element
-- array literal of @dest@'s row type.
ruleBasicOp vtable pat _ (Update dest is se)
  | Just dest_t <- ST.lookupType dest vtable,
    isFullSlice (arrayShape dest_t) is =
    Simplify $ overwriteWith dest_t
  where
    overwriteWith dest_t
      | Var v <- se,
        not (null (sliceDims is)) = do
        let new_dims = map DimNew (arrayDims dest_t)
            reshaped_name = baseString v ++ "_reshaped"
        v_reshaped <- letExp reshaped_name $ BasicOp $ Reshape new_dims v
        letBind pat $ BasicOp $ Copy v_reshaped
      | otherwise =
        letBind pat $ BasicOp $ ArrayLit [se] $ rowType dest_t
-- Pattern: v1 = (copy (dest1[is1])) with [is2] = se2, followed by
-- dest1 with [is1] = v1.  The outer update only changes the elements the
-- inner update touched, so both fuse into one update of @dest1@ at the
-- composed slice of @is1@ and @is2@.  All four statements' certificates
-- are kept, and the outer statement's attributes are preserved.
ruleBasicOp vtable pat (StmAux cs1 attrs _) (Update dest1 is1 (Var v1))
  | Just (Update dest2 is2 se2, cs2) <- ST.lookupBasicOp v1 vtable,
    Just (Copy v3, cs3) <- ST.lookupBasicOp dest2 vtable,
    Just (Index v4 is4, cs4) <- ST.lookupBasicOp v3 vtable,
    v4 == dest1,
    is4 == is1 =
    Simplify . certifying (cs1 <> cs2 <> cs3 <> cs4) $ do
      let composed = sliceSlice (primExpSlice is1) (primExpSlice is2)
      is5 <- subExpSlice composed
      attributing attrs . letBind pat . BasicOp $ Update dest1 is5 se2
-- Comparing @x@ for equality with a result of an @if p@ whose branches
-- return @y@ and @z@ is rewritten as @(p && x == y) || (!p && x == z)@.
-- This removes the comparison's dependency on the @if@ result. It is only
-- done when neither @y@ nor @z@ refers to names bound inside its branch.
-- Both operand orders of the comparison are tried.
ruleBasicOp vtable pat _ (CmpOp (CmpEq t) se1 se2)
  | Just m <- simplifyWith se1 se2 = Simplify m
  | Just m <- simplifyWith se2 se1 = Simplify m
  where
    simplifyWith (Var v) x
      | Just bnd <- ST.lookupStm v vtable,
        If p tbranch fbranch _ <- stmExp bnd,
        Just (y, z) <-
          returns v (stmPattern bnd) tbranch fbranch,
        -- y and z must be usable outside their respective branches.
        not $ boundInBody tbranch `namesIntersect` freeIn y,
        not $ boundInBody fbranch `namesIntersect` freeIn z = Just $ do
        eq_x_y <-
          letSubExp "eq_x_y" $ BasicOp $ CmpOp (CmpEq t) x y
        eq_x_z <-
          letSubExp "eq_x_z" $ BasicOp $ CmpOp (CmpEq t) x z
        p_and_eq_x_y <-
          letSubExp "p_and_eq_x_y" $ BasicOp $ BinOp LogAnd p eq_x_y
        not_p <-
          letSubExp "not_p" $ BasicOp $ UnOp Not p
        -- Previously named "p_and_eq_x_y" by copy-paste; use a hint that
        -- matches the value actually bound.
        not_p_and_eq_x_z <-
          letSubExp "not_p_and_eq_x_z" $ BasicOp $ BinOp LogAnd not_p eq_x_z
        letBind pat $
          BasicOp $ BinOp LogOr p_and_eq_x_y not_p_and_eq_x_z
    simplifyWith _ _ =
      Nothing

    -- The (true-branch, false-branch) results bound to the value pattern
    -- element of the @if@ named @v@, if there is one.
    returns v ifpat tbranch fbranch =
      fmap snd $
        find ((== v) . patElemName . fst) $
          zip (patternValueElements ifpat) $
            zip (bodyResult tbranch) (bodyResult fbranch)
-- A replicate with an empty shape of a constant is just that constant.
ruleBasicOp _ pat _ (Replicate (Shape []) se@Constant {}) =
  Simplify . letBind pat . BasicOp $ SubExp se
-- A replicate with an empty shape of a variable yields the variable.
-- A primitive-typed value is used directly; otherwise it is copied.
ruleBasicOp _ pat _ (Replicate (Shape []) (Var v)) = Simplify $ do
  is_prim <- primType <$> lookupType v
  let op
        | is_prim = SubExp (Var v)
        | otherwise = Copy v
  letBind pat $ BasicOp op
-- A replicate of a replicate collapses into a single replicate of the
-- inner element over the concatenated outer and inner shapes.
ruleBasicOp vtable pat _ (Replicate shape (Var v))
  | Just (BasicOp (Replicate shape2 se), cs) <- ST.lookupExp v vtable =
    Simplify . certifying cs . letBind pat . BasicOp $
      Replicate (shape <> shape2) se
-- A non-empty array literal whose elements are all equal is a one-
-- dimensional replicate of that element.
ruleBasicOp _ pat _ (ArrayLit (se : ses) _)
  | all (== se) ses = Simplify $ do
    let n :: SubExp
        n = constant (1 + fromIntegral (length ses) :: Int64)
    letBind pat $ BasicOp $ Replicate (Shape [n]) se
-- Indexing a reshaped array with only scalar indices (one per new
-- dimension) is turned into an index of the array before the reshape.
-- If the reshape is a mere coercion, the indices are reused unchanged.
-- Otherwise they are mapped into the original index space with
-- 'reshapeIndex'.
ruleBasicOp vtable pat aux (Index idd slice)
  | Just inds <- sliceIndices slice,
    Just (BasicOp (Reshape newshape idd2), idd_cs) <- ST.lookupExp idd vtable,
    length newshape == length inds =
    Simplify $ do
      new_slice <- case shapeCoercion newshape of
        Just _ -> pure slice
        Nothing -> do
          oldshape <- arrayDims <$> lookupType idd2
          let new_inds :: [TPrimExp Int64 VName]
              new_inds =
                reshapeIndex
                  (map pe64 oldshape)
                  (map pe64 (newDims newshape))
                  (map pe64 inds)
          map DimFix <$> mapM (toSubExp "new_index") new_inds
      certifying idd_cs . auxing aux . letBind pat . BasicOp $
        Index idd2 new_slice
-- Raising the integer constant 2 to a power becomes a left shift of 1.
-- NOTE(review): relies on Shl agreeing with Pow for every exponent of
-- type t (including negative or bit-width-sized ones) — confirm.
ruleBasicOp _ pat _ (BinOp (Pow t) e1 e2)
  | e1 == intConst t 2 =
    Simplify . letBind pat . BasicOp $ BinOp (Shl t) (intConst t 1) e2
-- A rearrange whose permutation is already in ascending order is the
-- identity permutation, so the array is used as-is.
ruleBasicOp _ pat _ (Rearrange perm v)
  | perm == sort perm =
    Simplify . letBind pat . BasicOp . SubExp $ Var v
-- A rearrange of a rearrange becomes a single rearrange of the inner
-- array by the composed permutation.
ruleBasicOp vtable pat aux (Rearrange perm v)
  | Just (BasicOp (Rearrange perm2 e), v_cs) <- ST.lookupExp v vtable =
    Simplify . certifying v_cs . auxing aux . letBind pat . BasicOp $
      Rearrange (rearrangeCompose perm perm2) e
-- Pattern: rearrange perm (rotate offsets (rearrange perm3 v3)).
-- The rotation is applied directly to @v3@ with its offsets permuted by
-- the inverse of @perm3@. The two rearranges then merge into a single
-- rearrange by the composed permutation.
ruleBasicOp vtable pat aux (Rearrange perm v)
  | Just (BasicOp (Rotate offsets v2), v_cs) <- ST.lookupExp v vtable,
    Just (BasicOp (Rearrange perm3 v3), v2_cs) <- ST.lookupExp v2 vtable =
    Simplify $ do
      rearrange_rotate <-
        letExp "rearrange_rotate" . BasicOp $
          Rotate (rearrangeShape (rearrangeInverse perm3) offsets) v3
      certifying (v_cs <> v2_cs) . auxing aux . letBind pat . BasicOp $
        Rearrange (rearrangeCompose perm perm3) rearrange_rotate
ruleBasicOp TopDown lore
vtable Pattern lore
pat StmAux (ExpDec lore)
aux (Rearrange [Int]
perm VName
v1)
| Just (BasicOp (Replicate Shape
dims (Var VName
v2)), Certificates
v1_cs) <- VName -> TopDown lore -> Maybe (Exp lore, Certificates)
forall lore.
VName -> SymbolTable lore -> Maybe (Exp lore, Certificates)
ST.lookupExp VName
v1 TopDown lore
vtable,
Int
num_dims <- Shape -> Int
forall a. ArrayShape a => a -> Int
shapeRank Shape
dims,
([Int]
rep_perm, [Int]
rest_perm) <- Int -> [Int] -> ([Int], [Int])
forall a. Int -> [a] -> ([a], [a])
splitAt Int
num_dims [Int]
perm,
Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ [Int] -> Bool
forall (t :: * -> *) a. Foldable t => t a -> Bool
null [Int]
rest_perm,
[Int]
rep_perm [Int] -> [Int] -> Bool
forall a. Eq a => a -> a -> Bool
== [Int
0 .. [Int] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [Int]
rep_perm Int -> Int -> Int
forall a. Num a => a -> a -> a
-Int
1] =
RuleM lore () -> Rule lore
forall lore. RuleM lore () -> Rule lore
Simplify (RuleM lore () -> Rule lore) -> RuleM lore () -> Rule lore
forall a b. (a -> b) -> a -> b
$
Certificates -> RuleM lore () -> RuleM lore ()
forall (m :: * -> *) a. MonadBinder m => Certificates -> m a -> m a
certifying Certificates
v1_cs (RuleM lore () -> RuleM lore ()) -> RuleM lore () -> RuleM lore ()
forall a b. (a -> b) -> a -> b
$
StmAux (ExpDec lore) -> RuleM lore () -> RuleM lore ()
forall (m :: * -> *) anylore a.
MonadBinder m =>
StmAux anylore -> m a -> m a
auxing StmAux (ExpDec lore)
aux (RuleM lore () -> RuleM lore ()) -> RuleM lore () -> RuleM lore ()
forall a b. (a -> b) -> a -> b
$ do
SubExp
v <-
String -> Exp (Lore (RuleM lore)) -> RuleM lore SubExp
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m SubExp
letSubExp String
"rearrange_replicate" (Exp (Lore (RuleM lore)) -> RuleM lore SubExp)
-> Exp (Lore (RuleM lore)) -> RuleM lore SubExp
forall a b. (a -> b) -> a -> b
$
BasicOp -> Exp lore
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> Exp lore) -> BasicOp -> Exp lore
forall a b. (a -> b) -> a -> b
$ [Int] -> VName -> BasicOp
Rearrange ((Int -> Int) -> [Int] -> [Int]
forall a b. (a -> b) -> [a] -> [b]
map (Int -> Int -> Int
forall a. Num a => a -> a -> a
subtract Int
num_dims) [Int]
rest_perm) VName
v2
Pattern (Lore (RuleM lore))
-> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall (m :: * -> *).
MonadBinder m =>
Pattern (Lore m) -> Exp (Lore m) -> m ()
letBind Pattern lore
Pattern (Lore (RuleM lore))
pat (Exp (Lore (RuleM lore)) -> RuleM lore ())
-> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp lore
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> Exp lore) -> BasicOp -> Exp lore
forall a b. (a -> b) -> a -> b
$ Shape -> SubExp -> BasicOp
Replicate Shape
dims SubExp
v
ruleBasicOp TopDown lore
_ Pattern lore
pat StmAux (ExpDec lore)
_ (Rotate [SubExp]
offsets VName
v)
| (SubExp -> Bool) -> [SubExp] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all SubExp -> Bool
isCt0 [SubExp]
offsets = RuleM lore () -> Rule lore
forall lore. RuleM lore () -> Rule lore
Simplify (RuleM lore () -> Rule lore) -> RuleM lore () -> Rule lore
forall a b. (a -> b) -> a -> b
$ Pattern (Lore (RuleM lore))
-> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall (m :: * -> *).
MonadBinder m =>
Pattern (Lore m) -> Exp (Lore m) -> m ()
letBind Pattern lore
Pattern (Lore (RuleM lore))
pat (Exp (Lore (RuleM lore)) -> RuleM lore ())
-> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp lore
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> Exp lore) -> BasicOp -> Exp lore
forall a b. (a -> b) -> a -> b
$ SubExp -> BasicOp
SubExp (SubExp -> BasicOp) -> SubExp -> BasicOp
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
v
ruleBasicOp TopDown lore
vtable Pattern lore
pat StmAux (ExpDec lore)
aux (Rotate [SubExp]
offsets VName
v)
| Just (BasicOp (Rearrange [Int]
perm VName
v2), Certificates
v_cs) <- VName -> TopDown lore -> Maybe (Exp lore, Certificates)
forall lore.
VName -> SymbolTable lore -> Maybe (Exp lore, Certificates)
ST.lookupExp VName
v TopDown lore
vtable,
Just (BasicOp (Rotate [SubExp]
offsets2 VName
v3), Certificates
v2_cs) <- VName -> TopDown lore -> Maybe (Exp lore, Certificates)
forall lore.
VName -> SymbolTable lore -> Maybe (Exp lore, Certificates)
ST.lookupExp VName
v2 TopDown lore
vtable = RuleM lore () -> Rule lore
forall lore. RuleM lore () -> Rule lore
Simplify (RuleM lore () -> Rule lore) -> RuleM lore () -> Rule lore
forall a b. (a -> b) -> a -> b
$ do
let offsets2' :: [SubExp]
offsets2' = [Int] -> [SubExp] -> [SubExp]
forall a. [Int] -> [a] -> [a]
rearrangeShape ([Int] -> [Int]
rearrangeInverse [Int]
perm) [SubExp]
offsets2
addOffsets :: SubExp -> SubExp -> m SubExp
addOffsets SubExp
x SubExp
y = String -> Exp (Lore m) -> m SubExp
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m SubExp
letSubExp String
"summed_offset" (Exp (Lore m) -> m SubExp) -> Exp (Lore m) -> m SubExp
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp (Lore m)
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> Exp (Lore m)) -> BasicOp -> Exp (Lore m)
forall a b. (a -> b) -> a -> b
$ BinOp -> SubExp -> SubExp -> BasicOp
BinOp (IntType -> Overflow -> BinOp
Add IntType
Int64 Overflow
OverflowWrap) SubExp
x SubExp
y
[SubExp]
offsets' <- (SubExp -> SubExp -> RuleM lore SubExp)
-> [SubExp] -> [SubExp] -> RuleM lore [SubExp]
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m [c]
zipWithM SubExp -> SubExp -> RuleM lore SubExp
forall (m :: * -> *). MonadBinder m => SubExp -> SubExp -> m SubExp
addOffsets [SubExp]
offsets [SubExp]
offsets2'
VName
rotate_rearrange <-
StmAux (ExpDec lore) -> RuleM lore VName -> RuleM lore VName
forall (m :: * -> *) anylore a.
MonadBinder m =>
StmAux anylore -> m a -> m a
auxing StmAux (ExpDec lore)
aux (RuleM lore VName -> RuleM lore VName)
-> RuleM lore VName -> RuleM lore VName
forall a b. (a -> b) -> a -> b
$ String -> Exp (Lore (RuleM lore)) -> RuleM lore VName
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m VName
letExp String
"rotate_rearrange" (Exp (Lore (RuleM lore)) -> RuleM lore VName)
-> Exp (Lore (RuleM lore)) -> RuleM lore VName
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp lore
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> Exp lore) -> BasicOp -> Exp lore
forall a b. (a -> b) -> a -> b
$ [Int] -> VName -> BasicOp
Rearrange [Int]
perm VName
v3
Certificates -> RuleM lore () -> RuleM lore ()
forall (m :: * -> *) a. MonadBinder m => Certificates -> m a -> m a
certifying (Certificates
v_cs Certificates -> Certificates -> Certificates
forall a. Semigroup a => a -> a -> a
<> Certificates
v2_cs) (RuleM lore () -> RuleM lore ()) -> RuleM lore () -> RuleM lore ()
forall a b. (a -> b) -> a -> b
$
Pattern (Lore (RuleM lore))
-> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall (m :: * -> *).
MonadBinder m =>
Pattern (Lore m) -> Exp (Lore m) -> m ()
letBind Pattern lore
Pattern (Lore (RuleM lore))
pat (Exp (Lore (RuleM lore)) -> RuleM lore ())
-> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp lore
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> Exp lore) -> BasicOp -> Exp lore
forall a b. (a -> b) -> a -> b
$ [SubExp] -> VName -> BasicOp
Rotate [SubExp]
offsets' VName
rotate_rearrange
ruleBasicOp TopDown lore
vtable Pattern lore
pat StmAux (ExpDec lore)
aux (Rotate [SubExp]
offsets1 VName
v)
| Just (BasicOp (Rotate [SubExp]
offsets2 VName
v2), Certificates
v_cs) <- VName -> TopDown lore -> Maybe (Exp lore, Certificates)
forall lore.
VName -> SymbolTable lore -> Maybe (Exp lore, Certificates)
ST.lookupExp VName
v TopDown lore
vtable = RuleM lore () -> Rule lore
forall lore. RuleM lore () -> Rule lore
Simplify (RuleM lore () -> Rule lore) -> RuleM lore () -> Rule lore
forall a b. (a -> b) -> a -> b
$ do
[SubExp]
offsets <- (SubExp -> SubExp -> RuleM lore SubExp)
-> [SubExp] -> [SubExp] -> RuleM lore [SubExp]
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m [c]
zipWithM SubExp -> SubExp -> RuleM lore SubExp
forall (m :: * -> *). MonadBinder m => SubExp -> SubExp -> m SubExp
add [SubExp]
offsets1 [SubExp]
offsets2
Certificates -> RuleM lore () -> RuleM lore ()
forall (m :: * -> *) a. MonadBinder m => Certificates -> m a -> m a
certifying Certificates
v_cs (RuleM lore () -> RuleM lore ()) -> RuleM lore () -> RuleM lore ()
forall a b. (a -> b) -> a -> b
$
StmAux (ExpDec lore) -> RuleM lore () -> RuleM lore ()
forall (m :: * -> *) anylore a.
MonadBinder m =>
StmAux anylore -> m a -> m a
auxing StmAux (ExpDec lore)
aux (RuleM lore () -> RuleM lore ()) -> RuleM lore () -> RuleM lore ()
forall a b. (a -> b) -> a -> b
$
Pattern (Lore (RuleM lore))
-> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall (m :: * -> *).
MonadBinder m =>
Pattern (Lore m) -> Exp (Lore m) -> m ()
letBind Pattern lore
Pattern (Lore (RuleM lore))
pat (Exp (Lore (RuleM lore)) -> RuleM lore ())
-> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp lore
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> Exp lore) -> BasicOp -> Exp lore
forall a b. (a -> b) -> a -> b
$ [SubExp] -> VName -> BasicOp
Rotate [SubExp]
offsets VName
v2
where
add :: SubExp -> SubExp -> m SubExp
add SubExp
x SubExp
y = String -> Exp (Lore m) -> m SubExp
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m SubExp
letSubExp String
"offset" (Exp (Lore m) -> m SubExp) -> Exp (Lore m) -> m SubExp
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp (Lore m)
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> Exp (Lore m)) -> BasicOp -> Exp (Lore m)
forall a b. (a -> b) -> a -> b
$ BinOp -> SubExp -> SubExp -> BasicOp
BinOp (IntType -> Overflow -> BinOp
Add IntType
Int64 Overflow
OverflowWrap) SubExp
x SubExp
y
ruleBasicOp TopDown lore
vtable Pattern lore
pat StmAux (ExpDec lore)
aux (Update VName
arr_x Slice SubExp
slice_x (Var VName
v))
| Just [SubExp]
_ <- Slice SubExp -> Maybe [SubExp]
forall d. Slice d -> Maybe [d]
sliceIndices Slice SubExp
slice_x,
Just (Index VName
arr_y Slice SubExp
slice_y, Certificates
cs_y) <- VName -> TopDown lore -> Maybe (BasicOp, Certificates)
forall lore.
VName -> SymbolTable lore -> Maybe (BasicOp, Certificates)
ST.lookupBasicOp VName
v TopDown lore
vtable,
VName -> TopDown lore -> Bool
forall lore. VName -> SymbolTable lore -> Bool
ST.available VName
arr_y TopDown lore
vtable,
VName
arr_y VName -> VName -> Bool
forall a. Eq a => a -> a -> Bool
/= VName
arr_x,
Just (Slice SubExp
slice_x_bef, DimFix SubExp
i, []) <- Int
-> Slice SubExp
-> Maybe (Slice SubExp, DimIndex SubExp, Slice SubExp)
forall int a. Integral int => int -> [a] -> Maybe ([a], a, [a])
focusNth (Slice SubExp -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length Slice SubExp
slice_x Int -> Int -> Int
forall a. Num a => a -> a -> a
- Int
1) Slice SubExp
slice_x,
Just (Slice SubExp
slice_y_bef, DimFix SubExp
j, []) <- Int
-> Slice SubExp
-> Maybe (Slice SubExp, DimIndex SubExp, Slice SubExp)
forall int a. Integral int => int -> [a] -> Maybe ([a], a, [a])
focusNth (Slice SubExp -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length Slice SubExp
slice_y Int -> Int -> Int
forall a. Num a => a -> a -> a
- Int
1) Slice SubExp
slice_y = RuleM lore () -> Rule lore
forall lore. RuleM lore () -> Rule lore
Simplify (RuleM lore () -> Rule lore) -> RuleM lore () -> Rule lore
forall a b. (a -> b) -> a -> b
$ do
let slice_x' :: Slice SubExp
slice_x' = Slice SubExp
slice_x_bef Slice SubExp -> Slice SubExp -> Slice SubExp
forall a. [a] -> [a] -> [a]
++ [SubExp -> SubExp -> SubExp -> DimIndex SubExp
forall d. d -> d -> d -> DimIndex d
DimSlice SubExp
i (IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
1) (IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
1)]
slice_y' :: Slice SubExp
slice_y' = Slice SubExp
slice_y_bef Slice SubExp -> Slice SubExp -> Slice SubExp
forall a. [a] -> [a] -> [a]
++ [SubExp -> SubExp -> SubExp -> DimIndex SubExp
forall d. d -> d -> d -> DimIndex d
DimSlice SubExp
j (IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
1) (IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
1)]
VName
v' <- String -> Exp (Lore (RuleM lore)) -> RuleM lore VName
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m VName
letExp (VName -> String
baseString VName
v String -> String -> String
forall a. [a] -> [a] -> [a]
++ String
"_slice") (Exp (Lore (RuleM lore)) -> RuleM lore VName)
-> Exp (Lore (RuleM lore)) -> RuleM lore VName
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp lore
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> Exp lore) -> BasicOp -> Exp lore
forall a b. (a -> b) -> a -> b
$ VName -> Slice SubExp -> BasicOp
Index VName
arr_y Slice SubExp
slice_y'
Certificates -> RuleM lore () -> RuleM lore ()
forall (m :: * -> *) a. MonadBinder m => Certificates -> m a -> m a
certifying Certificates
cs_y (RuleM lore () -> RuleM lore ()) -> RuleM lore () -> RuleM lore ()
forall a b. (a -> b) -> a -> b
$
StmAux (ExpDec lore) -> RuleM lore () -> RuleM lore ()
forall (m :: * -> *) anylore a.
MonadBinder m =>
StmAux anylore -> m a -> m a
auxing StmAux (ExpDec lore)
aux (RuleM lore () -> RuleM lore ()) -> RuleM lore () -> RuleM lore ()
forall a b. (a -> b) -> a -> b
$
Pattern (Lore (RuleM lore))
-> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall (m :: * -> *).
MonadBinder m =>
Pattern (Lore m) -> Exp (Lore m) -> m ()
letBind Pattern lore
Pattern (Lore (RuleM lore))
pat (Exp (Lore (RuleM lore)) -> RuleM lore ())
-> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp lore
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> Exp lore) -> BasicOp -> Exp lore
forall a b. (a -> b) -> a -> b
$ VName -> Slice SubExp -> SubExp -> BasicOp
Update VName
arr_x Slice SubExp
slice_x' (SubExp -> BasicOp) -> SubExp -> BasicOp
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
v'
ruleBasicOp TopDown lore
vtable Pattern lore
pat StmAux (ExpDec lore)
aux (CmpOp CmpSle {} SubExp
x SubExp
y)
| Constant (IntValue (Int64Value Int64
0)) <- SubExp
x,
Var VName
v <- SubExp
y,
Just SubExp
_ <- VName -> TopDown lore -> Maybe SubExp
forall lore. VName -> SymbolTable lore -> Maybe SubExp
ST.lookupLoopVar VName
v TopDown lore
vtable =
RuleM lore () -> Rule lore
forall lore. RuleM lore () -> Rule lore
Simplify (RuleM lore () -> Rule lore) -> RuleM lore () -> Rule lore
forall a b. (a -> b) -> a -> b
$ StmAux (ExpDec lore) -> RuleM lore () -> RuleM lore ()
forall (m :: * -> *) anylore a.
MonadBinder m =>
StmAux anylore -> m a -> m a
auxing StmAux (ExpDec lore)
aux (RuleM lore () -> RuleM lore ()) -> RuleM lore () -> RuleM lore ()
forall a b. (a -> b) -> a -> b
$ Pattern (Lore (RuleM lore))
-> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall (m :: * -> *).
MonadBinder m =>
Pattern (Lore m) -> Exp (Lore m) -> m ()
letBind Pattern lore
Pattern (Lore (RuleM lore))
pat (Exp (Lore (RuleM lore)) -> RuleM lore ())
-> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp lore
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> Exp lore) -> BasicOp -> Exp lore
forall a b. (a -> b) -> a -> b
$ SubExp -> BasicOp
SubExp (SubExp -> BasicOp) -> SubExp -> BasicOp
forall a b. (a -> b) -> a -> b
$ Bool -> SubExp
forall v. IsValue v => v -> SubExp
constant Bool
True
ruleBasicOp TopDown lore
vtable Pattern lore
pat StmAux (ExpDec lore)
aux (CmpOp CmpSlt {} SubExp
x SubExp
y)
| Var VName
v <- SubExp
x,
Just SubExp
n <- VName -> TopDown lore -> Maybe SubExp
forall lore. VName -> SymbolTable lore -> Maybe SubExp
ST.lookupLoopVar VName
v TopDown lore
vtable,
SubExp
n SubExp -> SubExp -> Bool
forall a. Eq a => a -> a -> Bool
== SubExp
y =
RuleM lore () -> Rule lore
forall lore. RuleM lore () -> Rule lore
Simplify (RuleM lore () -> Rule lore) -> RuleM lore () -> Rule lore
forall a b. (a -> b) -> a -> b
$ StmAux (ExpDec lore) -> RuleM lore () -> RuleM lore ()
forall (m :: * -> *) anylore a.
MonadBinder m =>
StmAux anylore -> m a -> m a
auxing StmAux (ExpDec lore)
aux (RuleM lore () -> RuleM lore ()) -> RuleM lore () -> RuleM lore ()
forall a b. (a -> b) -> a -> b
$ Pattern (Lore (RuleM lore))
-> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall (m :: * -> *).
MonadBinder m =>
Pattern (Lore m) -> Exp (Lore m) -> m ()
letBind Pattern lore
Pattern (Lore (RuleM lore))
pat (Exp (Lore (RuleM lore)) -> RuleM lore ())
-> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp lore
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> Exp lore) -> BasicOp -> Exp lore
forall a b. (a -> b) -> a -> b
$ SubExp -> BasicOp
SubExp (SubExp -> BasicOp) -> SubExp -> BasicOp
forall a b. (a -> b) -> a -> b
$ Bool -> SubExp
forall v. IsValue v => v -> SubExp
constant Bool
True
ruleBasicOp TopDown lore
vtable Pattern lore
pat StmAux (ExpDec lore)
aux (CmpOp CmpSlt {} (Var VName
x) SubExp
y)
| SubExp -> Bool
isCt0 SubExp
y,
Bool -> (Entry lore -> Bool) -> Maybe (Entry lore) -> Bool
forall b a. b -> (a -> b) -> Maybe a -> b
maybe Bool
False Entry lore -> Bool
forall lore. Entry lore -> Bool
ST.entryIsSize (Maybe (Entry lore) -> Bool) -> Maybe (Entry lore) -> Bool
forall a b. (a -> b) -> a -> b
$ VName -> TopDown lore -> Maybe (Entry lore)
forall lore. VName -> SymbolTable lore -> Maybe (Entry lore)
ST.lookup VName
x TopDown lore
vtable =
RuleM lore () -> Rule lore
forall lore. RuleM lore () -> Rule lore
Simplify (RuleM lore () -> Rule lore) -> RuleM lore () -> Rule lore
forall a b. (a -> b) -> a -> b
$ StmAux (ExpDec lore) -> RuleM lore () -> RuleM lore ()
forall (m :: * -> *) anylore a.
MonadBinder m =>
StmAux anylore -> m a -> m a
auxing StmAux (ExpDec lore)
aux (RuleM lore () -> RuleM lore ()) -> RuleM lore () -> RuleM lore ()
forall a b. (a -> b) -> a -> b
$ Pattern (Lore (RuleM lore))
-> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall (m :: * -> *).
MonadBinder m =>
Pattern (Lore m) -> Exp (Lore m) -> m ()
letBind Pattern lore
Pattern (Lore (RuleM lore))
pat (Exp (Lore (RuleM lore)) -> RuleM lore ())
-> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp lore
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> Exp lore) -> BasicOp -> Exp lore
forall a b. (a -> b) -> a -> b
$ SubExp -> BasicOp
SubExp (SubExp -> BasicOp) -> SubExp -> BasicOp
forall a b. (a -> b) -> a -> b
$ Bool -> SubExp
forall v. IsValue v => v -> SubExp
constant Bool
False
ruleBasicOp TopDown lore
_ Pattern lore
_ StmAux (ExpDec lore)
_ BasicOp
_ =
Rule lore
forall lore. Rule lore
Skip
-- | Bottom-up rule for @if@ expressions: drop every branch result whose
-- pattern element is not directly used (per the usage table).
--
-- The rule only applies when the pattern has exactly as many elements as
-- the branch has return types, and at least one of those elements is
-- unused.  NOTE(review): presumably the size check excludes patterns with
-- an existential context part, since the rewritten pattern is always
-- built with an empty context.
--
-- The same keep/drop mask is applied to the results of both branch
-- bodies, to the pattern elements and to the return-type list.  The
-- statements inside the bodies that computed the dropped results are left
-- in place.
removeDeadBranchResult :: BinderOps lore => BottomUpRuleIf lore
removeDeadBranchResult (_, used) pat _ (e1, tb, fb, IfDec rettype ifsort)
  | patternSize pat == length rettype,
    keep <- map (`UT.isUsedDirectly` used) (patternNames pat),
    not (and keep) =
    let -- Explicit signature needed: the helper closes over 'keep', so
        -- without it MonoLocalBinds (implied by TypeFamilies) would make
        -- it monomorphic, but it is used at several element types.
        prune :: [a] -> [a]
        prune xs = [x | (True, x) <- zip keep xs]
        tb' = tb {bodyResult = prune (bodyResult tb)}
        fb' = fb {bodyResult = prune (bodyResult fb)}
        pat' = Pattern [] (prune (patternElements pat))
        dec' = IfDec (prune rettype) ifsort
     in Simplify $ letBind pat' $ If e1 tb' fb' dec'
  | otherwise = Skip
-- | Is this subexpression a literal constant for which 'oneIsh' holds?
-- A variable ('Var') is never treated as a constant here.
isCt1 :: SubExp -> Bool
isCt1 se = case se of
  Constant v -> oneIsh v
  _ -> False
-- | Is this subexpression a literal constant for which 'zeroIsh' holds?
-- A variable ('Var') is never treated as a constant here.
isCt0 :: SubExp -> Bool
isCt0 se = case se of
  Constant v -> zeroIsh v
  _ -> False