-----------------------------------------------------------------------------
-- |
-- Module      :  Algorithm.EqSat.Info
-- Copyright   :  (c) Fabricio Olivetti 2021 - 2024
-- License     :  BSD3
-- Maintainer  :  fabricio.olivetti@gmail.com
-- Stability   :  experimental
-- Portability :
--
-- Functions related to info/data calculation in Equality Graph data structure
-- Heavily based on hegg (https://github.com/alt-romes/hegg by alt-romes)
--
-----------------------------------------------------------------------------

module Algorithm.EqSat.Info where

import Control.Lens ( over )
import Control.Monad --(forM, forM_, when, foldM, void)
import Control.Monad.State
import Data.AEq (AEq ((~==)))
import Data.IntMap (IntMap) -- , delete, empty, insert, toList)
import qualified Data.IntMap as IntMap
import Data.Map (Map)
import qualified Data.Map as Map
import Data.SRTree
import Data.SRTree.Eval (evalFun, evalOp, PVector)
import Data.HashSet (HashSet)
import qualified Data.HashSet as Set
import qualified Data.IntSet as IntSet
import Algorithm.EqSat.Egraph
import Data.AEq (AEq ((~==)))
import Algorithm.EqSat.Queries
import Data.Maybe
import qualified Data.Set as TrueSet
import Data.Sequence (Seq(..), (><))

import Debug.Trace

-- * Data related functions 

-- | Join the analysis data of two e-classes that are being merged.
--
-- The resulting data keeps:
--
--   * the smaller cost, together with the e-node of the cheaper class
--     (ties favour the first argument);
--   * the constant information merged by the local 'combineConsts';
--   * the smaller of the two fitness values, where 'Nothing' (not yet
--     evaluated) is ignored in favour of a 'Just';
--   * the parameter vector belonging to the fitness that was kept;
--   * the smaller size.
--
-- TODO: instead of folding, just do not apply rules
-- list of values instead of single value
joinData :: EClassData -> EClassData -> EClassData
joinData (EData c1 b1 cn1 fit1 p1 sz1) (EData c2 b2 cn2 fit2 p2 sz2) =
  EData (min c1 c2)
        b
        (combineConsts cn1 cn2)
        (minMaybe fit1 fit2)
        (bestParam p1 p2 fit1 fit2)
        (min sz1 sz2)
  where
    -- 'min' on 'Maybe' would prefer 'Nothing' (it is the smallest value),
    -- so an absent value is explicitly skipped instead.
    minMaybe Nothing x = x
    minMaybe x Nothing = x
    minMaybe x y       = min x y

    -- Parameters that go with the fitness selected by 'minMaybe'.
    bestParam Nothing x _ _           = x
    bestParam x Nothing _ _           = x
    bestParam x y (Just f1) (Just f2) = if f1 < f2 then x else y
    -- Both parameter vectors present but at least one fitness missing
    -- (previously a pattern-match failure): follow 'minMaybe', i.e. when
    -- only the second fitness exists it wins, otherwise keep the first.
    bestParam _ y Nothing (Just _)    = y
    bestParam x _ _       _           = x

    b = if c1 <= c2 then b1 else b2

    combineConsts (ConstVal x) (ConstVal y)
      | abs (x - y) < 1e-7      = ConstVal $ (x + y) / 2
      -- a NaN/infinite value is discarded in favour of the other one;
      -- when both are NaN/infinite the second one is kept
      | isNaN x || isInfinite x = ConstVal y
      | isNaN y || isInfinite y = ConstVal x
      | x ~== y                 = ConstVal $ (x + y) / 2
      -- NOTE(review): for finite non-NaN x and y one of |x/y| and |y/x| is
      -- always <= 1, so this guard always holds and the 'otherwise' error
      -- below is practically unreachable: any two finite constants collapse
      -- to their minimum. Kept as-is so silent merges do not become crashes.
      | abs (x / y) < 1 + 1e-6 || abs (y / x) < 1 + 1e-6 = ConstVal $ min x y
      | otherwise = error $ "Combining different values: " <> show x <> " " <> show y <> " " <> show (x/y)
    combineConsts (ParamIx ix) (ParamIx iy) = ParamIx (min ix iy)
    combineConsts NotConst x = x
    combineConsts x NotConst = x
    combineConsts x y = error (show x <> " " <> show y)

