{-# LANGUAGE TupleSections #-}
module Algorithm.EqSat where
import Control.Monad (zipWithM)
import Data.Function (on)
import Data.List (intercalate, minimumBy)
import Data.Maybe (mapMaybe)
import Debug.Trace
import GHC.Stack (HasCallStack)

import Control.Lens (element, makeLenses, over, (&), (+~), (-~), (.~), (^.))
import Control.Monad.State
import Data.HashSet (HashSet)
import qualified Data.HashSet as Set
import Data.IntMap (IntMap)
import qualified Data.IntMap as IntMap
import Data.Map (Map)
import qualified Data.Map as Map

import Algorithm.EqSat.Build
import Algorithm.EqSat.DB
import Algorithm.EqSat.Egraph
import Algorithm.EqSat.Info
import Data.SRTree
-- | State monad used to select which rules run in an iteration. The state
-- maps a rule's index (its position in the rule list) to the iteration
-- value that 'matchWithScheduler' recorded for it.
type Scheduler result = State (IntMap Int) result
-- | Extract the value from a 'Just'.
--
-- Partial: calls 'error' with the message @"fromJust called with Nothing"@
-- when given 'Nothing'. The 'HasCallStack' constraint makes that error
-- report the caller's source location instead of only this definition's.
fromJust :: HasCallStack => Maybe a -> a
fromJust (Just x) = x
fromJust Nothing  = error "fromJust called with Nothing"
{-# INLINE fromJust #-}
-- | Equality-saturation driver.
--
-- Inserts @expr@ into the current e-graph with 'fromTree', runs 'runEqSat'
-- with a budget of @maxIt@ iterations, and extracts the best expression for
-- the inserted root with 'getBest'. When 'runEqSat' reports an early stop
-- (its first result is 'False'), the search restarts from that best
-- expression using the iteration count 'runEqSat' returned. Otherwise the
-- best expression is the result.
--
-- NOTE(review): the e-graph is not reset before restarting, so the
-- re-inserted expression joins the existing graph.
eqSat :: Monad m => Fix SRTree -> [Rule] -> CostFun -> Int -> EGraphST m (Fix SRTree)
eqSat expr rules costFun maxIt = do
  rootClass        <- fromTree costFun expr
  (finished, left) <- runEqSat costFun rules maxIt
  bestExpr         <- getBest rootClass
  if finished
    then pure bestExpr
    else eqSat bestExpr rules costFun left
-- | Map from an e-class id to the cheapest known @(cost, expression)@ pair
-- for that class.
type CostMap = Map EClassId (Int, Fix SRTree)

-- | Recompute, from scratch, the cheapest expression represented by e-class
-- @eid@.
--
-- The computation starts from an empty 'CostMap' and repeatedly scans every
-- e-class of the e-graph until a full pass improves no class. It then
-- returns the entry for the canonical id of @eid@.
--
-- The cost of an e-node is @costFun@ applied to the node with its child ids
-- replaced by the children's costs, plus the sum of those children's costs.
--
-- Calls 'error' with a descriptive message if no cost could be computed for
-- the class. For example, this happens if every node in it depends on a
-- class that never obtains a cost.
recalculateBest :: Monad m => CostFun -> EClassId -> EGraphST m (Fix SRTree)
recalculateBest costFun eid = do
    classes <- gets _eClass
    let costs = fillUpCosts classes Map.empty
    eid' <- canonical eid
    case costs Map.!? eid' of
      Just (_, best) -> pure best
      Nothing        ->
        error ("recalculateBest: no cost could be computed for e-class " ++ show eid')
  where
    -- Cost and extracted tree of a single e-node. Returns 'Nothing' while any
    -- child class is still missing from the cost map.
    nodeCost :: CostMap -> ENode -> Maybe (Int, Fix SRTree)
    nodeCost costMap enode = do
      optChildren <- traverse (costMap Map.!?) (childrenOf enode)
      let cc = map fst optChildren
          nc = map snd optChildren
          -- costFun sees the node with child ids replaced by child costs
          n  = replaceChildren cc enode
          c  = costFun n
      pure (c + sum cc, Fix $ replaceChildren nc enode)

    -- Total variant of 'minimumBy': 'Nothing' on an empty list.
    minimumBy' :: (a -> a -> Ordering) -> [a] -> Maybe a
    minimumBy' _ [] = Nothing
    minimumBy' f xs = Just $ minimumBy f xs

    -- Fold 'costOfClass' over all classes, repeating until a pass changes
    -- no entry.
    fillUpCosts :: IntMap EClass -> CostMap -> CostMap
    fillUpCosts classes m =
      case IntMap.foldrWithKey costOfClass (False, m) classes of
        (False, _)  -> m
        (True,  m') -> fillUpCosts classes m'

    -- Give class @cid@ the cheapest cost among its nodes whose children are
    -- all costed. The entry is written when it is absent or strictly
    -- improved, and the flag is set to True in that case. Only strict
    -- improvements count as changes, which lets the fixed point stop.
    costOfClass :: EClassId -> EClass -> (Bool, CostMap) -> (Bool, CostMap)
    costOfClass cid ecl (changed, m) =
      let currentCost = m Map.!? cid
          minCost     = minimumBy' (compare `on` fst)
                      $ mapMaybe (nodeCost m . decodeEnode)
                      $ Set.toList (_eNodes ecl)
      in case (currentCost, minCost) of
           (_, Nothing)          -> (changed, m)
           (Nothing, Just new)   -> (True, Map.insert cid new m)
           (Just old, Just new)
             | fst old <= fst new -> (changed, m)
             | otherwise          -> (True, Map.insert cid new m)
-- | Run up to @maxIter@ rounds of equality saturation on the current e-graph.
--
-- Each round does the following:
--
-- * selects rules through 'matchWithScheduler' (rule index = position in
--   the expanded rule list);
-- * matches every selected rule's source pattern;
-- * applies all matches with 'applyMatch';
-- * calls 'rebuild'.
--
-- Bidirectional @:==:@ rules are expanded into both directions, and
-- conditional rules keep their condition on every expanded direction.
--
-- Results:
--
-- * @(True, it)@ when the round counter reaches 1, or when a round leaves
--   both the node-to-class map and the class map unchanged.
-- * @(False, remaining)@ when the e-graph grows past 500 e-classes.
--   @remaining@ excludes the round just performed, so a caller that
--   restarts with it (as 'eqSat' does) makes progress toward exhausting
--   the budget.
--
-- NOTE(review): a budget @<= 0@ never reaches @it == 1@ and runs until
-- saturation or the size limit.
runEqSat :: Monad m => CostFun -> [Rule] -> Int -> EGraphST m (Bool, Int)
runEqSat costFun rules maxIter = go maxIter IntMap.empty
  where
    rules' :: [Rule]
    rules' = concatMap replaceEqRules rules

    replaceEqRules :: Rule -> [Rule]
    replaceEqRules (p1 :=> p2)  = [p1 :=> p2]
    replaceEqRules (p1 :==: p2) = [p1 :=> p2, p2 :=> p1]
    replaceEqRules (r :| cond)  = map (:| cond) $ replaceEqRules r

    go it sch = do
      eNodes   <- gets _eNodeToEClass
      eClasses <- gets _eClass
      let matchSch          = matchWithScheduler it
          matchAll          = zipWithM matchSch [0..]
          (scheduled, sch') = runState (matchAll rules') sch
      matches <- mapM (\rule -> map (rule,) <$> match (source rule)) $ concat scheduled
      mapM_ (uncurry (applyMatch costFun)) $ concat matches
      rebuild costFun
      eNodes'   <- gets _eNodeToEClass
      eClasses' <- gets _eClass
      if it == 1 || (eNodes' == eNodes && eClasses' == eClasses)
        then pure (True, it)
        else if IntMap.size eClasses' > 500
          -- This round already used iteration `it`. Returning `it` here made
          -- a restarting caller retry forever with the same budget.
          then pure (False, it - 1)
          else go (it - 1) sch'
-- | Run a single round of the given rules in merge-only mode.
--
-- Bidirectional @:==:@ rules are expanded into both directions, and
-- conditional rules keep their condition. Rules are selected with
-- 'matchWithScheduler' at iteration 10, starting from an empty scheduler
-- state that is discarded afterwards. Every selected source pattern is
-- matched, each match is applied with 'applyMergeOnlyMatch', and the
-- e-graph is rebuilt once at the end.
applySingleMergeOnlyEqSat :: Monad m => CostFun -> [Rule] -> EGraphST m ()
applySingleMergeOnlyEqSat costFun rules = do
    let matchSch  = matchWithScheduler 10
        matchAll  = zipWithM matchSch [0..]
        scheduled = evalState (matchAll rules') IntMap.empty
    matches <- mapM (\rule -> map (rule,) <$> match (source rule)) $ concat scheduled
    mapM_ (uncurry (applyMergeOnlyMatch costFun)) $ concat matches
    rebuild costFun
  where
    rules' :: [Rule]
    rules' = concatMap replaceEqRules rules

    replaceEqRules :: Rule -> [Rule]
    replaceEqRules (p1 :=> p2)  = [p1 :=> p2]
    replaceEqRules (p1 :==: p2) = [p1 :=> p2, p2 :=> p1]
    replaceEqRules (r :| cond)  = map (:| cond) $ replaceEqRules r
-- | Decide whether rule number @ruleNumber@ is used at iteration @it@.
--
-- If the scheduler state holds a value for this rule that is @<= it@, the
-- rule is skipped and @[]@ is returned. Otherwise the state entry for the
-- rule is set to @it + 5@ and @[rule]@ is returned.
--
-- NOTE(review): the skip condition never holds for the callers visible in
-- this module, so no rule is ever skipped there:
--
-- * 'runEqSat' counts @it@ down by one per round, so the stored @it + 5@
--   always exceeds the next @it@.
-- * 'applySingleMergeOnlyEqSat' starts from an empty map with distinct
--   rule numbers.
--
-- Confirm the intended ban window before relying on this as a back-off
-- scheduler.
matchWithScheduler :: Int -> Int -> Rule -> Scheduler [Rule]
matchWithScheduler it ruleNumber rule = do
  mbBan <- gets (IntMap.!? ruleNumber)
  case mbBan of
    Just ban | ban <= it -> pure []
    _ -> do
      modify (IntMap.insert ruleNumber (it + 5))
      pure [rule]