module Control.Unification.Ranked
(
module Control.Unification.Types
, getFreeVars
, applyBindings
, freshen
, (===)
, (=~=)
, (=:=)
, equals
, equiv
, unify
, getFreeVarsAll
, applyBindingsAll
, freshenAll
) where
import Prelude
hiding (mapM, mapM_, sequence, foldr, foldr1, foldl, foldl1, all, or)
import qualified Data.IntMap as IM
import Data.Traversable
#if __GLASGOW_HASKELL__ < 710
import Control.Applicative
#endif
import Control.Monad.Trans (MonadTrans(..))
#if (MIN_VERSION_mtl(2,2,1))
import Control.Monad.Except (MonadError(..))
#else
import Control.Monad.Error (MonadError(..))
#endif
import Control.Monad.State (MonadState(..), StateT, evalStateT)
import Control.Monad.State.UnificationExtras
import Control.Unification.Types
import Control.Unification hiding (unify, (=:=))
-- | Infix synonym for 'unify': @tl =:= tr@ is exactly @unify tl tr@.
--
-- Both this operator and @`unify`@ are declared @infix 4@ (the same
-- fixity as '==' in the Prelude), so neither of them associates with
-- itself.
(=:=)
    :: ( RankedBindingMonad t v m
       , Fallible t v e
       , MonadTrans em
       , Functor (em m)
       , MonadError e (em m)
       )
    => UTerm t v        -- ^ Left-hand term.
    -> UTerm t v        -- ^ Right-hand term.
    -> em m (UTerm t v) -- ^ Whatever 'unify' returns for the two terms.
tl =:= tr = unify tl tr

infix 4 =:=
infix 4 `unify`
-- | Record that variable @v0@ is currently being unfolded to @t0@.
--
-- The state maps each variable's 'getVarID' to the term it was
-- recorded with. If @v0@ is already present, an 'occursFailure' for
-- @v0@ is thrown in the underlying @em m@, carrying the term that was
-- recorded first (wrapped in 'UTerm'). Otherwise @v0@ is added with
-- @t0@, and the updated map is forced to WHNF before being stored.
seenAs
    :: ( BindingMonad t v m
       , Fallible t v e
       , MonadTrans em
       , MonadError e (em m)
       )
    => v                                            -- ^ Variable being unfolded.
    -> t (UTerm t v)                                -- ^ Term it is bound to.
    -> StateT (IM.IntMap (t (UTerm t v))) (em m) () -- ^ Fails in @em m@ on a repeat.
seenAs v0 t0 = do
    seen <- get
    let key = getVarID v0
    maybe (put $! IM.insert key t0 seen)
          (lift . throwError . occursFailure v0 . UTerm)
          (IM.lookup key seen)
-- | Unify two terms, or throw an error in @em m@ explaining why
-- unification failed.
--
-- This is the rank-aware counterpart of the @unify@ exported by
-- "Control.Unification" (which this module hides on import). Both
-- sides are first passed through 'semiprune'. When two distinct
-- variables meet, their ranks (from 'lookupRankVar') decide the
-- direction of the binding:
--
-- * the lower-ranked variable is pointed at the other one;
--
-- * on a tie, the variable that remains the representative has its
--   rank incremented;
--
-- * if exactly one of the two is bound and that one gets demoted, its
--   term is first copied onto the representative;
--
-- * if both are bound, their terms are matched and the result is
--   stored on the representative.
--
-- Structural comparison is delegated to 'zipMatch'. When it returns
-- 'Nothing', a 'mismatchFailure' is thrown. A variable that is
-- encountered again while its own binding is being matched raises an
-- 'occursFailure' (see 'seenAs'). Binding a currently unbound variable
-- directly to a term performs no occurs check here.
--
-- The result is a variable whenever at least one (semipruned) side was
-- a variable. Otherwise it is the term rebuilt by matching.
unify
    :: ( RankedBindingMonad t v m
       , Fallible t v e
       , MonadTrans em
       , Functor (em m) -- NOTE(review): presumably kept for GHC < 7.10, where Monad did not imply Functor
       , MonadError e (em m)
       )
    => UTerm t v        -- ^ Left-hand term.
    -> UTerm t v        -- ^ Right-hand term.
    -> em m (UTerm t v) -- ^ Representative of the unified terms.
