module Unifier ( Unifier, Substitution , unify, unify_with_occurs_check , apply, (+++) ) where import Control.Monad (MonadPlus, mzero) import Control.Arrow (second) import Data.Function (fix) import Data.Generics (everything, mkQ) import Syntax type Unifier = [Substitution] type Substitution = (VariableName, Term) unify, unify_with_occurs_check :: MonadPlus m => Term -> Term -> m Unifier unify = fix unify' unify_with_occurs_check = fix $ \self t1 t2 -> if (t1 `occursIn` t2 || t2 `occursIn` t1) then fail "occurs check" else unify' self t1 t2 where occursIn t = everything (||) (mkQ False (==t)) unify' _ Wildcard _ = return [] unify' _ _ Wildcard = return [] unify' _ (Var v) t = return [(v,t)] unify' _ t (Var v) = return [(v,t)] unify' self (Struct a1 ts1) (Struct a2 ts2) | a1 == a2 && same length ts1 ts2 = unifyList self (zip ts1 ts2) unify' _ _ _ = mzero same :: Eq b => (a -> b) -> a -> a -> Bool same f x y = f x == f y unifyList :: Monad m => (Term -> Term -> m Unifier) -> [(Term, Term)] -> m Unifier unifyList _ [] = return [] unifyList unify ((x,y):xys) = do u <- unify x y u' <- unifyList unify (map (both (apply u)) xys) return (u++u') both f (x,y) = (f x, f y) u1 +++ u2 = simplify $ u1 ++ u2 simplify :: Unifier -> Unifier simplify u = map (second (apply u)) u apply :: Unifier -> Term -> Term apply = flip $ foldl $ flip substitute where substitute (v,t) (Var v') | v == v' = t substitute s (Struct a ts) = Struct a (map (substitute s) ts) substitute _ t = t