-- | Equations.
{-# LANGUAGE TypeFamilies #-}
module Twee.Equation where

import Twee.Base
import Control.Monad

--------------------------------------------------------------------------------
-- * Equations.
--------------------------------------------------------------------------------

-- | An equation between two terms, written @lhs :=: rhs@.
--
-- Both sides are strict, unpacked fields.
data Equation f =
  (:=:) {
    -- | The left-hand side of the equation.
    eqn_lhs :: {-# UNPACK #-} !(Term f),
    -- | The right-hand side of the equation.
    eqn_rhs :: {-# UNPACK #-} !(Term f) }
  deriving (Eq, Ord, Show)

-- | The equation type whose function symbols are the 'ConstantOf' type
-- of the symbolic type @a@.
type EquationOf a = Equation (ConstantOf a)

-- | An equation is symbolic through its two sides: its terms are those of
-- the left-hand side followed by those of the right-hand side, and a
-- substitution is applied to each side independently.
instance Symbolic (Equation f) where
  type ConstantOf (Equation f) = f
  -- Left-hand side terms first, then right-hand side terms.
  termsDL (t :=: u) = termsDL t `mplus` termsDL u
  subst_ sub (t :=: u) = subst_ sub t :=: subst_ sub u

-- | Pretty-prints an equation as @lhs = rhs@.
instance (Labelled f, PrettyTerm f) => Pretty (Equation f) where
  pPrint (x :=: y) = pPrint x <+> text "=" <+> pPrint y

-- | Order an equation roughly left-to-right.
-- However, there is no guarantee that the result is oriented.
--
-- If the two sides differ and the left-hand side is 'lessEqSkolem' the
-- right-hand side, the sides are swapped. Otherwise the equation is
-- returned unchanged.
order :: Function f => Equation f -> Equation f
order (l :=: r)
  | l == r = l :=: r
  | lessEqSkolem l r = r :=: l
  | otherwise = l :=: r

-- | Apply a function to both sides of an equation.
--
-- The function may change the constant type, so this maps an
-- @'Equation' f@ to an @'Equation' f'@.
bothSides :: (Term f -> Term f') -> Equation f -> Equation f'
bothSides f (t :=: u) = f t :=: f u

-- | Is an equation of the form t = t?
--
-- Returns 'True' exactly when both sides are equal under '(==)'.
trivial :: Eq f => Equation f -> Bool
trivial (t :=: u) = t == u

-- | A strict ordering on equations: equations with lesser terms are smaller.
--
-- Both equations are first oriented with 'order' and then canonicalised.
-- They are then compared lexicographically by 'lessEqSkolem': first the
-- left-hand sides, and if those are equal, the right-hand sides.
-- An equation is never 'simplerThan' itself.
simplerThan :: Function f => Equation f -> Equation f -> Bool
eq1 `simplerThan` eq2 =
  t1 `lessEqSkolem` t2 && (t1 /= t2 || (u1 `lessEqSkolem` u2 && u1 /= u2))
  where
    -- Orient and canonicalise first, so that variable naming and side
    -- order do not affect the comparison.
    t1 :=: u1 = canonicalise (order eq1)
    t2 :=: u2 = canonicalise (order eq2)

-- | Match one equation against another.
--
-- First matches the left-hand side of the pattern equation against the
-- left-hand side of the target with 'match'. Then extends the resulting
-- substitution by matching the right-hand sides with 'matchIn'.
-- Returns 'Nothing' if either step fails.
matchEquation :: Equation f -> Equation f -> Maybe (Subst f)
matchEquation (pat1 :=: pat2) (t1 :=: t2) = do
  sub <- match pat1 t1
  matchIn sub pat2 t2