{-# LANGUAGE TupleSections #-}
{-# LANGUAGE FlexibleContexts #-}
-----------------------------------------------------------------------------
-- |
-- Module      :  Algorithm.EqSat.DB
-- Copyright   :  (c) Fabricio Olivetti 2021 - 2024
-- License     :  BSD3
-- Maintainer  :  fabricio.olivetti@gmail.com
-- Stability   :  experimental
-- Portability :
--
-- Pattern matching and rule application functions
-- Heavily based on hegg (https://github.com/alt-romes/hegg by alt-romes)
--
-----------------------------------------------------------------------------
module Algorithm.EqSat.DB where

import Algorithm.EqSat.Egraph
import Control.Lens ( over )
import Control.Monad (when, foldM, forM)
import Control.Monad.State
import Data.IntMap (IntMap)
import qualified Data.IntMap as IntMap
import Data.List (intercalate, nub, sortBy)
import Data.Map (Map)
import qualified Data.Map as Map
import Data.Maybe (fromMaybe)
import Data.Ord (comparing)
import Data.SRTree
--import Data.Set (Set)
import Data.HashSet (HashSet)
import qualified Data.HashSet as Set
import Data.String (IsString (..))

import Debug.Trace

-- | A 'Pattern' is either the fixed point of a tree whose children are
-- themselves patterns, or a pattern variable identified by a single
-- character. A pattern variable matches anything.
data Pattern
  = Fixed (SRTree Pattern) -- ^ fixed structure of a pattern
  | VarPat Char            -- ^ a variable that matches anything
  deriving Show

-- | The 'IsString' instance of a 'Pattern' is valid only for a
-- single-letter string in the ranges @a-z@ and @A-Z@.
-- When string literals are overloaded, patterns can be written as
-- @"x" + "y"@, for example, which translates to
-- @Fixed (Bin Add (VarPat 'x') (VarPat 'y'))@.
--
-- Any other string (empty, longer than one character, or a single
-- non-letter character) raises an 'error'.
instance IsString Pattern where
  fromString []  = error "empty string in VarPat"
  fromString [c]
    | isAsciiLetter = VarPat c
    where
      -- A plain 65..122 code-point check would also let through the six
      -- punctuation characters that sit between 'Z' and 'a'.
      isAsciiLetter = (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z')
  -- reached for multi-character strings and for a rejected single character
  fromString s   = error $ "invalid string in VarPat: " <> s

-- | A rule is one of:
--
-- * a directional rule (':=>'), where the left pattern can be replaced by the right one;
-- * a bidirectional rule (':==:'), where either pattern can replace the other;
-- * a rule guarded by a conditional function (':|') describing when to apply it.
data Rule = Pattern :=> Pattern | Pattern :==: Pattern | Rule :| Condition

-- ':|' has lower precedence than the two arrows and associates to the left,
-- so @a :=> b :| c1 :| c2@ parses as @((a :=> b) :| c1) :| c2@.
infix  3 :=>
infix  3 :==:
infixl 2 :|

-- | Renders a rule as @source => target@ or @source == target@.
-- A condition is a function and cannot be shown, so each one attached to
-- the rule is rendered as the placeholder @ | <cond>@.
instance Show Rule where
  show (lhs :=> rhs)   = show lhs <> " => " <> show rhs
  show (lhs :==: rhs)  = show lhs <> " == " <> show rhs
  show (rule :| _cond) = show rule <> " | <cond>"

-- | A 'Query' is a list of 'Atom's.
type Query = [Atom]

-- | A 'Condition' is a function that takes a substitution map and
-- an e-graph and returns whether the match satisfies the condition.
type Condition = Map ClassOrVar ClassOrVar -> EGraph -> Bool

-- | Either an e-class id ('Left') or a pattern variable id ('Right').
type ClassOrVar = Either EClassId Int

-- | An 'Atom' is composed of a root, which is either an e-class id or a
-- pattern variable id, and the tree that generated that pattern, with
-- its children given as e-class ids or pattern variable ids.
data Atom = Atom ClassOrVar (SRTree ClassOrVar) deriving Show

-- | Unwraps the tree held by a 'Fixed' pattern.
--
-- This function is partial: a 'VarPat' carries no tree, so applying it to
-- a pattern variable raises an 'error' naming the offending variable.
unFixPat :: Pattern -> SRTree Pattern
unFixPat (Fixed p)  = p
unFixPat (VarPat c) = error $ "unFixPat: pattern variable " <> show c <> " has no tree"
{-# INLINE unFixPat #-}


-- | Arithmetic on patterns builds the corresponding tree node, so that
-- rules can be written with ordinary numeric syntax.
instance Num Pattern where
  l + r = Fixed (Bin Add l r)
  {-# INLINE (+) #-}
  l - r = Fixed (Bin Sub l r)
  {-# INLINE (-) #-}
  l * r = Fixed (Bin Mul l r)
  {-# INLINE (*) #-}

  abs = Fixed . Uni Abs
  {-# INLINE abs #-}

  -- negation is encoded as a multiplication by the constant -1
  negate t = Fixed (Const (-1)) * t
  {-# INLINE negate #-}

  -- only a constant pattern has a known sign; every other pattern
  -- yields the constant 0
  -- NOTE(review): confirm that 0 is the intended result for non-constants
  signum (Fixed (Const x)) = Fixed (Const (signum x))
  signum _                 = Fixed (Const 0)

  -- integer literals become constant nodes
  fromInteger x = Fixed (Const (fromInteger x))
  {-# INLINE fromInteger #-}

-- | Division builds a 'Div' node; fractional literals become constant nodes.
instance Fractional Pattern where
  l / r = Fixed (Bin Div l r)
  {-# INLINE (/) #-}

  fromRational = Fixed . Const . fromRational
  {-# INLINE fromRational #-}

-- | Each transcendental function wraps its argument in the matching
-- unary node; '**' builds a 'Power' node.
instance Floating Pattern where
  pi      = Fixed (Const pi)
  {-# INLINE pi #-}
  exp     = Fixed . Uni Exp
  {-# INLINE exp #-}
  log     = Fixed . Uni Log
  {-# INLINE log #-}
  sqrt    = Fixed . Uni Sqrt
  {-# INLINE sqrt #-}
  sin     = Fixed . Uni Sin
  {-# INLINE sin #-}
  cos     = Fixed . Uni Cos
  {-# INLINE cos #-}
  tan     = Fixed . Uni Tan
  {-# INLINE tan #-}
  asin    = Fixed . Uni ASin
  {-# INLINE asin #-}
  acos    = Fixed . Uni ACos
  {-# INLINE acos #-}
  atan    = Fixed . Uni ATan
  {-# INLINE atan #-}
  sinh    = Fixed . Uni Sinh
  {-# INLINE sinh #-}
  cosh    = Fixed . Uni Cosh
  {-# INLINE cosh #-}
  tanh    = Fixed . Uni Tanh
  {-# INLINE tanh #-}
  asinh   = Fixed . Uni ASinh
  {-# INLINE asinh #-}
  acosh   = Fixed . Uni ACosh
  {-# INLINE acosh #-}
  atanh   = Fixed . Uni ATanh
  {-# INLINE atanh #-}

  l ** r  = Fixed (Bin Power l r)
  {-# INLINE (**) #-}

  -- The first argument is the base: @logBase b x = log x / log b@, as in
  -- the default definition of the 'Floating' class.
  -- NOTE(review): the previous body computed @log b / log x@; confirm no
  -- rule relied on the inverted form.
  logBase base x = log x / log base
  {-# INLINE logBase #-}

-- | Returns the right-hand side pattern of a rule, looking through any
-- conditions attached to it.
target :: Rule -> Pattern
target rule = case rule of
  inner :| _ -> target inner
  _ :=> rhs  -> rhs
  _ :==: rhs -> rhs

-- | Returns the left-hand side pattern of a rule, looking through any
-- conditions attached to it.
source :: Rule -> Pattern
source rule = case rule of
  inner :| _ -> source inner
  lhs :=> _  -> lhs
  lhs :==: _ -> lhs

-- | Collects every condition attached to a rule, outermost first.
-- A rule without conditions yields the empty list.
getConditions :: Rule -> [Condition]
getConditions rule = case rule of
  inner :| cond -> cond : getConditions inner
  _             -> []


-- | Replaces the pattern database stored in the e-graph (the @patDB@
-- field reached through @eDB@) with an empty map.
cleanDB :: Monad m => EGraphST m ()
cleanDB = modify' emptyPatDB
  where
    -- overwrite whatever the pattern database currently holds
    emptyPatDB = over (eDB . patDB) (const Map.empty)

-- | Returns the substitution maps
-- for every match of the pattern @src@ inside the e-graph, each paired
-- with the value bound to the root of the pattern.
--
-- Substitutions that do not bind the root (in particular the empty
-- substitution) are discarded.
match :: Monad m => Pattern -> EGraphST m [(Map ClassOrVar ClassOrVar, ClassOrVar)]
match src = do
  -- compile the source of the pattern into a query
  let (query, root) = compileToQuery src
  -- find the substitution maps for this pattern
  substs <- genericJoin query root
  -- a total lookup: keeps exactly the maps that contain the root
  pure [ (subst, rootId)
       | subst <- substs
       , Just rootId <- [Map.lookup root subst]
       ]

-- | Compiles a pattern into a 'Query' (list of atoms) together with the
-- id of the root of the pattern.
--
-- Pattern variables keep the code point of their character as id, while
-- every 'Fixed' node receives a fresh id taken from a counter that starts
-- at 256 (presumably so that fresh ids do not collide with the ids of
-- single-letter variables).
compileToQuery :: Pattern -> (Query, ClassOrVar)
compileToQuery pat = evalState (processPat pat) 256 -- returns (atoms, root)
  where
    -- creates the atoms of a pattern
    processPat :: Pattern -> State Int (Query, ClassOrVar)
    -- a variable generates no atom; it is its own root
    processPat (VarPat x)   = pure ([], Right (fromEnum x))
    processPat (Fixed node) = do
      -- take the next available var id as root and bump the counter
      -- before visiting the children
      v <- state (\next -> (next, next + 1))
      -- recursively process the children of the pattern
      (childAtoms, childRoots) <- unzip <$> mapM processPat (getElems node)
      let root = Right v
          -- the atom is the root plus the node with its children
          -- replaced by the roots of the children
          atom = Atom root (replaceChildren childRoots node)
      -- the atom of this node comes first, followed by the child atoms
      pure (atom : concat childAtoms, root)

-- | Extracts the integer id from a 'ClassOrVar', regardless of whether
-- it is an e-class id ('Left') or a pattern variable id ('Right').
getInt :: ClassOrVar -> Int
getInt = either id id

-- | Returns the list of the children values of a node: two for a
-- binary node, one for a unary node and none for anything else.
getElems :: SRTree a -> [a]
getElems node = case node of
  Bin _ l r -> [l, r]
  Uni _ t   -> [t]
  _         -> []

-- | Creates the substitution maps for
-- the pattern variables, one for each
-- matched subgraph.
--
-- The variables are visited in the order given by @orderedVars@ (per the
-- original note: starting with the most frequently occurring).
genericJoin :: Monad m => Query -> ClassOrVar -> EGraphST m [Map ClassOrVar ClassOrVar]
genericJoin atoms root = go atoms (orderedVars atoms)
  where
    -- for each variable
    --   for each possible e-class id for that variable
    --      replace the var id with this e-class id, and
    --      recurse to find the possible matches for the remaining variables
    go :: Monad m => Query -> [ClassOrVar] -> EGraphST m [Map ClassOrVar ClassOrVar]
    -- no variable left: a single empty substitution to be extended by the callers
    go _ [] = pure [Map.empty]
    go pending (x:rest) = do
      candidates <- domainX x pending root
      extended <- forM candidates $ \classId ->
        map (Map.insert x classId) <$> go (updateVar x classId pending) rest
      pure (concat extended)


     -- [Map.insert x classId y | classId <- domainX db x atoms
     --                                           , y <- go (updateVar x classId atoms) vars]


-- | Returns the e-class ids for a certain variable that
-- match the pattern described by the atoms. Each id is wrapped in 'Left'.
domainX :: Monad m => ClassOrVar -> Query -> ClassOrVar -> EGraphST m [ClassOrVar]
domainX var atoms root = map Left <$> intersectAtoms var relevant root
  where
    -- look only at the atoms that mention this var; the candidates are the
    -- intersection of the possible keys of each one of them
    relevant = filter (elemOfAtom var) atoms

  --let ss = (map Left
  --                                $ intersectAtoms var db
  --                                $
  --                     in ss

-- | Returns all e-class ids that can match this sequence of atoms for
-- the variable @var@: the candidates of each atom are computed
-- independently and then intersected. An empty list of atoms yields no
-- candidates.
--
-- The @root@ argument is currently not used by this function.
-- NOTE(review): an unused local helper that canonicalised the candidate
-- ids whenever @var@ differed from @root@ was dropped here; confirm that
-- canonicalisation is not expected at this point.
intersectAtoms :: Monad m => ClassOrVar -> Query -> ClassOrVar -> EGraphST m [EClassId]
intersectAtoms _ [] _ = pure []
intersectAtoms var (a:atoms) _ = do
  first <- candidates a
  Set.toList <$> foldM (\acc atom -> Set.intersection acc <$> candidates atom) first atoms
  where
    -- candidate e-class ids for @var@ according to a single atom
    candidates :: Monad m => Atom -> EGraphST m (HashSet EClassId)
    candidates (Atom r t) = do
      -- the pattern database is indexed by the operator of the node
      mTrie <- gets ((Map.!? getOperator t) . _patDB . _eDB)
      pure $ case mTrie of
        -- the e-graph does not contain this operator
        Nothing   -> Set.empty
        -- match the root followed by the children against the trie
        Just trie -> fromMaybe Set.empty
                       (intersectTries var Map.empty trie (r : getElems t))

-- | searches for the intersection of e-class ids that
-- matches each part of the query.
-- Returns Nothing if the intersection is empty.
--
-- var is the current variable being investigated
-- xs is the map of ids being investigated and their corresponding e-class id
-- trie is the current trie of the pattern
-- (i:ids) sequence of root : children of the atom to investigate
-- NOTE: it must be Maybe Set to differentiate between empty set and no answer
intersectTries :: ClassOrVar -> Map ClassOrVar EClassId -> IntTrie -> [ClassOrVar] -> Maybe (HashSet EClassId)
intersectTries :: ClassOrVar
-> Map ClassOrVar Int
-> IntTrie
-> [ClassOrVar]
-> Maybe (HashSet Int)
intersectTries ClassOrVar
var Map ClassOrVar Int
xs IntTrie
trie [] = HashSet Int -> Maybe (HashSet Int)
forall a. a -> Maybe a
Just HashSet Int
forall a. HashSet a
Set.empty
intersectTries ClassOrVar
var Map ClassOrVar Int
xs IntTrie
trie (ClassOrVar
i:[ClassOrVar]
ids) =
    case ClassOrVar
i of
      Left Int
x  -> if Int
x Int -> HashSet Int -> Bool
forall a. (Eq a, Hashable a) => a -> HashSet a -> Bool
`Set.member` IntTrie -> HashSet Int
_keys IntTrie
trie
                    -- if the current investigated id is an e-class id and
                    -- it is one of the keys of the trie...
                    -- ..try to match the next id with the next trie
                    then ClassOrVar
-> Map ClassOrVar Int
-> IntTrie
-> [ClassOrVar]
-> Maybe (HashSet Int)
intersectTries ClassOrVar
var Map ClassOrVar Int
xs (IntTrie -> IntMap IntTrie
_trie IntTrie
trie IntMap IntTrie -> Int -> IntTrie
forall a. IntMap a -> Int -> a
IntMap.! Int
x) [ClassOrVar]
ids
                    else Maybe (HashSet Int)
forall a. Maybe a
Nothing
      Right Int
x -> if ClassOrVar
i ClassOrVar -> Map ClassOrVar Int -> Bool
forall k a. Ord k => k -> Map k a -> Bool
`Map.member` Map ClassOrVar Int
xs
                    -- if it is a pattern variable under investigation
                    -- and the e-class id is part of the trie
                    then if Map ClassOrVar Int
xs Map ClassOrVar Int -> ClassOrVar -> Int
forall k a. Ord k => Map k a -> k -> a
Map.! ClassOrVar
i Int -> HashSet Int -> Bool
forall a. (Eq a, Hashable a) => a -> HashSet a -> Bool
`Set.member` IntTrie -> HashSet Int
_keys IntTrie
trie
                            -- match the next id with the next trie
                            then ClassOrVar
-> Map ClassOrVar Int
-> IntTrie
-> [ClassOrVar]
-> Maybe (HashSet Int)
intersectTries ClassOrVar
var Map ClassOrVar Int
xs (IntTrie -> IntMap IntTrie
_trie IntTrie
trie IntMap IntTrie -> Int -> IntTrie
forall a. IntMap a -> Int -> a
IntMap.! (Map ClassOrVar Int
xs Map ClassOrVar Int -> ClassOrVar -> Int
forall k a. Ord k => Map k a -> k -> a
Map.! ClassOrVar
i)) [ClassOrVar]
ids
                            else Maybe (HashSet Int)
forall a. Maybe a
Nothing
                    else if Int -> ClassOrVar
forall a b. b -> Either a b
Right Int
x ClassOrVar -> ClassOrVar -> Bool
forall a. Eq a => a -> a -> Bool
== ClassOrVar
var
                            -- not under investigation and is the var of interest
                            then if (ClassOrVar -> Bool) -> [ClassOrVar] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all (Int -> ClassOrVar -> Bool
isDiffFrom Int
x) [ClassOrVar]
ids
                                    -- if there are no other occurrence of x in the next vars,
                                    -- the keys of the trie are all possible candidates
                                    then HashSet Int -> Maybe (HashSet Int)
forall a. a -> Maybe a
Just (HashSet Int -> Maybe (HashSet Int))
-> HashSet Int -> Maybe (HashSet Int)
forall a b. (a -> b) -> a -> b
$ IntTrie -> HashSet Int
_keys IntTrie
trie
                                    -- oterwise, put i under investigation and check the next occurrences
                                    -- returning the intersection
                                    else HashSet Int -> Maybe (HashSet Int)
forall a. a -> Maybe a
Just (HashSet Int -> Maybe (HashSet Int))
-> HashSet Int -> Maybe (HashSet Int)
forall a b. (a -> b) -> a -> b
$ (Int -> IntTrie -> HashSet Int -> HashSet Int)
-> HashSet Int -> IntMap IntTrie -> HashSet Int
forall a b. (Int -> a -> b -> b) -> b -> IntMap a -> b
IntMap.foldrWithKey (\Int
k IntTrie
v HashSet Int
acc ->
                                                    case ClassOrVar
-> Map ClassOrVar Int
-> IntTrie
-> [ClassOrVar]
-> Maybe (HashSet Int)
intersectTries ClassOrVar
var (ClassOrVar -> Int -> Map ClassOrVar Int -> Map ClassOrVar Int
forall k a. Ord k => k -> a -> Map k a -> Map k a
Map.insert ClassOrVar
i Int
k Map ClassOrVar Int
xs) IntTrie
v [ClassOrVar]
ids of
                                                      Maybe (HashSet Int)
Nothing -> HashSet Int
acc
                                                      Maybe (HashSet Int)
_       -> Int -> HashSet Int -> HashSet Int
forall a. (Eq a, Hashable a) => a -> HashSet a -> HashSet a
Set.insert Int
k HashSet Int
acc) HashSet Int
forall a. HashSet a
Set.empty (IntTrie -> IntMap IntTrie
_trie IntTrie
trie)
                            -- if it is not the var of interest
                            -- assign and test all possible e-class ids to it
                            -- and move forward
                            else HashSet Int -> Maybe (HashSet Int)
forall a. a -> Maybe a
Just (HashSet Int -> Maybe (HashSet Int))
-> HashSet Int -> Maybe (HashSet Int)
forall a b. (a -> b) -> a -> b
$ (Int -> IntTrie -> HashSet Int -> HashSet Int)
-> HashSet Int -> IntMap IntTrie -> HashSet Int
forall a b. (Int -> a -> b -> b) -> b -> IntMap a -> b
IntMap.foldrWithKey (\Int
k IntTrie
v HashSet Int
acc ->
                                                case ClassOrVar
-> Map ClassOrVar Int
-> IntTrie
-> [ClassOrVar]
-> Maybe (HashSet Int)
intersectTries ClassOrVar
var (ClassOrVar -> Int -> Map ClassOrVar Int -> Map ClassOrVar Int
forall k a. Ord k => k -> a -> Map k a -> Map k a
Map.insert ClassOrVar
i Int
k Map ClassOrVar Int
xs) IntTrie
v [ClassOrVar]
ids of
                                                  Maybe (HashSet Int)
Nothing -> HashSet Int
acc
                                                  Just HashSet Int
s  -> HashSet Int -> HashSet Int -> HashSet Int
forall a. Eq a => HashSet a -> HashSet a -> HashSet a
Set.union HashSet Int
acc HashSet Int
s
                                                     ) HashSet Int
forall a. HashSet a
Set.empty (IntTrie -> IntMap IntTrie
_trie IntTrie
trie)

-- | Replaces every occurrence of @var@ in a query by @x@.
--
-- For each atom, both its root and every element returned by 'getElems'
-- for its tree are compared against @var@: matches are swapped for @x@,
-- everything else is kept untouched. The tree is rebuilt with
-- 'replaceChildren' from the substituted elements.
updateVar :: ClassOrVar -> ClassOrVar -> Query -> Query
updateVar var x = map replace
  where
    -- substitutes a single id, leaving anything that is not @var@ as is
    subst :: ClassOrVar -> ClassOrVar
    subst c
      | c == var  = x
      | otherwise = c

    -- applies the substitution to the root and to the children of an atom
    replace :: Atom -> Atom
    replace (Atom r t) =
      let children = map subst (getElems t)
          t'       = replaceChildren children t
       in Atom (subst r) t'

-- | Checks whether a 'ClassOrVar' is a pattern variable different from @x@.
--
-- Returns 'True' only for @Right z@ with @z /= x@. Any @Left@ value
-- yields 'False' (it does NOT return 'True'), so a test such as
-- @all (isDiffFrom x) ids@ fails as soon as @ids@ holds a @Left@ entry
-- or another occurrence of the variable @x@.
isDiffFrom :: Int -> ClassOrVar -> Bool
isDiffFrom _ (Left _)  = False
isDiffFrom x (Right z) = x /= z

-- | Checks if @v@ is an element of an atom.
--
-- @v@ always counts as an element when it appears among the elements
-- returned by 'getElems' for the atom's tree. The root of the atom is
-- only compared against @v@ when the root is a @Right@ value; a @Left@
-- root is never matched by itself.
elemOfAtom :: ClassOrVar -> Atom -> Bool
elemOfAtom v (Atom root tree) =
    case root of
      Left _  -> inTree
      Right _ -> root == v || inTree
  where
    -- whether v occurs among the elements of the tree
    inTree :: Bool
    inTree = v `elem` getElems tree

-- | Sorts the variables of a query, most frequently occurring first.
--
-- Only the @Right@ ids (roots and tree elements) of the atoms are
-- collected, without duplicates. They are sorted in ascending order of
-- 'varCost': every atom containing the variable contributes
-- @atomLen - 100@ to its cost, so a variable appearing in more atoms gets
-- a lower cost, and among those, shorter atoms give a lower cost.
-- The sort is stable, so ties keep the order of first occurrence.
orderedVars :: Query -> [ClassOrVar]
orderedVars atoms = map snd $ sortBy (comparing fst) [(varCost v, v) | v <- vars]
  where
    -- distinct variables of the query, in order of first occurrence;
    -- each one is paired with its cost above so that the cost is computed
    -- once per variable instead of once per comparison
    vars :: [ClassOrVar]
    vars = nub [a | atom <- atoms, a <- getIdsFrom atom, isRight a]

    -- root of the atom followed by the elements of its tree
    getIdsFrom :: Atom -> [ClassOrVar]
    getIdsFrom (Atom r t) = r : getElems t

    isRight :: Either a b -> Bool
    isRight (Right _) = True
    isRight _         = False

    -- sum of (atomLen - 100) over all atoms that contain the variable
    varCost :: ClassOrVar -> Int
    varCost var = foldr (\a acc -> if elemOfAtom var a then acc - 100 + atomLen a else acc) 0 atoms

    -- one (for the root) plus the number of elements of the tree
    atomLen :: Atom -> Int
    atomLen (Atom _ t) = 1 + length (getElems t)