{-# LANGUAGE TupleSections #-}
{-# LANGUAGE BangPatterns #-}
module Algorithm.EqSat.Build where
import System.Random (Random (randomR), StdGen)
import Control.Lens ( over )
import Control.Monad ( forM_, when, foldM, forM )
import Data.Maybe ( fromMaybe, catMaybes )
import Data.SRTree
import Algorithm.EqSat.Egraph
import Algorithm.EqSat.DB
import qualified Data.IntMap.Strict as IntMap
import Data.Map.Strict ( Map )
import qualified Data.Map.Strict as Map
import qualified Data.HashSet as Set
import Control.Monad.State.Strict
import Data.SRTree.Recursion (cataM)
import Algorithm.EqSat.Info
import qualified Data.IntSet as IntSet
import Data.Maybe
import Data.Sequence (Seq(..), (><))
import Debug.Trace (trace, traceShow)
-- | Insert an e-node into the e-graph and return the id of the e-class
-- that holds it.
--
-- The node is canonized first. If 'calculateConsts' reports a 'ConstVal'
-- or a 'ParamIx' for the canonized node, the node is replaced by the
-- matching 'Const' / 'Param' leaf; otherwise the canonized node is used
-- as is.
--
-- When the resulting node is already a key of '_eNodeToEClass', the stored
-- id is returned and no class is created. Otherwise a new class is created
-- under the id currently held in '_nextId':
--
--   * the id is mapped to itself in 'canonicalMap', the node is mapped to
--     the id in 'eNodeToEClass', 'nextId' is incremented and the pair
--     @(id, node)@ is inserted into the 'worklist';
--   * the pair is recorded as a parent of every child class of the node;
--   * the class built by 'createEClass' (from 'makeAnalysis' and
--     'getChildrenMinHeight') is stored in 'eClass';
--   * the node is passed to 'addToDB', the id is added to the 'sizeDB'
--     bucket keyed by the @_size@ of the analysis data, and to
--     'unevaluated'.
add :: Monad m => CostFun -> ENode -> EGraphST m EClassId
add costFun enode = do
  enode''    <- canonize enode
  constEnode <- calculateConsts enode''
  let enode' = case constEnode of
        ConstVal x -> Const x
        ParamIx x  -> Param x
        _          -> enode''
  maybeEid <- gets ((Map.!? enode') . _eNodeToEClass)
  case maybeEid of
    Just eid -> pure eid
    Nothing  -> do
      curId <- gets (_nextId . _eDB)
      -- Register the fresh id and node before touching the children.
      modify' $ over canonicalMap (IntMap.insert curId curId)
              . over eNodeToEClass (Map.insert enode' curId)
              . over (eDB . nextId) (+1)
              . over (eDB . worklist) (Set.insert (curId, enode'))
      forM_ (childrenOf enode') (addParents curId enode')
      info <- makeAnalysis costFun enode'
      h    <- getChildrenMinHeight enode'
      let newClass = createEClass curId enode' info h
      modify' $ over eClass (IntMap.insert curId newClass)
      addToDB enode' curId
      -- Index the new class by the size reported in its analysis data.
      modify' $ over (eDB . sizeDB)
              $ IntMap.insertWith IntSet.union (_size info) (IntSet.singleton curId)
      modify' $ over (eDB . unevaluated) (IntSet.insert curId)
      pure curId
  where
    -- | Record @(cId, node)@ in the parent set of the class @c@.
    addParents :: Monad m => EClassId -> ENode -> EClassId -> EGraphST m ()
    addParents cId node c = do
      ec <- getEClass c
      let ec' = ec { _parents = Set.insert (cId, node) (_parents ec) }
      modify' $ over eClass (IntMap.insert c ec')
-- | Process the pending repair work of the e-graph once.
--
-- The current 'worklist' and 'analysis' sets are read, both are reset to
-- the empty set, and then 'repair' is run on every entry of the worklist
-- snapshot followed by 'repairAnalysis' on every entry of the analysis
-- snapshot.
--
-- Only the snapshots are traversed: entries queued while this pass runs
-- stay in the state and are not handled by this call.
rebuild :: Monad m => CostFun -> EGraphST m ()
rebuild costFun = do
  wl <- gets (_worklist . _eDB)
  al <- gets (_analysis . _eDB)
  -- Clear both queues up front so that work produced below is kept apart
  -- from the work being processed.
  modify' $ over (eDB . worklist) (const Set.empty)
          . over (eDB . analysis) (const Set.empty)
  forM_ wl (uncurry (repair costFun))
  forM_ al (uncurry (repairAnalysis costFun))
-- | Re-register an e-node under its canonical form.
--
-- The node is removed from 'eNodeToEClass', then both the node and the
-- class id are canonized. If the canonical node is already a key of the
-- map, the class it points to is merged with the canonical class id and
-- the node is mapped to the id returned by 'merge'. Otherwise the
-- canonical node is simply mapped to the canonical class id.
repair :: Monad m => CostFun -> EClassId -> ENode -> EGraphST m ()
repair costFun ecId enode = do
  -- Drop the (possibly stale) entry before looking up the canonical node.
  modify' $ over eNodeToEClass (Map.delete enode)
  enode'  <- canonize enode
  ecId'   <- canonical ecId
  doExist <- gets ((Map.!? enode') . _eNodeToEClass)
  case doExist of
    Just ecIdCanon -> do
      mergedId <- merge costFun ecIdCanon ecId'
      modify' $ over eNodeToEClass (Map.insert enode' mergedId)
    Nothing ->
      modify' $ over eNodeToEClass (Map.insert enode' ecId')
-- | Refresh the analysis data of a class from one of its e-nodes.
--
-- The class id and the node are canonized, the analysis of the canonical
-- node is computed with 'makeAnalysis' and combined with the data stored
-- in the class through 'joinData'. Nothing is written when the combined
-- data equals the stored data. When it differs, the parents of the class
-- are added to the 'analysis' queue, the class is stored with the new
-- data, and 'modifyEClass' is run on it (its result is discarded).
repairAnalysis :: Monad m => CostFun -> EClassId -> ENode -> EGraphST m ()
repairAnalysis costFun ecId enode = do
  ecId'  <- canonical ecId
  enode' <- canonize enode
  eclass <- getEClass ecId'
  info   <- makeAnalysis costFun enode'
  let newData = joinData (_info eclass) info
      eclass' = eclass { _info = newData }
  when (_info eclass /= newData) $ do
    modify' $ over (eDB . analysis) (_parents eclass <>)
            . over eClass (IntMap.insert ecId' eclass')
    _ <- modifyEClass costFun ecId'
    pure ()
merge :: Monad m => CostFun -> EClassId -> EClassId -> EGraphST m EClassId
merge :: forall (m :: * -> *).
Monad m =>
CostFun -> EClassId -> EClassId -> EGraphST m EClassId
merge CostFun
costFun EClassId
c1 EClassId
c2 =
do EClassId
c1' <- EClassId -> EGraphST m EClassId
forall (m :: * -> *). Monad m => EClassId -> EGraphST m EClassId
canonical EClassId
c1
EClassId
c2' <- EClassId -> EGraphST m EClassId
forall (m :: * -> *). Monad m => EClassId -> EGraphST m EClassId
canonical EClassId
c2
if EClassId
c1' EClassId -> EClassId -> Bool
forall a. Eq a => a -> a -> Bool
== EClassId
c2'
then EClassId -> EGraphST m EClassId
forall a. a -> StateT EGraph m a
forall (f :: * -> *) a. Applicative f => a -> f a
pure EClassId
c1'
else do (EClassId
led, EClass
ledC, EClassId
ledOrig, EClassId
sub, EClass
subC, EClassId
subOrig) <- EClassId
-> EClassId
-> EClassId
-> EClassId
-> StateT
EGraph m (EClassId, EClass, EClassId, EClassId, EClass, EClassId)
forall {m :: * -> *} {c}.
Monad m =>
EClassId
-> c
-> EClassId
-> c
-> StateT EGraph m (EClassId, EClass, c, EClassId, EClass, c)
getLeaderSub EClassId
c1' EClassId
c1 EClassId
c2' EClassId
c2
EClassId
-> EClass
-> EClassId
-> EClassId
-> EClass
-> EClassId
-> EGraphST m EClassId
forall (m :: * -> *).
Monad m =>
EClassId
-> EClass
-> EClassId
-> EClassId
-> EClass
-> EClassId
-> EGraphST m EClassId
mergeClasses EClassId
led EClass
ledC EClassId
ledOrig EClassId
sub EClass
subC EClassId
subOrig
where
mergeClasses :: Monad m => EClassId -> EClass -> EClassId -> EClassId -> EClass -> EClassId -> EGraphST m EClassId
mergeClasses :: forall (m :: * -> *).
Monad m =>
EClassId
-> EClass
-> EClassId
-> EClassId
-> EClass
-> EClassId
-> EGraphST m EClassId
mergeClasses EClassId
led EClass
ledC EClassId
ledO EClassId
sub EClass
subC EClassId
subO =
do (EGraph -> EGraph) -> StateT EGraph m ()
forall s (m :: * -> *). MonadState s m => (s -> s) -> m ()
modify' ((EGraph -> EGraph) -> StateT EGraph m ())
-> (EGraph -> EGraph) -> StateT EGraph m ()
forall a b. (a -> b) -> a -> b
$ ASetter EGraph EGraph (ClassIdMap EClassId) (ClassIdMap EClassId)
-> (ClassIdMap EClassId -> ClassIdMap EClassId) -> EGraph -> EGraph
forall s t a b. ASetter s t a b -> (a -> b) -> s -> t
over ASetter EGraph EGraph (ClassIdMap EClassId) (ClassIdMap EClassId)
Lens' EGraph (ClassIdMap EClassId)
canonicalMap (EClassId -> EClassId -> ClassIdMap EClassId -> ClassIdMap EClassId
forall a. EClassId -> a -> IntMap a -> IntMap a
IntMap.insert EClassId
sub EClassId
led)
let
newC :: EClass
newC = EClassId
-> HashSet ENodeEnc
-> HashSet (EClassId, ENode)
-> EClassId
-> EClassData
-> EClass
EClass EClassId
led
(EClass -> HashSet ENodeEnc
_eNodes EClass
ledC HashSet ENodeEnc -> HashSet ENodeEnc -> HashSet ENodeEnc
forall a. Eq a => HashSet a -> HashSet a -> HashSet a
`Set.union` EClass -> HashSet ENodeEnc
_eNodes EClass
subC)
(EClass -> HashSet (EClassId, ENode)
_parents EClass
ledC HashSet (EClassId, ENode)
-> HashSet (EClassId, ENode) -> HashSet (EClassId, ENode)
forall a. Semigroup a => a -> a -> a
<> EClass -> HashSet (EClassId, ENode)
_parents EClass
subC)
(EClassId -> EClassId -> EClassId
forall a. Ord a => a -> a -> a
min (EClass -> EClassId
_height EClass
ledC) (EClass -> EClassId
_height EClass
subC))
(EClassData -> EClassData -> EClassData
joinData (EClass -> EClassData
_info EClass
ledC) (EClass -> EClassData
_info EClass
subC))
(EGraph -> EGraph) -> StateT EGraph m ()
forall s (m :: * -> *). MonadState s m => (s -> s) -> m ()
modify' ((EGraph -> EGraph) -> StateT EGraph m ())
-> (EGraph -> EGraph) -> StateT EGraph m ()
forall a b. (a -> b) -> a -> b
$ ASetter EGraph EGraph (ClassIdMap EClass) (ClassIdMap EClass)
-> (ClassIdMap EClass -> ClassIdMap EClass) -> EGraph -> EGraph
forall s t a b. ASetter s t a b -> (a -> b) -> s -> t
over ASetter EGraph EGraph (ClassIdMap EClass) (ClassIdMap EClass)
Lens' EGraph (ClassIdMap EClass)
eClass (EClassId -> EClass -> ClassIdMap EClass -> ClassIdMap EClass
forall a. EClassId -> a -> IntMap a -> IntMap a
IntMap.insert EClassId
led EClass
newC (ClassIdMap EClass -> ClassIdMap EClass)
-> (ClassIdMap EClass -> ClassIdMap EClass)
-> ClassIdMap EClass
-> ClassIdMap EClass
forall b c a. (b -> c) -> (a -> b) -> a -> c
. EClassId -> ClassIdMap EClass -> ClassIdMap EClass
forall a. EClassId -> IntMap a -> IntMap a
IntMap.delete EClassId
sub)
(EGraph -> EGraph) -> (EGraph -> EGraph) -> EGraph -> EGraph
forall b c a. (b -> c) -> (a -> b) -> a -> c
. ASetter
EGraph
EGraph
(HashSet (EClassId, ENode))
(HashSet (EClassId, ENode))
-> (HashSet (EClassId, ENode) -> HashSet (EClassId, ENode))
-> EGraph
-> EGraph
forall s t a b. ASetter s t a b -> (a -> b) -> s -> t
over ((EGraphDB -> Identity EGraphDB) -> EGraph -> Identity EGraph
Lens' EGraph EGraphDB
eDB ((EGraphDB -> Identity EGraphDB) -> EGraph -> Identity EGraph)
-> ((HashSet (EClassId, ENode)
-> Identity (HashSet (EClassId, ENode)))
-> EGraphDB -> Identity EGraphDB)
-> ASetter
EGraph
EGraph
(HashSet (EClassId, ENode))
(HashSet (EClassId, ENode))
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (HashSet (EClassId, ENode) -> Identity (HashSet (EClassId, ENode)))
-> EGraphDB -> Identity EGraphDB
Lens' EGraphDB (HashSet (EClassId, ENode))
worklist) (EClass -> HashSet (EClassId, ENode)
_parents EClass
subC HashSet (EClassId, ENode)
-> HashSet (EClassId, ENode) -> HashSet (EClassId, ENode)
forall a. Semigroup a => a -> a -> a
<>)
Bool -> StateT EGraph m () -> StateT EGraph m ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
when (EClass -> EClassData
_info EClass
newC EClassData -> EClassData -> Bool
forall a. Eq a => a -> a -> Bool
/= EClass -> EClassData
_info EClass
ledC)
(StateT EGraph m () -> StateT EGraph m ())
-> StateT EGraph m () -> StateT EGraph m ()
forall a b. (a -> b) -> a -> b
$ (EGraph -> EGraph) -> StateT EGraph m ()
forall s (m :: * -> *). MonadState s m => (s -> s) -> m ()
modify' ((EGraph -> EGraph) -> StateT EGraph m ())
-> (EGraph -> EGraph) -> StateT EGraph m ()
forall a b. (a -> b) -> a -> b
$ ASetter
EGraph
EGraph
(HashSet (EClassId, ENode))
(HashSet (EClassId, ENode))
-> (HashSet (EClassId, ENode) -> HashSet (EClassId, ENode))
-> EGraph
-> EGraph
forall s t a b. ASetter s t a b -> (a -> b) -> s -> t
over ((EGraphDB -> Identity EGraphDB) -> EGraph -> Identity EGraph
Lens' EGraph EGraphDB
eDB ((EGraphDB -> Identity EGraphDB) -> EGraph -> Identity EGraph)
-> ((HashSet (EClassId, ENode)
-> Identity (HashSet (EClassId, ENode)))
-> EGraphDB -> Identity EGraphDB)
-> ASetter
EGraph
EGraph
(HashSet (EClassId, ENode))
(HashSet (EClassId, ENode))
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (HashSet (EClassId, ENode) -> Identity (HashSet (EClassId, ENode)))
-> EGraphDB -> Identity EGraphDB
Lens' EGraphDB (HashSet (EClassId, ENode))
analysis) (EClass -> HashSet (EClassId, ENode)
_parents EClass
ledC HashSet (EClassId, ENode)
-> HashSet (EClassId, ENode) -> HashSet (EClassId, ENode)
forall a. Semigroup a => a -> a -> a
<>)
Bool -> StateT EGraph m () -> StateT EGraph m ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
when (EClass -> EClassData
_info EClass
newC EClassData -> EClassData -> Bool
forall a. Eq a => a -> a -> Bool
/= EClass -> EClassData
_info EClass
subC)
(StateT EGraph m () -> StateT EGraph m ())
-> StateT EGraph m () -> StateT EGraph m ()
forall a b. (a -> b) -> a -> b
$ (EGraph -> EGraph) -> StateT EGraph m ()
forall s (m :: * -> *). MonadState s m => (s -> s) -> m ()
modify' ((EGraph -> EGraph) -> StateT EGraph m ())
-> (EGraph -> EGraph) -> StateT EGraph m ()
forall a b. (a -> b) -> a -> b
$ ASetter
EGraph
EGraph
(HashSet (EClassId, ENode))
(HashSet (EClassId, ENode))
-> (HashSet (EClassId, ENode) -> HashSet (EClassId, ENode))
-> EGraph
-> EGraph
forall s t a b. ASetter s t a b -> (a -> b) -> s -> t
over ((EGraphDB -> Identity EGraphDB) -> EGraph -> Identity EGraph
Lens' EGraph EGraphDB
eDB ((EGraphDB -> Identity EGraphDB) -> EGraph -> Identity EGraph)
-> ((HashSet (EClassId, ENode)
-> Identity (HashSet (EClassId, ENode)))
-> EGraphDB -> Identity EGraphDB)
-> ASetter
EGraph
EGraph
(HashSet (EClassId, ENode))
(HashSet (EClassId, ENode))
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (HashSet (EClassId, ENode) -> Identity (HashSet (EClassId, ENode)))
-> EGraphDB -> Identity EGraphDB
Lens' EGraphDB (HashSet (EClassId, ENode))
analysis) (EClass -> HashSet (EClassId, ENode)
_parents EClass
subC HashSet (EClassId, ENode)
-> HashSet (EClassId, ENode) -> HashSet (EClassId, ENode)
forall a. Semigroup a => a -> a -> a
<>)
EClass
-> EClassId
-> EClass
-> EClassId
-> EClassId
-> EClass
-> EClassId
-> StateT EGraph m ()
forall (m :: * -> *).
Monad m =>
EClass
-> EClassId
-> EClass
-> EClassId
-> EClassId
-> EClass
-> EClassId
-> EGraphST m ()
updateDBs EClass
newC EClassId
led EClass
ledC EClassId
ledO EClassId
sub EClass
subC EClassId
subO
CostFun -> EClassId -> EGraphST m EClassId
forall (m :: * -> *).
Monad m =>
CostFun -> EClassId -> EGraphST m EClassId
modifyEClass CostFun
costFun EClassId
led
EClassId -> EGraphST m EClassId
forall a. a -> StateT EGraph m a
forall (f :: * -> *) a. Applicative f => a -> f a
pure EClassId
led
getLeaderSub :: EClassId
-> c
-> EClassId
-> c
-> StateT EGraph m (EClassId, EClass, c, EClassId, EClass, c)
getLeaderSub EClassId
c1 c
c1O EClassId
c2 c
c2O =
do EClass
ec1 <- EClassId -> EGraphST m EClass
forall (m :: * -> *). Monad m => EClassId -> EGraphST m EClass
getEClass EClassId
c1
EClass
ec2 <- EClassId -> EGraphST m EClass
forall (m :: * -> *). Monad m => EClassId -> EGraphST m EClass
getEClass EClassId
c2
let n1 :: EClassId
n1 = HashSet (EClassId, ENode) -> EClassId
forall a. HashSet a -> EClassId
forall (t :: * -> *) a. Foldable t => t a -> EClassId
length (EClass -> HashSet (EClassId, ENode)
_parents EClass
ec1)
n2 :: EClassId
n2 = HashSet (EClassId, ENode) -> EClassId
forall a. HashSet a -> EClassId
forall (t :: * -> *) a. Foldable t => t a -> EClassId
length (EClass -> HashSet (EClassId, ENode)
_parents EClass
ec2)
(EClassId, EClass, c, EClassId, EClass, c)
-> StateT EGraph m (EClassId, EClass, c, EClassId, EClass, c)
forall a. a -> StateT EGraph m a
forall (f :: * -> *) a. Applicative f => a -> f a
pure ((EClassId, EClass, c, EClassId, EClass, c)
-> StateT EGraph m (EClassId, EClass, c, EClassId, EClass, c))
-> (EClassId, EClass, c, EClassId, EClass, c)
-> StateT EGraph m (EClassId, EClass, c, EClassId, EClass, c)
forall a b. (a -> b) -> a -> b
$ if EClassId
n1 EClassId -> EClassId -> Bool
forall a. Ord a => a -> a -> Bool
>= EClassId
n2
then (EClassId
c1, EClass
ec1, c
c1O, EClassId
c2, EClass
ec2, c
c2O)
else (EClassId
c2, EClass
ec2, c
c2O, EClassId
c1, EClass
ec1, c
c1O)
-- | Refresh the auxiliary databases after two e-classes have been merged.
--
-- All seven arguments are forwarded unchanged, first to 'updateFitnessDB'
-- and then to 'updateSizeDB':
--
-- * @newC@ – the e-class record that now represents the merged class;
-- * @led@, @ledC@ – id and record of the leader before the merge;
-- * @sub@, @subC@ – id and record of the class merged into the leader;
-- * @ledO@, @subO@ – the companion ids that 'merge' carries alongside
--   @led@ and @sub@ (presumably the ids before canonicalisation —
--   TODO confirm against 'merge').
updateDBs :: Monad m => EClass -> EClassId -> EClass -> EClassId -> EClassId -> EClass -> EClassId -> EGraphST m ()
updateDBs newC led ledC ledO sub subC subO =
  -- The fitness index is refreshed before the size index.
  updateFitnessDB newC led ledC ledO sub subC subO
    >> updateSizeDB newC led ledC ledO sub subC subO
-- | Re-index the merged e-class in the size database ('sizeDB').
--
-- 'sizeDB' maps a size to the set of e-class ids stored under that size.
-- After a merge this function, in order:
--
-- 1. removes @sub@ and @subO@ from the bucket of the subordinate's old size;
-- 2. removes @led@ and @ledO@ from the bucket of the leader's old size;
-- 3. inserts @led@ into the bucket of the merged class's size.
--
-- Sizes are read from @_size . _info@ of @subC@, @ledC@ and @newC@
-- respectively.
--
-- Step 3 creates the bucket when it does not exist yet. A plain
-- 'IntMap.adjust' would be a no-op on a missing key and the leader would
-- silently vanish from the index; 'updateFitnessDB' takes the same
-- precaution for 'sizeFitDB'.
updateSizeDB :: Monad m => EClass -> EClassId -> EClass -> EClassId -> EClassId -> EClass -> EClassId -> EGraphST m ()
updateSizeDB newC led ledC ledO sub subC subO =
  modify' $ over (eDB . sizeDB) reindex
  where
    -- Size recorded in an e-class' analysis data.
    sizeOf = _size . _info

    -- Drop both ids from one bucket.
    forget a b = IntSet.delete a . IntSet.delete b

    -- Composition runs right to left: subordinate removal, leader removal,
    -- then registration of the leader under the merged size.
    reindex =
        IntMap.insertWith IntSet.union (sizeOf newC) (IntSet.singleton led)
      . IntMap.adjust (forget led ledO) (sizeOf ledC)
      . IntMap.adjust (forget sub subO) (sizeOf subC)
-- | Keep the fitness-related databases consistent after a merge.
--
-- Three structures under 'eDB' are touched:
--
-- * 'unevaluated' – a set of e-class ids;
-- * 'fitRangeDB'  – a range tree of (id, fitness) entries;
-- * 'sizeFitDB'   – one such range tree per size.
--
-- Behaviour, driven by the fitness (@_fitness . _info@) of the merged class
-- @newC@:
--
-- * No fitness: @led@ is inserted into 'unevaluated' and @ledO@, @sub@ and
--   @subO@ are deleted from it. Nothing else changes.
-- * Some fitness: if it differs from the old leader's fitness, the leader's
--   old entries are dropped (see @unindex@) and @led@ is inserted with the
--   new fitness into 'fitRangeDB' and into the 'sizeFitDB' bucket of the
--   merged size, creating that bucket if needed. The subordinate's entries
--   are dropped in either case.
--
-- Fitness values are obtained by pattern matching on the 'Maybe', so no
-- partial 'fromJust' is needed.
updateFitnessDB :: Monad m => EClass -> EClassId -> EClass -> EClassId -> EClassId -> EClass -> EClassId -> EGraphST m ()
updateFitnessDB newC led ledC ledO sub subC subO =
  case fitNew of
    Nothing ->
      modify' $
        over (eDB . unevaluated)
             (IntSet.insert led . IntSet.delete ledO . IntSet.delete sub . IntSet.delete subO)
    Just fNew -> do
      -- Only re-index the leader when its fitness actually changed.
      when (fitNew /= fitLed) $ do
        unindex led ledO szLed fitLed
        modify' $
            over (eDB . fitRangeDB) (insertRange led fNew)
          . over (eDB . sizeFitDB)
                 ( IntMap.adjust (insertRange led fNew) szNew
                   -- make sure the bucket exists before adjusting it
                 . IntMap.insertWith (><) szNew Empty
                 )
      -- The subordinate's ids no longer denote a class of their own.
      unindex sub subO szSub fitSub
  where
    -- Remove the ids @cur@ and @old@ from wherever a class with the given
    -- fitness is stored: 'unevaluated' when it had no fitness, otherwise
    -- 'fitRangeDB' and the 'sizeFitDB' bucket @sz@.
    unindex cur old _ Nothing =
      modify' $ over (eDB . unevaluated) (IntSet.delete cur . IntSet.delete old)
    unindex cur old sz (Just fit) =
      modify' $
          over (eDB . fitRangeDB) (removeRange cur fit . removeRange old fit)
        . over (eDB . sizeFitDB)
               (IntMap.adjust (removeRange old fit . removeRange cur fit) sz)

    fitNew = (_fitness . _info) newC
    fitLed = (_fitness . _info) ledC
    fitSub = (_fitness . _info) subC
    szNew  = (_size . _info) newC
    szLed  = (_size . _info) ledC
    szSub  = (_size . _info) subC
-- | Collapse an e-class to a single constant node when its analysis says
-- it evaluates to a constant.
--
-- The class @ecId@ is looked up with 'getEClass'. If its @_consts@ info is
-- @ConstVal x@:
--
-- 1. the node @Const x@ is costed with 'calculateCost';
-- 2. the class is stored back in 'eClass' with @_eNodes@ replaced by the
--    singleton encoded @Const x@, and with @_cost@, @_best@ and @_consts@
--    updated accordingly;
-- 3. if @Const x@ was already registered in '_eNodeToEClass' (checked
--    before the update), the class is merged with that class and the
--    result of 'merge' is returned.
--
-- In every other case the graph is left untouched and @ecId@ is returned.
modifyEClass :: Monad m => CostFun -> EClassId -> EGraphST m EClassId
modifyEClass costFun ecId = do
  ec <- getEClass ecId
  case (_consts . _info) ec of
    ConstVal x -> do
      let en = Const x
      c <- calculateCost costFun en
      -- Look the node up before mutating the graph, as the original did.
      maybeEid <- gets ((Map.!? en) . _eNodeToEClass)
      let infoEc = (_info ec) { _cost = c, _best = en, _consts = ConstVal x }
          ec'    = ec { _eNodes = Set.singleton (encodeEnode en), _info = infoEc }
      modify' $ over eClass (IntMap.insert ecId ec')
      case maybeEid of
        Nothing   -> pure ecId
        Just eid' -> merge costFun eid' ecId
    _ -> pure ecId
-- | Rebuild the pattern database from the current e-node table.
--
-- The 'patDB' field is first reset to an empty map. Every
-- @(e-node, e-class id)@ pair of '_eNodeToEClass' is then fed to 'addToDB'
-- in ascending key order, and the resulting '_patDB' is returned.
createDB :: Monad m => EGraphST m DB
createDB = do
  -- Start from a clean slate.
  modify' (over (eDB . patDB) (const Map.empty))
  nodeTable <- gets _eNodeToEClass
  forM_ (Map.toList nodeTable) $ \(enode, eid) ->
    addToDB enode eid
  gets (_patDB . _eDB)
-- | Register a single e-node in the pattern database.
--
-- The database is keyed by the node's operator ('getOperator'). The id
-- sequence handed to 'populate' is the owning e-class id followed by the
-- ids returned by 'childrenOf'. The trie currently stored for the operator,
-- if any, is passed along. When 'populate' yields 'Nothing' the database is
-- left as is; otherwise the returned trie replaces the operator's entry.
addToDB :: Monad m => ENode -> EClassId -> EGraphST m ()
addToDB enode eid = do
  existing <- gets (Map.lookup operator . _patDB . _eDB)
  maybe (pure ()) store (populate existing path)
  where
    operator = getOperator enode
    path     = eid : childrenOf enode
    -- Overwrite the trie stored for this operator.
    store t  = modify' $ over (eDB . patDB) (Map.insert operator t)
-- | Insert a path of e-class ids into an optional trie.
--
-- * An empty path yields 'Nothing', whatever the input trie is.
-- * With no existing trie, a fresh chain is built from the path, nesting
--   from the last id outwards.
-- * With an existing trie, the first id is added to the node's key set and
--   the remaining ids are inserted recursively under that id.
populate :: Maybe IntTrie -> [EClassId] -> Maybe IntTrie
populate _ [] = Nothing
populate Nothing path = foldr extend Nothing path
  where
    -- wrap what has been built so far (if anything) under the current id
    extend :: EClassId -> Maybe IntTrie -> Maybe IntTrie
    extend eid built =
      Just (trie eid (maybe IntMap.empty (IntMap.singleton eid) built))
populate (Just node) (eid : rest) = Just (IntTrie keys' children')
  where
    keys' = Set.insert eid (_keys node)
    -- used when the remaining path is empty
    leaf = trie eid IntMap.empty
    subTrie = fromMaybe leaf (populate (IntMap.lookup eid (_trie node)) rest)
    children' = IntMap.insert eid subTrie (_trie node)
-- | Canonicalise the values of a substitution map.
--
-- Every value that is an e-class id ('Left') is replaced by its canonical
-- id; all other values are kept as they are. The second component of the
-- pair is returned untouched.
canonizeMap :: Monad m => (Map ClassOrVar ClassOrVar, ClassOrVar) -> EGraphST m (Map ClassOrVar ClassOrVar, ClassOrVar)
canonizeMap (subst, cv) = do
  subst' <- traverse canonizeEntry subst
  pure (subst', cv)
  where
    canonizeEntry :: Monad m => ClassOrVar -> EGraphST m ClassOrVar
    canonizeEntry (Left ec) = Left <$> canonical ec
    canonizeEntry other = pure other
-- | Apply a rewrite rule to a single match.
--
-- The match is canonicalised first. If the height check and every condition
-- of the rule hold, the rule's target pattern is instantiated with the
-- substitution (adding e-nodes as needed) and the resulting e-class is
-- merged with the matched root class.
applyMatch :: Monad m => CostFun -> Rule -> (Map ClassOrVar ClassOrVar, ClassOrVar) -> EGraphST m ()
applyMatch costFun rule match' = do
  match <- canonizeMap match'
  heightOk <- isValidHeight match
  condsOk <- and <$> mapM (`isValidConditions` match) (getConditions rule)
  when (heightOk && condsOk) $ do
    newClass <- reprPrat costFun (fst match) (target rule)
    -- the id returned by 'merge' is not needed here
    () <$ merge costFun (getInt (snd match)) newClass
-- | Apply a rewrite rule to a single match, merging only with a class
-- found by 'classOfENode'.
--
-- The validity checks are the same as in 'applyMatch'. The target pattern
-- is resolved with 'classOfENode'; when that yields no class the match is
-- skipped, otherwise the class is merged with the matched root class.
applyMergeOnlyMatch :: Monad m => CostFun -> Rule -> (Map ClassOrVar ClassOrVar, ClassOrVar) -> EGraphST m ()
applyMergeOnlyMatch costFun rule match' = do
  match <- canonizeMap match'
  heightOk <- isValidHeight match
  condResults <- mapM (`isValidConditions` match) (getConditions rule)
  when (heightOk && and condResults) $ do
    existing <- classOfENode costFun (fst match) (target rule)
    -- Nothing: no class for the target, so there is nothing to merge
    forM_ existing $ \eid ->
      merge costFun (getInt (snd match)) eid
-- | Resolve a pattern to an e-class id under a substitution.
--
-- * A pattern variable resolves to the canonical id of whatever it is bound
--   to in the substitution, or 'Nothing' if it is unbound.
-- * A constant pattern is added to the e-graph and its class is returned.
-- * For any other fixed pattern the children are resolved first; if any of
--   them fails the result is 'Nothing'. When all (canonicalised) children
--   are constant classes, the node is added and the graph rebuilt;
--   otherwise the node is only looked up in the node-to-class map.
classOfENode :: Monad m => CostFun -> Map ClassOrVar ClassOrVar -> Pattern -> EGraphST m (Maybe EClassId)
classOfENode _ subst (VarPat c) =
  case Map.lookup (Right (fromEnum c)) subst of
    Nothing -> pure Nothing
    Just bound -> Just <$> canonical (getInt bound)
classOfENode costFun _ (Fixed (Const x)) = Just <$> add costFun (Const x)
classOfENode costFun subst (Fixed shape) = do
  resolved <- mapM (classOfENode costFun subst) (getElems shape)
  case sequence resolved of
    Nothing -> pure Nothing
    Just cs -> do
      -- the node is built from the ids as resolved, not re-canonicalised
      let node = replaceChildren cs shape
      canonCs <- mapM canonical cs
      constFlags <- mapM isConst canonCs
      if and constFlags
        then do
          eid <- add costFun node
          rebuild costFun
          pure (Just eid)
        else gets (Map.lookup node . _eNodeToEClass)
-- | Instantiate a pattern under a substitution, adding e-nodes as needed,
-- and return the id of the resulting e-class.
--
-- A pattern variable must be bound in the substitution: the lookup uses
-- 'Map.!' and raises an error otherwise. A fixed pattern instantiates its
-- children recursively and adds the resulting node with 'add'.
reprPrat :: Monad m => CostFun -> Map ClassOrVar ClassOrVar -> Pattern -> EGraphST m EClassId
reprPrat costFun subst pat =
  case pat of
    VarPat c ->
      canonical (getInt (subst Map.! Right (fromEnum c)))
    Fixed shape -> do
      kids <- mapM (reprPrat costFun subst) (getElems shape)
      add costFun (replaceChildren kids shape)
-- | Check that the root of a match is not too tall.
--
-- For an e-class root the height stored in the (canonical) class is used;
-- any other root counts as height 0. The match is valid when the height is
-- strictly below 15.
isValidHeight :: Monad m => (Map ClassOrVar ClassOrVar, ClassOrVar) -> EGraphST m Bool
isValidHeight match = do
  height <- either classHeight (const (pure 0)) (snd match)
  pure (height < 15)
  where
    -- height recorded for the canonical representative of the class
    classHeight ec = do
      ec' <- canonical ec
      gets (\eg -> _height (_eClass eg IntMap.! ec'))
-- | Evaluate a rule condition for a match against the current e-graph.
--
-- The condition receives the substitution (first component of the match)
-- and is then applied to the e-graph state.
isValidConditions :: Monad m => Condition -> (Map ClassOrVar ClassOrVar, ClassOrVar) -> EGraphST m Bool
isValidConditions cond match = gets holds
  where
    holds = cond (fst match)
-- | Insert an expression tree into the e-graph bottom-up and return the
-- id of the e-class of its root.
--
-- Children are inserted before their parent; every node is inserted with
-- 'add'.
fromTree :: Monad m => CostFun -> Fix SRTree -> EGraphST m EClassId
fromTree costFun tree = cataM sequence insertNode tree
  where
    insertNode node = add costFun node
-- | Insert several expression trees into the e-graph.
--
-- Trees are inserted left to right with 'fromTree'. Each new id is
-- prepended to the accumulator, so the returned list is in reverse order
-- with respect to the input list (the id of the last tree comes first).
fromTrees :: Monad m => CostFun -> [Fix SRTree] -> EGraphST m [EClassId]
fromTrees costFun = foldM prepend []
  where
    prepend acc tree = (: acc) <$> fromTree costFun tree
-- | Extract the expression formed by following the @_best@ node of each
-- e-class, starting at the given class.
--
-- The class id is canonicalised, the best node stored in the class info is
-- read, and the same procedure is applied recursively to its children.
getBest :: Monad m => EClassId -> EGraphST m (Fix SRTree)
getBest eid = do
  root <- canonical eid
  bestNode <- gets (\eg -> _best (_info (_eClass eg IntMap.! root)))
  subtrees <- mapM getBest (childrenOf bestNode)
  pure (Fix (replaceChildren subtrees bestNode))
-- | Build one expression represented by the given e-class.
--
-- The class id is canonicalised and one of its e-nodes is chosen. If the
-- class contains at least one terminal node (variable, constant or
-- parameter), the choice is restricted to terminals, which ends the
-- recursion at this class. Otherwise the first node is used and its
-- children are expanded recursively.
--
-- Raises an error if the class holds no e-nodes.
getExpressionFrom :: Monad m => EClassId -> EGraphST m (Fix SRTree)
getExpressionFrom eId' = do
  eId <- canonical eId'
  nodes <- gets (Set.map decodeEnode . _eNodes . (IntMap.! eId) . _eClass)
  let allNodes = Set.toList nodes
      -- prefer terminals; fall back to every node when there are none
      cands = case filter isTerm allNodes of
        [] -> allNodes
        terms -> terms
  case cands of
    [] -> error "getExpressionFrom: e-class has no e-nodes"
    (node : _) -> Fix <$> expand node
  where
    isTerm :: SRTree val -> Bool
    isTerm (Var _) = True
    isTerm (Const _) = True
    isTerm (Param _) = True
    isTerm _ = False

    -- rebuild the chosen node with its children replaced by expressions
    expand (Bin op l r) = Bin op <$> getExpressionFrom l <*> getExpressionFrom r
    expand (Uni f t) = Uni f <$> getExpressionFrom t
    expand (Var ix) = pure (Var ix)
    expand (Const x) = pure (Const x)
    expand (Param ix) = pure (Param ix)
-- | Enumerate every expression tree that can be read off the e-class @eId'@.
--
-- The id is canonicalized first, then each e-node stored in that class is
-- expanded: terminals ('Var', 'Const', 'Param') yield a single tree, 'Uni'
-- yields one tree per expression of its child class, and 'Bin' yields the
-- cartesian product of the expressions of both child classes. The per-node
-- results are concatenated in the order the nodes are listed by 'Set.toList'.
--
-- The canonical id must be present in the class map (it is looked up with
-- 'IntMap.!', which errors on a missing key).
--
-- NOTE(review): the recursion has no depth bound or visited set, so this
-- presumably relies on the reachable classes not forming a cycle -- confirm
-- against callers.
getAllExpressionsFrom :: Monad m => EClassId -> EGraphST m [Fix SRTree]
getAllExpressionsFrom eId' = do
  eId <- canonical eId'
  nodes <- gets (map decodeEnode . Set.toList . _eNodes . (IntMap.! eId) . _eClass)
  concat <$> traverse expand nodes
  where
    -- All trees rooted at one e-node. For binary nodes the left operand
    -- varies slowest, matching the order of the list comprehension.
    expand node = map Fix <$> case node of
      Bin op l r -> do
        ls <- getAllExpressionsFrom l
        rs <- getAllExpressionsFrom r
        pure [Bin op li ri | li <- ls, ri <- rs]
      Uni f t -> map (Uni f) <$> getAllExpressionsFrom t
      Var ix -> pure [Var ix]
      Const x -> pure [Const x]
      Param ix -> pure [Param ix]
-- | Draw one expression tree at random from the e-class @eId'@.
--
-- After canonicalizing the id, one e-node of the class is selected by drawing
-- an index with 'randomR' from @(0, length nodes - 1)@, using the 'StdGen'
-- held in the inner 'State'. Child classes are then sampled recursively; for
-- binary nodes the left operand is sampled before the right one.
--
-- The canonical id must be present in the class map (lookup uses 'IntMap.!')
-- and the class must hold at least one e-node (selection uses '!!').
--
-- NOTE(review): the recursion carries no depth limit; termination presumably
-- depends on the random choices eventually reaching terminals -- confirm.
getRndExpressionFrom :: EClassId -> EGraphST (State StdGen) (Fix SRTree)
getRndExpressionFrom eId' = do
  eId <- canonical eId'
  eclass <- gets ((IntMap.! eId) . _eClass)
  chosen <- lift (pickOne (Set.toList (_eNodes eclass)))
  case decodeEnode chosen of
    Bin op l r -> do
      lhs <- getRndExpressionFrom l
      rhs <- getRndExpressionFrom r
      pure (Fix (Bin op lhs rhs))
    Uni f t -> Fix . Uni f <$> getRndExpressionFrom t
    Var ix -> pure (Fix (Var ix))
    Const x -> pure (Fix (Const x))
    Param ix -> pure (Fix (Param ix))
  where
    -- Pick the element at a randomly drawn index of the list.
    pickOne :: [a] -> State StdGen a
    pickOne xs = do
      i <- state (randomR (0, length xs - 1))
      pure (xs !! i)
-- | Rebuild the e-graph lookup tables so they only refer to canonical data.
--
-- * Every entry of the e-node -> e-class map is re-keyed with the canonized
--   e-node and re-pointed at the canonical class id (on key collisions the
--   later entry wins, as with 'Map.fromList').
-- * Only classes whose id equals its own canonical id are kept in the class
--   map.
-- * The canonical-id map and the database are written back unchanged.
--
-- The new state is forced with 'forceState' before returning.
cleanMaps :: Monad m => EGraphST m ()
cleanMaps = do
  enode2eclass <- gets _eNodeToEClass
  entries <- forM (Map.toList enode2eclass) $ \(node, cid) ->
    (,) <$> canonize node <*> canonical cid
  let enode2eclass' = Map.fromList entries

  eclassMap <- gets _eClass
  kept <- forM (IntMap.toList eclassMap) $ \(cid, cls) -> do
    cid' <- canonical cid
    pure $ if cid == cid' then Just (cid, cls) else Nothing
  let eclassMap' = IntMap.fromList (catMaybes kept)

  -- Read only after the canonize/canonical calls above, keeping the original
  -- ordering in case those calls touch the canonical map.
  -- NOTE(review): a pruned copy of this map (entries with key == value only)
  -- used to be computed here but was never stored; the unpruned map is what
  -- is written back. Confirm that keeping the stale ids is intended.
  canon <- gets _canonicalMap
  eDB' <- gets _eDB
  put $ EGraph canon enode2eclass' eclassMap' eDB'
  forceState
-- | Evaluate the current state to weak head normal form.
--
-- The state itself is left untouched; the action only forces its outermost
-- constructor so that pending updates do not pile up as a thunk.
forceState :: Monad m => StateT s m ()
forceState = do
  s <- get
  s `seq` return ()