-- | Hindley-Milner unification implemented on top of propositional logic,
-- for a nominal type system.
--
-- Author: Taine Zhao(thautwarm)
-- Date: 2019-08-04
-- License: MIT
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE ViewPatterns #-}
module RSolve.HM where

import RSolve.Logic
import RSolve.Solver
import RSolve.MultiState
import Control.Lens (Lens', view, over, makeLenses)
import Control.Applicative
import Control.Monad
import Debug.Trace
import qualified Data.List as L
import qualified Data.Map as M
import qualified Data.Set as S

-- | An open-recursive function: it receives the function to use for its own
-- recursive calls (see 'unify' and the 'CtxSolver' instance at the end of
-- this module).
type Fix a = a -> a

-- NOTE(review): both operators are declared left-associative, so
-- @a :-> b :-> c@ builds @(a :-> b) :-> c@. Changing the fixity would change
-- how existing expressions elsewhere parse, so it is left as is.
infixl 6 :->, :*

-- | Types.
data T
  = TVar Int                 -- ^ unification variable; index into '_tvars'
  | TFresh String            -- ^ named type variable; substituted by 'free'
                             --   when its binding 'TForall' is instantiated
  | T :-> T                  -- ^ function type
  | T :* T                   -- ^ tuple type
  | TForall (S.Set String) T -- ^ polymorphic type binding 'TFresh' names
  | TApp T T                 -- ^ type application
  | TNom Int                 -- ^ nominal type index
  deriving (Eq, Ord)

-- | Split a binary type operator ('(:->)', '(:*)' or 'TApp') into a function
-- that rebuilds it plus its two operands; 'Nothing' for any other constructor.
deConsTOp :: T -> Maybe (T -> T -> T, T, T)
deConsTOp = \case
  a :-> b -> Just ((:->), a, b)
  a :* b -> Just ((:*), a, b)
  TApp a b -> Just (TApp, a, b)
  _ -> Nothing

-- | Readable rendering. Unification variables print as "@n" and nominal
-- types as "@tn". Left operands of "->" / "*" and type-application arguments
-- are parenthesised when they are compound or a forall.
instance Show T where
  show = \case
    TVar idx -> "@" ++ show idx
    TFresh s -> s
    a :-> b -> showNest a ++ " -> " ++ show b
    a :* b -> showNest a ++ " * " ++ show b
    TForall l t -> "forall " ++ unwords (S.toList l) ++ ". " ++ show t
    TApp t1 t2 -> show t1 ++ " " ++ showNest t2
    TNom i -> "@t" ++ show i
    where
      showNest s
        | isNest s = "(" ++ show s ++ ")"
        | otherwise = show s
      -- A forall extends as far right as possible, so it always needs
      -- brackets in a left/argument position: without them
      -- (forall a. a) -> b would render as "forall a. a -> b".
      isNest = \case
        TApp _ _ -> True
        TForall _ _ -> True
        _ :-> _ -> True
        _ :* _ -> True
        _ -> False

-- | A unification atom: @lhs == rhs@, or the disequality @lhs /= rhs@ when
-- 'neq' is set.
data Unif = Unif
  { lhs :: T
  , rhs :: T
  , neq :: Bool
  }
  deriving (Eq, Ord)

instance Show Unif where
  show Unif {lhs, rhs, neq} =
    let op = if neq then " /= " else " == "
     in show lhs ++ op ++ show rhs

-- | Negating an atom flips it between equality and disequality.
instance AtomF Unif where
  notA a@Unif {neq} = [a {neq = not neq}]

-- | Type-checking state threaded through the solver.
data TCEnv = TCEnv
  { _noms :: M.Map Int T
    -- ^ nominal types allocated by 'newTNom' (id @i@ maps to @TNom i@)
  , _tvars :: M.Map Int T
    -- ^ unification-variable bindings; an entry @i -> TVar i@ means unbound
  , _neqs :: S.Set (T, T)
    -- ^ disequality constraints recorded by 'unify' (nothing in this module
    --   reads them back)
  }
  deriving (Show)

-- | The initial environment: no nominal types, no variables, no constraints.
emptyTCEnv :: TCEnv
emptyTCEnv = TCEnv M.empty M.empty S.empty

makeLenses ''TCEnv

-- | Allocate a fresh, unbound unification variable and return its index.
-- The index is the map size at allocation time. That is unique because this
-- module only ever inserts into '_tvars' and never removes entries.
newTVar :: MS TCEnv Int
newTVar = do
  i <- getsMS $ M.size . view tvars
  modifyMS $ over tvars $ M.insert i (TVar i)
  return i

-- | Allocate a fresh nominal type and return its index (same scheme as
-- 'newTVar').
newTNom :: MS TCEnv Int
newTNom = do
  i <- getsMS $ M.size . view noms
  modifyMS $ over noms $ M.insert i (TNom i)
  return i

-- | Read the current binding of a unification variable. This is partial: it
-- fails ('M.!') for an index that was never allocated by 'newTVar'.
loadTVar :: Int -> MS TCEnv T
loadTVar i = getsMS $ (M.! i) . view tvars

-- | @occurIn v t@ reports whether unification variable @v@ occurs in @t@,
-- following the bindings stored in '_tvars'.
occurIn :: Int -> T -> MS TCEnv Bool
occurIn l = contains
  where
    contains (deConsTOp -> Just (_, a, b)) = (||) <$> contains a <*> contains b
    contains (TNom _) = return False
    contains (TForall _ body) = contains body
    contains (TFresh _) = return False
    contains (TVar v)
      | v == l = return True
      | otherwise = do
          bound <- loadTVar v
          case bound of
            -- unbound variable distinct from l
            TVar v' | v' == v -> return False
            _ -> contains bound

-- | Substitute 'TFresh' names according to the map. Inner 'TForall's shadow
-- the names they bind, so those names are removed from the map inside them.
free :: M.Map String T -> T -> T
free m = mkFree
  where
    mkFree (deConsTOp -> Just (op, a, b)) = op (mkFree a) (mkFree b)
    mkFree t@(TNom _) = t
    mkFree (TForall names t) = TForall names $ free (M.withoutKeys m names) t
    mkFree t@(TVar _) = t
    mkFree t@(TFresh name) = M.findWithDefault t name m

-- | Replace every bound unification variable with its (recursively pruned)
-- binding, writing the pruned result back to '_tvars' (path compression).
-- Only unbound variables remain in the result.
prune :: T -> MS TCEnv T
prune = \case
  (deConsTOp -> Just (op, a, b)) -> op <$> prune a <*> prune b
  t@(TNom _) -> return t
  TVar i -> loadTVar i >>= \case
    t@(TVar i') | i' == i -> return t
    t -> do
      t' <- prune t
      update i t'
      return t'
  t@(TFresh _) -> return t
  TForall names body -> TForall names <$> prune body

-- | Set the binding of unification variable @i@.
update :: Int -> T -> MS TCEnv ()
update i t = modifyMS $ over tvars $ M.insert i t

-- | Record a disequality constraint.
addNEq :: (T, T) -> MS TCEnv ()
addNEq t = modifyMS $ over neqs (S.insert t)

-- | One step of unification, written in open-recursive style: @self@ handles
-- every sub-problem. The 'CtxSolver' instance passes a @self@ that prunes both
-- sides first, so the clauses below assume pruned (unbound-root) variables.
--
-- * Disequalities are only recorded, via 'addNEq'.
-- * Mismatching shapes (e.g. a nominal type against a function type) fail
--   with 'empty'.
-- * Binding a variable to a type that contains it raises an 'error'.
unify :: Fix (Unif -> MS TCEnv ())
unify _ Unif {lhs, rhs, neq = True} = addNEq (lhs, rhs)
unify _ Unif {lhs = TNom a, rhs = TNom b}
  | a == b = return ()
  | otherwise = empty
-- A variable trivially unifies with itself. This must not reach the occurs
-- check below, which would report it as cyclic.
unify _ Unif {lhs = TVar a, rhs = TVar b}
  | a == b = return ()
-- Bind a variable. The occurs check rejects cyclic bindings such as
-- a = a -> b, which would otherwise make 'prune' loop forever.
unify _ Unif {lhs = TVar a, rhs} = do
  recursive <- occurIn a rhs
  if recursive
    then error "ill formed definition like a = a -> b"
    else update a rhs
unify self u@Unif {lhs, rhs = rhs@(TVar _)} = self u {lhs = rhs, rhs = lhs}
-- type operators are not first class
unify self Unif {lhs = l1 :-> l2, rhs = r1 :-> r2} =
  self Unif {lhs = l1, rhs = r1, neq = False}
    >> self Unif {lhs = l2, rhs = r2, neq = False}
unify self Unif {lhs = l1 :* l2, rhs = r1 :* r2} =
  self Unif {lhs = l1, rhs = r1, neq = False}
    >> self Unif {lhs = l2, rhs = r2, neq = False}
-- TODO: type aliases?
unify self Unif {lhs = TApp l1 l2, rhs = TApp r1 r2} =
  self Unif {lhs = l1, rhs = r1, neq = False}
    >> self Unif {lhs = l2, rhs = r2, neq = False}
-- Instantiate a polytype: each bound name gets a fresh unification variable.
unify self Unif {lhs = TForall freevars poly, rhs} = do
  pairs <- mapM freepair $ S.toList freevars
  let l = free (M.fromList pairs) poly
  self Unif {lhs = l, rhs = rhs, neq = False}
  where
    freepair freevar = (freevar,) . TVar <$> newTVar
unify self u@Unif {lhs, rhs = rhs@(TForall _ _)} = self u {lhs = rhs, rhs = lhs}
-- Identical named type variables are equal.
unify _ Unif {lhs = TFresh a, rhs = TFresh b}
  | a == b = return ()
-- Any remaining combination is a shape mismatch: fail instead of crashing on
-- a non-exhaustive pattern.
unify _ _ = empty

-- | Solve an atom: tie 'unify''s open recursion with a step that prunes both
-- sides of every (non-disequality) sub-problem before unifying it.
instance CtxSolver TCEnv Unif where
  solve =
    let frec = unify (pruneUnif >=> frec)
     in pruneUnif >=> frec
    where
      pruneUnif u@Unif {neq = True} = return u
      pruneUnif u@Unif {lhs, rhs} = do
        lhs' <- prune lhs
        rhs' <- prune rhs
        return u {lhs = lhs', rhs = rhs'}