{-# LANGUAGE TupleSections #-}
-----------------------------------------------------------------------------
-- |
-- Module      :  Algorithm.EqSat
-- Copyright   :  (c) Fabricio Olivetti 2021 - 2024
-- License     :  BSD3
-- Maintainer  :  fabricio.olivetti@gmail.com
-- Stability   :  experimental
-- Portability :
--
-- Equality Saturation for SRTree
-- Heavily based on hegg (https://github.com/alt-romes/hegg by alt-romes)
--
-----------------------------------------------------------------------------

module Algorithm.EqSat where

import Algorithm.EqSat.Egraph
import Algorithm.EqSat.DB
import Algorithm.EqSat.Info
import Algorithm.EqSat.Build
import Control.Lens (element, makeLenses, over, (&), (+~), (-~), (.~), (^.))
import Control.Monad.State
import Data.Function (on)
import Data.IntMap (IntMap)
import qualified Data.IntMap as IntMap
import Data.List (intercalate, minimumBy)
import Data.Map (Map)
import qualified Data.Map as Map
import Data.Maybe (mapMaybe)
import Data.SRTree
import Data.HashSet (HashSet)
import qualified Data.HashSet as Set
import Control.Monad ( zipWithM )

import Debug.Trace

-- | The 'Scheduler' is a 'State' computation over an 'IntMap' keyed by the
-- rule number. The value stored for a rule is the iteration threshold that
-- 'matchWithScheduler' compares against the current iteration to decide
-- whether the rule is banned.
-- TODO: make it more customizable.
type Scheduler a = State (IntMap Int) a

-- | Local copy of @fromJust@, defined here to avoid importing it.
--
-- Partial: calls 'error' with the message
-- @"fromJust called with Nothing"@ when applied to 'Nothing'.
fromJust :: Maybe a -> a
fromJust (Just x) = x
fromJust Nothing  = error "fromJust called with Nothing"
{-# INLINE fromJust #-}

-- | Runs equality saturation starting from an expression tree, with a given
-- set of rules, a cost function and an iteration budget.
-- Returns the tree with the smallest cost for the class of the inserted
-- expression.
--
-- When 'runEqSat' reports an early stop (its first result is 'False'),
-- saturation is applied again on the best expression found so far, using
-- what is left of the iteration budget.
eqSat :: Monad m => Fix SRTree -> [Rule] -> CostFun -> Int -> EGraphST m (Fix SRTree)
eqSat expr rules costFun maxIt =
    do root      <- fromTree costFun expr
       (end, it) <- runEqSat costFun rules maxIt
       best      <- getBest root
       if end || it <= 1
         -- finished normally, or no iteration budget is left
         then pure best
         -- Early stop: the iteration numbered @it@ has already been executed
         -- by 'runEqSat', so only @it - 1@ iterations remain. Passing the
         -- decremented counter makes the budget strictly decrease on every
         -- restart, which guarantees that this recursion terminates.
         else eqSat best rules costFun (it - 1)

-- | Maps an e-class id to a pair with a cost and the expression tree
-- that has that cost.
type CostMap = Map EClassId (Int, Fix SRTree)

-- | Recalculates the costs with a new cost function and returns the
-- lowest-cost expression of the given e-class.
--
-- The cost table is built from scratch (starting from an empty 'CostMap')
-- by sweeping over all e-classes repeatedly until a sweep changes nothing.
--
-- Calls 'error' if no cost could be computed for the canonical id of the
-- requested e-class.
recalculateBest :: Monad m => CostFun -> EClassId -> EGraphST m (Fix SRTree)
recalculateBest costFun eid =
    do classes <- gets _eClass
       let costs = fillUpCosts classes Map.empty
       eid' <- canonical eid
       pure $ maybe (error "recalculateBest: no cost could be computed for the requested e-class")
                    snd
                    (Map.lookup eid' costs)
    where
        -- Cost and expression of a single e-node.
        -- Returns Nothing if the cost of any of its children is still missing.
        nodeCost :: CostMap -> ENode -> Maybe (Int, Fix SRTree)
        nodeCost costMap enode =
          do costedChildren <- traverse (costMap Map.!?) (childrenOf enode)
             let childCosts = map fst costedChildren
                 childTrees = map snd costedChildren
                 -- the cost function is applied to the node with its
                 -- children replaced by their costs
                 ownCost    = costFun (replaceChildren childCosts enode)
             -- cost of the node + children and the expression so far
             pure (ownCost + sum childCosts, Fix $ replaceChildren childTrees enode)

        -- total version of 'minimumBy': Nothing for the empty list
        minimumBy' :: (a -> a -> Ordering) -> [a] -> Maybe a
        minimumBy' _   [] = Nothing
        minimumBy' cmp xs = Just $ minimumBy cmp xs

        -- Applies costOfClass to every class; if something changed, recurse.
        fillUpCosts :: IntMap EClass -> CostMap -> CostMap
        fillUpCosts classes m =
            case IntMap.foldrWithKey costOfClass (False, m) classes of
              (False, _) -> m
              (True, m') -> fillUpCosts classes m'

        -- Updates the entry of one class and flags whether the map changed.
        costOfClass :: EClassId -> EClass -> (Bool, CostMap) -> (Bool, CostMap)
        costOfClass cid ecl (changed, m) =
            let -- costs of the nodes of this class that can already be computed
                candidates = mapMaybe (nodeCost m . decodeEnode) $ Set.toList (_eNodes ecl)
                -- minimum available cost among the nodes of this class
                minCost    = minimumBy' (compare `on` fst) candidates
            in case (m Map.!? cid, minCost) of -- replace the costs accordingly
                 (_, Nothing)        -> (changed, m)
                 (Nothing, Just new) -> (True, Map.insert cid new m)
                 (Just old, Just new)
                   | fst old <= fst new -> (changed, m)
                   | otherwise          -> (True, Map.insert cid new m)

-- | Runs equality saturation for at most @maxIter@ iterations.
--
-- Returns a pair @(finished, it)@ where @it@ is the counter of the last
-- iteration that was executed (the counter starts at @maxIter@ and is
-- decremented after each iteration):
--
-- * @(True, it)@ when the iteration budget ran out or the last iteration
--   changed neither the e-node map nor the e-classes;
-- * @(False, it)@ when the loop stopped early because the number of
--   e-classes exceeded the allowed maximum.
runEqSat :: Monad m => CostFun -> [Rule] -> Int -> EGraphST m (Bool, Int)
runEqSat costFun rules maxIter = go maxIter IntMap.empty
    where
        -- maximum allowed number of e-classes. TODO: customize
        maxEClasses :: Int
        maxEClasses = 500

        rules' = concatMap replaceEqRules rules

        -- replaces the equality rules with two one-way rules
        replaceEqRules :: Rule -> [Rule]
        replaceEqRules (p1 :=> p2)  = [p1 :=> p2]
        replaceEqRules (p1 :==: p2) = [p1 :=> p2, p2 :=> p1]
        replaceEqRules (r :| cond)  = map (:| cond) $ replaceEqRules r

        go it sch =
            do eNodes   <- gets _eNodeToEClass
               eClasses <- gets _eClass

               -- step 1: let the scheduler select the rules of this iteration
               let selectRules        = zipWithM (matchWithScheduler it) [0..]
                   (scheduled, sch')  = runState (selectRules rules') sch

               -- step 2: match the selected rules, apply the matches and rebuild
               matches <- mapM (\rule -> map (rule,) <$> match (source rule)) $ concat scheduled
               mapM_ (uncurry (applyMatch costFun)) $ concat matches
               rebuild costFun

               eNodes'   <- gets _eNodeToEClass
               eClasses' <- gets _eClass

               -- Stop when the budget is spent or nothing changed.
               -- Using @<=@ (rather than @==@) also stops for a non-positive
               -- @maxIter@, which would otherwise never reach the limit.
               if it <= 1 || (eNodes' == eNodes && eClasses' == eClasses)
                  then pure (True, it)
                  else if IntMap.size eClasses' > maxEClasses
                         then pure (False, it)
                         else go (it - 1) sch'

-- | Applies a single step of merge-only equality saturation: the rules are
-- matched once, every match is applied with 'applyMergeOnlyMatch', and the
-- e-graph is rebuilt.
--
-- The rules go through 'matchWithScheduler' with a fresh (empty) scheduler
-- state and a fixed iteration number of 10.
applySingleMergeOnlyEqSat :: Monad m => CostFun -> [Rule] -> EGraphST m ()
applySingleMergeOnlyEqSat costFun rules =
  do let selectRules = zipWithM (matchWithScheduler 10) [0..]
         -- the final scheduler state is not needed for a single step
         scheduled   = fst $ runState (selectRules rules') IntMap.empty
     matches <- mapM (\rule -> map (rule,) <$> match (source rule)) $ concat scheduled
     mapM_ (uncurry (applyMergeOnlyMatch costFun)) $ concat matches
     rebuild costFun
  where
    rules' = concatMap replaceEqRules rules

    -- replaces the equality rules with two one-way rules
    replaceEqRules :: Rule -> [Rule]
    replaceEqRules (p1 :=> p2)  = [p1 :=> p2]
    replaceEqRules (p1 :==: p2) = [p1 :=> p2, p2 :=> p1]
    replaceEqRules (r :| cond)  = map (:| cond) $ replaceEqRules r

-- | Decides, given the scheduler state, whether a rule takes part in the
-- current iteration.
--
-- @it@ is the current iteration counter and @ruleNumber@ is the key of the
-- rule in the scheduler map. A rule is considered banned when the map holds
-- a value for it that is less than or equal to @it@; in that case the result
-- is @[]@ and the state is left untouched. Otherwise the value @it + 5@ is
-- stored for the rule and the result is @[rule]@.
--
-- NOTE(review): 'runEqSat' counts @it@ down from its maximum, so a stored
-- value of @it + 5@ looks like it is always greater than any later @it@,
-- meaning the ban test would never succeed with that caller. Confirm the
-- intended direction of the comparison.
matchWithScheduler :: Int -> Int -> Rule -> Scheduler [Rule]
matchWithScheduler it ruleNumber rule =
  do mbBan <- gets (IntMap.!? ruleNumber)
     case mbBan of
       -- the rule is banned
       Just ban | ban <= it -> pure []
       -- no entry, or the stored value is greater than the current iteration
       _ -> do modify (IntMap.insert ruleNumber (it + 5))
               pure [rule]