module Control.Unification
(
MutTerm(..)
, freeze
, unfreeze
, UnificationFailure(..)
, Unifiable(..)
, Variable(..)
, BindingMonad(..)
, getFreeVars
, applyBindings
, freshen
, (===)
, (=~=)
, (=:=)
, (<:=)
, equals
, equiv
, unify
, unifyOccurs
, subsumes
, fullprune
, semiprune
, occursIn
) where
import Prelude
hiding (mapM, mapM_, sequence, foldr, foldr1, foldl, foldl1, all, and, or)
import qualified Data.IntMap as IM
import qualified Data.IntSet as IS
import Data.Foldable
import Data.Traversable
import Control.Applicative
import Control.Monad (MonadPlus(..))
import Control.Monad.Trans (MonadTrans(..))
import Control.Monad.Error (MonadError(..))
import Control.Monad.State (MonadState(..), StateT, evalStateT, execStateT)
import Control.Monad.MaybeK
import Control.Monad.State.UnificationExtras
import Control.Unification.Types
-- | Follow a chain of variable bindings all the way to its end and return
-- what it ends at: either a non-variable term or an unbound variable.
-- Every variable visited along the chain is rebound directly to that final
-- result (path compression). A 'MutTerm' argument is returned unchanged.
-- There is no visited set, so a cycle made only of variable-to-variable
-- bindings would make this diverge.
fullprune :: (BindingMonad v t m) => MutTerm v t -> m (MutTerm v t)
fullprune t0@(MutTerm _) = return t0
fullprune t0@(MutVar v) =
    lookupVar v >>= maybe (return t0) compress
  where
    -- Resolve the rest of the chain, then short-circuit @v@ to its end.
    compress next = do
        final <- fullprune next
        bindVar v final
        return final
