{-# LANGUAGE TupleSections #-}
-----------------------------------------------------------------------------
-- |
-- Module      :  Algorithm.EqSat
-- Copyright   :  (c) Fabricio Olivetti 2021 - 2024
-- License     :  BSD3
-- Maintainer  :  fabricio.olivetti@gmail.com
-- Stability   :  experimental
-- Portability :
--
-- Equality Saturation for SRTree
-- Heavily based on hegg (https://github.com/alt-romes/hegg by alt-romes)
--
-----------------------------------------------------------------------------

module Algorithm.EqSat where

import Algorithm.EqSat.Egraph
import Algorithm.EqSat.DB
import Algorithm.EqSat.Info
import Algorithm.EqSat.Build
import Control.Lens (element, makeLenses, over, (&), (+~), (-~), (.~), (^.))
import Control.Monad.State
import Data.Function (on)
import Data.IntMap (IntMap)
import qualified Data.IntMap as IntMap
import Data.List (intercalate, minimumBy)
import Data.Map (Map)
import qualified Data.Map as Map
import Data.Maybe (mapMaybe)
import Data.SRTree
import Data.HashSet (HashSet)
import qualified Data.HashSet as Set
import Control.Monad ( zipWithM )

import Debug.Trace

-- | State monad used for rule scheduling. The state maps a rule index (its
-- position in the rule list) to an iteration number that decides whether
-- the rule is banned; see 'matchWithScheduler'.
-- TODO: make it more customizable.
type Scheduler a = State (IntMap Int) a

-- | Local, partial version of 'fromJust': yields the payload of a 'Just'
-- and calls 'error' with "fromJust called with Nothing" on 'Nothing'.
fromJust :: Maybe a -> a
fromJust mx =
  case mx of
    Just x  -> x
    Nothing -> error "fromJust called with Nothing"
{-# INLINE fromJust #-}

-- | Runs equality saturation from an expression tree, a given set of rules,
-- and a cost function, and returns the tree with the smallest cost found
-- for the e-class of @expr@.
--
-- @maxIt@ bounds the number of saturation iterations. When 'runEqSat'
-- reports an early stop (the e-graph grew past its size limit), saturation
-- is restarted from the best tree found so far with the remaining budget.
--
-- NOTE(review): the restart reuses the current e-graph state; nothing in
-- this function resets it. Confirm that this is intended.
eqSat :: Monad m => Fix SRTree -> [Rule] -> CostFun -> Int -> EGraphST m (Fix SRTree)
eqSat expr rules costFun maxIt =
    do root      <- fromTree costFun expr
       (end, it) <- runEqSat costFun rules maxIt
       best      <- getBest root
       -- 'runEqSat' reports an early stop together with the counter of the
       -- iteration it had just run, so only @it - 1@ iterations remain.
       -- Restarting with @it@ itself never shrinks the budget and can recurse
       -- forever while the e-graph stays above the size limit.
       if not end && it > 1
         then eqSat best rules costFun (it - 1) -- reapplies eqsat on the best so far
         else pure best

-- | Cost table used by 'recalculateBest': for each e-class id, the best cost
-- found so far together with the expression tree that achieves it.
type CostMap = Map EClassId (Int, Fix SRTree)

-- | Re-extracts the cheapest expression of e-class @eid@ under @costFun@,
-- computing every cost from scratch.
--
-- A 'CostMap' is built by repeated passes over all e-classes. An e-node can
-- be costed once every child class has an entry. Its cost is @costFun@
-- applied to the node with its children replaced by their costs, plus the
-- sum of those child costs. A class entry is added or replaced only when a
-- strictly cheaper e-node is found. Passes repeat until a whole pass changes
-- nothing, and the tree stored for the canonical id of @eid@ is returned.
--
-- NOTE(review): the final lookup uses 'Map.!', so this errors if the
-- canonical class of @eid@ never received a cost. Termination of the passes
-- presumably relies on @costFun@ never yielding negative costs — confirm.
recalculateBest :: Monad m => CostFun -> EClassId -> EGraphST m (Fix SRTree)
recalculateBest costFun eid = do
    classMap <- gets _eClass
    canonId  <- canonical eid
    let table = converge classMap Map.empty
    pure (snd (table Map.! canonId))
  where
    -- Cost and tree of a single e-node; Nothing while any child is uncosted.
    costOfNode :: CostMap -> ENode -> Maybe (Int, Fix SRTree)
    costOfNode table enode = do
        kids <- mapM (`Map.lookup` table) (childrenOf enode)
        let kidCosts = map fst kids
            kidTrees = map snd kids
            own      = costFun (replaceChildren kidCosts enode)
        Just (own + sum kidCosts, Fix (replaceChildren kidTrees enode))

    -- Cheapest costed e-node of a class (the first one wins on ties).
    cheapest :: CostMap -> EClass -> Maybe (Int, Fix SRTree)
    cheapest table ecl =
        case mapMaybe (costOfNode table . decodeEnode) (Set.toList (_eNodes ecl)) of
          []         -> Nothing
          candidates -> Just (minimumBy (compare `on` fst) candidates)

    -- Fold step of one pass: stores a strictly better entry and flags the change.
    relax :: EClassId -> EClass -> (Bool, CostMap) -> (Bool, CostMap)
    relax cid ecl acc@(_, table) =
        case cheapest table ecl of
          Just new | maybe True (\old -> fst new < fst old) (Map.lookup cid table)
                   -> (True, Map.insert cid new table)
          _        -> acc

    -- Repeats passes over all classes until a pass improves nothing.
    converge :: IntMap EClass -> CostMap -> CostMap
    converge classMap table =
        case IntMap.foldrWithKey relax (False, table) classMap of
          (True, table') -> converge classMap table'
          (False, _)     -> table

-- | Runs equality saturation on the current e-graph for at most @maxIter@
-- iterations.
--
-- Equality rules (@:==:@) are first expanded into two one-way rules. Each
-- iteration filters the rules through 'matchWithScheduler', matches them,
-- applies every match with 'applyMatch' and calls 'rebuild'.
--
-- The iteration counter counts down from @maxIter@. The result is:
--
-- * @(True, it)@ when the budget is exhausted, or when an iteration left
--   both '_eNodeToEClass' and '_eClass' unchanged (saturation);
-- * @(False, it)@ when, after iteration @it@, the e-graph holds more than
--   500 e-classes (early stop).
--
-- A non-positive @maxIter@ performs no iterations and returns
-- @(True, maxIter)@. Previously the loop only stopped when the counter was
-- exactly 1, so such budgets ran until saturation or blow-up.
runEqSat :: Monad m => CostFun -> [Rule] -> Int -> EGraphST m (Bool, Int)
runEqSat costFun rules maxIter
    | maxIter <= 0 = pure (True, maxIter)
    | otherwise    = go maxIter IntMap.empty
    where
        -- maximum allowed number of e-classes. TODO: customize
        maxClasses :: Int
        maxClasses = 500

        rules' :: [Rule]
        rules' = concatMap replaceEqRules rules

        -- replaces the equality rules with two one-way rules
        replaceEqRules :: Rule -> [Rule]
        replaceEqRules (p1 :=> p2)  = [p1 :=> p2]
        replaceEqRules (p1 :==: p2) = [p1 :=> p2, p2 :=> p1]
        replaceEqRules (r :| cond)  = map (:| cond) $ replaceEqRules r

        go it sch =
          do eNodes   <- gets _eNodeToEClass
             eClasses <- gets _eClass

             -- step 1: keep the rules the scheduler allows in this iteration
             let (active, sch') = runState (zipWithM (matchWithScheduler it) [0..] rules') sch

             -- step 2: apply matches and rebuild
             matches <- mapM (\rule -> map (rule,) <$> match (source rule)) (concat active)
             mapM_ (uncurry (applyMatch costFun)) (concat matches)
             rebuild costFun

             eNodes'   <- gets _eNodeToEClass
             eClasses' <- gets _eClass

             -- stop when the budget is used up or nothing changed
             if it <= 1 || (eNodes' == eNodes && eClasses' == eClasses)
               then pure (True, it)
               else if IntMap.size eClasses' > maxClasses
                      then pure (False, it)
                      else go (it - 1) sch'

-- | Applies a single step of merge-only equality saturation.
--
-- Equality rules (@:==:@) are expanded into two one-way rules. Every rule
-- is matched against the current e-graph, each match is applied with
-- 'applyMergeOnlyMatch', and the e-graph is then rebuilt with 'rebuild'.
--
-- No scheduler is consulted. This is a single step starting from an empty
-- ban map, where each rule index is looked up before it is ever inserted.
-- Such a step keeps every rule, so the rules are matched directly.
applySingleMergeOnlyEqSat :: Monad m => CostFun -> [Rule] -> EGraphST m ()
applySingleMergeOnlyEqSat costFun rules =
  do matches <- mapM (\rule -> map (rule,) <$> match (source rule)) rules'
     mapM_ (uncurry (applyMergeOnlyMatch costFun)) (concat matches)
     rebuild costFun
  where
    rules' :: [Rule]
    rules' = concatMap replaceEqRules rules

    -- replaces the equality rules with two one-way rules
    replaceEqRules :: Rule -> [Rule]
    replaceEqRules (p1 :=> p2)  = [p1 :=> p2]
    replaceEqRules (p1 :==: p2) = [p1 :=> p2, p2 :=> p1]
    replaceEqRules (r :| cond)  = map (:| cond) $ replaceEqRules r

-- | Filters one rule through the scheduler for iteration @it@.
--
-- The state maps rule indices to a stored iteration value. Suppose rule
-- @ruleNumber@ has a stored value that is @<= it@. Then the rule is
-- skipped: @[]@ is returned and the state is left untouched. Otherwise
-- @it + 5@ is stored for the rule and @[rule]@ is returned.
--
-- NOTE(review): 'runEqSat' calls this with a counter that decreases by one
-- per iteration. The stored @it + 5@ is therefore always greater than any
-- later @it@, and no rule is ever skipped there. The comparison looks
-- inverted for a countdown counter — confirm the intended backoff policy.
matchWithScheduler :: Int -> Int -> Rule -> Scheduler [Rule]
matchWithScheduler it ruleNumber rule =
  do banned <- gets (IntMap.lookup ruleNumber)
     case banned of
       Just banIter | banIter <= it -> pure []
       _ -> do modify (IntMap.insert ruleNumber (it + 5))
               pure [rule]