{-# LANGUAGE TupleSections #-}
module Algorithm.EqSat where
import Algorithm.EqSat.Egraph
import Algorithm.EqSat.DB
import Algorithm.EqSat.Info
import Algorithm.EqSat.Build
import Control.Lens (element, makeLenses, over, (&), (+~), (-~), (.~), (^.))
import Control.Monad.State
import Data.Function (on)
import Data.IntMap (IntMap)
import qualified Data.IntMap as IntMap
import Data.List (intercalate, minimumBy)
import Data.Map (Map)
import qualified Data.Map as Map
import Data.Maybe (mapMaybe)
import Data.SRTree
import Data.HashSet (HashSet)
import qualified Data.HashSet as Set
import Control.Monad ( zipWithM )
import Debug.Trace
-- | Rule-scheduling state: maps a rule number (its position in the expanded
-- rule list) to the iteration value stored for it by 'matchWithScheduler'.
type Scheduler a = State (IntMap Int) a
-- | Unwrap a 'Just'. A 'Nothing' argument aborts with the error
-- "fromJust called with Nothing".
fromJust :: Maybe a -> a
fromJust mx =
  case mx of
    Just v  -> v
    Nothing -> error "fromJust called with Nothing"
{-# INLINE fromJust #-}
-- | Run equality saturation starting from @expr@ and return the cheapest
-- expression found (via 'getBest') for the e-class of @expr@.
--
-- @maxIt@ is the iteration budget handed to 'runEqSat'. When 'runEqSat'
-- reports an early stop (first component 'False'), saturation is restarted
-- from the best expression found so far with the remaining budget.
eqSat :: Monad m => Fix SRTree -> [Rule] -> CostFun -> Int -> EGraphST m (Fix SRTree)
eqSat expr rules costFun maxIt = do
  root      <- fromTree costFun expr
  (end, it) <- runEqSat costFun rules maxIt
  best      <- getBest root
  -- 'runEqSat' returns the counter of the iteration that just ran, so the
  -- remaining budget is @it - 1@. Passing @it@ unchanged never shrinks the
  -- budget and can recurse without bound while the e-graph stays above
  -- runEqSat's size cap.
  -- NOTE(review): the e-graph is not reset before restarting, so the restart
  -- continues on the same graph — confirm this is intended.
  if end
    then pure best
    else eqSat best rules costFun (it - 1)
-- | Cheapest known @(cost, expression)@ pair for each e-class, as built by
-- 'recalculateBest'.
type CostMap = Map EClassId (Int, Fix SRTree)
-- | Recompute from scratch the cheapest expression represented by the e-class
-- @eid@ under @costFun@.
--
-- Costs are propagated over all e-classes until a sweep changes nothing.
-- An e-node is priced only once every child class has a price. A class keeps
-- its current entry unless a strictly cheaper e-node turns up.
--
-- Fails with a 'Map.!' error if the canonical class of @eid@ never receives
-- a price.
recalculateBest :: Monad m => CostFun -> EClassId -> EGraphST m (Fix SRTree)
recalculateBest costFun eid = do
  allClasses <- gets _eClass
  let table = fillUpCosts allClasses Map.empty
  rootId <- canonical eid
  pure (snd (table Map.! rootId))
  where
    -- Price a single e-node from the prices known so far; 'Nothing' while any
    -- child class is still unpriced.
    nodeCost :: CostMap -> ENode -> Maybe (Int, Fix SRTree)
    nodeCost known enode = do
      priced <- traverse (`Map.lookup` known) (childrenOf enode)
      let childCosts = fst <$> priced
          childTrees = snd <$> priced
          ownCost    = costFun (replaceChildren childCosts enode)
      Just (ownCost + sum childCosts, Fix (replaceChildren childTrees enode))

    -- Total variant of 'minimumBy': 'Nothing' for an empty list, otherwise
    -- the first minimal element.
    minimumBy' :: (a -> a -> Ordering) -> [a] -> Maybe a
    minimumBy' _ []    = Nothing
    minimumBy' cmp xs  = Just (minimumBy cmp xs)

    -- Sweep every class repeatedly until a sweep makes no update.
    fillUpCosts :: IntMap EClass -> CostMap -> CostMap
    fillUpCosts classes known
      | changed   = fillUpCosts classes known'
      | otherwise = known
      where
        (changed, known') = IntMap.foldrWithKey costOfClass (False, known) classes

    -- Update the entry for one class if it is missing or a cheaper e-node is
    -- now priceable; the flag records whether any update happened.
    costOfClass :: EClassId -> EClass -> (Bool, CostMap) -> (Bool, CostMap)
    costOfClass cid cls (dirty, known) =
      case (Map.lookup cid known, cheapest) of
        (_, Nothing)           -> (dirty, known)
        (Just old, Just new)
          | fst old <= fst new -> (dirty, known)
        (_, Just new)          -> (True, Map.insert cid new known)
      where
        cheapest = minimumBy' (compare `on` fst)
                 . mapMaybe (nodeCost known . decodeEnode)
                 . Set.toList
                 $ _eNodes cls
-- | Run equality saturation on the current e-graph for at most @maxIter@
-- iterations of the given rules.
--
-- Bidirectional rules (@:==:@) are first expanded into two directed rules,
-- and conditions (@:|@) are carried onto each expanded rule. Each iteration:
--
-- 1. selects rules through 'matchWithScheduler';
-- 2. matches their source patterns;
-- 3. applies every match with 'applyMatch';
-- 4. calls 'rebuild'.
--
-- The counter starts at @maxIter@ and counts down. The result is:
--
-- * @(True, it)@ when the counter reached 1 or below, or when an iteration
--   left both the e-node map and the e-class map unchanged;
-- * @(False, it)@ when the e-graph grew past 'maxClasses' e-classes.
--
-- In both cases @it@ is the counter of the last iteration that ran.
runEqSat :: Monad m => CostFun -> [Rule] -> Int -> EGraphST m (Bool, Int)
runEqSat costFun rules maxIter = go maxIter IntMap.empty
  where
    -- e-class count above which saturation is aborted early
    maxClasses :: Int
    maxClasses = 500

    rules' = concatMap replaceEqRules rules

    replaceEqRules :: Rule -> [Rule]
    replaceEqRules (p1 :=> p2)  = [p1 :=> p2]
    replaceEqRules (p1 :==: p2) = [p1 :=> p2, p2 :=> p1]
    replaceEqRules (r :| cond)  = map (:| cond) $ replaceEqRules r

    go it sch = do
      eNodes   <- gets _eNodeToEClass
      eClasses <- gets _eClass
      let (scheduled, sch') =
            runState (zipWithM (matchWithScheduler it) [0 ..] rules') sch
      matches <- mapM (\rule -> map (rule,) <$> match (source rule)) (concat scheduled)
      mapM_ (uncurry (applyMatch costFun)) (concat matches)
      rebuild costFun
      eNodes'   <- gets _eNodeToEClass
      eClasses' <- gets _eClass
      let unchanged = eNodes' == eNodes && eClasses' == eClasses
      -- `<= 1` rather than `== 1`: with a non-positive budget the decreasing
      -- counter would otherwise never hit 1 and the loop could run unbounded.
      if it <= 1 || unchanged
        then pure (True, it)
        else if IntMap.size eClasses' > maxClasses
          then pure (False, it)
          else go (it - 1) sch'
-- | Apply one round of the given rules to the current e-graph using
-- 'applyMergeOnlyMatch', then 'rebuild' once.
--
-- Bidirectional rules (@:==:@) are expanded into two directed rules, and
-- conditions (@:|@) are carried onto each expanded rule. Every expanded rule
-- is matched against the current e-graph. All matches are collected before
-- any of them is applied.
applySingleMergeOnlyEqSat :: Monad m => CostFun -> [Rule] -> EGraphST m ()
applySingleMergeOnlyEqSat costFun rules = do
  -- Every rule is selected. A scheduler starting from an empty state, with a
  -- distinct rule number per rule, never bans a rule on its first lookup, so
  -- routing through 'matchWithScheduler' here would be a no-op.
  matches <- mapM (\rule -> map (rule,) <$> match (source rule)) rules'
  mapM_ (uncurry (applyMergeOnlyMatch costFun)) (concat matches)
  rebuild costFun
  where
    rules' = concatMap replaceEqRules rules

    replaceEqRules :: Rule -> [Rule]
    replaceEqRules (p1 :=> p2)  = [p1 :=> p2]
    replaceEqRules (p1 :==: p2) = [p1 :=> p2, p2 :=> p1]
    replaceEqRules (r :| cond)  = map (:| cond) $ replaceEqRules r
-- | Decide whether rule number @ruleNumber@ is selected at iteration @it@.
--
-- The scheduler state maps a rule number to a stored iteration value.
--
-- * If a value is stored and it is @<= it@, the rule is skipped (@[]@).
-- * Otherwise the rule is returned as a singleton list, and @it + 5@ is
--   stored for it.
--
-- NOTE(review): under a decreasing @it@ (as in 'runEqSat', which counts
-- down), a stored @it + 5@ is never @<=@ a later @it@. So no rule is ever
-- skipped there — confirm the intended back-off semantics.
matchWithScheduler :: Int -> Int -> Rule -> Scheduler [Rule]
matchWithScheduler it ruleNumber rule = do
  mbBan <- gets (IntMap.!? ruleNumber)
  case mbBan of
    Just ban | ban <= it -> pure []
    _ -> do
      modify (IntMap.insert ruleNumber (it + 5))
      pure [rule]