-- | Equations.
{-# LANGUAGE TypeFamilies #-}
module Twee.Equation where

import Twee.Base
import Control.Monad

--------------------------------------------------------------------------------
-- * Equations.
--------------------------------------------------------------------------------

-- | An equation @lhs :=: rhs@ between two terms.
--
-- Both sides are strict, unpacked fields. The derived 'Ord' instance
-- compares the left-hand sides first and then the right-hand sides.
data Equation f =
  (:=:) {
    -- | The left-hand side of the equation.
    eqn_lhs :: {-# UNPACK #-} !(Term f),
    -- | The right-hand side of the equation.
    eqn_rhs :: {-# UNPACK #-} !(Term f) }
  deriving (Eq, Ord, Show)

-- | Equations whose function symbols are the constants of the symbolic
-- type @a@.
type EquationOf a = Equation (ConstantOf a)

-- | An equation is symbolic through its two sides. Its terms are those of
-- the left-hand side and the right-hand side, combined with 'mplus'. A
-- substitution is applied to each side independently.
instance Symbolic (Equation f) where
  type ConstantOf (Equation f) = f
  termsDL (t :=: u) = termsDL t `mplus` termsDL u
  subst_ sub (t :=: u) = subst_ sub t :=: subst_ sub u

-- | Renders an equation as @lhs = rhs@.
instance (Labelled f, PrettyTerm f) => Pretty (Equation f) where
  pPrint (x :=: y) = pPrint x <+> text "=" <+> pPrint y

-- | Order an equation roughly left-to-right, and
-- canonicalise its variables.
-- There is no guarantee that the result is oriented.
--
-- The decision is based on the skeletons of the two sides, as computed by
-- 'ground'.
--
-- NOTE(review): only the equal-skeleton case calls 'canonicalise'; the
-- other two cases return the sides unchanged apart from swapping them.
-- Confirm whether variable canonicalisation is intended there as well.
order :: Function f => Equation f -> Equation f
order (l :=: r)
  -- If the two terms have the same skeleton,
  -- then take whichever orientation gives a simpler equation
  | gl == gr =
    let eq1 = canonicalise (l :=: r)
        eq2 = canonicalise (r :=: l)
    in if eq1 == eq2 || orderedSimplerThan eq1 eq2 then eq1 else eq2
  -- Otherwise, the LHS should be the term with the greater skeleton
  | gl `lessEq` gr = r :=: l
  | otherwise = l :=: r
  where
    -- Skeletons of the two sides. The 'Minimal' constraint on 'ground'
    -- suggests variables are replaced by the minimal constant; see
    -- "Twee.Base" to confirm.
    gl = ground l
    gr = ground r

-- Helper for 'order' and 'simplerThan'.
--
-- @(t1 :=: u1) `orderedSimplerThan` (t2 :=: u2)@ holds when
-- @t1 `lessEqSkolem` t2@ and either the left-hand sides differ or the
-- right-hand sides satisfy @u1 `lessEqSkolem` u2@ with @u1 /= u2@.
-- In other words, it is a lexicographic comparison: the left-hand sides
-- first, then the right-hand sides as a tie-breaker, strict overall.
-- The equations are expected to have been passed through 'order' already.
orderedSimplerThan :: Function f => Equation f -> Equation f -> Bool
orderedSimplerThan (t1 :=: u1) (t2 :=: u2) =
  t1 `lessEqSkolem` t2 && (t1 /= t2 || (u1 `lessEqSkolem` u2 && u1 /= u2))

-- | Apply a function to both sides of an equation.
--
-- The function may change the constant type, e.g. to relabel function
-- symbols.
bothSides :: (Term f -> Term f') -> Equation f -> Equation f'
bothSides f (t :=: u) = f t :=: f u

-- | Is an equation of the form t = t?
--
-- Uses syntactic equality of the two sides; no variable renaming or
-- rewriting is taken into account.
trivial :: Eq f => Equation f -> Bool
trivial (t :=: u) = t == u

-- | A total order on equations. Equations with lesser terms are smaller.
--
-- Both equations are first normalised with 'order' and then compared with
-- 'orderedSimplerThan', so the result does not depend on which side of an
-- input equation was written first.
simplerThan :: Function f => Equation f -> Equation f -> Bool
eq1 `simplerThan` eq2 =
  order eq1 `orderedSimplerThan` order eq2

-- | Match one equation against another.
--
-- The first argument is the pattern. Its left-hand side is matched
-- against the target's left-hand side. Its right-hand side is then matched
-- against the target's right-hand side, continuing from the substitution
-- found so far. Returns 'Nothing' if either match fails.
matchEquation :: Equation f -> Equation f -> Maybe (Subst f)
matchEquation (pat1 :=: pat2) (t1 :=: t2) = do
  sub <- match pat1 t1
  matchIn sub pat2 t2