-- | Like 'fullprune', but stop at the last variable of a binding chain
-- instead of stepping past it. The result is one of:
--
-- * the argument itself, if it is a 'MutTerm';
-- * an unbound variable;
-- * a variable bound directly to a non-variable term.
--
-- Every intermediate variable on the chain is rebound to that final
-- variable (path compression).
semiprune :: (BindingMonad v t m) => MutTerm v t -> m (MutTerm v t)
semiprune t0@(MutTerm _) = return t0
semiprune t0@(MutVar v0) = walk t0 v0
  where
    -- @var@ is the 'MutVar' wrapper for @v@. It is passed along so the
    -- same value can be returned without being rebuilt.
    walk var v = do
        mb <- lookupVar v
        case mb of
            Just next@(MutVar v') -> do
                final <- walk next v'
                v `bindVar` final
                return final
            _ -> return var
-- | Check whether variable @v@ occurs anywhere in a term. At every level
-- the term is first resolved with 'fullprune', which may compress
-- bindings as a side effect. All subterms are visited; the search does
-- not stop early after the first hit.
occursIn :: (BindingMonad v t m) => v (MutTerm v t) -> MutTerm v t -> m Bool
occursIn v t0 = fullprune t0 >>= check
  where
    check (MutVar v')  = return $! eqVar v v'
    check (MutTerm t') = fmap or (mapM (occursIn v) t')
-- | Record in the visited map that variable @v@ is being expanded with
-- binding @t@. If @v@ is already recorded, throw 'OccursIn' carrying the
-- previously recorded binding instead, because the variable has been
-- re-entered while it was still being expanded.
seenAs
    :: ( BindingMonad v t m
       , MonadTrans e
       , MonadError (UnificationFailure v t) (e m)
       )
    => v (MutTerm v t)
    -> MutTerm v t
    -> StateT (IM.IntMap (MutTerm v t)) (e m) ()
seenAs v t = do
    let key = getVarID v
    visited <- get
    maybe (put $! IM.insert key t visited)
          (lift . throwError . OccursIn v)
          (IM.lookup key visited)
-- | Collect the distinct unbound variables reachable from a term, looking
-- through bound variables along the way (bound variables themselves are
-- not returned). Each variable ID is processed at most once. The result
-- is ordered by ascending variable ID, as produced by 'IM.elems'.
getFreeVars :: (BindingMonad v t m) => MutTerm v t -> m [v (MutTerm v t)]
getFreeVars t0 = IM.elems <$> evalStateT (collect t0) IS.empty
  where
    collect t = do
        pruned <- lift (semiprune t)
        case pruned of
            MutTerm body -> fold <$> mapM collect body
            MutVar v -> do
                let i = getVarID v
                visited <- get
                if i `IS.member` visited
                    then return IM.empty
                    else do
                        put $! IS.insert i visited
                        lift (lookupVar v)
                            >>= maybe (return (IM.singleton i v)) collect
-- | Substitute the current bindings into a term. Every bound variable
-- reached is replaced by its (recursively substituted) binding; unbound
-- variables are left in place. Results are memoized per variable ID, so
-- a shared bound variable is expanded only once. If a variable is reached
-- again while its own binding is still being expanded (cyclic bindings),
-- 'OccursIn' is thrown.
applyBindings
    :: ( BindingMonad v t m
       , MonadTrans e
       , Functor (e m)
       , MonadError (UnificationFailure v t) (e m)
       )
    => MutTerm v t
    -> e m (MutTerm v t)
applyBindings t0 = evalStateT (expand t0) IM.empty
  where
    -- Memo entries: @Left t@ while the binding @t@ is being expanded, and
    -- @Right t'@ once the fully substituted term @t'@ is known.
    expand t = do
        pruned <- lift . lift $ semiprune t
        case pruned of
            MutTerm body -> MutTerm <$> mapM expand body
            MutVar v -> do
                let i = getVarID v
                memo <- get
                case IM.lookup i memo of
                    Just entry ->
                        either (lift . throwError . OccursIn v) return entry
                    Nothing -> do
                        binding <- lift . lift $ lookupVar v
                        case binding of
                            Nothing -> return pruned
                            Just bound -> do
                                modify' (IM.insert i (Left bound))
                                expanded <- expand bound
                                modify' (IM.insert i (Right expanded))
                                return expanded
-- | Make a copy of a term in which every variable is replaced by a new one.
-- An unbound variable is replaced by a new unbound variable ('freeVar').
-- A bound variable is replaced by a new variable ('newVar') bound to a
-- freshened copy of its binding. Each original variable ID maps to
-- exactly one replacement, so sharing is preserved. Re-entering a
-- variable while its binding is being copied throws 'OccursIn'.
freshen
    :: ( BindingMonad v t m
       , MonadTrans e
       , Functor (e m)
       , MonadError (UnificationFailure v t) (e m)
       )
    => MutTerm v t
    -> e m (MutTerm v t)
freshen t0 = evalStateT (copy t0) IM.empty
  where
    -- Map entries: @Left t@ while the binding @t@ is being copied, and
    -- @Right v'@ once the replacement variable @v'@ exists.
    copy t = do
        pruned <- lift . lift $ semiprune t
        case pruned of
            MutTerm body -> MutTerm <$> mapM copy body
            MutVar v -> do
                let i = getVarID v
                seen <- get
                case IM.lookup i seen of
                    Just entry ->
                        either (lift . throwError . OccursIn v) return entry
                    Nothing -> do
                        binding <- lift . lift $ lookupVar v
                        case binding of
                            Nothing -> do
                                fresh <- lift . lift $ MutVar <$> freeVar
                                put $! IM.insert i (Right fresh) seen
                                return fresh
                            Just bound -> do
                                put $! IM.insert i (Left bound) seen
                                copied <- copy bound
                                fresh <- lift . lift $ MutVar <$> newVar copied
                                modify' (IM.insert i (Right fresh))
                                return fresh
-- | Operator synonym for 'equals'.
(===)
    :: (BindingMonad v t m)
    => MutTerm v t
    -> MutTerm v t
    -> m Bool
tl === tr = equals tl tr
infix 4 ===
infix 4 `equals`
-- | Operator synonym for 'equiv'.
(=~=)
    :: (BindingMonad v t m)
    => MutTerm v t
    -> MutTerm v t
    -> m (Maybe (IM.IntMap Int))
tl =~= tr = equiv tl tr
infix 4 =~=
infix 4 `equiv`
-- | Operator synonym for 'unify'.
(=:=)
    :: ( BindingMonad v t m
       , MonadTrans e
       , Functor (e m)
       , MonadError (UnificationFailure v t) (e m)
       )
    => MutTerm v t
    -> MutTerm v t
    -> e m (MutTerm v t)
tl =:= tr = unify tl tr
infix 4 =:=
infix 4 `unify`
-- | Operator synonym for 'subsumes'.
(<:=)
    :: ( BindingMonad v t m
       , MonadTrans e
       , Functor (e m)
       , MonadError (UnificationFailure v t) (e m)
       )
    => MutTerm v t
    -> MutTerm v t
    -> e m Bool
tl <:= tr = subsumes tl tr
infix 4 <:=
infix 4 `subsumes`
-- | Determine whether two terms are structurally equal under the current
-- bindings, without first substituting those bindings into new terms.
--
-- Bound variables are looked through. Two distinct unbound variables are
-- never equal, and an unbound variable never equals a non-variable term.
-- Apart from the path compression done by 'semiprune', the bindings are
-- not modified.
equals
    :: (BindingMonad v t m)
    => MutTerm v t
    -> MutTerm v t
    -> m Bool
equals tl0 tr0 = do
    mb <- runMaybeKT (loop tl0 tr0)
    case mb of
        Nothing -> return False
        Just () -> return True
  where
    -- Fails (via 'mzero') as soon as a mismatch is found.
    loop tl1 tr1 = do
        tl <- lift $ semiprune tl1
        tr <- lift $ semiprune tr1
        case (tl, tr) of
            (MutVar vl', MutVar vr')
                | vl' `eqVar` vr' -> return ()
                | otherwise -> do
                    mtl <- lift $ lookupVar vl'
                    mtr <- lift $ lookupVar vr'
                    case (mtl, mtr) of
                        (Just tl', Just tr') -> loop tl' tr'
                        -- At least one side is an unbound variable that
                        -- differs from the other side.
                        _                    -> mzero
            -- 'semiprune' may stop at a variable that is bound directly to
            -- a term, so look through it before giving up.
            (MutVar vl', MutTerm _) -> do
                mtl <- lift $ lookupVar vl'
                case mtl of
                    Nothing  -> mzero
                    Just tl' -> loop tl' tr
            (MutTerm _, MutVar vr') -> do
                mtr <- lift $ lookupVar vr'
                case mtr of
                    Nothing  -> mzero
                    Just tr' -> loop tl tr'
            (MutTerm tl', MutTerm tr') ->
                case zipMatch tl' tr' of
                    Nothing  -> mzero
                    Just tlr -> mapM_ (uncurry loop) tlr
-- | Determine whether two terms are structurally equivalent, i.e.
-- structurally equal up to a one-to-one renaming of their free variables.
-- On success, return the renaming as a map from the variable IDs of the
-- left term to the variable IDs of the right term.
--
-- Both terms are resolved with 'fullprune' at every level, so bound
-- variables are compared by what they are bound to.
equiv
    :: (BindingMonad v t m)
    => MutTerm v t
    -> MutTerm v t
    -> m (Maybe (IM.IntMap Int))
equiv tl0 tr0 =
    fmap fst <$> runMaybeKT (execStateT (loop tl0 tr0) (IM.empty, IS.empty))
  where
    -- State: the left-to-right renaming, plus the set of right-hand IDs
    -- already used as targets. Checking the target set keeps the renaming
    -- injective. Without it, @f x y@ would be "equivalent" to @f z z@ but
    -- not the other way round, which is subsumption, not equivalence.
    loop tl1 tr1 = do
        tl <- lift . lift $ fullprune tl1
        tr <- lift . lift $ fullprune tr1
        case (tl, tr) of
            (MutVar vl', MutVar vr') -> do
                let il = getVarID vl'
                    ir = getVarID vr'
                (renaming, targets) <- get
                case IM.lookup il renaming of
                    Just x
                        | x == ir   -> return ()
                        | otherwise -> lift mzero
                    Nothing
                        | IS.member ir targets -> lift mzero
                        | otherwise -> do
                            let renaming' = IM.insert il ir renaming
                                targets'  = IS.insert ir targets
                            renaming' `seq` targets' `seq`
                                put (renaming', targets')
            (MutVar _, MutTerm _) -> lift mzero
            (MutTerm _, MutVar _) -> lift mzero
            (MutTerm tl', MutTerm tr') ->
                case zipMatch tl' tr' of
                    Nothing  -> lift mzero
                    Just tlr -> mapM_ (uncurry loop) tlr
-- | Unify two terms, binding variables as needed, and return a term for the
-- unified result.
--
-- Before an unbound variable is bound to a non-variable term, or to a
-- bound variable, an eager occurs check ('occursIn') is performed. If the
-- variable occurs in the target, 'OccursIn' is thrown. Binding two
-- distinct unbound variables together skips the check. Non-matching term
-- constructors throw 'TermMismatch'.
unifyOccurs
    :: ( BindingMonad v t m
       , MonadTrans e
       , Functor (e m)
       , MonadError (UnificationFailure v t) (e m)
       )
    => MutTerm v t
    -> MutTerm v t
    -> e m (MutTerm v t)
unifyOccurs = go
  where
    -- Record a binding unconditionally.
    bind v t = lift (bindVar v t)

    -- Record a binding only if it would not make @v@ occur in itself.
    bindChecked v t = do
        cyclic <- lift (occursIn v t)
        if cyclic
            then throwError (OccursIn v t)
            else bind v t

    go tl0 tr0 = do
        tl <- lift (semiprune tl0)
        tr <- lift (semiprune tr0)
        case (tl, tr) of
            (MutTerm l, MutTerm r) ->
                maybe (throwError (TermMismatch l r))
                      (fmap MutTerm . mapM (uncurry go))
                      (zipMatch l r)
            (MutVar vl, MutVar vr)
                | eqVar vl vr -> return tr
                | otherwise -> do
                    ml <- lift (lookupVar vl)
                    mr <- lift (lookupVar vr)
                    case (ml, mr) of
                        (Just l, Just r) -> do
                            t <- go l r
                            bind vr t
                            bind vl tr
                            return tr
                        (Just _, Nothing)  -> bindChecked vr tl >> return tl
                        (Nothing, Just _)  -> bindChecked vl tr >> return tr
                        (Nothing, Nothing) -> bind vl tr >> return tr
            (MutVar vl, MutTerm _) -> do
                ml <- lift (lookupVar vl)
                case ml of
                    Nothing -> bindChecked vl tr
                    Just l  -> go l tr >>= bind vl
                return tl
            (MutTerm _, MutVar vr) -> do
                mr <- lift (lookupVar vr)
                case mr of
                    Nothing -> bindChecked vr tl
                    Just r  -> go tl r >>= bind vr
                return tr
-- | Unify two terms, binding variables as needed, and return a term for the
-- unified result.
--
-- Unlike 'unifyOccurs', no occurs check is done when binding. Cycles are
-- instead detected during the descent: each time a bound variable's
-- binding is followed, the variable is recorded with 'seenAs' inside
-- 'localState'. Reaching it again within that scope throws 'OccursIn'.
-- (This assumes 'localState' restores the visited map afterwards — see
-- "Control.Monad.State.UnificationExtras".) Non-matching term
-- constructors throw 'TermMismatch'.
unify
    :: ( BindingMonad v t m
       , MonadTrans e
       , Functor (e m)
       , MonadError (UnificationFailure v t) (e m)
       )
    => MutTerm v t
    -> MutTerm v t
    -> e m (MutTerm v t)
unify tl0 tr0 = evalStateT (go tl0 tr0) IM.empty
  where
    bind v t = lift . lift $ bindVar v t

    go l0 r0 = do
        tl <- lift . lift $ semiprune l0
        tr <- lift . lift $ semiprune r0
        case (tl, tr) of
            (MutTerm l, MutTerm r) ->
                case zipMatch l r of
                    Just pairs -> MutTerm <$> mapM (uncurry go) pairs
                    Nothing    -> lift . throwError $ TermMismatch l r
            (MutVar vl, MutVar vr)
                | eqVar vl vr -> return tr
                | otherwise -> do
                    ml <- lift . lift $ lookupVar vl
                    mr <- lift . lift $ lookupVar vr
                    case (ml, mr) of
                        (Just l, Just r) -> do
                            t <- localState $ do
                                vl `seenAs` l
                                vr `seenAs` r
                                go l r
                            bind vr t
                            bind vl tr
                            return tr
                        (Just _, Nothing) -> bind vr tl >> return tl
                        (Nothing, _)      -> bind vl tr >> return tr
            (MutVar vl, MutTerm _) -> do
                ml <- lift . lift $ lookupVar vl
                t <- case ml of
                    Nothing -> return tr
                    Just l  -> localState (seenAs vl l >> go l tr)
                bind vl t
                return tl
            (MutTerm _, MutVar vr) -> do
                mr <- lift . lift $ lookupVar vr
                t <- case mr of
                    Nothing -> return tl
                    Just r  -> localState (seenAs vr r >> go tl r)
                bind vr t
                return tr
-- | Determine whether the left term subsumes the right one, i.e. whether
-- the left term can be made equal to the right term by binding only
-- variables from the left-hand side.
--
-- As a side effect, unbound left-hand variables are bound to the matching
-- right-hand subterms. This function does not undo those bindings when it
-- returns 'False'. Bound variables on either side are looked through. A
-- variable re-entered while its binding is being followed throws
-- 'OccursIn' via 'seenAs'.
subsumes
    :: ( BindingMonad v t m
       , MonadTrans e
       , Functor (e m)
       , MonadError (UnificationFailure v t) (e m)
       )
    => MutTerm v t
    -> MutTerm v t
    -> e m Bool
subsumes tl0 tr0 = evalStateT (loop tl0 tr0) IM.empty
  where
    -- Bind a left-hand variable; succeeding this way always means True.
    v =: t = lift . lift $ do
        v `bindVar` t
        return True

    loop tl1 tr1 = do
        tl <- lift . lift $ semiprune tl1
        tr <- lift . lift $ semiprune tr1
        case (tl, tr) of
            (MutVar vl', MutVar vr')
                | vl' `eqVar` vr' -> return True
                | otherwise -> do
                    mtl <- lift . lift $ lookupVar vl'
                    mtr <- lift . lift $ lookupVar vr'
                    case (mtl, mtr) of
                        (Nothing, _)        -> vl' =: tr
                        -- A bound (non-variable) left side cannot become a
                        -- free right-hand variable.
                        (Just _, Nothing)   -> return False
                        (Just tl', Just tr') ->
                            localState $ do
                                vl' `seenAs` tl'
                                vr' `seenAs` tr'
                                loop tl' tr'
            (MutVar vl', MutTerm _) -> do
                mtl <- lift . lift $ lookupVar vl'
                case mtl of
                    Nothing  -> vl' =: tr
                    Just tl' -> localState $ do
                        vl' `seenAs` tl'
                        loop tl' tr
            -- 'semiprune' may stop at a right-hand variable bound directly
            -- to a term. Only an unbound one makes subsumption fail here.
            (MutTerm _, MutVar vr') -> do
                mtr <- lift . lift $ lookupVar vr'
                case mtr of
                    Nothing  -> return False
                    Just tr' -> localState $ do
                        vr' `seenAs` tr'
                        loop tl tr'
            (MutTerm tl', MutTerm tr') ->
                case zipMatch tl' tr' of
                    Nothing  -> return False
                    Just tlr -> and <$> mapM (uncurry loop) tlr