{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE StrictData #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE TypeSynonymInstances, FlexibleInstances #-}
module Algorithm.EqSat.Egraph where
import Control.Lens (element, makeLenses, view, over, (&), (+~), (-~), (.~), (^.))
import Data.List ( intercalate )
import Control.Monad.State.Strict
import Data.IntMap.Strict (IntMap)
import qualified Data.IntMap.Strict as IntMap
import Data.Map.Strict (Map)
import qualified Data.Map.Strict as Map
import Data.HashSet (HashSet)
import qualified Data.HashSet as Set
import Data.IntSet (IntSet)
import qualified Data.IntSet as IntSet
import Data.Sequence ( Seq(..), (><) )
import qualified Data.Sequence as FingerTree
import Data.Foldable ( toList )
import Data.SRTree
import Data.SRTree.Eval
import Data.Hashable
import Debug.Trace
-- | Identifier of an e-class. Kept as a plain 'Int' so that ids can be used
-- as keys of a 'ClassIdMap' (an 'IntMap').
type EClassId = Int
-- | Map keyed by e-class ids.
type ClassIdMap = IntMap
-- | A single expression node whose children are e-class ids.
type ENode = SRTree EClassId
-- | Flat tuple encoding of an 'ENode'; produced by 'encodeEnode' and read
-- back by 'decodeEnode'.
type ENodeEnc = (Int, Int, Int, Double)
-- | State monad transformer carrying the 'EGraph' as its state.
type EGraphST m a = StateT EGraph m a
-- | Cost assigned to an expression.
type Cost = Int
-- | Maps a node whose children already carry a 'Cost' to the node's 'Cost'.
type CostFun = SRTree Cost -> Cost
-- | An 'ENode' is hashed through its flat tuple encoding ('encodeEnode').
instance Hashable ENode where
  hashWithSalt salt = hashWithSalt salt . encodeEnode
-- | Sequence of @(key, e-class id)@ pairs. 'insertRange' and 'removeRange'
-- assume the sequence is kept in ascending order of the pairs.
type RangeTree a = Seq (a, EClassId)
-- | Flatten an 'ENode' into a tuple @(tag, arg1, arg2, value)@.
--
-- Tags: @0@ variable, @1@ parameter, @2@ constant, @300 + fromEnum f@ for a
-- unary function and @400 + fromEnum op@ for a binary operator. Slots that a
-- constructor does not use are filled with @-1@ (integer slots) or @0@
-- (the 'Double' slot).
encodeEnode :: ENode -> ENodeEnc
encodeEnode node = case node of
  Var ix         -> (0, ix, -1, 0)
  Param ix       -> (1, ix, -1, 0)
  Const x        -> (2, -1, -1, x)
  Uni f ed       -> (300 + fromEnum f, ed, -1, 0)
  Bin op ed1 ed2 -> (400 + fromEnum op, ed1, ed2, 0)
{-# INLINE encodeEnode #-}
-- | Rebuild an 'ENode' from the tuple produced by 'encodeEnode'.
--
-- Tags @0@, @1@ and @2@ decode to 'Var', 'Param' and 'Const'. Any other tag
-- below @400@ is read as a unary function (@tag - 300@) and the remaining
-- tags as a binary operator (@tag - 400@), both through 'toEnum'.
decodeEnode :: ENodeEnc -> ENode
decodeEnode (tag, arg1, arg2, val) = case tag of
  0 -> Var arg1
  1 -> Param arg1
  2 -> Const val
  _ | tag < 400 -> Uni (toEnum (tag - 300)) arg1
    | otherwise -> Bin (toEnum (tag - 400)) arg1 arg2
{-# INLINE decodeEnode #-}
-- | Insert the entry @(x, eid)@ into a 'RangeTree' that is sorted in
-- ascending order of its @(key, e-class id)@ pairs.
--
-- The new entry is prepended or appended directly when it lies outside the
-- current bounds. Otherwise the sequence is halved repeatedly until the
-- insertion point is found. An entry that is already present is not
-- inserted again; this now also holds when the tree contains exactly that
-- one entry, which previously produced a duplicate.
insertRange :: (Ord a, Show a) => EClassId -> a -> RangeTree a -> RangeTree a
insertRange eid x Empty = FingerTree.singleton (x, eid)
insertRange eid x rt@(lo :<| _) | (x, eid) < lo = (x, eid) :<| rt
insertRange eid x rt@(_ :|> hi) | (x, eid) > hi = rt :|> (x, eid)
insertRange eid x rt = go rt
  where
    entry = (x, eid)

    go root = case FingerTree.splitAt (FingerTree.length root `div` 2) root of
      (Empty, Empty) -> FingerTree.singleton entry
      -- An empty left half means the split index was 0: root is a singleton.
      (Empty, z :<| zs)
        | entry < z  -> entry :<| root
        | entry == z -> root -- already present, keep set semantics
        | otherwise  -> z :<| go zs
      -- Kept for totality; splitting at (length `div` 2) leaves the right
      -- half non-empty for any non-empty input.
      (ys :|> y, Empty)
        | entry > y  -> root :|> entry
        | entry == y -> root
        | otherwise  -> go ys :|> y
      (left@(_ :|> y), right@(z :<| _))
        | entry > y && entry < z -> (left :|> entry) >< right
        | entry > z              -> left >< go right
        | entry < y              -> go left >< right
        | otherwise              -> root -- equal to y or z: already present
-- | Remove the entry @(x, eid)@ from a 'RangeTree' that is sorted in
-- ascending order of its @(key, e-class id)@ pairs.
--
-- The tree is returned unchanged when the entry lies outside the current
-- bounds or is not present. Otherwise the sequence is halved repeatedly
-- until the entry is located. An entry sitting exactly at a split boundary
-- (the last element of the left half or the first of the right half) is
-- now removed; previously that case returned the tree untouched.
removeRange :: (Ord a, Show a) => EClassId -> a -> RangeTree a -> RangeTree a
removeRange _ _ Empty = Empty
removeRange eid x rt@(lo :<| _) | (x, eid) < lo = rt
removeRange eid x rt@(_ :|> hi) | (x, eid) > hi = rt
removeRange eid x rt = go rt
  where
    entry = (x, eid)

    go root = case FingerTree.splitAt (FingerTree.length root `div` 2) root of
      (Empty, Empty) -> root
      -- An empty left half means the split index was 0: root is a singleton.
      (Empty, z :<| zs)
        | entry < z  -> root
        | entry == z -> zs
        | otherwise  -> z :<| go zs
      -- Kept for totality; splitting at (length `div` 2) leaves the right
      -- half non-empty for any non-empty input.
      (ys :|> y, Empty)
        | entry > y  -> root
        | entry == y -> ys
        | otherwise  -> go ys :|> y
      (left@(ys :|> y), right@(z :<| zs))
        | entry < y  -> go left >< right
        | entry == y -> ys >< right   -- boundary hit on the left half
        | entry == z -> left >< zs    -- boundary hit on the right half
        | entry > z  -> left >< go right
        | otherwise  -> root          -- strictly between y and z: absent
-- | E-class ids of every entry whose key lies in the closed interval
-- @[lb, ub]@, in ascending order of the entries.
--
-- The tree is assumed to be sorted by key (as maintained by 'insertRange').
-- Two binary searches locate the first entry with key @>= lb@ and the first
-- entry with key @> ub@; the slice between them is the answer. When
-- @lb > ub@ the result is empty.
--
-- The previous implementation never returned any element and did not
-- terminate once it reached a single entry with key @<= ub@.
getWithinRange :: Ord a => a -> a -> RangeTree a -> [EClassId]
getWithinRange lb ub rt =
  map snd . toList . FingerTree.take (upper - lower) $ FingerTree.drop lower rt
  where
    lower = firstIndexWhere ((>= lb) . fst)
    upper = firstIndexWhere ((> ub) . fst)

    -- Smallest index whose entry satisfies a predicate that is False on a
    -- prefix of the tree and True on the rest; the length if none does.
    firstIndexWhere holds = search 0 (FingerTree.length rt)
      where
        search from to
          | from >= to = from
          -- from <= mid < to <= length, so the index is always in range.
          | holds (FingerTree.index rt mid) = search from mid
          | otherwise = search (mid + 1) to
          where
            mid = from + (to - from) `div` 2
-- | First entry of the tree (the smallest one when the tree is sorted).
-- Calls 'error' on an empty tree.
getSmallest :: Ord a => RangeTree a -> (a, EClassId)
getSmallest Empty = error "empty finger"
getSmallest (first :<| _) = first
-- | Last entry of the tree (the greatest one when the tree is sorted).
-- Calls 'error' on an empty tree.
getGreatest :: Ord a => RangeTree a -> (a, EClassId)
getGreatest Empty = error "empty finger"
getGreatest (_ :|> final) = final
-- | The e-graph state threaded through 'EGraphST'.
data EGraph = EGraph
  { _canonicalMap :: ClassIdMap EClassId
    -- ^ parent pointer of each e-class id; an id mapped to itself is
    -- canonical (see 'canonical')
  , _eNodeToEClass :: Map ENode EClassId
    -- ^ e-class id recorded for each known e-node
  , _eClass :: ClassIdMap EClass
    -- ^ e-class record stored under each e-class id
  , _eDB :: EGraphDB
    -- ^ auxiliary indices and bookkeeping
  } deriving Show
-- | Auxiliary indices and bookkeeping kept inside an 'EGraph'.
--
-- NOTE(review): the roles below are read off the field names and types;
-- confirm against the code that maintains them.
data EGraphDB = EDB
  { _worklist :: HashSet (EClassId, ENode)
    -- ^ pending @(e-class id, e-node)@ pairs; presumably scheduled for repair
  , _analysis :: HashSet (EClassId, ENode)
    -- ^ @(e-class id, e-node)@ pairs; presumably awaiting re-analysis
  , _patDB :: DB
    -- ^ pattern database
  , _fitRangeDB :: RangeTree Double
    -- ^ e-class ids ordered by a 'Double' key; presumably the fitness
  , _sizeDB :: IntMap IntSet
    -- ^ sets of ids grouped under an 'Int' key; presumably the size
  , _sizeFitDB :: IntMap (RangeTree Double)
    -- ^ one 'RangeTree' per 'Int' key; presumably fitness grouped by size
  , _unevaluated :: IntSet
    -- ^ set of ids; presumably e-classes not evaluated yet
  , _nextId :: Int
    -- ^ id counter; starts at @0@ in 'emptyDB'
  } deriving Show
-- | A single e-class.
data EClass = EClass
  { _eClassId :: Int
    -- ^ id of this e-class
  , _eNodes :: HashSet ENodeEnc
    -- ^ member e-nodes, stored in their 'encodeEnode' form
  , _parents :: HashSet (EClassId, ENode)
    -- ^ @(e-class id, e-node)@ pairs; presumably the e-nodes that have this
    -- class as a child
  , _height :: Int
    -- ^ height value supplied on creation (see 'createEClass')
  , _info :: EClassData
    -- ^ analysis data attached to the class
  } deriving (Show, Eq)
-- | Constant-ness information attached to an e-class.
data Consts
  = NotConst        -- ^ no constant information
  | ParamIx Int     -- ^ carries an 'Int'; presumably a parameter index
  | ConstVal Double -- ^ a constant value ('isConst' reports 'True' for it)
  deriving (Show, Eq)
-- | Property tags; presumably describe the values an expression can take.
data Property
  = Positive
  | Negative
  | NonZero
  | Real
  deriving (Show, Eq)
-- | Analysis data stored in every 'EClass'.
data EClassData = EData
  { _cost :: Cost
    -- ^ cost value; presumably that of the '_best' node
  , _best :: ENode
    -- ^ an e-node of the class; presumably the cheapest known one
  , _consts :: Consts
    -- ^ constant-ness information (inspected by 'isConst')
  , _fitness :: Maybe Double
    -- ^ optional fitness value
  , _theta :: Maybe PVector
    -- ^ optional parameter vector (not compared by the 'Eq' instance)
  , _size :: Int
    -- ^ size value; presumably the node count of the expression
  } deriving (Show)
-- | Equality over every field except '_theta', which is ignored.
instance Eq EClassData where
  lhs == rhs = comparable lhs == comparable rhs
    where
      -- Projection onto the fields that take part in the comparison.
      comparable (EData c b cs ft _ s) = (c, b, cs, ft, s)
-- | Database keyed by an 'SRTree' node whose children are erased to @()@,
-- each key holding an 'IntTrie' of e-class ids.
type DB = Map (SRTree ()) IntTrie
-- | Trie over 'Int' keys: a set of e-class ids plus the sub-tries indexed
-- by the next key.
data IntTrie = IntTrie
  { _keys :: HashSet EClassId -- ^ e-class ids stored at this node
  , _trie :: IntMap IntTrie   -- ^ child tries by 'Int' key
  }
-- | Renders as @{k1,k2} - {i -> sub,...}@: the key set first, then every
-- child trie prefixed by its index.
instance Show IntTrie where
  show (IntTrie ks children) = "{" <> keyStr <> "} - {" <> childStr <> "}"
    where
      keyStr = intercalate "," [show key | key <- Set.toList ks]
      childStr =
        intercalate
          ","
          [show ix <> " -> " <> show sub | (ix, sub) <- IntMap.toList children]
-- Template Haskell splices generating one lens per record field of the
-- types below; each lens is named after its field without the leading
-- underscore (e.g. 'canonicalMap' for '_canonicalMap').
makeLenses ''EGraph
makeLenses ''EClass
makeLenses ''EClassData
makeLenses ''EGraphDB
-- | An e-graph with no e-classes, no e-nodes and an empty database.
emptyGraph :: EGraph
emptyGraph =
  EGraph
    { _canonicalMap = IntMap.empty
    , _eNodeToEClass = Map.empty
    , _eClass = IntMap.empty
    , _eDB = emptyDB
    }
-- | A database with every index empty and the id counter at @0@.
emptyDB :: EGraphDB
emptyDB =
  EDB
    { _worklist = Set.empty
    , _analysis = Set.empty
    , _patDB = Map.empty
    , _fitRangeDB = FingerTree.empty
    , _sizeDB = IntMap.empty
    , _sizeFitDB = IntMap.empty
    , _unevaluated = IntSet.empty
    , _nextId = 0
    }
-- | Build a fresh e-class with id @cId@ that contains the single e-node
-- @enode'@ (stored encoded), has no parents, and carries the given
-- analysis data @info@ and height @h@.
createEClass :: EClassId -> ENode -> EClassData -> Int -> EClass
createEClass cId enode' info h =
  EClass
    { _eClassId = cId
    , _eNodes = Set.singleton (encodeEnode enode')
    , _parents = Set.empty
    , _height = h
    , _info = info
    }
{-# INLINE createEClass #-}
-- | Canonical (representative) id of an e-class.
--
-- Follows the '_canonicalMap' pointers starting at @eclassId@ until an id
-- that maps to itself is reached. When more than one step was needed, the
-- map is updated so that @eclassId@ points directly at the representative.
-- Every id visited must be a key of the map ('IntMap.!' is partial).
canonical :: Monad m => EClassId -> EGraphST m EClassId
canonical eclassId = do
  table <- gets _canonicalMap
  let parent = table IntMap.! eclassId
  if parent == eclassId
    then pure eclassId
    else chase table parent
  where
    chase :: Monad m => IntMap EClassId -> EClassId -> EGraphST m EClassId
    chase table current =
      let next = table IntMap.! current
       in if next == current
            then do
              -- Shortcut the starting id to the representative just found.
              modify' (over canonicalMap (IntMap.insert eclassId current))
              pure current
            else chase table next
{-# INLINE canonical #-}
-- | Replace every child e-class id of an e-node by its canonical id
-- (see 'canonical').
canonize :: Monad m => ENode -> EGraphST m ENode
canonize = mapM canonical
{-# INLINE canonize #-}
-- | Fetch the e-class stored under id @c@. The id must be present in
-- '_eClass' ('IntMap.!' is partial).
getEClass :: Monad m => EClassId -> EGraphST m EClass
getEClass c = gets (\egraph -> _eClass egraph IntMap.! c)
{-# INLINE getEClass #-}
-- | Build a trie node whose key set holds only @eid@ and whose children
-- are the given map.
trie :: EClassId -> IntMap IntTrie -> IntTrie
trie eid children = IntTrie {_keys = Set.singleton eid, _trie = children}
-- | 'True' when the analysis data of e-class @eid@ records a constant value
-- ('ConstVal'); 'False' for any other 'Consts'. The id must be present in
-- '_eClass' ('IntMap.!' is partial).
isConst :: Monad m => EClassId -> EGraphST m Bool
isConst eid = do
  classes <- gets _eClass
  case _consts (_info (classes IntMap.! eid)) of
    ConstVal _ -> pure True
    _ -> pure False
{-# INLINE isConst #-}