-- | Calculate e-node data (constant values and cost).
--
-- The constant information is computed from the e-node as given. The cost
-- and the stored e-node come from its canonical form. Fitness and parameters
-- start out as 'Nothing'. The size is one, for the node itself, plus the sum
-- of the '_size' recorded for each child e-class of the canonical node.
makeAnalysis :: Monad m => CostFun -> ENode -> EGraphST m EClassData
makeAnalysis costFun enode = do
  constInfo <- calculateConsts enode
  canonNode <- canonize enode
  nodeCost  <- calculateCost costFun canonNode
  egraph    <- get
  let childSize cid = _size (_info (_eClass egraph IntMap.! cid))
      totalSize     = 1 + sum [ childSize cid | cid <- childrenOf canonNode ]
  return (EData nodeCost canonNode constInfo Nothing Nothing totalSize)

-- | Smallest '_height' among the child e-classes of an e-node, or 0 when
-- the e-node has no children.
--
-- Child ids are looked up in the e-class map as they are, without
-- canonicalization.
getChildrenMinHeight :: Monad m => ENode -> EGraphST m Int
getChildrenMinHeight enode = do
  classMap <- gets _eClass
  let heights = [ _height (classMap IntMap.! cid) | cid <- childrenOf enode ]
  pure $ if null heights then 0 else minimum heights

-- | Update the '_height' of every e-class to its breadth-first distance
-- from the root e-classes returned by 'findRootClasses'.
--
-- First, every e-class gets the number of e-classes as an upper bound. The
-- roots are then set to 0, and the e-graph is explored one level at a time:
-- children first reached at level @h@ get @min currentHeight h@. Classes
-- that are never reached keep the upper bound, so this won't produce
-- meaningful heights if there's no root.
calculateHeights :: Monad m => EGraphST m ()
calculateHeights = do
  roots    <- findRootClasses
  classIds <- gets (IntMap.keys . _eClass)
  let upperBound = length classIds
  mapM_ (assignHeight upperBound) classIds -- set all heights to max possible height (number of e-classes)
  mapM_ (assignHeight 0) roots             -- set root e-classes height to zero
  bfs roots (TrueSet.fromList roots) 1     -- next height is 1
  where
    -- overwrite the height of the canonical e-class of eid
    assignHeight newHeight eid = do
      cid <- canonical eid
      cls <- getEClass cid
      modify' $ over eClass (IntMap.insert cid (over height (const newHeight) cls))

    -- lower the height of the canonical e-class of eid to at most newHeight
    lowerHeight newHeight eid = do
      cid     <- canonical eid
      current <- _height <$> getEClass cid
      assignHeight (min current newHeight) cid

    -- child e-class ids of every e-node stored in the canonical class of eid
    childIdsOf :: Monad m => EClassId -> EGraphST m [EClassId]
    childIdsOf eid = do
      cid <- canonical eid
      gets (concatMap encodedChildren . _eNodes . (IntMap.! cid) . _eClass)

    -- encoded e-nodes use -1 for a missing child
    encodedChildren (_, -1, -1, _) = []
    encodedChildren (_, l,  -1, _) = [l]
    encodedChildren (_, l,  r,  _) = [l, r]

    bfs [] _ _ = pure ()
    bfs frontier visited level = do
      -- retrieve all unvisited children of the current frontier
      nextSet <- (TrueSet.\\ visited) . TrueSet.fromList . concat <$> mapM childIdsOf frontier
      let next = TrueSet.toList nextSet
      -- set the height of each child to the minimum of its current height and level
      mapM_ (lowerHeight level) next
      -- advance one level, breadth-first
      bfs next (TrueSet.union visited nextSet) (level + 1)

