{-# LANGUAGE FlexibleContexts      #-}
{-# LANGUAGE MultiParamTypeClasses #-}
-------------------------------------------------------------------------------
-- |
-- Module      :  Data.Comp.Unification
-- Copyright   :  (c) 2010-2011 Patrick Bahr
-- License     :  BSD3
-- Maintainer  :  Patrick Bahr <paba@diku.dk>
-- Stability   :  experimental
-- Portability :  non-portable (GHC Extensions)
--
-- This module implements a simple unification algorithm using compositional
-- data types.
--
--------------------------------------------------------------------------------

module Data.Comp.Unification where

import Data.Comp.Decompose
import Data.Comp.Term
import Data.Comp.Variables

import Control.Monad
import Control.Monad.Except
import Control.Monad.State

import qualified Data.Map as Map

{-| An equation between two terms over the signature @f@, represented
as the pair of its two sides. -}

type Equation f = (Term f,Term f)

{-| A list of equations over the signature @f@. -}

type Equations f = [Equation f]

{-| This type represents the errors that might occur during
unification. The parameter @f@ is the signature of the terms and @v@
is the type of variables. -}

data UnifError f v = FailedOccursCheck v (Term f)
                     -- ^ An occurs check failed for the given variable
                     -- and term.
                   | HeadSymbolMismatch (Term f) (Term f)
                     -- ^ The two given terms have different head symbols.
                   | UnifError String
                     -- ^ An error described by a message. NOTE(review):
                     -- no definition in this module constructs this
                     -- value; confirm how callers use it.

-- | Signals a failed occurs check during unification by throwing a
-- 'FailedOccursCheck' error that carries the given variable and term.
failedOccursCheck :: (MonadError (UnifError f v) m) => v -> Term f -> m a
failedOccursCheck v t = throwError $ FailedOccursCheck v t

-- | Signals a head symbol mismatch during unification by throwing a
-- 'HeadSymbolMismatch' error that carries the two given terms.
headSymbolMismatch :: (MonadError (UnifError f v) m) => Term f -> Term f -> m a
headSymbolMismatch f g = throwError $ HeadSymbolMismatch f g

-- | This function applies a substitution to both sides of a single
-- equation.
appSubstEq :: (Ord v,  HasVars f v, Traversable f) =>
     Subst f v -> Equation f -> Equation f
appSubstEq s (lhs, rhs) = (appSubst s lhs, appSubst s rhs)


{-| This function returns the most general unifier of the given
equations using the algorithm of Martelli and Montanari. Failures are
reported in the error monad @m@ as a 'UnifError'. -}

unify :: (MonadError (UnifError f v) m, Decompose f v, Ord v, Eq (Const f), Traversable f)
      => Equations f -> m (Subst f v)
unify = runUnifyM runUnify

-- | This type represents the state for the unification algorithm.
data UnifyState f v = UnifyState
    { usEqs   :: Equations f  -- ^ The equations that are still pending.
    , usSubst :: Subst f v    -- ^ The substitution accumulated so far.
    }

-- | This is the unification monad that is used to run the unification
-- algorithm: a 'StateT' layer carrying a 'UnifyState' on top of the
-- underlying monad @m@.
type UnifyM f v m a = StateT (UnifyState f v) m a

-- | This function runs a unification monad with the given initial
-- list of equations and an empty substitution. The result of the
-- monadic action itself is discarded; only the substitution found in
-- the final state is returned.
runUnifyM :: MonadError (UnifError f v) m
          => UnifyM f v m a -> Equations f -> m (Subst f v)
runUnifyM m eqs = liftM usSubst $ execStateT m initialState
    where initialState = UnifyState { usEqs = eqs, usSubst = Map.empty }

-- | Removes the first pending equation from the state and runs the
-- given action on it. If there are no pending equations, nothing is
-- done.
withNextEq :: Monad m
           => (Equation f -> UnifyM f v m ()) -> UnifyM f v m ()
withNextEq m = do
    eqs <- gets usEqs
    case eqs of
      [] -> return ()
      x : xs -> do
        -- Drop the equation from the state before handing it to the
        -- action, so the action sees only the remaining equations.
        modify (\s -> s { usEqs = xs })
        m x

-- | Adds the given equations in front of the pending equations.
putEqs :: Monad m
       => Equations f -> UnifyM f v m ()
putEqs eqs = modify addEqs
    where addEqs s = s { usEqs = eqs ++ usEqs s }

-- | Records a variable binding in the state: the binding is applied as
-- a substitution to every pending equation, and it is combined with
-- the substitution accumulated so far using 'compSubst'.
putBinding :: (Monad m, Ord v, HasVars f v, Traversable f) => (v, Term f) -> UnifyM f v m ()
putBinding bind = modify applyBinding
    where -- The binding viewed as a one-element substitution.
          binds = uncurry Map.singleton bind
          -- Named so that it does not shadow the imported 'appSubst'.
          applyBinding s = s { usEqs = map (appSubstEq binds) (usEqs s),
                               usSubst = compSubst binds (usSubst s) }


-- | Processes the pending equations one at a time: takes the next
-- equation, performs a 'unifyStep' on it and repeats. The loop ends
-- once no pending equations are left.
runUnify :: (MonadError (UnifError f v) m, Decompose f v, Ord v, Eq (Const f), Traversable f)
         => UnifyM f v m ()
runUnify = withNextEq (\ e -> unifyStep e >> runUnify)

-- | Performs a single unification step on the given equation.
--
-- * Two identical variables: nothing is done.
--
-- * A variable against any other term (on either side): the variable
--   is bound to that term, unless the term contains the variable, in
--   which case a failed occurs check is signalled.
--
-- * Two function applications: if the head symbols are equal, the
--   arguments are paired up and added to the pending equations;
--   otherwise a head symbol mismatch is signalled.
unifyStep :: (MonadError (UnifError f v) m, Decompose f v, Ord v, Eq (Const f), Traversable f)
          => Equation f -> UnifyM f v m ()
unifyStep (s, t) = case (decompose s, decompose t) of
    (Var v1, Var v2) -> unless (v1 == v2) $ putBinding (v1, t)
    (Var v1, _) -> bindVar v1 t
    (Fun _ _, Var v2) -> bindVar v2 s
    (Fun s1 args1, Fun s2 args2)
        | s1 == s2  -> putEqs $ zip args1 args2
        | otherwise -> headSymbolMismatch s t
    where -- Bind a variable to a non-variable term, guarded by the
          -- occurs check.
          bindVar v term
              | containsVar v term = failedOccursCheck v term
              | otherwise          = putBinding (v, term)