{-# LANGUAGE TupleSections #-}
{-# LANGUAGE FlexibleContexts #-}
module Algorithm.EqSat.DB where
import Algorithm.EqSat.Egraph
import Control.Lens ( over )
import Control.Monad (when, foldM, forM)
import Control.Monad.State
import Data.IntMap (IntMap)
import qualified Data.IntMap as IntMap
import Data.List (intercalate, nub, sortBy)
import Data.Map (Map)
import qualified Data.Map as Map
import Data.Maybe (fromMaybe)
import Data.Ord (comparing)
import Data.SRTree
import Data.HashSet (HashSet)
import qualified Data.HashSet as Set
import Data.String (IsString (..))
import Debug.Trace
-- | A rewrite-rule pattern: either a concrete 'SRTree' node whose children
-- are themselves patterns, or a pattern variable named by a single 'Char'.
data Pattern
  = Fixed (SRTree Pattern)  -- ^ a concrete node with pattern children
  | VarPat Char             -- ^ a pattern variable
  deriving Show
-- | Lets a one-character string denote a pattern variable ('VarPat'),
-- e.g. through OverloadedStrings.
--
-- The string must be exactly one character whose code lies in 65..122
-- ('A'..'z'). Note that this range also admits the six ASCII punctuation
-- characters with codes 91-96 (for example '_'). The empty string and any
-- other string are rejected with 'error'.
instance IsString Pattern where
  fromString str =
      case str of
        []  -> error "empty string in VarPat"
        [c] | isVarCode (fromEnum c) -> VarPat c
        _   -> error ("invalid string in VarPat: " <> str)
    where
      isVarCode :: Int -> Bool
      isVarCode code = 65 <= code && code <= 122
-- | A rule relating two patterns, optionally guarded by side conditions.
--
-- '(:|)' is @infixl 2@, looser than the two rule arrows (@infix 3@), so
-- @a :=> b :| c1 :| c2@ parses as @((a :=> b) :| c1) :| c2@.
data Rule
  = Pattern :=> Pattern   -- shown as "a => b"; presumably a one-way rewrite (TODO confirm)
  | Pattern :==: Pattern  -- shown as "a == b"; presumably a two-way equality (TODO confirm)
  | Rule :| Condition     -- attaches a side 'Condition' to a rule
infix 3 :=>
infix 3 :==:
infixl 2 :|
-- | Renders rules as @lhs => rhs@ or @lhs == rhs@. Conditions are functions
-- and cannot be shown, so each attached condition is rendered as the
-- placeholder @ | <cond>@ after the rule it guards.
instance Show Rule where
  show rule =
      case rule of
        lhs :=> rhs  -> joinWith " => " lhs rhs
        lhs :==: rhs -> joinWith " == " lhs rhs
        inner :| _   -> show inner <> " | <cond>"
    where
      joinWith sep l r = show l <> sep <> show r
-- | A conjunctive query: one 'Atom' per 'Fixed' node of a compiled pattern.
type Query = [Atom]

-- | A side condition of a rule: a predicate over a substitution (the same
-- map type that 'match' returns, presumably its result) and the e-graph.
type Condition = Map ClassOrVar ClassOrVar -> EGraph -> Bool

-- | An identifier inside a query. 'Left' holds a concrete e-class id;
-- 'Right' holds a query variable, which is either a pattern variable (the
-- 'fromEnum' of its 'Char') or a fresh id given to an interior pattern node
-- by 'compileToQuery'.
type ClassOrVar = Either EClassId Int
-- | One query atom: the id standing for a node's e-class, together with the
-- node's operator whose children have been replaced by their own ids.
data Atom = Atom ClassOrVar (SRTree ClassOrVar)
  deriving Show
-- | Unwrap a concrete pattern node.
--
-- Partial: a 'VarPat' has no underlying 'SRTree' node. Instead of a bare
-- non-exhaustive-pattern failure, it is rejected with an error that names
-- the offending variable.
unFixPat :: Pattern -> SRTree Pattern
unFixPat (Fixed p)  = p
unFixPat (VarPat c) =
  error ("unFixPat: pattern variable " <> show c <> " has no SRTree node")
{-# INLINE unFixPat #-}
-- | Arithmetic on patterns builds 'Fixed' 'SRTree' nodes, so rule patterns
-- can be written with ordinary numeric operators and literals.
instance Num Pattern where
  -- Binary operators become 'Bin' nodes.
  (+) l r = Fixed (Bin Add l r)
  {-# INLINE (+) #-}
  (-) l r = Fixed (Bin Sub l r)
  {-# INLINE (-) #-}
  (*) l r = Fixed (Bin Mul l r)
  {-# INLINE (*) #-}
  -- Absolute value is the unary 'Abs' function node.
  abs t = Fixed (Uni Abs t)
  {-# INLINE abs #-}
  -- Negation is encoded as multiplication by the constant -1.
  negate t = Fixed (Bin Mul (Fixed (Const (-1))) t)
  {-# INLINE negate #-}
  -- Constant patterns are folded. Every non-constant pattern yields the
  -- constant 0. NOTE(review): that is not the sign of the expression; confirm intended.
  signum (Fixed (Const x)) = Fixed (Const (signum x))
  signum _                 = Fixed (Const 0)
  -- Integer literals become 'Double' constants.
  fromInteger n = Fixed (Const (fromInteger n))
  {-# INLINE fromInteger #-}
-- | Division and rational literals on patterns.
instance Fractional Pattern where
  -- Division becomes a 'Bin' 'Div' node.
  (/) l r = Fixed (Bin Div l r)
  {-# INLINE (/) #-}
  -- Rational literals are stored as 'Double' constants.
  fromRational q = Fixed (Const (fromRational q))
  {-# INLINE fromRational #-}
-- | Transcendental functions on patterns build 'Fixed' nodes. Unary
-- functions become 'Uni' nodes, '(**)' becomes a 'Bin' 'Power' node and
-- 'pi' becomes a 'Const'. 'logBase' is expressed through 'log' and '(/)'.
instance Floating Pattern where
  pi = Fixed $ Const pi
  {-# INLINE pi #-}
  exp = Fixed . Uni Exp
  {-# INLINE exp #-}
  log = Fixed . Uni Log
  {-# INLINE log #-}
  sqrt = Fixed . Uni Sqrt
  {-# INLINE sqrt #-}
  sin = Fixed . Uni Sin
  {-# INLINE sin #-}
  cos = Fixed . Uni Cos
  {-# INLINE cos #-}
  tan = Fixed . Uni Tan
  {-# INLINE tan #-}
  asin = Fixed . Uni ASin
  {-# INLINE asin #-}
  acos = Fixed . Uni ACos
  {-# INLINE acos #-}
  atan = Fixed . Uni ATan
  {-# INLINE atan #-}
  sinh = Fixed . Uni Sinh
  {-# INLINE sinh #-}
  cosh = Fixed . Uni Cosh
  {-# INLINE cosh #-}
  tanh = Fixed . Uni Tanh
  {-# INLINE tanh #-}
  asinh = Fixed . Uni ASinh
  {-# INLINE asinh #-}
  acosh = Fixed . Uni ACosh
  {-# INLINE acosh #-}
  atanh = Fixed . Uni ATanh
  {-# INLINE atanh #-}
  l ** r = Fixed $ Bin Power l r
  {-# INLINE (**) #-}
  -- The base is the FIRST argument, as in the Prelude contract
  -- (logBase 2 8 == 3), so log_b x = log x / log b.
  logBase b x = log x / log b
  {-# INLINE logBase #-}
-- | The right-hand pattern of a rule. Any attached conditions are skipped.
target :: Rule -> Pattern
target rule = case rule of
  inner :| _ -> target inner
  _ :=> rhs  -> rhs
  _ :==: rhs -> rhs
-- | The left-hand pattern of a rule. Any attached conditions are skipped.
source :: Rule -> Pattern
source rule = case rule of
  inner :| _ -> source inner
  lhs :=> _  -> lhs
  lhs :==: _ -> lhs
-- | All conditions attached to a rule, outermost (last attached) first.
-- A rule without conditions yields the empty list.
getConditions :: Rule -> [Condition]
getConditions rule = case rule of
  inner :| cond -> cond : getConditions inner
  _             -> []
-- | Empty the pattern database (the 'patDB' field reached through 'eDB').
-- The rest of the e-graph state is left untouched.
cleanDB :: Monad m => EGraphST m ()
cleanDB = modify' clearPatDB
  where
    clearPatDB = over (eDB . patDB) (\_ -> Map.empty)
-- | Find all matches of a pattern in the e-graph.
--
-- The pattern is compiled into a query ('compileToQuery') and solved with
-- 'genericJoin'. Each non-empty substitution is paired with the id bound to
-- the pattern's root. A complete substitution binds every query variable,
-- including the root, so the lookup cannot fail. A bare 'VarPat' compiles to
-- an empty query, whose only solution is the empty map; it is filtered out,
-- so such a pattern yields no matches.
match :: Monad m => Pattern -> EGraphST m [(Map ClassOrVar ClassOrVar, ClassOrVar)]
match src = map withRoot . filter (not . Map.null) <$> genericJoin query root
  where
    (query, root) = compileToQuery src
    withRoot subst = (subst, subst Map.! root)
-- | Compile a pattern into a conjunctive query plus the id of its root.
--
-- Each 'Fixed' node gets a fresh variable id and contributes one 'Atom'
-- whose children are the ids of its sub-patterns. Atoms are listed in
-- pre-order, so the parent comes before its children. A 'VarPat' contributes
-- no atom; its id is the 'fromEnum' of its 'Char'.
--
-- Fresh ids must never coincide with a pattern variable's id. They start at
-- 256, which is above every code 'fromString' accepts. They start higher
-- when the pattern holds a directly constructed 'VarPat' with a larger
-- 'Char'.
compileToQuery :: Pattern -> (Query, ClassOrVar)
compileToQuery pat = evalState (processPat pat) firstFresh
  where
    firstFresh :: Int
    firstFresh = maximum (256 : [fromEnum c + 1 | c <- patVars pat])

    -- Every pattern-variable name occurring in a pattern.
    patVars :: Pattern -> [Char]
    patVars (VarPat c) = [c]
    patVars (Fixed t)  = concatMap patVars (getElems t)

    processPat :: Pattern -> State Int (Query, ClassOrVar)
    processPat (VarPat x) = pure ([], Right $ fromEnum x)
    processPat (Fixed p) = do
      v <- get
      let root = Right v
      modify (+1)
      patChilds <- mapM processPat (getElems p)
      let atoms = concatMap fst patChilds
          roots = map snd patChilds
          atom  = Atom root (replaceChildren roots p)
      pure (atom : atoms, root)
-- | The raw integer inside a 'ClassOrVar', whether it is an e-class id
-- ('Left') or a variable id ('Right').
getInt :: ClassOrVar -> Int
getInt = either id id
-- | The direct children of an 'SRTree' node, left to right: two for 'Bin',
-- one for 'Uni' and none for every other constructor.
getElems :: SRTree a -> [a]
getElems node = case node of
  Bin _ l r -> [l, r]
  Uni _ t   -> [t]
  _         -> []
-- | Enumerate every substitution that satisfies a query.
--
-- Variables are bound one at a time in the order given by 'orderedVars'.
-- For the current variable, the candidate e-classes come from 'domainX'.
-- Each candidate is substituted into the atoms ('updateVar') and the
-- remaining variables are solved recursively. Each result maps every query
-- variable to the @Left@ e-class id chosen for it. The second argument is
-- the query's root id and is passed through to 'domainX'.
genericJoin :: Monad m => Query -> ClassOrVar -> EGraphST m [Map ClassOrVar ClassOrVar]
genericJoin atoms root = solve atoms (orderedVars atoms)
  where
    solve :: Monad n => Query -> [ClassOrVar] -> EGraphST n [Map ClassOrVar ClassOrVar]
    solve _ [] = pure [Map.empty]
    solve qs (v : rest) = do
      candidates <- domainX v qs root
      fmap concat . forM candidates $ \cid ->
        map (Map.insert v cid) <$> solve (updateVar v cid qs) rest
-- | Candidate bindings for one query variable.
--
-- Only the atoms that mention the variable ('elemOfAtom') are considered.
-- The e-class ids they agree on ('intersectAtoms') are returned wrapped in
-- 'Left'.
domainX :: Monad m => ClassOrVar -> Query -> ClassOrVar -> EGraphST m [ClassOrVar]
domainX var atoms root = map Left <$> intersectAtoms var relevant root
  where
    relevant = filter (elemOfAtom var) atoms
-- | E-class ids for @var@ that are consistent with every atom in the list.
--
-- For each atom, the trie stored in the pattern database under the atom's
-- operator ('getOperator') is walked with 'intersectTries'. The per-atom
-- candidate sets are then intersected. An atom with no trie for its
-- operator, or whose trie walk fails, contributes the empty set. An empty
-- atom list yields no candidates. The last argument (the query root) is not
-- used.
--
-- NOTE(review): candidates are returned exactly as stored in the tries,
-- without passing them through 'canonical'. Confirm that the pattern DB only
-- holds canonical ids.
intersectAtoms :: Monad m => ClassOrVar -> Query -> ClassOrVar -> EGraphST m [EClassId]
intersectAtoms _ [] _ = pure []
intersectAtoms var (a : atoms) _ = do
  a0 <- candidatesFor a
  Set.toList <$> foldM (\acc atom -> Set.intersection acc <$> candidatesFor atom) a0 atoms
  where
    -- Values of @var@ admitted by a single atom.
    candidatesFor :: Monad n => Atom -> EGraphST n (HashSet EClassId)
    candidatesFor (Atom r t) = do
      mTrie <- gets ((Map.!? getOperator t) . _patDB . _eDB)
      pure $ case mTrie of
        Just trie -> fromMaybe Set.empty (intersectTries var Map.empty trie (r : getElems t))
        Nothing   -> Set.empty
-- | Walk a trie level by level along an atom's ids (root first, then the
-- children) and collect the values the target variable @var@ can take.
--
-- * A 'Left' id must be a key at the current level; the walk follows it or
--   fails with 'Nothing'.
-- * A 'Right' id already bound earlier in this walk (the @bound@ map) is
--   handled like a 'Left' id with its bound key.
-- * @var@ itself: if the rest of the path contains only variables other
--   than @var@, every key at this level is returned (that rest is not
--   walked). Otherwise each child key is tried with @var@ bound to it, and
--   the keys whose suffix walk succeeds are returned.
-- * Any other unbound variable: every child is tried with it bound, and
--   the sets found below are unioned.
--
-- An exhausted path yields @Just@ the empty set.
intersectTries :: ClassOrVar -> Map ClassOrVar EClassId -> IntTrie -> [ClassOrVar] -> Maybe (HashSet EClassId)
intersectTries _ _ _ [] = Just Set.empty
intersectTries var bound trie (i : rest) =
    case i of
      Left cid -> descend cid
      Right x
        | Just cid <- Map.lookup i bound -> descend cid
        | i /= var                       -> Just (branch (\_ found acc -> Set.union acc found))
        | all (isDiffFrom x) rest        -> Just (_keys trie)
        | otherwise                      -> Just (branch (\k _ acc -> Set.insert k acc))
  where
    -- Follow the child keyed by a known e-class id, or fail.
    descend cid
      | cid `Set.member` _keys trie = intersectTries var bound (_trie trie IntMap.! cid) rest
      | otherwise                   = Nothing

    -- Try every child of this level with @i@ bound to its key. Children
    -- whose suffix fails are skipped; the rest are merged with @combine@.
    branch combine = IntMap.foldrWithKey step Set.empty (_trie trie)
      where
        step k child acc =
          case intersectTries var (Map.insert i k bound) child rest of
            Nothing    -> acc
            Just found -> combine k found acc
-- | Replace every occurrence of @var@ with @x@ throughout a query, both in
-- each atom's root position and among its children.
updateVar :: ClassOrVar -> ClassOrVar -> Query -> Query
updateVar var x = map substAtom
  where
    subst c
      | c == var  = x
      | otherwise = c

    substAtom (Atom r t) = Atom (subst r) (replaceChildren (map subst (getElems t)) t)
-- | True when the id is a variable ('Right') whose number differs from @x@.
-- A concrete e-class id ('Left') always gives False.
isDiffFrom :: Int -> ClassOrVar -> Bool
isDiffFrom x = either (const False) (/= x)
-- | Whether @v@ occurs in an atom, either among its children or as its root.
-- The root is only compared when it is a variable ('Right'); a 'Left' root
-- is never considered.
elemOfAtom :: ClassOrVar -> Atom -> Bool
elemOfAtom v (Atom root tree) = v `elem` candidates
  where
    kids = getElems tree
    candidates = case root of
      Left _  -> kids
      Right _ -> root : kids
-- | The distinct variables ('Right' ids) of a query, in the order the join
-- should bind them.
--
-- Every atom that mentions a variable adds @(atom size - 100)@ to its cost,
-- where the atom size is 1 plus its number of children. Lower cost comes
-- first, so variables shared by more atoms are bound earlier, and smaller
-- atoms break the remaining ties. Because the sort is stable, equal costs
-- keep their first-occurrence order.
orderedVars :: Query -> [ClassOrVar]
orderedVars atoms = sortBy (comparing varCost) distinctVars
  where
    distinctVars = nub (filter isVar (concatMap getIdsFrom atoms))

    getIdsFrom (Atom r t) = r : getElems t

    isVar = either (const False) (const True)

    varCost :: ClassOrVar -> Int
    varCost var = sum [atomLen a - 100 | a <- atoms, elemOfAtom var a]

    atomLen (Atom _ t) = 1 + length (getElems t)