-- | Unification for first order terms.
--
-- Copyright (c) 2003-2007, John Harrison. (See "LICENSE.txt" for details.)

{-# OPTIONS -Wall #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}

module Data.Logic.ATP.Unif
    ( Unify(unify', UTermOf)
    , unify
    , unify_terms
    , unify_literals
    , unify_atoms
    , unify_atoms_eq
    , solve
    , fullunify
    , unify_and_apply
    , testUnif
    ) where

import Control.Monad.State hiding (fail) -- (evalStateT, runStateT, State, StateT, get)
import Prelude hiding (fail)
import Control.Monad.Fail
import Data.Bool (bool)
import Data.List as List (map)
import Data.Logic.ATP.Apply (HasApply(TermOf, PredOf), JustApply, zipApplys)
import Data.Logic.ATP.Equate (HasEquate, zipEquates)
import Data.Logic.ATP.FOL (tsubst)
import Data.Logic.ATP.Formulas (IsFormula(AtomOf))
import Data.Logic.ATP.Lib (Failing(Success, Failure))
import Data.Logic.ATP.Lit (IsLiteral, JustLiteral, zipLiterals')
import Data.Logic.ATP.Skolem (SkAtom, SkTerm)
import Data.Logic.ATP.Term (IsTerm(..), IsVariable)
import Data.Map.Strict as Map
import Data.Maybe (fromMaybe)
-- import Data.Sequence (Seq, viewl, ViewL(EmptyL, (:<)))
import Test.HUnit hiding (State)

-- | Main unification procedure, expressed as a class of unifiable
-- values.  The result of unification is a mapping of variables to
-- terms, so although we can unify two dissimilar types, they must at
-- least have the same term type (which means the variable type will
-- also match).  The result of unifying the two arguments is added to
-- the substitution held in the 'StateT' state, while failure is
-- signalled in the underlying monad @m@ (the instance in this module
-- requires 'MonadFail' and reports errors with 'fail').
--
-- One might think that Unify should take two type parameters, the
-- types of two values to be unified, but there are instances where a
-- single type contains both - for example, in template-haskell we
-- want to unify @a@ and @b@ in a predicate such as this: @(AppT (AppT
-- EqualityT a) b)@.
class (Monad m, IsTerm (UTermOf a), IsVariable (TVarOf (UTermOf a))) => Unify m a where
    -- | The term type shared by everything being unified inside @a@.
    type UTermOf a
    -- | Unify the parts of @a@, extending the substitution in the state.
    unify' :: a -> StateT (Map (TVarOf (UTermOf a)) (UTermOf a)) m ()

-- | Run 'unify'' on @a@ starting from the substitution @mp0@ and return
-- the resulting, possibly extended, substitution.  Any failure raised
-- by 'unify'' is reported in @m@.
unify :: (Unify m a, Monad m) => a -> Map (TVarOf (UTermOf a)) (UTermOf a) -> m (Map (TVarOf (UTermOf a)) (UTermOf a))
unify a mp0 = execStateT (unify' a) mp0

-- | Unify every pair of terms in the list, left to right, threading the
-- substitution through the state.  Fails (via 'fail') with
-- @"impossible unification"@ on a function-symbol or arity clash, and
-- with @"cyclic"@ when the occurs check rejects a binding.
unify_terms :: (IsTerm term, v ~ TVarOf term, MonadFail m) =>
               [(term,term)] -> StateT (Map v term) m ()
unify_terms = mapM_ (uncurry unify_term_pair)

-- | Unify a single pair of terms, extending the substitution in the state.
--
-- * If either side is a variable it is handled by @vr@: a bound variable
--   has its binding unified with the other side; a free variable is
--   bound to the other side unless the occurs check ('istriv') says the
--   binding would be trivial (or fails with @"cyclic"@).
-- * If both sides are function applications, @fn@ requires the same
--   function symbol and the same number of arguments, then unifies the
--   arguments pairwise; otherwise it fails with @"impossible unification"@.
unify_term_pair :: forall term v f m.
                   (IsTerm term, v ~ TVarOf term, f ~ FunOf term, MonadFail m) =>
                   term -> term -> StateT (Map v term) m ()
unify_term_pair a b =
    foldTerm (vr b) (\f fargs -> foldTerm (vr a) (fn f fargs) b) a
    where
      -- Unify the variable x with the term t.
      vr :: term -> v -> StateT (Map v term) m ()
      vr t x = do
          binding <- gets (Map.lookup x)
          case binding of
            -- x already has a value: unify that value with t instead.
            Just y -> unify_term_pair y t
            -- x is free: bind it, unless t is (or dereferences to) x itself.
            Nothing -> do
              trivial <- istriv x t
              if trivial then return () else modify (Map.insert x t)
      -- Unify two function applications.
      fn :: f -> [term] -> f -> [term] -> StateT (Map v term) m ()
      fn f fargs g gargs
        | f == g && length fargs == length gargs = unify_terms (zip fargs gargs)
        | otherwise = fail "impossible unification"

-- | Occurs check performed before binding variable @x@ to term @t@
-- under the current substitution.
--
-- * Returns 'True' when @t@, after following existing bindings, is the
--   variable @x@ itself, so the binding would be trivial.
-- * Returns 'False' when @x@ does not occur in @t@.
-- * Fails with @"cyclic"@ when @x@ occurs inside a function application
--   in @t@, because binding it would create an infinite term.
istriv :: forall term v f m. (IsTerm term, v ~ TVarOf term, f ~ FunOf term, MonadFail m) =>
          v -> term -> StateT (Map v term) m Bool
istriv x t =
    foldTerm vr fn t
    where
      vr :: v -> StateT (Map v term) m Bool
      vr y
        | x == y = return True
        -- A different variable: follow its binding if it has one.
        | otherwise = gets (Map.lookup y) >>= maybe (return False) (istriv x)
      fn :: f -> [term] -> StateT (Map v term) m Bool
      fn _ args = do
          trivs <- mapM (istriv x) args
          -- x appearing directly as an argument means x = f(..x..).
          if or trivs then fail "cyclic" else return False

-- | Solve to obtain a single instantiation: repeatedly apply the
-- substitution to its own range until it stops changing.  This only
-- reaches a fixed point for an acyclic substitution, such as one built
-- by 'unify_terms', whose occurs check rejects cycles.
solve :: (IsTerm term, v ~ TVarOf term) =>
         Map v term -> Map v term
solve env
    | env' == env = env
    | otherwise = solve env'
    where
      env' = Map.map (tsubst env) env

-- | Unification reaching a final solved form (often this isn't needed):
-- unify all pairs starting from the empty substitution, then 'solve' it.
fullunify :: (IsTerm term, v ~ TVarOf term, f ~ FunOf term, MonadFail m) =>
             [(term,term)] -> m (Map v term)
fullunify eqs = solve <$> execStateT (unify_terms eqs) Map.empty

-- | Unify all the given equations with 'fullunify', then apply the
-- solved substitution to both sides of every equation.  On success
-- each returned pair has syntactically equal components.
unify_and_apply :: (IsTerm term, v ~ TVarOf term, f ~ FunOf term, MonadFail m) =>
                   [(term, term)] -> m [(term, term)]
unify_and_apply eqs = instantiate <$> fullunify eqs
    where
      instantiate i = List.map (\(t1, t2) -> (tsubst i t1, tsubst i t2)) eqs

-- | Unify literals, perhaps of different types, but sharing term and
-- variable type.  Note that only one needs to be 'JustLiteral': if
-- the unification succeeds the other must have been too, and if it
-- fails, who cares.
--
-- The literals are matched structurally with 'zipLiterals''.  Any case
-- that 'zipLiterals'' or the callbacks below reject fails with
-- @"Can't unify literals"@.
unify_literals :: forall lit1 lit2 atom1 atom2 v term m.
                  (IsLiteral lit1, HasApply atom1, atom1 ~ AtomOf lit1, term ~ TermOf atom1,
                   JustLiteral lit2, HasApply atom2, atom2 ~ AtomOf lit2, term ~ TermOf atom2,
                   Unify m (atom1, atom2), term ~ UTermOf (atom1, atom2), v ~ TVarOf term,
                   MonadFail m) =>
                  lit1 -> lit2 -> StateT (Map v term) m ()
unify_literals f1 f2 =
    fromMaybe (fail "Can't unify literals") (zipLiterals' ho ne tf at f1 f2)
    where
      -- First zipLiterals' callback: always declines.
      ho _ _ = Nothing
      -- Second callback: recursively unify the two sub-literals it is
      -- given (presumably the operands of two negations - see zipLiterals').
      ne p q = Just $ unify_literals p q
      -- Truth constants unify exactly when equal, binding nothing.
      tf p q = if p == q then Just (return ()) else Nothing
      -- Atoms are delegated to the Unify instance for the pair of atoms.
      at a1 a2 = Just (unify' (a1, a2))

-- | Unify two atoms that are plain predicate applications.  'zipApplys'
-- pairs up their argument terms; the predicate it passes is ignored
-- here, so compatibility is whatever 'zipApplys' checks (presumably the
-- same predicate and arity).  The term pairs are then unified with
-- 'unify_terms'.  If 'zipApplys' yields 'Nothing', this fails with
-- @"unify_atoms"@.
unify_atoms :: (JustApply atom1, term ~ TermOf atom1,
                JustApply atom2, term ~ TermOf atom2,
                v ~ TVarOf term, PredOf atom1 ~ PredOf atom2, MonadFail m) =>
               (atom1, atom2) -> StateT (Map v term) m ()
unify_atoms (a1, a2) =
    fromMaybe (fail "unify_atoms")
              (zipApplys (\_ tpairs -> Just (unify_terms tpairs)) a1 a2)

-- | Unify two atoms that may be equations.  'zipEquates' is given two
-- callbacks:
--
-- * For two equations @l1 = r1@ and @l2 = r2@, unify @l1@ with @l2@ and
--   @r1@ with @r2@.
-- * For predicate applications, unify the argument pairs it supplies;
--   the predicate itself is ignored here.
--
-- If 'zipEquates' yields 'Nothing', this fails with @"unify_atoms"@.
unify_atoms_eq :: (HasEquate atom1, term ~ TermOf atom1,
                   HasEquate atom2, term ~ TermOf atom2,
                   PredOf atom1 ~ PredOf atom2, v ~ TVarOf term, MonadFail m) =>
                  atom1 -> atom2 -> StateT (Map v term) m ()
unify_atoms_eq a1 a2 =
    fromMaybe (fail "unify_atoms")
              (zipEquates (\l1 r1 l2 r2 -> Just (unify_terms [(l1, l2), (r1, r2)]))
                          (\_ tpairs -> Just (unify_terms tpairs))
                          a1 a2)

--unify_and_apply' :: (v ~ TVarOf term, f ~ FunOf term, IsTerm term, Monad m) => [(term, term)] -> m [(term, term)]
--unify_and_apply' eqs =
--    mapM app eqs
--        where
--          app (t1, t2) = fullunify eqs >>= \i -> return $ (tsubst i t1, tsubst i t2)

-- | Pairs of Skolem atoms unify via 'unify_atoms_eq', which handles both
-- the equation case and the predicate-application case of 'zipEquates'.
instance MonadFail m => Unify m (SkAtom, SkAtom) where
    type UTermOf (SkAtom, SkAtom) = TermOf SkAtom
    unify' = uncurry unify_atoms_eq

test01, test02, test03, test04 :: Test
-- Unify f(x, g(y)) with f(f(z), w): x := f(z), w := g(y).
test01 = TestCase (assertEqual "Unify test 1"
                     (Success [(f [f [z], g [y]],
                                f [f [z], g [y]])]) -- expected
                     (unify_and_apply [(f [x, g [y]], f [f [z], w])]))
    where
      f, g :: [SkTerm] -> SkTerm
      f = fApp "f"
      g = fApp "g"
      w, x, y, z :: SkTerm
      w = vt "w"
      x = vt "x"
      y = vt "y"
      z = vt "z"
-- Unify f(x, y) with f(y, x): x := y, after which y = y is trivial.
test02 = TestCase (assertEqual "Unify test 2"
                     (Success [(f [y, y],
                                f [y, y])]) -- expected
                     (unify_and_apply [(f [x, y], f [y, x])]))
    where
      f :: [SkTerm] -> SkTerm
      f = fApp "f"
      x, y :: SkTerm
      x = vt "x"
      y = vt "y"
-- Unify f(x, g(y)) with f(y, x): x := y, then y = g(y) fails the occurs check.
test03 = TestCase (assertEqual "Unify test 3"
                     (Failure ["cyclic"]) -- expected
                     (unify_and_apply [(f [x, g [y]], f [y, x])]))
    where
      f, g :: [SkTerm] -> SkTerm
      f = fApp "f"
      g = fApp "g"
      x, y :: SkTerm
      x = vt "x"
      y = vt "y"
-- A chain of bindings whose solved form grows exponentially in size.
test04 = TestCase (assertEqual "Unify test 4"
                     (Success [(f [f [f [x_3, x_3], f [x_3, x_3]], f [f [x_3, x_3], f [x_3, x_3]]],
                                f [f [f [x_3, x_3], f [x_3, x_3]], f [f [x_3, x_3], f [x_3, x_3]]]),
                               (f [f [x_3, x_3], f [x_3, x_3]],
                                f [f [x_3, x_3], f [x_3, x_3]]),
                               (f [x_3, x_3],
                                f [x_3, x_3])]) -- expected
                     (unify_and_apply [(x_0, f [x_1, x_1]),
                                       (x_1, f [x_2, x_2]),
                                       (x_2, f [x_3, x_3])]))

    where
      f :: [SkTerm] -> SkTerm
      f = fApp "f"
      x_0, x_1, x_2, x_3 :: SkTerm
      x_0 = vt "x0"
      x_1 = vt "x1"
      x_2 = vt "x2"
      x_3 = vt "x3"
{-

START_INTERACTIVE;;
unify_and_apply [<<|f(x,g(y))|>>,<<|f(f(z),w)|>>];;

unify_and_apply [<<|f(x,y)|>>,<<|f(y,x)|>>];;

(****  unify_and_apply [<<|f(x,g(y))|>>,<<|f(y,x)|>>];; *****)

unify_and_apply [<<|x_0|>>,<<|f(x_1,x_1)|>>;
                 <<|x_1|>>,<<|f(x_2,x_2)|>>;
                 <<|x_2|>>,<<|f(x_3,x_3)|>>];;
END_INTERACTIVE;;
-}

-- | All unification tests in this module, grouped under the label "Unif".
testUnif :: Test
testUnif = TestLabel "Unif" (TestList [test01, test02, test03, test04])