module LText.Internal.Inference where
import LText.Internal.Expr
import LText.Internal.Types
import LText.Internal.Classes
import Control.Monad.State
import Control.Monad.Except
import qualified Data.Map as Map
import qualified Data.Set as Set
newtype Context = Context (Map.Map ExprVar Prenex)
remove :: Context -> String -> Context
remove (Context env) var = Context (Map.delete var env)
instance Bindable Set.Set TypeVar Context where
fv (Context env) = fv (Map.elems env)
instance Substitutable Map.Map TypeVar Type Context where
apply s (Context env) = Context (fmap (apply s) env)
generalize :: Context -> Type -> Prenex
generalize env t = Prenex vars t
where vars = Set.toList $ fv t `difference` fv env
data TIState = TIState { tiSupply :: Int
, tiSubst :: Subst TypeVar Type
}
runTI :: ( Monad m
, MonadError String m
) => StateT TIState m a -> m a
runTI t = evalStateT t initTIState
where initTIState = TIState { tiSupply = 0
, tiSubst = nullSubst
}
newTyVar :: ( Monad m
, MonadState TIState m
, MonadError String m
) => TypeVar -> m Type
newTyVar prefix = do
s <- get
put s {tiSupply = tiSupply s + 1}
return $ TVar $ prefix ++ show (tiSupply s)
instantiate :: ( Monad m
, MonadState TIState m
, MonadError String m
) => Prenex -> m Type
instantiate (Prenex vars t) = do
nvars <- mapM newTyVar vars
return $ apply (Map.fromList $ zip vars nvars) t
mgu :: ( Monad m
, MonadState TIState m
, MonadError String m
) => Type -> Type -> m (Subst TypeVar Type)
mgu (TFun l r) (TFun l' r') = do s1 <- mgu l l'
s2 <- mgu (apply s1 r) (apply s1 r')
return (s1 `composeSubst` s2)
mgu (TVar u) t = varBind u t
mgu t (TVar u) = varBind u t
mgu TText TText = return nullSubst
mgu t1 t2 = throwError $ "Types do not unify: " ++ show t1 ++
" vs. " ++ show t2
varBind :: ( Monad m
, MonadState TIState m
, MonadError String m
) => TypeVar -> Type -> m (Subst TypeVar Type)
varBind u t | t == TVar u = return nullSubst
| u `Set.member` fv t = throwError $ "Occur check fails: " ++ u ++
" vs. " ++ show t
| otherwise = return (Map.singleton u t)
ti :: ( Monad m
, MonadState TIState m
, MonadError String m
) => Context -> Expr -> m (Subst TypeVar Type, Type)
ti (Context env) (EVar n) = case Map.lookup n env of
Nothing -> throwError $ "Unbound variable: " ++ n
Just sigma -> do t <- instantiate sigma
return (nullSubst, t)
ti env (EAbs n e) = do
tv <- newTyVar "a"
let Context env' = remove env n
env'' = Context $ env' `Map.union` Map.singleton n (Prenex [] tv)
(s1, t1) <- ti env'' e
return (s1, TFun (apply s1 tv) t1)
ti env (EApp e1 e2) = do
(s1, t1) <- ti env e1
(s2, t2) <- ti (apply s1 env) e2
tv <- newTyVar "a"
s3 <- mgu (apply s2 t1) (TFun t2 tv)
return (s3 `composeSubst` s2 `composeSubst` s1, apply s3 tv)
ti env (ELet x e1 e2) = do
(s1, t1) <- ti env e1
let Context env' = remove env x
t' = generalize (apply s1 env) t1
env'' = Context (Map.insert x t' env')
(s2, t2) <- ti (apply s1 env'') e2
return (s1 `composeSubst` s2, t2)
ti _ (EText _) = return (nullSubst, TText)
ti env (EConc e1 e2) = do
(s1, t1) <- ti env e1
(s2, t2) <- ti env e2
s3 <- mgu t1 t2
s4 <- mgu (apply s3 t1) TText
return (s4 `composeSubst` s3 `composeSubst` s2 `composeSubst` s1, TText)
typeInference :: ( Monad m
, MonadState TIState m
, MonadError String m
) => Context -> Expr -> m Type
typeInference env e = do
(s, t) <- ti env e
return (apply s t)
test :: ( Monad m
, MonadIO m
) => Expr -> m ()
test e = do
eRes <- runExceptT $ runTI $ typeInference (Context Map.empty) e
case eRes of
Left err -> liftIO $ putStrLn $ "Error: " ++ err
Right t -> let q = generalize (Context Map.empty) t in
liftIO $ putStrLn $ show e ++ " :: " ++ show q