-- | Constraint solving by unification.
--
-- 'unify' generates constraints for an expression (via 'constrain') and
-- hands them to 'solver', which works through the constraint list one item
-- at a time while accumulating a list of @(variable, type)@ substitutions.
module Unify where

import Constrain
import Control.Arrow (second)
import Control.Monad (liftM)
import Data.List (foldl')
import qualified Data.Set as Set
import Guid
import Types

import Control.DeepSeq (NFData (..), deepseq)

-- | Generate the constraints for @expr@ under @hints@ with 'constrain',
-- solve them with 'solver' starting from an empty substitution list, and
-- execute the whole computation with 'run'.
--
-- The result is whatever 'run' yields for a computation producing
-- @Left errorMessage@ or @Right substitutions@.
-- NOTE(review): 'run' and 'guid' presumably come from "Guid" and
-- 'constrain' from "Constrain" -- confirm against those modules.
unify hints expr = run $ do
    cs <- constrain hints expr
    solver cs []

-- | Solve a list of constraints.
--
-- Arguments: the constraints still to be processed, and the substitutions
-- collected so far as @(variable, type)@ pairs.
--
-- Returns (in the surrounding monad) @Right subs@ once no constraints are
-- left, or the @Left message@ built by 'uniError' as soon as two types
-- cannot be made equal.
solver [] subs = return $ Right subs

-------- Destruct Type-constructors --------

-- Two ADTs unify only when they have the same name AND the same number of
-- arguments. The arity test matters because 'zipWith' stops at the shorter
-- list: without it, surplus arguments would be dropped silently and the
-- mismatch would go unreported.
solver ((t1@(ADT n1 ts1) :=: t2@(ADT n2 ts2)) : cs) subs
    | n1 /= n2 || length ts1 /= length ts2 = uniError t1 t2
    | otherwise = solver (zipWith (:=:) ts1 ts2 ++ cs) subs

-- Function types unify component-wise: argument with argument, result with
-- result.
solver ((LambdaT t1 t2 :=: LambdaT t1' t2') : cs) subs =
    solver ((t1 :=: t1') : (t2 :=: t2') : cs) subs

-------- Type-equality --------

-- A variable on either side is bound to the type on the other side.
solver ((VarT x :=: t) : cs) subs = bindTypeVar x t cs subs
solver ((t :=: VarT x) : cs) subs = bindTypeVar x t cs subs

-- Anything else must already be structurally equal.
solver ((t1 :=: t2) : cs) subs
    | t1 /= t2 = uniError t1 t2
    | otherwise = solver cs subs

-------- subtypes --------

-- Every variable occurring in @t2@ is replaced by a fresh variable obtained
-- from 'guid'; the renamed type is then required to be equal to @t1@.
-- NOTE(review): this looks like instantiation of a polymorphic type --
-- confirm the intended meaning of ':<:' in "Types".
solver ((t1 :<: t2) : cs) subs = do
    let freshFor x = liftM (\y -> (x, VarT y)) guid
    pairs <- mapM freshFor . Set.toList $ getVars t2
    let t2' = foldr (uncurry tSub) t2 pairs
    solver ((t1 :=: t2') : cs) subs

-- | Bind variable @x@ to type @t@ and continue solving.
--
-- The binding is substituted through the remaining constraints @cs@ and
-- through every entry of the substitution list, including the newly added
-- @(x, t)@ pair itself.
--
-- NOTE(review): no occurs check is performed, so a constraint such as
-- @x :=: (x -> a)@ is accepted and yields a binding that still mentions
-- @x@. Confirm whether this is intended before tightening it.
bindTypeVar x t cs subs =
    solver (map (cSub x t) cs) (map (second (tSub x t)) ((x, t) : subs))

-- | Substitute type @v@ for variable @k@ on both sides of a constraint,
-- fully evaluating the new constraint when it is demanded (see 'force').
cSub k v (t1 :=: t2) = force $ tSub k v t1 :=: tSub k v t2
cSub k v (t1 :<: t2) = force $ tSub k v t1 :<: tSub k v t2

-- | Substitute type @v@ for every occurrence of variable @k@ in a type.
--
-- Only 'VarT', 'LambdaT' and 'ADT' are traversed; any other type is
-- returned unchanged.
tSub k v t@(VarT x)
    | k == x = v
    | otherwise = t
tSub k v (LambdaT t1 t2) = force $ LambdaT (tSub k v t1) (tSub k v t2)
tSub k v (ADT name ts) = ADT name (map (force . tSub k v) ts)
tSub _ _ t = t

-- | Collect the set of variables occurring in a type.
--
-- Types other than 'VarT', 'LambdaT' and 'ADT' contribute no variables.
getVars (VarT x) = Set.singleton x
getVars (LambdaT t1 t2) = Set.union (getVars t1) (getVars t2)
getVars (ADT _ ts) = Set.unions $ map getVars ts
getVars _ = Set.empty

-- | Produce (in the surrounding monad) a @Left@ carrying a message that
-- says the two given types are not equal. Both types are rendered with
-- 'show'.
uniError t1 t2 =
    return . Left $ "Type error: " ++ show t1 ++ " is not equal to " ++ show t2

-- | Return the argument, evaluating it completely (via 'deepseq') as soon
-- as the result is demanded.
-- NOTE(review): presumably used to stop substitutions from piling up as
-- unevaluated thunks -- confirm with profiling if this is ever changed.
force x = x `deepseq` x

-- | Fully evaluating a constraint means fully evaluating both of its sides.
instance NFData Constraint where
    rnf (t1 :=: t2) = t1 `deepseq` t2 `deepseq` ()
    rnf (t1 :<: t2) = t1 `deepseq` t2 `deepseq` ()

-- | 'LambdaT' forces both components and 'ADT' forces every argument (its
-- name is left untouched); every other constructor is only evaluated to
-- weak head normal form.
instance NFData Type where
    rnf (LambdaT t1 t2) = t1 `deepseq` t2 `deepseq` ()
    rnf (ADT _ ts) = foldl' (\acc x -> x `deepseq` acc) () ts
    rnf t = t `seq` ()