{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE StrictData #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE TypeSynonymInstances, FlexibleInstances #-}
-----------------------------------------------------------------------------
-- |
-- Module      :  Algorithm.EqSat.Egraph
-- Copyright   :  (c) Fabricio Olivetti 2021 - 2024
-- License     :  BSD3
-- Maintainer  :  fabricio.olivetti@gmail.com
-- Stability   :  experimental
-- Portability :
--
-- Equality Graph data structure 
-- Heavily based on hegg (https://github.com/alt-romes/hegg by alt-romes)
--
-----------------------------------------------------------------------------

module Algorithm.EqSat.Egraph where

import Control.Lens (element, makeLenses, view, over, (&), (+~), (-~), (.~), (^.))
--import Control.Monad (forM, forM_, when, foldM, void)
import Data.List ( intercalate )
import Control.Monad.State.Strict
import Data.IntMap.Strict (IntMap)
import qualified Data.IntMap.Strict as IntMap
import Data.Map.Strict (Map)
import qualified Data.Map.Strict as Map
import Data.HashSet (HashSet)
import qualified Data.HashSet as Set
import Data.IntSet (IntSet)
import qualified Data.IntSet as IntSet
import Data.Sequence ( Seq(..), (><) )
import qualified Data.Sequence as FingerTree
import Data.Foldable ( toList )
import Data.SRTree
import Data.SRTree.Eval
import Data.Hashable

import Debug.Trace

-- | Identifier of an e-class.
type EClassId     = Int -- NOTE: DO NOT CHANGE THIS, this will break the use of IntMap and IntSet
-- | Map keyed by e-class ids.
type ClassIdMap   = IntMap
-- | An e-node: one 'SRTree' layer whose children are e-class ids.
type ENode        = SRTree EClassId
-- | Flat tuple encoding of an e-node (see 'encodeEnode').
type ENodeEnc     = (Int, Int, Int, Double)
-- | State transformer carrying the 'EGraph'.
type EGraphST m a = StateT EGraph m a
-- | Cost of an expression.
type Cost         = Int
-- | Cost function over a single node whose children have been replaced by
-- their costs (as the type shows).
type CostFun      = SRTree Cost -> Cost

-- | E-nodes are hashed through their flat tuple encoding ('encodeEnode').
-- NOTE(review): this is an orphan instance ('SRTree' and 'Hashable' are both
-- imported from other modules).
instance Hashable ENode where
  hashWithSalt salt = hashWithSalt salt . encodeEnode

-- | A sequence of (value, e-class id) pairs. The range functions in this
-- module ('insertRange', 'removeRange', 'getWithinRange', ...) assume it is
-- kept sorted in ascending order of the pair.
type RangeTree a = Seq (a, EClassId)

-- | Encodes an e-node as a flat tuple @(tag, arg1, arg2, value)@. This is the
-- key used to hash e-nodes and the element type stored in '_eNodes'.
--
--   * @Var ix@         ->  @(0, ix, -1, 0)@
--   * @Param ix@       ->  @(1, ix, -1, 0)@
--   * @Const x@        ->  @(2, -1, -1, x)@
--   * @Uni f c@        ->  @(300 + fromEnum f, c, -1, 0)@
--   * @Bin op l r@     ->  @(400 + fromEnum op, l, r, 0)@
--
-- Unused integer slots hold @-1@ and an unused value slot holds @0@.
-- Variable/parameter indices occupy their own slot, so their count is not
-- limited by the encoding. 'decodeEnode' separates unary from binary tags
-- with @tag < 400@, so this encoding assumes 'Function' has fewer than 100
-- constructors (otherwise unary tags would collide with binary ones).
encodeEnode :: ENode -> ENodeEnc
encodeEnode (Var ix)         = (0, ix, -1, 0)
encodeEnode (Param ix)       = (1, ix, -1, 0)
encodeEnode (Const x)        = (2, -1, -1, x)
encodeEnode (Uni f ed)       = (300 + fromEnum f, ed, -1, 0)
encodeEnode (Bin op ed1 ed2) = (400 + fromEnum op, ed1, ed2, 0)
{-# INLINE encodeEnode #-}

-- | Inverse of 'encodeEnode': rebuilds an e-node from its tuple encoding.
--
-- Tags 0, 1 and 2 are 'Var', 'Param' and 'Const'. Any other tag below 400 is
-- a unary 'Function' offset by 300; tags from 400 up are a binary 'Op'
-- offset by 400. Slots not used by the constructor are ignored. The offset
-- tag is passed to 'toEnum' unchecked, so an invalid tag fails there.
decodeEnode :: ENodeEnc -> ENode
decodeEnode (tag, a1, a2, v) =
  case tag of
    0 -> Var a1
    1 -> Param a1
    2 -> Const v
    _ | tag < 400 -> Uni (toEnum (tag - 300)) a1
      | otherwise -> Bin (toEnum (tag - 400)) a1 a2
{-# INLINE decodeEnode #-}

-- | Inserts the entry @(x, eid)@ into a 'RangeTree', keeping the sequence
-- sorted in ascending order of @(value, e-class id)@.
--
-- The tree behaves as a set: inserting an entry that is already present
-- returns the tree unchanged. Inserting before the first or after the last
-- entry is handled directly. Otherwise the sequence is bisected recursively
-- with 'FingerTree.splitAt' to find the insertion point.
--
-- NOTE(review): the @Show a@ constraint is unused by the body; it is kept so
-- the signature stays unchanged for callers.
insertRange :: (Ord a, Show a) => EClassId -> a -> RangeTree a -> RangeTree a
insertRange eid x Empty = FingerTree.singleton (x, eid)
insertRange eid x (y :<| xs) | (x, eid) < y = (x, eid) :<| y :<| xs
insertRange eid x (xs :|> y) | (x, eid) > y = xs :|> y :|> (x, eid)
insertRange eid x rt = go rt
  where
    entry = (x, eid)

    go root = case FingerTree.splitAt (FingerTree.length root `div` 2) root of
      -- only reached when root is empty
      (Empty, Empty) -> FingerTree.singleton entry
      -- root holds exactly one element z (zs is empty)
      (Empty, z :<| zs)
        | entry < z  -> entry :<| z :<| zs
        | entry == z -> root  -- already present: do not duplicate it
        | otherwise  -> z :<| go zs
      -- not reachable for a non-empty root (splitAt (n `div` 2) always
      -- leaves the right half non-empty); kept for pattern completeness
      (ys :|> y, Empty)
        | entry > y -> ys :|> y :|> entry
        | otherwise -> go ys :|> y
      (ys :|> y, z :<| zs)
        | entry > y && entry < z -> (ys :|> y :|> entry) >< (z :<| zs)
        | entry > z              -> (ys :|> y) >< go (z :<| zs)
        | entry < y              -> go (ys :|> y) >< (z :<| zs)
        | otherwise              -> root  -- entry == y or entry == z

-- | Removes the entry @(x, eid)@ from a sorted 'RangeTree'. If the entry is
-- absent, the tree is returned unchanged.
--
-- Entries outside the first/last bounds are rejected immediately. Otherwise
-- the sequence is bisected recursively with 'FingerTree.splitAt'. The entry
-- is removed whether it lands inside a half or on either side of the split
-- point.
--
-- NOTE(review): the @Show a@ constraint is unused by the body; it is kept so
-- the signature stays unchanged for callers.
removeRange :: (Ord a, Show a) => EClassId -> a -> RangeTree a -> RangeTree a
removeRange _ _ Empty = Empty
removeRange eid x rt@(y :<| _) | (x, eid) < y = rt
removeRange eid x rt@(_ :|> y) | (x, eid) > y = rt
removeRange eid x rt = go rt
  where
    entry = (x, eid)

    go root = case FingerTree.splitAt (FingerTree.length root `div` 2) root of
      (Empty, Empty) -> root
      -- root holds exactly one element z (zs is empty)
      (Empty, z :<| zs)
        | entry == z -> zs
        | entry < z  -> root
        | otherwise  -> z :<| go zs
      -- not reachable for a non-empty root; kept for pattern completeness
      (ys :|> y, Empty)
        | entry == y -> ys
        | entry > y  -> root
        | otherwise  -> go ys :|> y
      (ys :|> y, z :<| zs)
        -- the entry sits exactly at the split point: drop it there
        | entry == y             -> ys >< (z :<| zs)
        | entry == z             -> (ys :|> y) >< zs
        -- strictly between the two halves: not present
        | entry > y && entry < z -> root
        | entry > z              -> (ys :|> y) >< go (z :<| zs)
        | entry < y              -> go (ys :|> y) >< (z :<| zs)
        | otherwise              -> root


-- | Returns the e-class ids of every entry whose value lies in the closed
-- interval @[lb, ub]@, in the tree's ascending order.
--
-- The sorted sequence is bisected recursively. A half whose boundary value
-- already falls outside the interval is pruned; otherwise both halves are
-- searched. A single remaining entry is kept only if its value is in range.
-- The result is @[]@ when @lb > ub@ or no value falls within the bounds.
getWithinRange :: Ord a => a -> a -> RangeTree a -> [EClassId]
getWithinRange lb ub rt = map snd . toList $ go rt
  where
    inRange v = lb <= v && v <= ub

    go Empty = Empty
    go root@(z :<| Empty)
      | inRange (fst z) = root
      | otherwise       = Empty
    go root = case FingerTree.splitAt (FingerTree.length root `div` 2) root of
      (ys :|> y, z :<| zs)
        -- every value in the left half is <= fst y < lb: skip that half
        | fst y < lb -> go (z :<| zs)
        -- every value in the right half is >= fst z > ub: skip that half
        | fst z > ub -> go (ys :|> y)
        | otherwise  -> go (ys :|> y) >< go (z :<| zs)
      -- not reachable: root has at least two elements here, so splitting it
      -- in half leaves both sides non-empty
      _ -> Empty


-- | Returns the first (smallest) entry of a 'RangeTree'.
-- Fails with @error "empty finger"@ when the tree is empty.
getSmallest :: Ord a => RangeTree a -> (a, EClassId)
getSmallest (lowest :<| _) = lowest
getSmallest Empty          = error "empty finger"
-- | Returns the last (greatest) entry of a 'RangeTree'.
-- Fails with @error "empty finger"@ when the tree is empty.
getGreatest :: Ord a => RangeTree a -> (a, EClassId)
getGreatest (_ :|> highest) = highest
getGreatest Empty           = error "empty finger"


-- | The e-graph: canonical-id map, e-node hash-consing table, e-classes and
-- auxiliary databases.
data EGraph = EGraph
  { _canonicalMap  :: ClassIdMap EClassId  -- ^ maps an e-class id to its canonical form
  , _eNodeToEClass :: Map ENode EClassId   -- ^ maps an e-node to its e-class id
  , _eClass        :: ClassIdMap EClass    -- ^ maps an e-class id to its e-class data
  , _eDB           :: EGraphDB             -- ^ auxiliary databases
  } deriving Show

-- | Auxiliary databases kept alongside the e-graph.
data EGraphDB = EDB
  { _worklist    :: HashSet (EClassId, ENode)  -- ^ e-nodes and e-class schedule for analysis
  , _analysis    :: HashSet (EClassId, ENode)  -- ^ e-nodes and e-class that changed data
  , _patDB       :: DB                         -- ^ database of patterns
  , _fitRangeDB  :: RangeTree Double           -- ^ database of valid fitness
  , _sizeDB      :: IntMap IntSet              -- ^ database of model sizes
  , _sizeFitDB   :: IntMap (RangeTree Double)  -- ^ hacky! Size x Fitness DB
  , _unevaluated :: IntSet                     -- ^ set of not-evaluated e-classes
  , _nextId      :: Int                        -- ^ next available id
  } deriving Show

-- | An e-class: a set of equivalent e-nodes plus bookkeeping data.
data EClass = EClass
  { _eClassId :: Int                        -- ^ e-class id (maybe we don't need that here)
  , _eNodes   :: HashSet ENodeEnc           -- ^ e-nodes in this class, stored in 'encodeEnode' form
  , _parents  :: HashSet (EClassId, ENode)  -- ^ parents as (e-class, e-node) pairs
  , _height   :: Int                        -- ^ height
  , _info     :: EClassData                 -- ^ analysis data
  } deriving (Show, Eq)

-- | Constant information for an e-class: not a constant, a parameter with
-- the given index, or a known constant value.
data Consts
  = NotConst
  | ParamIx Int
  | ConstVal Double
  deriving (Show, Eq)
-- | Value properties of an e-class. Not yet stored in 'EClassData' (the
-- corresponding field is commented out there).
data Property
  = Positive
  | Negative
  | NonZero
  | Real
  deriving (Show, Eq) -- TODO: incorporate properties

-- | Analysis data attached to every e-class.
data EClassData = EData
  { _cost    :: Cost           -- ^ cost of this e-class
  , _best    :: ENode          -- ^ best e-node of this e-class
  , _consts  :: Consts         -- ^ constant information
  , _fitness :: Maybe Double   -- ^ NOTE: this cannot be NaN
  , _theta   :: Maybe PVector  -- ^ presumably the fitted parameter vector, if any
  , _size    :: Int            -- ^ size
  -- , _properties :: Property
  -- TODO: include evaluation of expression from this e-class
  } deriving Show

-- | Equality on cost, best e-node, constant info, fitness and size.
-- The '_theta' field is not compared.
instance Eq EClassData where
  lhs == rhs =
       _cost lhs    == _cost rhs
    && _best lhs    == _best rhs
    && _consts lhs  == _consts rhs
    && _fitness lhs == _fitness rhs
    && _size lhs    == _size rhs

-- | Pattern database: maps a symbol (an 'SRTree' layer with its children
-- erased to @()@) to an 'IntTrie' storing the possible paths from an e-class
-- that match a pattern.
type DB = Map (SRTree ()) IntTrie

-- | A trie over e-class ids. Each id maps to the trie of the next child, so a
-- path through the trie spells out a sequence of e-class ids.
data IntTrie = IntTrie
  { _keys :: HashSet EClassId  -- ^ the set of available keys (for convenience)
  , _trie :: IntMap IntTrie    -- ^ maps an e-class id to the next child trie
  } -- deriving Show

-- | Renders a trie as @{k1,k2,...} - {k -> <child>,...}@, recursively
-- showing each child trie.
instance Show IntTrie where
  show (IntTrie ks children) =
    "{" <> renderedKeys <> "} - {" <> renderedChildren <> "}"
    where
      renderedKeys     = intercalate "," (show <$> Set.toList ks)
      renderedChildren = intercalate ","
        [ show k <> " -> " <> show sub | (k, sub) <- IntMap.toList children ]

-- Template Haskell: generate a lens for every '_field' of these records,
-- named without the underscore (e.g. '_canonicalMap' -> 'canonicalMap').
-- The splices also split the module into declaration groups, so they must
-- stay after the record declarations and before the code that uses the lenses.
makeLenses ''EGraph
makeLenses ''EClass
makeLenses ''EClassData
makeLenses ''EGraphDB

-- * E-Graph basic supporting functions

-- | An e-graph with no canonical mappings, no e-nodes, no e-classes and an
-- empty database ('emptyDB').
emptyGraph :: EGraph
emptyGraph = EGraph
  { _canonicalMap  = IntMap.empty
  , _eNodeToEClass = Map.empty
  , _eClass        = IntMap.empty
  , _eDB           = emptyDB
  }

-- | An e-graph database with every collection empty and the next id at 0.
emptyDB :: EGraphDB
emptyDB = EDB
  { _worklist    = Set.empty
  , _analysis    = Set.empty
  , _patDB       = Map.empty
  , _fitRangeDB  = FingerTree.empty
  , _sizeDB      = IntMap.empty
  , _sizeFitDB   = IntMap.empty
  , _unevaluated = IntSet.empty
  , _nextId      = 0
  }

-- | Builds a fresh e-class with id @cId@ that holds only @enode'@ (stored in
-- 'encodeEnode' form), has no parents, height @h@ and analysis data @info@.
createEClass :: EClassId -> ENode -> EClassData -> Int -> EClass
createEClass cId enode' info h = EClass
  { _eClassId = cId
  , _eNodes   = Set.singleton (encodeEnode enode')
  , _parents  = Set.empty
  , _height   = h
  , _info     = info
  }
{-# INLINE createEClass #-}

-- | Returns the canonical id of an e-class by following '_canonicalMap'
-- until it reaches an id that maps to itself.
--
-- When the chain is longer than one step, the starting id is re-pointed
-- directly at the result, so the next lookup takes a single step. Lookups
-- use 'IntMap.!', so an id missing from the map is an error.
canonical :: Monad m => EClassId -> EGraphST m EClassId
canonical eclassId = do
  links <- gets _canonicalMap
  let firstHop = links IntMap.! eclassId
  if firstHop == eclassId
     then pure eclassId
     else follow links firstHop
  where
    follow :: Monad m => IntMap EClassId -> EClassId -> EGraphST m EClassId
    follow links cur
      | nxt /= cur = follow links nxt  -- not a root yet: keep walking
      | otherwise  = do
          -- cur is canonical: shortcut the starting id to it for next time
          modify' $ over canonicalMap (IntMap.insert eclassId cur)
          pure cur
      where
        nxt = links IntMap.! cur
{-# INLINE canonical #-}

-- | Replaces every child id of an e-node by its canonical id, running
-- 'canonical' on each child in traversal order.
canonize :: Monad m => ENode -> EGraphST m ENode
canonize node = traverse canonical node
{-# INLINE canonize #-}

-- | Looks up the e-class with id @c@ exactly as given (no canonicalisation).
-- Uses 'IntMap.!', so an absent id is an error once the result is forced.
getEClass :: Monad m => EClassId -> EGraphST m EClass
getEClass c = do
  classes <- gets _eClass
  pure (classes IntMap.! c)
{-# INLINE getEClass #-}

-- | Builds a trie node whose key set contains only @eid@ and whose children
-- are the given map.
trie :: EClassId -> IntMap IntTrie -> IntTrie
trie eid children = IntTrie { _keys = Set.singleton eid, _trie = children }

-- | True when the e-class @eid@ holds a known constant value ('ConstVal');
-- 'ParamIx' and 'NotConst' both give False. The id is used as given (no
-- canonicalisation) and a missing id is an error.
isConst :: Monad m => EClassId -> EGraphST m Bool
isConst eid = do
  cls <- getEClass eid
  case _consts (_info cls) of
    ConstVal _ -> pure True
    _          -> pure False
{-# INLINE isConst #-}