module Data.Logic.ATP.Unif
( Unify(unify', UTermOf)
, unify
, unify_terms
, unify_literals
, unify_atoms
, unify_atoms_eq
, solve
, fullunify
, unify_and_apply
, testUnif
) where
import Control.Monad.State
import Data.Bool (bool)
import Data.List as List (map)
import Data.Logic.ATP.Apply (HasApply(TermOf, PredOf), JustApply, zipApplys)
import Data.Logic.ATP.Equate (HasEquate, zipEquates)
import Data.Logic.ATP.FOL (tsubst)
import Data.Logic.ATP.Formulas (IsFormula(AtomOf))
import Data.Logic.ATP.Lib (Failing(Success, Failure))
import Data.Logic.ATP.Lit (IsLiteral, JustLiteral, zipLiterals')
import Data.Logic.ATP.Skolem (SkAtom, SkTerm)
import Data.Logic.ATP.Term (IsTerm(..), IsVariable)
import Data.Map.Strict as Map
import Data.Maybe (fromMaybe)
import Test.HUnit hiding (State)
-- | Values of type @a@ that can be unified inside a monad @m@.  An
-- instance fixes the term type involved ('UTermOf') and supplies
-- 'unify'', which operates on a state holding a map from that term
-- type's variables to terms.
class ( Monad m
      , IsTerm (UTermOf a)
      , IsVariable (TVarOf (UTermOf a))
      ) => Unify m a where
    -- | The term type whose variables get bound when unifying an @a@.
    type UTermOf a
    -- | Unify the given value, recording the result in the
    -- variable-to-term map carried by the 'StateT' layer.
    unify' :: a -> StateT (Map (TVarOf (UTermOf a)) (UTermOf a)) m ()
-- | Run 'unify'' on a value, starting from the supplied
-- variable-to-term map, and return the map left in the state when it
-- finishes.
unify :: (Unify m a, Monad m) => a -> Map (TVarOf (UTermOf a)) (UTermOf a) -> m (Map (TVarOf (UTermOf a)) (UTermOf a))
unify = execStateT . unify'
-- | Unify every pair of terms in the list, in order.  Each pair is
-- handed to 'unify_term_pair', so bindings made for earlier pairs are
-- visible (through the state) when later pairs are processed.
unify_terms :: (IsTerm term, v ~ TVarOf term, Monad m) =>
               [(term,term)] -> StateT (Map v term) m ()
unify_terms pairs =
    mapM_ (\ (lhs, rhs) -> unify_term_pair lhs rhs) pairs
-- | Unify two terms, extending the variable-to-term map in the state.
--
-- * If either side is a variable, that variable is bound to the other
--   side (see @bindVar@); the left term is inspected first.
-- * If both sides are function applications, the function symbols and
--   argument counts must agree, and the arguments are then unified
--   pairwise; otherwise the computation fails with
--   @"impossible unification"@.
unify_term_pair :: forall term v f m.
                   (IsTerm term, v ~ TVarOf term, f ~ FunOf term, Monad m) =>
                   term -> term -> StateT (Map v term) m ()
unify_term_pair lhs rhs =
    foldTerm (bindVar rhs) lhsIsApp lhs
    where
      -- The left term is an application: now look at the right term.
      lhsIsApp :: f -> [term] -> StateT (Map v term) m ()
      lhsIsApp lfun largs = foldTerm (bindVar lhs) (matchApps lfun largs) rhs

      -- Unify variable @var@ with @other@.  An existing binding for
      -- @var@ is unified with @other@ instead; an unbound variable is
      -- bound to @other@ unless 'istriv' reports the binding as trivial.
      bindVar :: term -> v -> StateT (Map v term) m ()
      bindVar other var = do
        env <- get
        case Map.lookup var env of
          Just bound -> unify_term_pair bound other
          Nothing -> do
            trivial <- istriv var other
            if trivial
              then return ()
              else modify (Map.insert var other)

      -- Both terms are applications.
      matchApps :: f -> [term] -> f -> [term] -> StateT (Map v term) m ()
      matchApps lfun largs rfun rargs
          | lfun == rfun && length largs == length rargs =
              sequence_ (zipWith unify_term_pair largs rargs)
          | otherwise = fail "impossible unification"
-- | Decide whether binding variable @x@ to term @t@ would be trivial,
-- following bindings already present in the state.
--
-- * Returns 'True' when @t@ is @x@ itself, or a variable whose chain
--   of bindings leads back to @x@.
-- * Returns 'False' when @t@ is an unbound variable other than @x@, or
--   an application none of whose arguments is reported as trivial.
-- * Fails with @"cyclic"@ when @t@ is an application and one of its
--   arguments is reported as trivial, i.e. @x@ occurs inside @t@.
istriv :: forall term v f m. (IsTerm term, v ~ TVarOf term, f ~ FunOf term, Monad m) =>
          v -> term -> StateT (Map v term) m Bool
istriv x t =
    foldTerm onVar onApp t
    where
      onVar :: v -> StateT (Map v term) m Bool
      onVar y
          | x == y = return True
          | otherwise = do
              env <- get
              case Map.lookup y env of
                Nothing -> return False
                Just bound -> istriv x bound

      onApp :: f -> [term] -> StateT (Map v term) m Bool
      onApp _ args = do
        hits <- mapM (istriv x) args
        if or hits
          then fail "cyclic"
          else return False
-- | Apply an environment to its own right-hand sides with 'tsubst',
-- repeating until a pass leaves the environment unchanged, and return
-- that fixed point.
--
-- NOTE(review): the loop only stops once a pass changes nothing, so
-- this presumably relies on the bindings being acyclic - confirm
-- against callers.
solve :: (IsTerm term, v ~ TVarOf term) =>
         Map v term -> Map v term
solve env
    | next == env = env
    | otherwise = solve next
    where
      next = Map.map (tsubst env) env
-- | Unify a list of term pairs starting from an empty environment and
-- return the resulting bindings after they have been passed through
-- 'solve'.
fullunify :: (IsTerm term, v ~ TVarOf term, f ~ FunOf term, Monad m) =>
             [(term,term)] -> m (Map v term)
fullunify eqs =
    fmap solve rawBindings
    where
      -- Bindings as left in the state by 'unify_terms', not yet solved.
      rawBindings = execStateT (unify_terms eqs) Map.empty
-- | Compute the bindings for a list of term pairs with 'fullunify',
-- then substitute them (via 'tsubst') into both sides of every
-- original pair.
unify_and_apply :: (IsTerm term, v ~ TVarOf term, f ~ FunOf term, Monad m) =>
                   [(term, term)] -> m [(term, term)]
unify_and_apply eqs = do
    env <- fullunify eqs
    let substBoth (lhs, rhs) = (tsubst env lhs, tsubst env rhs)
    return (List.map substBoth eqs)
-- | Unify two literals by walking them in parallel with
-- 'zipLiterals''.  If the walk yields no action, the computation
-- fails with @"Can't unify literals"@.
unify_literals :: forall lit1 lit2 atom1 atom2 v term m.
                  (IsLiteral lit1, HasApply atom1, atom1 ~ AtomOf lit1, term ~ TermOf atom1,
                   JustLiteral lit2, HasApply atom2, atom2 ~ AtomOf lit2, term ~ TermOf atom2,
                   Unify m (atom1, atom2), term ~ UTermOf (atom1, atom2), v ~ TVarOf term) =>
                  lit1 -> lit2 -> StateT (Map v term) m ()
unify_literals l1 l2 =
    case zipLiterals' unsupported nested constants atoms l1 l2 of
      Just action -> action
      Nothing -> fail "Can't unify literals"
    where
      -- First handler passed to zipLiterals': never unifies.
      unsupported _ _ = Nothing
      -- Second handler: recurse on the pair it is given (presumably
      -- the bodies of two negated literals - confirm in zipLiterals').
      nested p q = Just (unify_literals p q)
      -- Third handler: equal values unify without touching the state,
      -- unequal ones do not unify.
      constants p q
          | p == q = Just (return ())
          | otherwise = Nothing
      -- Fourth handler: delegate a pair of atoms to the 'Unify' instance.
      atoms a1 a2 = Just (unify' (a1, a2))
-- | Unify a pair of atoms by pairing them up with 'zipApplys' and
-- unifying the resulting list of term pairs with 'unify_terms'.  The
-- first argument passed to the callback is ignored.  When 'zipApplys'
-- yields 'Nothing' the computation fails with @"unify_atoms"@.
unify_atoms :: (JustApply atom1, term ~ TermOf atom1,
                JustApply atom2, term ~ TermOf atom2,
                v ~ TVarOf term, PredOf atom1 ~ PredOf atom2, Monad m) =>
               (atom1, atom2) -> StateT (Map v term) m ()
unify_atoms (a1, a2) =
    -- 'fromMaybe' rather than @maybe d id@, matching 'unify_literals'.
    fromMaybe (fail "unify_atoms")
              (zipApplys (\_ tpairs -> Just (unify_terms tpairs)) a1 a2)
-- | Unify two atoms by pairing them up with 'zipEquates'.
--
-- * The first callback receives four terms @l1 r1 l2 r2@ and unifies
--   @l1@ with @l2@ and @r1@ with @r2@ (presumably the two sides of a
--   pair of equations - confirm in 'zipEquates').
-- * The second callback receives a list of term pairs and unifies
--   them with 'unify_terms'; its first argument is ignored.
--
-- When 'zipEquates' yields 'Nothing' the computation fails.
unify_atoms_eq :: (HasEquate atom1, term ~ TermOf atom1,
                   HasEquate atom2, term ~ TermOf atom2,
                   PredOf atom1 ~ PredOf atom2, v ~ TVarOf term, Monad m) =>
                  atom1 -> atom2 -> StateT (Map v term) m ()
unify_atoms_eq a1 a2 =
    -- 'fromMaybe' rather than @maybe d id@, matching 'unify_literals'.
    -- NOTE(review): the message names unify_atoms although it is raised
    -- here; left unchanged in case callers match on the text.
    fromMaybe (fail "unify_atoms")
              (zipEquates (\l1 r1 l2 r2 -> Just (unify_terms [(l1, l2), (r1, r2)]))
                          (\_ tpairs -> Just (unify_terms tpairs))
                          a1 a2)
-- | Pairs of 'SkAtom' are unified with 'unify_atoms_eq'; the term type
-- involved is the term type of 'SkAtom'.
instance Monad m => Unify m (SkAtom, SkAtom) where
    type UTermOf (SkAtom, SkAtom) = TermOf SkAtom
    unify' (a1, a2) = unify_atoms_eq a1 a2
test01, test02, test03, test04 :: Test
-- Unifying f(x, g(y)) with f(f(z), w) must make both sides equal to
-- f(f(z), g(y)).
test01 =
    TestCase (assertEqual "Unify test 1" (Success [(unified, unified)]) actual)
    where
      actual = unify_and_apply [(f [x, g [y]], f [f [z], w])]
      unified = f [f [z], g [y]]
      f = fApp "f"
      g = fApp "g"
      w = vt "w" :: SkTerm
      x = vt "x" :: SkTerm
      y = vt "y" :: SkTerm
      z = vt "z" :: SkTerm
-- Unifying f(x, y) with f(y, x) must make both sides equal to f(y, y).
test02 =
    TestCase (assertEqual "Unify test 2" (Success [(unified, unified)]) actual)
    where
      actual = unify_and_apply [(f [x, y], f [y, x])]
      unified = f [y, y]
      f = fApp "f"
      x = vt "x" :: SkTerm
      y = vt "y" :: SkTerm
-- Unifying f(x, g(y)) with f(y, x) is expected to fail with "cyclic".
test03 =
    TestCase (assertEqual "Unify test 3" (Failure ["cyclic"]) actual)
    where
      actual = unify_and_apply [(f [x, g [y]], f [y, x])]
      f = fApp "f"
      g = fApp "g"
      x = vt "x" :: SkTerm
      y = vt "y" :: SkTerm
-- A chain of equations x0 = f(x1,x1), x1 = f(x2,x2), x2 = f(x3,x3):
-- after unification every pair must be fully expanded in terms of x3.
test04 =
    TestCase (assertEqual "Unify test 4" (Success expected) actual)
    where
      actual = unify_and_apply [ (x_0, f [x_1, x_1])
                               , (x_1, f [x_2, x_2])
                               , (x_2, f [x_3, x_3]) ]
      expected = [ (depth3, depth3)
                 , (depth2, depth2)
                 , (depth1, depth1) ]
      -- Expected terms, named by nesting depth of f around x3.
      depth1 = f [x_3, x_3]
      depth2 = f [depth1, depth1]
      depth3 = f [depth2, depth2]
      f = fApp "f"
      x_0 = vt "x0" :: SkTerm
      x_1 = vt "x1" :: SkTerm
      x_2 = vt "x2" :: SkTerm
      x_3 = vt "x3" :: SkTerm
-- | All unification tests of this module, grouped under the label
-- @"Unif"@.
testUnif :: Test
testUnif =
    TestLabel "Unif" (TestList cases)
    where
      cases = [test01, test02, test03, test04]