{-# LANGUAGE TupleSections #-}
{-# LANGUAGE BangPatterns #-}
module Algorithm.EqSat.Build where
import System.Random (Random (randomR), StdGen)
import Control.Lens ( over )
import Control.Monad ( forM_, when, foldM, forM )
import Data.Maybe ( fromMaybe, catMaybes )
import Data.SRTree
import Algorithm.EqSat.Egraph
import Algorithm.EqSat.DB
import qualified Data.IntMap.Strict as IntMap
import Data.Map.Strict ( Map )
import qualified Data.Map.Strict as Map
import qualified Data.HashSet as Set
import Control.Monad.State.Strict
import Data.SRTree.Recursion (cataM)
import Algorithm.EqSat.Info
import qualified Data.IntSet as IntSet
import Data.Maybe
import Data.Sequence (Seq(..), (><))
import Debug.Trace (trace, traceShow)
-- | Insert an e-node into the e-graph and return the id of its e-class.
--
-- The node is canonicalized first. If 'calculateConsts' classifies it as a
-- constant value or a parameter index, it is replaced by the matching
-- 'Const' / 'Param' leaf. When an identical node is already recorded in
-- '_eNodeToEClass' the existing class id is returned and nothing is changed.
-- Otherwise a fresh id is taken from '_nextId' and the new class is
-- registered in the canonical map, the node map, the worklist, the parent
-- sets of its children, the class table, the pattern DB ('addToDB'),
-- 'sizeDB' (bucketed by the analysis '_size') and the 'unevaluated' set.
add :: Monad m => CostFun -> ENode -> EGraphST m EClassId
add costFun enode = do
  canonNode <- canonize enode
  folded    <- calculateConsts canonNode
  -- constant folding: collapse the node to a leaf when its value is known
  let node = case folded of
               ConstVal x -> Const x
               ParamIx ix -> Param ix
               _          -> canonNode
  known <- gets (Map.lookup node . _eNodeToEClass)
  maybe (insertNew node) pure known
  where
    -- Allocate and register a brand-new e-class holding exactly @node@.
    insertNew :: Monad n => ENode -> EGraphST n EClassId
    insertNew node = do
      newId <- gets (_nextId . _eDB)
      modify' $ over canonicalMap (IntMap.insert newId newId)
              . over eNodeToEClass (Map.insert node newId)
              . over (eDB . nextId) (+1)
              . over (eDB . worklist) (Set.insert (newId, node))
      mapM_ (linkParent newId node) (childrenOf node)
      nodeInfo <- makeAnalysis costFun node
      height   <- getChildrenMinHeight node
      modify' $ over eClass (IntMap.insert newId (createEClass newId node nodeInfo height))
      addToDB node newId
      modify' $ over (eDB . unevaluated) (IntSet.insert newId)
              . over (eDB . sizeDB)
                     (IntMap.insertWith IntSet.union (_size nodeInfo) (IntSet.singleton newId))
      pure newId

    -- Record (parentId, parentNode) in the parent set of class @child@.
    linkParent :: Monad n => EClassId -> ENode -> EClassId -> EGraphST n ()
    linkParent parentId parentNode child = do
      ec <- getEClass child
      let ec' = ec { _parents = Set.insert (parentId, parentNode) (_parents ec) }
      modify' $ over eClass (IntMap.insert child ec')
-- | Drain the pending repair queues once.
--
-- Both the '_worklist' and the '_analysis' queue are snapshotted and then
-- cleared in the state. Every snapshotted (class id, e-node) pair is then
-- handed to 'repair', followed by every pair from the analysis queue to
-- 'repairAnalysis'. Entries enqueued while these repairs run stay in the
-- (now fresh) queues for a later call.
rebuild :: Monad m => CostFun -> EGraphST m ()
rebuild costFun = do
  db <- gets _eDB
  let pendingRepairs  = _worklist db
      pendingAnalysis = _analysis db
  modify' $ over (eDB . worklist) (const Set.empty)
          . over (eDB . analysis) (const Set.empty)
  mapM_ (uncurry (repair costFun)) pendingRepairs
  mapM_ (uncurry (repairAnalysis costFun)) pendingAnalysis
-- | Re-hash one e-node after its children may have been merged.
--
-- The stale entry for @enode@ is removed from '_eNodeToEClass'. The node is
-- re-canonicalized and @ecId@ is resolved to its canonical id. If the
-- canonical node already belongs to some class, that class is merged with
-- @ecId@'s class and the node is mapped to the merge result. Otherwise the
-- node is mapped to @ecId@'s canonical id.
repair :: Monad m => CostFun -> EClassId -> ENode -> EGraphST m ()
repair costFun ecId enode = do
  modify' $ over eNodeToEClass (Map.delete enode)
  canonNode <- canonize enode
  canonId   <- canonical ecId
  existing  <- gets (Map.lookup canonNode . _eNodeToEClass)
  targetId  <- maybe (pure canonId) (\other -> merge costFun other canonId) existing
  modify' $ over eNodeToEClass (Map.insert canonNode targetId)
-- | Propagate analysis data from one parent e-node into its class.
--
-- The node's analysis ('makeAnalysis') is joined with the class's current
-- '_info'. Only when the join changes the data:
--
--   * the class's parents are queued on the '_analysis' list, so the change
--     propagates upward;
--   * the updated class is stored;
--   * 'modifyEClass' is run on it, and its result is ignored.
repairAnalysis :: Monad m => CostFun -> EClassId -> ENode -> EGraphST m ()
repairAnalysis costFun ecId enode = do
  classId  <- canonical ecId
  node     <- canonize enode
  eclass   <- getEClass classId
  nodeInfo <- makeAnalysis costFun node
  let oldData = _info eclass
      joined  = joinData oldData nodeInfo
  when (oldData /= joined) $ do
    modify' $ over (eDB . analysis) (_parents eclass <>)
            . over eClass (IntMap.insert classId eclass { _info = joined })
    _ <- modifyEClass costFun classId
    pure ()
-- | Union the e-classes of @c1@ and @c2@ and return the leader's id.
--
-- When both ids already share a canonical id, that id is returned and
-- nothing changes. Otherwise the class with at least as many parents becomes
-- the leader; ties favour @c1@. The other class becomes the subordinate and
-- is folded into the leader:
--
--   * '_canonicalMap' maps the subordinate to the leader;
--   * the node and parent sets are unioned, the height is the minimum of
--     both, and the analysis data is 'joinData' of both;
--   * the subordinate class is deleted from the class table;
--   * the subordinate's parents go on the worklist;
--   * the parents of each side whose analysis data changed go on the
--     analysis queue;
--   * the side indices are refreshed with 'updateDBs';
--   * 'modifyEClass' gets a chance to constant-fold the leader.
--
-- NOTE(review): the result of 'modifyEClass' is discarded and @led@ is
-- returned. If that call merges @led@ into another class, the returned id may
-- no longer be canonical — confirm callers canonicalize it.
merge :: Monad m => CostFun -> EClassId -> EClassId -> EGraphST m EClassId
merge costFun c1 c2 = do
  root1 <- canonical c1
  root2 <- canonical c2
  if root1 == root2
    then pure root1
    else do
      (led, ledC, ledO, sub, subC, subO) <- getLeaderSub root1 c1 root2 c2
      mergeClasses led ledC ledO sub subC subO
  where
    -- Fold class @sub@ (reached from original id @subO@) into @led@
    -- (reached from @ledO@) and return @led@.
    mergeClasses :: Monad m => EClassId -> EClass -> EClassId -> EClassId -> EClass -> EClassId -> EGraphST m EClassId
    mergeClasses led ledC ledO sub subC subO = do
      modify' $ over canonicalMap (IntMap.insert sub led)
      let joined = joinData (_info ledC) (_info subC)
          newC   = EClass led
                          (Set.union (_eNodes ledC) (_eNodes subC))
                          (_parents ledC <> _parents subC)
                          (min (_height ledC) (_height subC))
                          joined
      modify' $ over eClass (IntMap.insert led newC . IntMap.delete sub)
              . over (eDB . worklist) (_parents subC <>)
      -- parents must re-run their analysis when their child's data changed
      when (joined /= _info ledC) $
        modify' $ over (eDB . analysis) (_parents ledC <>)
      when (joined /= _info subC) $
        modify' $ over (eDB . analysis) (_parents subC <>)
      updateDBs newC led ledC ledO sub subC subO
      _ <- modifyEClass costFun led
      pure led

    -- Order the two classes as (leader, subordinate), each paired with its
    -- class record and the id originally passed to 'merge'. The class with
    -- at least as many parents leads.
    getLeaderSub a aOrig b bOrig = do
      ecA <- getEClass a
      ecB <- getEClass b
      pure $ if Set.size (_parents ecA) >= Set.size (_parents ecB)
               then (a, ecA, aOrig, b, ecB, bOrig)
               else (b, ecB, bOrig, a, ecA, aOrig)
-- | Refresh the side indices after class @sub@ has been merged into @led@.
-- It runs 'updateFitnessDB' first and then 'updateSizeDB'.
--
-- Arguments: merged class, leader id, leader class, leader's original id,
-- subordinate id, subordinate class, subordinate's original id.
updateDBs :: Monad m => EClass -> EClassId -> EClass -> EClassId -> EClassId -> EClass -> EClassId -> EGraphST m ()
updateDBs newC led ledC ledO sub subC subO =
  updateFitnessDB newC led ledC ledO sub subC subO
    *> updateSizeDB newC led ledC ledO sub subC subO
-- | Move the merged class to its new bucket in 'sizeDB'.
--
-- The subordinate (@sub@, @subO@) is removed from the bucket for its old
-- size, and the leader (@led@, @ledO@) from the bucket for its old size. The
-- leader is then added to the bucket for the merged class's size.
--
-- The final insertion uses 'IntMap.insertWith' rather than 'IntMap.adjust'.
-- 'adjust' is a no-op on a missing key, which would silently drop the leader
-- from the index when no bucket for that size exists yet. This also matches
-- how 'add' populates 'sizeDB'.
updateSizeDB :: Monad m => EClass -> EClassId -> EClass -> EClassId -> EClassId -> EClass -> EClassId -> EGraphST m ()
updateSizeDB newC led ledC ledO sub subC subO =
  modify' $ over (eDB . sizeDB) relocate
  where
    sizeOf = _size . _info
    -- applied right-to-left: drop sub, drop leader, then re-insert leader
    relocate = IntMap.insertWith IntSet.union (sizeOf newC) (IntSet.singleton led)
             . IntMap.adjust (IntSet.delete led . IntSet.delete ledO) (sizeOf ledC)
             . IntMap.adjust (IntSet.delete sub . IntSet.delete subO) (sizeOf subC)
-- | Keep the fitness indices ('unevaluated', 'fitRangeDB', 'sizeFitDB')
-- consistent after class @sub@ (original id @subO@) is merged into leader
-- @led@ (original id @ledO@).
--
-- Case 1: the merged class has no fitness. @led@ is marked unevaluated, and
-- @ledO@, @sub@ and @subO@ are removed from 'unevaluated'.
--
-- Case 2: the merged class has a fitness.
--
--   * If it differs from the leader's old fitness, the leader ids are
--     removed from wherever they were indexed. That is 'unevaluated' when
--     the leader had no fitness, or 'fitRangeDB'/'sizeFitDB' otherwise.
--     @led@ is then inserted under the new fitness and size; a missing size
--     bucket is created empty first.
--   * In all cases the subordinate ids are removed from their index. @sub@
--     is no longer a canonical class, so leaving it behind would keep stale
--     entries.
--
-- Pattern matching replaces the previous 'isJust' / 'fromJust' checks.
updateFitnessDB :: Monad m => EClass -> EClassId -> EClass -> EClassId -> EClassId -> EClass -> EClassId -> EGraphST m ()
updateFitnessDB newC led ledC ledO sub subC subO =
  case fitNew of
    Nothing ->
      modify' $ over (eDB . unevaluated)
                     (IntSet.insert led . IntSet.delete ledO . IntSet.delete sub . IntSet.delete subO)
    Just fNew -> do
      when (fitNew /= fitLed) $ do
        -- un-index the leader under its previous state
        case fitLed of
          Nothing   -> modify' $ over (eDB . unevaluated) (IntSet.delete led . IntSet.delete ledO)
          Just fLed -> modify' $ over (eDB . fitRangeDB) (removeRange led fLed . removeRange ledO fLed)
                               . over (eDB . sizeFitDB)
                                      (IntMap.adjust (removeRange ledO fLed . removeRange led fLed) szLed)
        -- re-index the leader under its new fitness and size
        modify' $ over (eDB . fitRangeDB) (insertRange led fNew)
                . over (eDB . sizeFitDB)
                       (IntMap.adjust (insertRange led fNew) szNew . IntMap.insertWith (><) szNew Empty)
      -- the subordinate is merged away: always drop it from the indices
      case fitSub of
        Nothing   -> modify' $ over (eDB . unevaluated) (IntSet.delete sub . IntSet.delete subO)
        Just fSub -> modify' $ over (eDB . fitRangeDB) (removeRange sub fSub . removeRange subO fSub)
                             . over (eDB . sizeFitDB)
                                    (IntMap.adjust (removeRange subO fSub . removeRange sub fSub) szSub)
  where
    fitNew = (_fitness . _info) newC
    fitLed = (_fitness . _info) ledC
    fitSub = (_fitness . _info) subC
    szNew  = (_size . _info) newC
    szLed  = (_size . _info) ledC
    szSub  = (_size . _info) subC
-- | Constant-fold an e-class whose analysis marks it as a known constant.
--
-- When the class's '_consts' is @ConstVal x@, the following happens:
--
--   * its node set is replaced by the single leaf @Const x@;
--   * its cost is recomputed with 'calculateCost';
--   * '_best' and '_consts' are set from that leaf.
--
-- If some class already maps @Const x@ in '_eNodeToEClass', the two classes
-- are merged and the merge result is returned. In every other case @ecId@ is
-- returned unchanged.
--
-- The unused local helper @isTerm@ was removed (dead code, flagged by
-- @-Wunused-local-binds@).
--
-- NOTE(review): when no class holds @Const x@ yet, the leaf is not inserted
-- into '_eNodeToEClass'. A later 'add' of the same constant would then create
-- a separate class. Confirm this is intended.
modifyEClass :: Monad m => CostFun -> EClassId -> EGraphST m EClassId
modifyEClass costFun ecId = do
  ec <- getEClass ecId
  case (_consts . _info) ec of
    ConstVal x -> do
      let en :: SRTree val
          en = Const x
      c <- calculateCost costFun en
      let infoEc = (_info ec) { _cost = c, _best = en, _consts = toConst en }
      maybeEid <- gets ((Map.!? en) . _eNodeToEClass)
      modify' $ over eClass (IntMap.insert ecId ec { _eNodes = Set.singleton (encodeEnode en), _info = infoEc })
      case maybeEid of
        Nothing   -> pure ecId
        Just eid' -> merge costFun eid' ecId
    _ -> pure ecId
  where
    -- Map a leaf node back to its constant classification.
    toConst :: SRTree val -> Consts
    toConst (Param ix) = ParamIx ix
    toConst (Const x)  = ConstVal x
    toConst _          = NotConst
-- | Rebuild the pattern database from scratch.
--
-- Clears '_patDB' and re-indexes every @(e-node, e-class)@ pair of
-- '_eNodeToEClass' with 'addToDB', in ascending key order. Returns the
-- freshly built database.
createDB :: Monad m => EGraphST m DB
createDB = do
  modify' (over (eDB . patDB) (const Map.empty))
  nodeMap <- gets _eNodeToEClass
  forM_ (Map.toList nodeMap) $ \(node, cid) -> addToDB node cid
  gets (_patDB . _eDB)
-- | Index one e-node in the pattern database.
--
-- The e-node is stored under its operator ('getOperator'). The path
-- @eid : childrenOf enode@ is inserted into that operator's trie with
-- 'populate'. If 'populate' yields 'Nothing' (empty path), the database is
-- left untouched.
addToDB :: Monad m => ENode -> EClassId -> EGraphST m ()
addToDB enode eid = do
  let key  = getOperator enode
      path = eid : childrenOf enode
  current <- gets (Map.lookup key . _patDB . _eDB)
  forM_ (populate current path) $ \updated ->
    modify' (over (eDB . patDB) (Map.insert key updated))
-- | Insert a path of e-class ids into an (optional) trie.
--
-- * Empty path: 'Nothing'.
-- * No existing trie: a fresh chain is built with 'trie'. Each id @x@
--   followed by more ids becomes @trie x (singleton x rest)@. The final id
--   becomes @trie x IntMap.empty@.
-- * Existing trie: the head id is added to '_keys'. The tail is inserted
--   recursively into the child stored under that id, and the child is
--   (re)stored under the id. A path that ends at this level stores
--   @trie x IntMap.empty@ as the child.
populate :: Maybe IntTrie -> [EClassId] -> Maybe IntTrie
populate _ [] = Nothing
populate Nothing (x:xs) =
  Just $ case populate Nothing xs of
    Nothing  -> trie x IntMap.empty
    Just sub -> trie x (IntMap.singleton x sub)
populate (Just node) (x:xs) =
  let children = _trie node
      child    = fromMaybe (trie x IntMap.empty)
                           (populate (IntMap.lookup x children) xs)
  in Just (IntTrie (Set.insert x (_keys node)) (IntMap.insert x child children))
-- | Canonicalize the e-class ids inside a match substitution.
--
-- Every value of the substitution that is a class id (@Left@) is replaced
-- by its canonical id. Variable values (@Right@), the map's keys and the
-- matched root @cv@ are returned unchanged.
canonizeMap :: Monad m => (Map ClassOrVar ClassOrVar, ClassOrVar) -> EGraphST m (Map ClassOrVar ClassOrVar, ClassOrVar)
canonizeMap (subst, cv) = (,cv) <$> traverse canonLeft subst
  where
    canonLeft :: Monad m => ClassOrVar -> EGraphST m ClassOrVar
    canonLeft (Left cid) = Left <$> canonical cid
    canonLeft other      = pure other
-- | Apply one match of a rewrite rule.
--
-- The substitution is canonicalized with 'canonizeMap'. If the matched
-- root passes 'isValidHeight' and every rule condition holds, the rule's
-- target pattern is instantiated with 'reprPrat', which adds any missing
-- e-nodes. The result is then merged into the matched root class.
applyMatch :: Monad m => CostFun -> Rule -> (Map ClassOrVar ClassOrVar, ClassOrVar) -> EGraphST m ()
applyMatch costFun rule rawMatch = do
  matched@(subst, root) <- canonizeMap rawMatch
  heightOk <- isValidHeight matched
  condsOk  <- and <$> mapM (`isValidConditions` matched) (getConditions rule)
  when (heightOk && condsOk) $ do
    rhs <- reprPrat costFun subst (target rule)
    _ <- merge costFun (getInt root) rhs
    pure ()
-- | Apply one rule match using only classes found by 'classOfENode'.
--
-- The validity checks are the same as 'applyMatch'. The target pattern is
-- resolved with 'classOfENode', which looks up existing e-nodes (it may
-- still add nodes whose children are all constants). When a class is
-- found, it is merged with the matched root. Otherwise nothing happens.
applyMergeOnlyMatch :: Monad m => CostFun -> Rule -> (Map ClassOrVar ClassOrVar, ClassOrVar) -> EGraphST m ()
applyMergeOnlyMatch costFun rule rawMatch = do
  matched@(subst, root) <- canonizeMap rawMatch
  heightOk <- isValidHeight matched
  condsOk  <- and <$> mapM (`isValidConditions` matched) (getConditions rule)
  when (heightOk && condsOk) $ do
    found <- classOfENode costFun subst (target rule)
    forM_ found $ \cid -> merge costFun (getInt root) cid
-- | Find the e-class a pattern denotes under a substitution, without
-- creating new structure where possible.
--
-- * @VarPat c@: the class bound to @c@ in the substitution (canonicalized),
--   or 'Nothing' if @c@ is unbound.
-- * @Fixed (Const x)@: the constant is added with 'add'.
-- * Any other fixed node: children are resolved recursively ('Nothing' if
--   any child fails). If every (canonical) child is a constant, the node is
--   added and the e-graph rebuilt. Otherwise the node is only looked up in
--   '_eNodeToEClass'.
--
-- Returned ids are canonical.
classOfENode :: Monad m => CostFun -> Map ClassOrVar ClassOrVar -> Pattern -> EGraphST m (Maybe EClassId)
classOfENode _ subst (VarPat c) =
  case getInt <$> subst Map.!? Right (fromEnum c) of
    Nothing  -> pure Nothing
    Just eid -> Just <$> canonical eid
classOfENode costFun _ (Fixed (Const x)) = Just <$> add costFun (Const x)
classOfENode costFun subst (Fixed target) = do
  newChildren <- mapM (classOfENode costFun subst) (getElems target)
  case sequence newChildren of
    Nothing -> pure Nothing
    Just cs -> do
      cs' <- mapM canonical cs
      -- Build the e-node from the CANONICAL children: '_eNodeToEClass' keys
      -- are canonical e-nodes, so stale child ids would make the lookup miss.
      let new_enode = replaceChildren cs' target
      areConsts <- mapM isConst cs'
      if and areConsts
        then do eid <- add costFun new_enode
                rebuild costFun
                -- the rebuild may have merged eid away; report its root
                Just <$> canonical eid
        else do found <- gets ((Map.!? new_enode) . _eNodeToEClass)
                traverse canonical found
-- | Instantiate a pattern in the e-graph and return its e-class.
--
-- A variable resolves to the canonical class bound in the substitution.
-- This is partial: it fails via 'Map.!' if the variable is unbound. A fixed
-- node has its children instantiated first (left to right), and then the
-- node itself is added with 'add'.
reprPrat :: Monad m => CostFun -> Map ClassOrVar ClassOrVar -> Pattern -> EGraphST m EClassId
reprPrat _ subst (VarPat c) =
  canonical (getInt (subst Map.! Right (fromEnum c)))
reprPrat costFun subst (Fixed target) =
  traverse (reprPrat costFun subst) (getElems target)
    >>= add costFun . flip replaceChildren target
-- | Check that a match's root is shallow enough to rewrite.
--
-- A class root (@Left@) is valid when the '_height' of its canonical class
-- is below 'maxHeight'. A variable root (@Right@) counts as height 0.
isValidHeight :: Monad m => (Map ClassOrVar ClassOrVar, ClassOrVar) -> EGraphST m Bool
isValidHeight (_, root) = do
  h <- either classHeight (const (pure 0)) root
  pure (h < maxHeight)
  where
    -- exclusive upper bound on the height of classes that may be rewritten
    maxHeight :: Int
    maxHeight = 15
    classHeight cid = do
      cid' <- canonical cid
      gets (_height . (IntMap.! cid') . _eClass)
-- | Evaluate one rule condition against the match's substitution and the
-- current e-graph. This is read-only.
isValidConditions :: Monad m => Condition -> (Map ClassOrVar ClassOrVar, ClassOrVar) -> EGraphST m Bool
isValidConditions cond ~(subst, _) = gets (cond subst)
-- | Insert an expression tree into the e-graph and return its root class.
--
-- Children are inserted first (left to right), then the node built from
-- their class ids is added with 'add'. This is a monadic catamorphism.
fromTree :: Monad m => CostFun -> Fix SRTree -> EGraphST m EClassId
fromTree costFun = go
  where
    go (Fix node) = traverse go node >>= add costFun
-- | Insert several trees with 'fromTree', processing them left to right.
--
-- NOTE: the returned ids are in REVERSE input order: the id of the last
-- tree comes first. Callers pairing ids with inputs must account for this.
fromTrees :: Monad m => CostFun -> [Fix SRTree] -> EGraphST m [EClassId]
fromTrees costFun trees = reverse <$> mapM (fromTree costFun) trees
-- | Extract the expression formed by following each class's '_best' e-node,
-- starting from the canonical class of @eid@ and recursing into children.
getBest :: Monad m => EClassId -> EGraphST m (Fix SRTree)
getBest eid = do
  root     <- canonical eid
  bestNode <- gets (\g -> _best (_info (_eClass g IntMap.! root)))
  subtrees <- traverse getBest (childrenOf bestNode)
  pure (Fix (replaceChildren subtrees bestNode))
-- | Extract one expression represented by an e-class.
--
-- A terminal e-node (@Var@, @Const@ or @Param@) is preferred when the class
-- has one. Otherwise the first e-node in the set's iteration order is used,
-- and the function recurses into its children.
--
-- No cycle detection is performed, so a class whose only e-nodes lead back
-- to itself can still recurse forever.
getExpressionFrom :: Monad m => EClassId -> EGraphST m (Fix SRTree)
getExpressionFrom eId' = do
  eId <- canonical eId'
  nodes <- gets (Set.toList . Set.map decodeEnode . _eNodes . (IntMap.! eId) . _eClass)
  -- prefer terminals: they end the extraction immediately
  let cands = case filter isTerm nodes of
                []    -> nodes
                terms -> terms
  Fix <$> case cands of
    (Bin op l r : _) -> Bin op <$> getExpressionFrom l <*> getExpressionFrom r
    (Uni f t : _)    -> Uni f <$> getExpressionFrom t
    (Var ix : _)     -> pure (Var ix)
    (Const x : _)    -> pure (Const x)
    (Param ix : _)   -> pure (Param ix)
    []               -> error "getExpressionFrom: e-class has no e-nodes"
  where
    isTerm :: SRTree val -> Bool
    isTerm (Var _)   = True
    isTerm (Const _) = True
    isTerm (Param _) = True
    isTerm _         = False
-- | Enumerate every expression represented by an e-class.
--
-- For each e-node of the canonical class (in set iteration order), all
-- combinations of its children's expressions are produced, and the results
-- are concatenated. The output can grow exponentially. There is no cycle
-- detection, so cyclic classes make this recurse forever.
getAllExpressionsFrom :: Monad m => EClassId -> EGraphST m [Fix SRTree]
getAllExpressionsFrom eId' = do
  eId <- canonical eId'
  nodes <- gets (map decodeEnode . Set.toList . _eNodes . (IntMap.! eId) . _eClass)
  concat <$> mapM expand nodes
  where
    -- all expressions rooted at one e-node
    expand n = map Fix <$> case n of
      Bin op l r -> do
        ls <- getAllExpressionsFrom l
        rs <- getAllExpressionsFrom r
        pure [Bin op li ri | li <- ls, ri <- rs]
      Uni f t  -> map (Uni f) <$> getAllExpressionsFrom t
      Var ix   -> pure [Var ix]
      Const x  -> pure [Const x]
      Param ix -> pure [Param ix]
-- | Extract a random expression from an e-class.
--
-- At every class, one of its e-nodes is chosen by drawing an index with
-- 'randomR' over @[0, n-1]@ from the 'StdGen' in the inner 'State'. The
-- function then recurses into the chosen node's children. This is partial:
-- an empty class makes the index out of range.
getRndExpressionFrom :: EClassId -> EGraphST (State StdGen) (Fix SRTree)
getRndExpressionFrom eId' = do
  eId <- canonical eId'
  encoded <- gets (Set.toList . _eNodes . (IntMap.! eId) . _eClass)
  pick <- lift (state (randomR (0, length encoded - 1)))
  Fix <$> case decodeEnode (encoded !! pick) of
    Bin op l r -> Bin op <$> getRndExpressionFrom l <*> getRndExpressionFrom r
    Uni f t    -> Uni f <$> getRndExpressionFrom t
    Var ix     -> pure (Var ix)
    Const x    -> pure (Const x)
    Param ix   -> pure (Param ix)
-- | Compact the e-graph's maps.
--
-- * '_eNodeToEClass': every key is re-canonized with 'canonize' and every
--   value with 'canonical'. If several old keys canonize to the same
--   e-node, 'Map.fromList' keeps the last one.
-- * '_eClass': only entries whose id is its own canonical id are kept.
-- * '_canonicalMap' and '_eDB' are written back as read.
--
-- The new state is then forced with 'forceState'.
cleanMaps :: Monad m => EGraphST m ()
cleanMaps = do
  enode2eclass <- gets _eNodeToEClass
  entries <- forM (Map.toList enode2eclass) $ \(k, v) -> do
    k' <- canonize k
    v' <- canonical v
    pure (k', v')
  let enode2eclass' = Map.fromList entries
  eclassMap <- gets _eClass
  entries' <- forM (IntMap.toList eclassMap) $ \(k, v) -> do
    k' <- canonical k
    pure $ if k == k' then Just (k, v) else Nothing
  let eclassMap' = IntMap.fromList (catMaybes entries')
  -- NOTE(review): '_canonicalMap' is kept whole, including non-root entries.
  -- Presumably this lets ids that were merged away (e.g. ones still held by
  -- callers) be resolved with 'canonical'. Pruning it to roots would lose
  -- that; confirm before changing.
  canon <- gets _canonicalMap
  eDB' <- gets _eDB
  put $ EGraph canon enode2eclass' eclassMap' eDB'
  forceState
-- | Evaluate the current state to weak head normal form and return '()'.
--
-- Only the outermost constructor of the state is forced; its fields are
-- left as they are. The state itself is not modified.
forceState :: Monad m => StateT s m ()
forceState = do
  current <- get
  current `seq` pure ()