{-# LANGUAGE ViewPatterns #-}
{-# LANGUAGE BangPatterns #-}
module Algorithm.EqSat.Queries where
import Algorithm.EqSat.Egraph
import qualified Data.IntMap as IntMap
import qualified Data.Map as Map
import qualified Data.HashSet as Set
import Control.Monad.State ( gets, modify' )
import Control.Monad ( filterM )
import Control.Lens ( over )
import Data.Maybe
import Data.Sequence ( Seq(..) )
import Debug.Trace
-- | Collect the ids of every e-class whose 'EClass' satisfies the
-- predicate @p@.
--
-- The predicate is applied to the classes exactly as they are stored in
-- '_eClass'. Ids come from 'IntMap.keys', so they are returned in
-- ascending order.
getEClassesThat :: Monad m => (EClass -> Bool) -> EGraphST m [EClassId]
getEClassesThat p = gets (IntMap.keys . IntMap.filter p . _eClass)
-- | Set the fitness of e-class @ecId@ to @Just f@, leaving every other
-- field of the class untouched.
--
-- 'IntMap.adjust' performs the lookup and the write in one pass. It is
-- used instead of a partial @IntMap.!@ lookup followed by 'IntMap.insert'
-- for two reasons, both caused by the map being lazy in its values:
--
-- * An id that is not a key of '_eClass' leaves the map unchanged. The
--   alternative stores an error thunk under that id, which only fails
--   once some later query forces it.
-- * The stored thunk references only the old element. The alternative
--   keeps the whole previous map alive until the entry is forced.
--
-- NOTE(review): @ecId@ is used as given (no 'canonical' call); callers are
-- presumably expected to pass a canonical id — confirm.
updateFitness :: Monad m => Double -> EClassId -> EGraphST m ()
updateFitness f ecId = modify' $ over eClass (IntMap.adjust setFitness ecId)
  where
    -- Nested record update: only '_fitness' inside '_info' changes.
    setFitness ec = ec { _info = (_info ec) { _fitness = Just f } }
-- | Ids of the e-classes that qualify as roots.
--
-- A class qualifies when its '_parents' set is empty, or when its own id
-- appears as the first component of one of its '_parents' entries
-- (presumably the class listing itself as a parent).
--
-- Ids are returned in ascending key order of '_eClass'.
findRootClasses :: Monad m => EGraphST m [EClassId]
findRootClasses = gets rootIds
  where
    rootIds :: EGraph -> [EClassId]
    rootIds = IntMap.keys . IntMap.filterWithKey isRoot . _eClass

    isRoot :: EClassId -> EClass -> Bool
    isRoot classId cls =
      let ps = _parents cls
       in Prelude.null ps || Prelude.any ((== classId) . fst) ps
-- | Collect up to @n@ e-class ids from '_fitRangeDB' whose class
-- satisfies @p@.
--
-- The database is walked from its last element towards its first. For
-- each entry, the stored id is canonicalized and the canonical class is
-- looked up:
--
-- * If that class's fitness is infinite (either sign), the walk stops.
-- * If the class fails @p@, the entry is skipped and does not count
--   towards @n@.
-- * Otherwise the entry's id is collected.
--
-- Ids are consed onto an accumulator, so the id collected last comes
-- first in the result.
--
-- NOTE(review): the id collected is the one stored in the database, not
-- the canonical id used for the lookup; confirm callers canonicalize it.
getTopECLassThat :: Monad m => Int -> (EClass -> Bool) -> EGraphST m [EClassId]
getTopECLassThat n p = gets (_fitRangeDB . _eDB) >>= go n []
  where
    go :: Monad m => Int -> [EClassId] -> RangeTree Double -> EGraphST m [EClassId]
    go remaining bests rt
      -- '<= 0' rather than matching exactly 0: a negative n must yield
      -- nothing instead of walking the entire database.
      | remaining <= 0 = pure bests
      | otherwise = case rt of
          Empty -> pure bests
          rest :|> (_, x) -> do
            ecId <- canonical x
            ec   <- gets ((IntMap.! ecId) . _eClass)
            -- fromJust: entries of the fitness DB are assumed to carry a
            -- fitness value (the original relies on the same invariant).
            if isInfinite . fromJust . _fitness . _info $ ec
              then pure bests
              else if p ec
                then go (remaining - 1) (x : bests) rest
                else go remaining bests rest
-- | Collect up to @n@ e-class ids from the range tree stored for size @sz@
-- in '_sizeFitDB'.
--
-- Returns @[]@ when no tree is stored for @sz@ or when @n <= 0@.
--
-- Otherwise the tree is walked from its last element towards its first.
-- For each entry, the stored id is canonicalized and the canonical class
-- is looked up:
--
-- * If that class's fitness is infinite (either sign), the walk stops.
-- * Otherwise the entry's id is collected.
--
-- Ids are consed onto an accumulator, so the id collected last comes
-- first in the result.
--
-- NOTE(review): the id collected is the one stored in the tree, not the
-- canonical id used for the lookup; confirm callers canonicalize it.
getTopECLassWithSize :: Monad m => Int -> Int -> EGraphST m [EClassId]
getTopECLassWithSize sz n =
  gets ((IntMap.!? sz) . _sizeFitDB . _eDB) >>= maybe (pure []) (go n [])
  where
    go :: Monad m => Int -> [EClassId] -> RangeTree Double -> EGraphST m [EClassId]
    go remaining bests rt
      -- '<= 0' rather than matching exactly 0: a negative n must yield
      -- nothing instead of walking the entire tree.
      | remaining <= 0 = pure bests
      | otherwise = case rt of
          Empty -> pure bests
          rest :|> (_, x) -> do
            ecId <- canonical x
            ec   <- gets ((IntMap.! ecId) . _eClass)
            -- fromJust: entries of the fitness DB are assumed to carry a
            -- fitness value (the original relies on the same invariant).
            if isInfinite . fromJust . _fitness . _info $ ec
              then pure bests
              else go (remaining - 1) (x : bests) rest