{-# LANGUAGE FlexibleContexts      #-}
{-# LANGUAGE MultiParamTypeClasses #-}
-------------------------------------------------------------------------------
-- |
-- Module      :  Data.Comp.Unification
-- Copyright   :  (c) 2010-2011 Patrick Bahr
-- License     :  BSD3
-- Maintainer  :  Patrick Bahr <paba@diku.dk>
-- Stability   :  experimental
-- Portability :  non-portable (GHC Extensions)
--
-- This module implements a simple unification algorithm using compositional
-- data types.
--
--------------------------------------------------------------------------------

module Data.Comp.Unification where

import Control.Monad (unless)
import Control.Monad.Except
import Control.Monad.State

import qualified Data.Map as Map

import Data.Comp.Decompose
import Data.Comp.Term
import Data.Comp.Variables

-- | An equation between two terms over the signature @f@; the two
-- components are the left- and right-hand side.
type Equation f =
  ( Term f
  , Term f
  )

-- | A list of equations over the signature @f@.
type Equations f =
  [Equation f]

-- | Errors that can occur while unifying terms over the signature @f@
-- with variables of type @v@.
data UnifError f v
  = FailedOccursCheck v (Term f)
    -- ^ The variable occurs in the term it was about to be bound to.
  | HeadSymbolMismatch (Term f) (Term f)
    -- ^ The two terms are applications with different head symbols.
  | UnifError String
    -- ^ Free-form error message. NOTE(review): not constructed by the
    -- unification steps in this module; presumably for client use.

-- | Aborts unification with a 'FailedOccursCheck' error: the variable
-- @v@ occurs in the term @t@ it was about to be bound to.
failedOccursCheck :: (MonadError (UnifError f v) m) => v -> Term f -> m a
failedOccursCheck v t = throwError (FailedOccursCheck v t)

-- | Aborts unification with a 'HeadSymbolMismatch' error for the two
-- given terms, whose head symbols differ.
headSymbolMismatch :: (MonadError (UnifError f v) m) => Term f -> Term f -> m a
headSymbolMismatch f g = throwError (HeadSymbolMismatch f g)

-- | Applies a substitution to both sides of a single equation.
appSubstEq :: (Ord v, HasVars f v, Traversable f)
           => Subst f v -> Equation f -> Equation f
appSubstEq s (t1, t2) = (appSubst s t1, appSubst s t2)


{-| This function returns the most general unifier of the given
equations using the algorithm of Martelli and Montanari.

Unification failures (a failed occurs check or a head symbol mismatch)
are raised as 'UnifError's through 'MonadError'. -}
unify :: (MonadError (UnifError f v) m, Decompose f v, Ord v, Eq (Const f), Traversable f)
      => Equations f -> m (Subst f v)
unify = runUnifyM runUnify

-- | The state threaded through the unification algorithm.
data UnifyState f v = UnifyState
  { usEqs   :: Equations f
    -- ^ Equations that still have to be processed.
  , usSubst :: Subst f v
    -- ^ Substitution accumulated from the bindings made so far.
  }

-- | The monad the unification algorithm runs in: a state transformer
-- over the base monad @m@ that threads a 'UnifyState'.
type UnifyM f v m a =
  StateT (UnifyState f v) m a

-- | Runs a unification computation, starting with the given pending
-- equations and an empty substitution. Returns the substitution held
-- in the final state; the computation's own result is discarded.
runUnifyM :: MonadError (UnifError f v) m
          => UnifyM f v m a -> Equations f -> m (Subst f v)
runUnifyM m eqs = usSubst <$> execStateT m initial
  where
    -- 'execStateT' keeps only the final state, so there is no need to
    -- project it out of a (result, state) pair.
    initial = UnifyState { usEqs = eqs, usSubst = Map.empty }

-- | Removes the first pending equation, if there is one, and passes it
-- to the given continuation. Does nothing when no equations remain.
withNextEq :: Monad m
           => (Equation f -> UnifyM f v m ()) -> UnifyM f v m ()
withNextEq m = do
  eqs <- gets usEqs
  case eqs of
    [] -> return ()
    x : xs -> do
      -- Drop the equation from the state before handling it, so the
      -- continuation sees only the remaining equations.
      modify (\s -> s { usEqs = xs })
      m x

-- | Prepends the given equations to the pending ones, so that they are
-- processed before any equation that was already pending.
putEqs :: Monad m
       => Equations f -> UnifyM f v m ()
putEqs eqs = modify (\s -> s { usEqs = eqs ++ usEqs s })

-- | Records the binding of a variable to a term. The binding is
-- applied to every pending equation and composed, via 'compSubst',
-- with the substitution accumulated so far.
--
-- No occurs check is performed here.
putBinding :: (Monad m, Ord v, HasVars f v, Traversable f) => (v, Term f) -> UnifyM f v m ()
putBinding (v, t) = modify applyBinding
  where
    binding = Map.singleton v t
    -- Named 'applyBinding' (not 'appSubst') so it does not shadow
    -- 'appSubst' from "Data.Comp.Variables".
    applyBinding s = s { usEqs   = map (appSubstEq binding) (usEqs s)
                       , usSubst = compSubst binding (usSubst s)
                       }


-- | Repeatedly takes the next pending equation and processes it with
-- 'unifyStep', until no pending equations remain.
runUnify :: (MonadError (UnifError f v) m, Decompose f v, Ord v, Eq (Const f), Traversable f)
         => UnifyM f v m ()
runUnify = withNextEq (\e -> unifyStep e >> runUnify)

-- | Performs one step of the Martelli-Montanari algorithm on the
-- equation @(s, t)@:
--
-- * two identical variables: the equation is dropped;
-- * two distinct variables: the first is bound to the second;
-- * a variable against a non-variable term: the variable is bound to
--   the term, unless it occurs in it ('FailedOccursCheck');
-- * two applications with equal head symbols: the equation is replaced
--   by the pairwise equations between their arguments;
-- * two applications with different head symbols: 'HeadSymbolMismatch'.
unifyStep :: (MonadError (UnifError f v) m, Decompose f v, Ord v, Eq (Const f), Traversable f)
          => Equation f -> UnifyM f v m ()
unifyStep (s, t) =
  case (decompose s, decompose t) of
    (Var v1, Var v2) -> unless (v1 == v2) $ putBinding (v1, t)
    (Var v1, _)      -> bindChecked v1 t
    (Fun _ _, Var v) -> bindChecked v s
    (Fun s1 args1, Fun s2 args2)
      -- NOTE(review): 'zip' truncates silently; this assumes that equal
      -- head symbols ('Const f') imply equal arity.
      | s1 == s2  -> putEqs (zip args1 args2)
      | otherwise -> headSymbolMismatch s t
  where
    -- Binds a variable to a term, failing if the variable occurs in it.
    bindChecked x u
      | containsVar x u = failedOccursCheck x u
      | otherwise       = putBinding (x, u)