module Control.Unification.Ranked
(
module Control.Unification.Types
, getFreeVars
, applyBindings
, freshen
, (===)
, (=~=)
, (=:=)
, equals
, equiv
, unify
, getFreeVarsAll
, applyBindingsAll
, freshenAll
) where
import Prelude
hiding (mapM, mapM_, sequence, foldr, foldr1, foldl, foldl1, all, or)
import qualified Data.IntMap as IM
import Data.Traversable
import Control.Applicative
import Control.Monad.Trans (MonadTrans(..))
import Control.Monad.Error (MonadError(..))
import Control.Monad.State (MonadState(..), evalStateT)
import Control.Monad.State.UnificationExtras
import Control.Unification.Types
import Control.Unification hiding (unify, (=:=))
(=:=)
:: ( RankedBindingMonad t v m
, MonadTrans e
, Functor (e m)
, MonadError (UnificationFailure t v) (e m)
)
=> UTerm t v
-> UTerm t v
-> e m (UTerm t v)
(=:=) = unify
infix 4 =:=, `unify`
unify
:: ( RankedBindingMonad t v m
, MonadTrans e
, Functor (e m)
, MonadError (UnificationFailure t v) (e m)
)
=> UTerm t v
-> UTerm t v
-> e m (UTerm t v)
unify tl0 tr0 = evalStateT (loop tl0 tr0) IM.empty
where
v `seenAs` t = do
seenVars <- get
case IM.lookup (getVarID v) seenVars of
Just t' -> lift . throwError $ OccursIn v t'
Nothing -> put $ IM.insert (getVarID v) t seenVars
v =: t = bindVar v t >> return t
loop tl0 tr0 = do
tl0 <- lift . lift $ semiprune tl0
tr0 <- lift . lift $ semiprune tr0
case (tl0, tr0) of
(UVar vl, UVar vr)
| vl == vr -> return tr0
| otherwise -> do
Rank rl mtl <- lift . lift $ lookupRankVar vl
Rank rr mtr <- lift . lift $ lookupRankVar vr
let cmp = compare rl rr
case (mtl, mtr) of
(Nothing, Nothing) -> lift . lift $
case cmp of
LT -> do { vl =: tr0 }
EQ -> do { incrementRank vr ; vl =: tr0 }
GT -> do { vr =: tl0 }
(Nothing, Just tr) -> lift . lift $
case cmp of
LT -> do { vl =: tr0 }
EQ -> do { incrementRank vr ; vl =: tr0 }
GT -> do { vl `bindVar` tr ; vr =: tl0 }
(Just tl, Nothing) -> lift . lift $
case cmp of
LT -> do { vr `bindVar` tl ; vl =: tr0 }
EQ -> do { incrementRank vl ; vr =: tl0 }
GT -> do { vr =: tl0 }
(Just tl, Just tr) -> do
t <- localState $ do
vl `seenAs` tl
vr `seenAs` tr
loop tl tr
lift . lift $
case cmp of
LT -> do { vr `bindVar` t ; vl =: tr0 }
EQ -> do { incrementBindVar vl t ; vr =: tl0 }
GT -> do { vl `bindVar` t ; vr =: tl0 }
(UVar vl, UTerm _) -> do
t <- do
mtl <- lift . lift $ lookupVar vl
case mtl of
Nothing -> return tr0
Just tl -> localState $ do
vl `seenAs` tl
loop tl tr0
lift . lift $ do
vl `bindVar` t
return tl0
(UTerm _, UVar vr) -> do
t <- do
mtr <- lift . lift $ lookupVar vr
case mtr of
Nothing -> return tl0
Just tr -> localState $ do
vr `seenAs` tr
loop tl0 tr
lift . lift $ do
vr `bindVar` t
return tr0
(UTerm tl, UTerm tr) ->
case zipMatch tl tr of
Nothing -> lift . throwError $ TermMismatch tl tr
Just tlr -> UTerm <$> mapM (uncurry loop) tlr