-- | Cost of an e-node. Each child e-class id is replaced by the '_cost'
-- stored in that child's e-class data, and the cost function is applied to
-- the resulting node.
calculateCost :: Monad m => CostFun -> SRTree EClassId -> EGraphST m Cost
calculateCost f t = do
  childCosts <- mapM childCost (childrenOf t)
  return (f (replaceChildren childCosts t))
  where
    childCost cid = _cost . _info <$> getEClass cid

-- | Check whether an e-node evaluates to a constant.
--
-- Each child e-class id is replaced by that child's recorded '_consts'. The
-- node is then folded one level with the top-level 'combineConsts'. A NaN
-- result is returned as-is, as @ConstVal NaN@.
calculateConsts :: Monad m => SRTree EClassId -> EGraphST m Consts
calculateConsts t =
  do let cs = childrenOf t
     consts <- traverse (fmap (_consts . _info) . getEClass) cs
     -- evaluate the result to WHNF before returning it, as the scrutinised
     -- case-expression used to
     pure $! combineConsts (replaceChildren consts t)

-- | Fold one level of constant information. The input is an 'SRTree' node
-- whose children have already been replaced by their 'Consts':
--
--   * a numeric constant becomes 'ConstVal' and a parameter 'ParamIx';
--   * a variable is 'NotConst';
--   * a unary function is evaluated on a 'ConstVal' child; any other child
--     information is passed through unchanged;
--   * a binary operator is evaluated when both sides are 'ConstVal'. When
--     both sides are 'ParamIx' it yields the smaller index, and in every
--     other case it yields 'NotConst'.
combineConsts :: SRTree Consts -> Consts
combineConsts node = case node of
  Const x            -> ConstVal x
  Param ix           -> ParamIx ix
  Var _              -> NotConst
  Uni f (ConstVal x) -> ConstVal (evalFun f x)
  Uni _ child        -> child
  Bin op l r         -> binary op l r
  where
    binary op (ConstVal x) (ConstVal y) = ConstVal (evalOp op x y)
    binary _  (ParamIx ix) (ParamIx iy) = ParamIx (min ix iy)
    binary _  _            _            = NotConst

-- | Record the fitness and fitted parameters of e-class @eId@, and update
-- the fitness indices in the e-graph database ('eDB').
--
-- The e-class is looked up directly with @eId@, without canonicalization,
-- so @eId@ presumably must be canonical — TODO confirm with callers.
--
-- * First evaluation (no previous fitness):
--     * @eId@ is removed from 'unevaluated';
--     * @eId@ is inserted into 'fitRangeDB';
--     * @eId@ is inserted into the 'sizeFitDB' bucket for the class size,
--       which is created if missing.
-- * Re-evaluation: the old fitness entry in 'fitRangeDB' is replaced by
--   the new one.
insertFitness :: Monad m => EClassId -> Double -> PVector -> EGraphST m ()
insertFitness eId fit params = do
  ec <- gets ((IntMap.! eId) . _eClass)
  let oldFit  = _fitness . _info $ ec
      newInfo = (_info ec){_fitness = Just fit, _theta = Just params}
      newEc   = ec{_info = newInfo}
      sz      = _size newInfo
  modify' $ over eClass (IntMap.insert eId newEc)
  case oldFit of
    Nothing ->
      modify' $ over (eDB . unevaluated) (IntSet.delete eId)
              . over (eDB . fitRangeDB) (insertRange eId fit)
              -- insertWith (><) with Empty only makes sure the bucket exists
              . over (eDB . sizeFitDB) (IntMap.adjust (insertRange eId fit) sz . IntMap.insertWith (><) sz Empty)
    Just old ->
      -- NOTE(review): sizeFitDB is not updated here, so its entry for eId
      -- keeps the first recorded fitness; confirm whether that is intended.
      modify' $ over (eDB . fitRangeDB) (insertRange eId fit . removeRange eId old)