unify tl0 tr0 = evalStateT (loop tl0 tr0) IM.empty
  where
    -- Bind @v@ to @t@ and hand back @t@.
    v =: t = bindVar v t >> return t

    -- Main recursion. The StateT layer carries the seen-set that
    -- 'seenAs' consults; the binding monad @m@ sits two lifts down.
    loop tl0 tr0 = do
        -- Shorten variable chains on both sides (shadows the arguments).
        tl0 <- lift . lift $ semiprune tl0
        tr0 <- lift . lift $ semiprune tr0
        case (tl0, tr0) of
            (UVar vl, UVar vr)
                | vl == vr  -> return tr0
                | otherwise -> do
                    -- Union by rank: fetch each variable's rank and its
                    -- current binding, if any.
                    Rank rl mtl <- lift . lift $ lookupRankVar vl
                    Rank rr mtr <- lift . lift $ lookupRankVar vr
                    let cmp = compare rl rr
                    -- In every branch, LT points vl at vr and GT points
                    -- vr at vl. On EQ the surviving representative has
                    -- its rank bumped.
                    case (mtl, mtr) of
                        -- Neither bound: only link the two variables.
                        (Nothing, Nothing) -> lift . lift $
                            case cmp of
                                LT -> do { vl =: tr0 }
                                EQ -> do { incrementRank vr ; vl =: tr0 }
                                GT -> do { vr =: tl0 }

                        -- Only vr bound: when vr is demoted (GT), its
                        -- term moves onto vl first so it is not lost.
                        (Nothing, Just tr) -> lift . lift $
                            case cmp of
                                LT -> do { vl =: tr0 }
                                EQ -> do { incrementRank vr ; vl =: tr0 }
                                GT -> do { vl `bindVar` tr ; vr =: tl0 }

                        -- Only vl bound: mirror image of the case above.
                        (Just tl, Nothing) -> lift . lift $
                            case cmp of
                                LT -> do { vr `bindVar` tl ; vl =: tr0 }
                                EQ -> do { incrementRank vl ; vr =: tl0 }
                                GT -> do { vr =: tl0 }

                        -- Both bound: match their terms with both
                        -- variables marked as seen (so a cycle back through
                        -- either is reported), then store the result on
                        -- the representative. incrementBindVar presumably
                        -- bumps the rank and binds in one step.
                        -- localState is assumed to restore the seen-set
                        -- afterwards (see Control.Monad.State.UnificationExtras).
                        (Just (UTerm tl), Just (UTerm tr)) -> do
                            t <- localState $ do
                                vl `seenAs` tl
                                vr `seenAs` tr
                                match tl tr
                            lift . lift $
                                case cmp of
                                    LT -> do { vr `bindVar` t ; vl =: tr0 }
                                    EQ -> do { incrementBindVar vl t ; vr =: tl0 }
                                    GT -> do { vl `bindVar` t ; vr =: tl0 }
                        -- A variable bound to a variable; presumably
                        -- ruled out by the semiprune above.
                        _ -> error _impossible_unify

            -- Variable vs. term: bind an unbound variable directly (no
            -- occurs check), otherwise match its current term against tr
            -- and rebind the variable to the result.
            (UVar vl, UTerm tr) -> do
                t <- do
                    mtl <- lift . lift $ lookupVar vl
                    case mtl of
                        Nothing         -> return tr0
                        Just (UTerm tl) -> localState $ do
                            vl `seenAs` tl
                            match tl tr
                        _ -> error _impossible_unify
                lift . lift $ do
                    vl `bindVar` t
                    return tl0

            -- Term vs. variable: mirror image of the case above. The
            -- term stays on the left of match to keep the orientation.
            (UTerm tl, UVar vr) -> do
                t <- do
                    mtr <- lift . lift $ lookupVar vr
                    case mtr of
                        Nothing         -> return tl0
                        Just (UTerm tr) -> localState $ do
                            vr `seenAs` tr
                            match tl tr
                        _ -> error _impossible_unify
                lift . lift $ do
                    vr `bindVar` t
                    return tr0

            -- Two terms: compare them structurally.
            (UTerm tl, UTerm tr) -> match tl tr

    -- One structural step. zipMatch pairs up the children of tl and tr.
    -- Left children are kept as they are, Right pairs are unified
    -- recursively, and a shape mismatch throws mismatchFailure.
    match tl tr =
        case zipMatch tl tr of
            Nothing  -> lift . throwError $ mismatchFailure tl tr
            Just tlr -> UTerm <$> mapM loop_ tlr

    loop_ (Left t)        = return t
    loop_ (Right (tl,tr)) = loop tl tr
-- | Message for 'error' in branches of 'unify' that are expected to be
-- unreachable (a bound variable whose binding is not a 'UTerm').
_impossible_unify :: String
_impossible_unify = "unify: " ++ "the impossible happened"