{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE ViewPatterns #-}
{-# LANGUAGE MonoLocalBinds #-}
{-|
   Given an e-graph representing expressions of our language, we might want to
   extract, out of all expressions represented by some equivalence class, /the best/
   expression (according to a 'CostFunction') represented by that class.

   The function 'extractBest' allows us to do exactly that: get the best
   expression represented in an e-class of an e-graph given a 'CostFunction'.
 -}
module Data.Equality.Extraction
  (
  -- * Extraction
    extractBest

  -- * Cost
  , CostFunction
  , depthCost
  ) where

import qualified Data.Set as S
import qualified Data.IntMap.Strict as IM

import Data.Equality.Graph.Internal (EGraph(classes))
import Data.Equality.Utils
import Data.Equality.Graph

-- Extracting the best expression of an e-class necessarily also extracts the
-- best sub-expressions of all its children equivalence classes.

-- | Extract the /best/ expression from an equivalence class according to a
-- 'CostFunction'.
--
-- Following @egg@'s strategy, a best cost (together with the expression that
-- achieves it) is computed for every e-class of the e-graph by sweeping over
-- all e-classes until a fixed point is reached. The entry recorded for the
-- canonical representative of the requested e-class is then returned.
--
-- Calls 'error' if no cost could be computed for the requested e-class, i.e.
-- when none of its nodes ever had all of its children e-classes costed.
--
-- @
-- (i, egr) = ...
--    i <- represent expr
--            ...
--
-- bestExpr = extractBest egr 'depthCost' i
-- @
--
-- For a real example you might want to check out the source code of 'Data.Equality.Saturation.equalitySaturation''
extractBest :: forall anl lang cost
             . (Language lang, Ord cost)
            => EGraph anl lang        -- ^ The e-graph out of which we are extracting an expression
            -> CostFunction lang cost -- ^ The cost function to define /best/
            -> ClassId                -- ^ The e-class from which we'll extract the expression
            -> Fix lang               -- ^ The resulting /best/ expression, in its fixed point form.
extractBest egr cost (flip find egr -> i) =

    -- Compute the best cost of every e-class in the e-graph (not only those
    -- reachable from the target), then pick the entry of the canonical target.
    let allCosts :: ClassIdMap (CostWithExpr lang cost)
        allCosts = findCosts (classes egr) mempty

     in case findBest i allCosts of
        Just (CostWithExpr (_, n)) -> n
        Nothing -> error $ "Couldn't find a best node for e-class " <> show i

  where

    -- | Repeatedly sweep over all e-classes, recording every e-class whose best
    -- cost appears for the first time or strictly improves, until a full sweep
    -- changes nothing.
    findCosts :: ClassIdMap (EClass anl lang) -> ClassIdMap (CostWithExpr lang cost) -> ClassIdMap (CostWithExpr lang cost)
    findCosts eclasses current =

      -- Strict left fold so the (modified, map) accumulator is evaluated as we
      -- go instead of piling up one thunk per e-class on every sweep.
      let (modified, updated) = IM.foldlWithKey' step (False, current) eclasses

          {-# INLINE step #-}
          step :: (Bool, ClassIdMap (CostWithExpr lang cost)) -> ClassId -> EClass anl lang -> (Bool, ClassIdMap (CostWithExpr lang cost))
          step acc@(_, beingUpdated) i' EClass{eClassNodes = nodes} =
                let
                    -- Best entry recorded so far for this e-class, if any
                    currentCost = IM.lookup i' beingUpdated

                    -- Cheapest node of this e-class whose children all have a
                    -- cost already (on ties the earlier node in set order wins)
                    newCost = S.foldl' (\best n -> cheaper best (nodeTotalCost beingUpdated n)) Nothing nodes
                 in case (currentCost, newCost) of
                    -- First time this e-class gets a cost
                    (Nothing, Just new) -> (True, IM.insert i' new beingUpdated)
                    -- Strictly cheaper than the recorded entry
                    (Just (CostWithExpr old), Just (CostWithExpr new))
                      | fst new < fst old -> (True, IM.insert i' (CostWithExpr new) beingUpdated)
                    _ -> acc

        -- If any class was modified, loop
       in if modified
            then findCosts eclasses updated
            else updated

    -- | Keep the cheaper of two optional candidates. A missing candidate never
    -- beats a present one, and on equal cost the first argument is kept
    -- (the default 'min' returns its first argument when @x <= y@).
    cheaper :: Maybe (CostWithExpr lang cost) -> Maybe (CostWithExpr lang cost) -> Maybe (CostWithExpr lang cost)
    cheaper Nothing y = y
    cheaper x Nothing = x
    cheaper (Just x) (Just y) = Just (x `min` y)
    {-# INLINE cheaper #-}

    -- | Get the total cost of a node at this stage of the extraction, if it
    -- has one.
    --
    -- A node only has a cost once every one of its (canonicalised) children
    -- e-classes has a recorded best entry. Its cost is then the cost function
    -- applied to the children's costs, and its expression is the node rebuilt
    -- over the children's best expressions.
    nodeTotalCost :: Traversable lang => ClassIdMap (CostWithExpr lang cost) -> ENode lang -> Maybe (CostWithExpr lang cost)
    nodeTotalCost m (Node n) = do
        children <- traverse ((`IM.lookup` m) . flip find egr) n
        let childCosts = fst . unCWE <$> children
            childExprs = snd . unCWE <$> children
        pure (CostWithExpr (cost childCosts, Fix childExprs))
    {-# INLINE nodeTotalCost #-}
{-# INLINABLE extractBest #-}

-- | A function assigning a cost to one layer of an expression, given the costs
-- already computed for that layer's children. It is used to rank the
-- representations stored in an e-graph so that the cheapest one can be
-- extracted.
--
-- Any type may be used for the cost, but it needs an 'Ord' instance for the
-- 'CostFunction' to be of any use. This is why both
-- 'Data.Equality.Saturation.equalitySaturation' and 'extractBest' require an
-- @Ord cost@ constraint.
--
-- === Example
-- @
-- symCost :: Expr Int -> Int
-- symCost = \case
--     BinOp Integral e1 e2 -> e1 + e2 + 20000
--     BinOp Diff e1 e2 -> e1 + e2 + 500
--     BinOp x e1 e2 -> e1 + e2 + 3
--     UnOp x e1 -> e1 + 30
--     Sym _ -> 1
--     Const _ -> 1
-- @
type CostFunction lang c = lang c -> c

-- | Simple cost function: each node costs @1@ plus the sum of its children's
-- costs, so the cost of an expression is the total number of nodes in it.
-- Expressions that are deeper (or wider) therefore get a bigger cost.
depthCost :: Language l => CostFunction l Int
depthCost = (+ 1) . sum
{-# INLINE depthCost #-}

-- | Look up the best node (and its cost) recorded for an equivalence class in
-- the current extraction state.
--
-- This is not necessarily the best node in the e-graph, only the best found so
-- far by the extraction. The class id is looked up as-is; no canonicalisation
-- is performed here, so callers should pass a canonical id.
findBest :: ClassId -> ClassIdMap (CostWithExpr lang a) -> Maybe (CostWithExpr lang a)
findBest = IM.lookup
{-# INLINE findBest #-}

-- | A cost paired with the expression that achieves it.
--
-- Comparisons ('Eq', 'Ord') on this type consider only the cost component,
-- never the expression.
newtype CostWithExpr lang a = CostWithExpr { unCWE :: (a, Fix lang) }

-- | Two entries are equal when their costs are equal; the expressions are
-- ignored.
instance Eq a => Eq (CostWithExpr lang a) where
  CostWithExpr (a, _) == CostWithExpr (b, _) = a == b
  {-# INLINE (==) #-}

-- | Entries are ordered by cost only; the expressions are ignored.
instance Ord a => Ord (CostWithExpr lang a) where
  compare (CostWithExpr (a, _)) (CostWithExpr (b, _)) = compare a b
  {-# INLINE compare #-}