module Language.ML.TypeCheck
(
TyVar
, Type (..)
, Scheme (..)
, Assump (..)
, TypeError (..)
, typeExpr
, typeProgram
, prettyType
, prettyScheme
) where
import Control.Applicative
import Control.Monad.Error
import Control.Monad.State
import Data.List hiding (find)
import Data.Maybe (fromMaybe)
import Text.PrettyPrint

import Language.ML.Syntax
-- | Type variables are identified by a plain integer taken from the
-- fresh-variable supply.
type TyVar = Int

-- | Types handled by the checker.
data Type
  = TyVar TyVar     -- ^ a type variable that unification may bind
  | TyArr Type Type -- ^ a function type, argument then result
  | TyGen Int       -- ^ the k-th generic variable of an enclosing 'Scheme'
  deriving (Eq)

-- | A type scheme: how many generic variables it binds, and the body in
-- which they appear as 'TyGen' slots.
data Scheme = Scheme Int Type
  deriving (Eq)
-- | A substitution, represented as an association list from type variables
-- to the types that replace them.
type Subst = [(TyVar, Type)]

-- | The empty substitution; applying it changes nothing.
nullSubst :: Subst
nullSubst = []

-- | The single-binding substitution sending the variable to the type.
(+->) :: TyVar -> Type -> Subst
(+->) v ty = [(v, ty)]
-- | Values that contain type variables and can have a substitution applied.
class Types t where
  -- | The type variables ('TyVar's) occurring in the value.
  tvs :: t -> [TyVar]
  -- | Rewrite the value by applying a substitution to its type variables.
  apply :: Subst -> t -> t
-- | Only 'TyVar's are affected; 'TyGen' slots are returned unchanged by
-- 'apply' and contribute nothing to 'tvs'.
instance Types Type where
  apply s ty = case ty of
    TyVar v   -> fromMaybe ty (lookup v s)
    TyArr a b -> TyArr (apply s a) (apply s b)
    _         -> ty

  tvs ty = case ty of
    TyVar v   -> [v]
    TyArr a b -> tvs a `union` tvs b
    _         -> []
-- | Lists are handled element by element; 'tvs' drops duplicates, keeping
-- the first occurrence of each variable.
instance Types a => Types [a] where
  apply s xs = [apply s x | x <- xs]
  tvs xs = nub [v | x <- xs, v <- tvs x]
-- | A scheme delegates to its body; its generic-variable count is kept.
instance Types Scheme where
  apply sub (Scheme n body) = Scheme n (apply sub body)
  tvs (Scheme _ body) = tvs body
infixr 4 @@

-- | Compose two substitutions: applying the result is the same as applying
-- the right-hand substitution first and then the left-hand one. The
-- right-hand bindings (rewritten by the left-hand substitution) come first,
-- followed by the left-hand bindings.
(@@) :: Subst -> Subst -> Subst
(@@) s1 s2 = map (fmap (apply s1)) s2 ++ s1
-- | Errors reported by the type checker.
data TypeError
  = UnificationFail Type Type -- ^ the two types cannot be made equal
  | InfiniteType TyVar Type   -- ^ occurs check: the variable appears in the type
  | UnboundVariable Id        -- ^ identifier not found in the assumptions
  | TypeError                 -- ^ generic failure without further detail
  deriving (Show)
-- | Lets 'TypeError' be the error type of the 'ErrorT' layer used for
-- inference. A failure without a message ('noMsg') becomes the generic
-- 'TypeError' constructor.
instance Error TypeError where
  noMsg = TypeError
-- | The effects inference needs: a state holding the supply of unused
-- type-variable numbers, and failure with a 'TypeError'.
--
-- NOTE(review): this superclass context and the instance head below need
-- FlexibleContexts / FlexibleInstances. No LANGUAGE pragmas are visible in
-- this chunk; confirm they are enabled for the module.
class (MonadState [Int] m, MonadError TypeError m) => MonadInfer m

-- | The concrete inference monad run by 'typeExpr' and 'typeProgram'.
instance MonadInfer (ErrorT TypeError (State [Int]))
-- | Find a substitution that makes two types equal.
--
-- Function types are unified component-wise. The substitution found for
-- the arguments is applied to the results before they are unified. A type
-- variable on either side is bound with 'varBind'; every other combination
-- fails with 'UnificationFail'.
unify :: MonadInfer m => Type -> Type -> m Subst
unify lhs rhs = case (lhs, rhs) of
  (TyArr a b, TyArr a' b') -> do
    sArg <- unify a a'
    sRes <- unify (apply sArg b) (apply sArg b')
    return (sRes @@ sArg)
  (TyVar v, _) -> varBind v rhs
  (_, TyVar v) -> varBind v lhs
  _ -> throwError (UnificationFail lhs rhs)
-- | Bind a type variable to a type.
--
-- Binding a variable to itself gives the empty substitution. Binding it to
-- a type that mentions it fails the occurs check with 'InfiniteType'.
varBind :: MonadInfer m => TyVar -> Type -> m Subst
varBind v ty
  | ty == TyVar v      = return nullSubst
  | v `notElem` tvs ty = return (v +-> ty)
  | otherwise          = throwError (InfiniteType v ty)
-- | A typing assumption: an identifier together with its type scheme.
data Assump = Id :>: Scheme

-- | Substitutions and 'tvs' act on the scheme; the identifier is untouched.
instance Types Assump where
  tvs (_ :>: scheme) = tvs scheme
  apply sub (name :>: scheme) = name :>: apply sub scheme
-- | Take the next unused number from the state's supply and return it as a
-- type variable.
--
-- The supply is matched explicitly instead of using the partial 'head' and
-- 'tail'. An exhausted supply is therefore reported through 'throwError' as
-- the generic 'TypeError' rather than crashing the program.
fresh :: MonadInfer m => m Type
fresh = do
  supply <- get
  case supply of
    v : rest -> do
      put rest
      return (TyVar v)
    [] -> throwError TypeError
-- | Instantiate a scheme: each generic slot 'TyGen' k of the body is
-- replaced by the k-th of @n@ newly drawn type variables, where @n@ is the
-- scheme's count.
freshen :: MonadInfer m => Scheme -> m Type
freshen (Scheme gens t) = do
  vars <- mapM (\_ -> fresh) [1 .. gens]
  let table = zip [0 ..] vars
      inst ty = case ty of
        TyGen k   -> fromMaybe (error "Malformed Scheme") (lookup k table)
        TyArr a b -> TyArr (inst a) (inst b)
        _         -> ty
  return (inst t)
-- | Look up an identifier's scheme. The first matching assumption in the
-- list wins; a missing identifier fails with 'UnboundVariable'.
find :: MonadInfer m => [Assump] -> Id -> m Scheme
find ctx wanted =
  case [sc | (name :>: sc) <- ctx, name == wanted] of
    sc : _ -> return sc
    []     -> throwError (UnboundVariable wanted)
-- | Generalise a type over the given type variables.
--
-- The k-th variable of the list (counting from 0) is replaced by
-- 'TyGen' k, and the resulting scheme records how many variables were
-- bound. The variables are assumed to be distinct; callers in this module
-- pass lists derived from 'tvs', which removes duplicates.
quantify :: [TyVar] -> Type -> Scheme
quantify tvs' t = Scheme (length tvs') (apply sub t)
  where
    -- Pair each variable with its generic slot. 'zip' stops at the end of
    -- the variable list, so no explicit index bound or '!!' lookup is needed.
    sub = zip tvs' (map TyGen [0 ..])
-- | Infer a type scheme for an expression under the given assumptions.
--
-- The local 'go' returns, for each sub-expression, the substitution it
-- accumulated together with the inferred type. It composes substitutions
-- in the manner of Damas-Milner algorithm W. The type of the whole
-- expression is finally generalised over all of its type variables.
--
-- NOTE(review): that final 'quantify' does not exclude variables free in
-- the assumptions. This is harmless when every assumption is closed, as in
-- 'typeProgram', whose context only holds schemes produced here. A caller
-- of 'typeExpr' passing assumptions with free 'TyVar's could get an
-- over-general scheme; confirm this is intended.
typecheck :: MonadInfer m => [Assump] -> Expr -> m Scheme
typecheck sctx se = (\(_, t) -> quantify (tvs t) t) <$> go sctx se
  where
    -- Variable: instantiate its scheme with fresh type variables; no
    -- substitution is produced.
    go ctx (Var i) = (,) [] <$> (find ctx i >>= freshen)
    -- Lambda: the parameter gets a fresh type wrapped in an unquantified
    -- (Scheme 0) scheme.
    go ctx (Lam i e) =
      do t1 <- fresh
         (s1, t2) <- go ((i :>: Scheme 0 t1) : ctx) e
         return (s1, apply s1 (TyArr t1 t2))
    -- Application: infer the function, then the argument under the updated
    -- context. Then unify the function's type with "argument -> fresh result".
    go ctx (App t1 t2) =
      do t3 <- fresh
         (s1, t4) <- go ctx t1
         (s2, t5) <- go (apply s1 ctx) t2
         s3 <- unify (apply s2 t4) (TyArr t5 t3)
         return (s3 @@ s2 @@ s1, apply s3 t3)
    -- Let: generalise the bound expression's type over the variables not
    -- free in the substituted context, then check the body with that scheme.
    go ctx (Let v e1 e2) =
      do (s1, t1) <- go ctx e1
         let ctx' = apply s1 ctx
             t2 = quantify (tvs t1 \\ tvs ctx') t1
         (s2, t3) <- go ((v :>: t2) : ctx') e2
         return (s2 @@ s1, t3)
    -- Fix: the recursive name is bound, unquantified, to a fresh variable.
    -- That variable is then unified with the type inferred for the body.
    go ctx (Fix v e) =
      do t1 <- fresh
         (s1, t2) <- go ((v :>: Scheme 0 t1) : ctx) e
         s2 <- unify (apply s1 t1) t2
         return (s2 @@ s1, apply s2 t2)
-- | Type-check a single expression under the given assumptions, numbering
-- fresh type variables from 1.
typeExpr :: [Assump] -> Expr -> Either TypeError Scheme
typeExpr ctx = flip evalState [(1 :: Int) ..] . runErrorT . typecheck ctx
-- | Type-check a whole program.
--
-- Definitions are checked in order; each sees the schemes of the
-- definitions before it, and the main expression is checked last. Returns
-- the scheme of every definition (in program order) and of the main
-- expression. The fresh-variable supply is reset before each definition.
-- That reset is safe because 'typecheck' generalises over every type
-- variable of its result.
--
-- NOTE(review): new assumptions are appended, and 'find' returns the first
-- match. If a name is defined twice, later definitions and the main
-- expression therefore still see the earlier scheme. Confirm this is intended.
typeProgram :: Program -> Either TypeError ([(Id, Scheme)], Scheme)
typeProgram (Program defs mainExpr) =
  evalState (runErrorT (checkDefs [] defs)) [(1 :: Int) ..]
  where
    checkDefs ctx ((name, body) : rest) = do
      put [1 ..]
      sc <- typecheck ctx body
      checkDefs (ctx ++ [name :>: sc]) rest
    checkDefs ctx [] =
      fmap (\sc -> ([(i, s) | (i :>: s) <- ctx], sc)) (typecheck ctx mainExpr)
-- | Render a type as a 'Doc'.
--
-- Generic slots print as their index, and type variables as @v@ followed by
-- their number. Arrows associate to the right, so only an arrow in argument
-- position needs parentheses. Atomic arguments, both 'TyGen' and 'TyVar',
-- are printed bare (e.g. @v1 -> v2@, @(0 -> 1) -> 1@).
pptype :: Type -> Doc
pptype (TyGen i) = int i
pptype (TyVar i) = text "v" <> int i
pptype (TyArr arg res) = argDoc <+> text "->" <+> pptype res
  where
    argDoc = case arg of
      TyArr _ _ -> parens (pptype arg)
      _         -> pptype arg
-- | Public entry point for rendering a type as a 'Doc'.
prettyType :: Type -> Doc
prettyType ty = pptype ty
-- | Render a scheme as a 'Doc'. The generic-variable count is not printed;
-- generic variables appear in the body as their numeric index.
prettyScheme :: Scheme -> Doc
prettyScheme (Scheme _ body) = pptype body
-- | 'show' renders the type with the pretty-printer.
instance Show Type where
  show ty = render (pptype ty)
-- | 'show' renders the scheme's body with the pretty-printer; the count of
-- generic variables is omitted.
instance Show Scheme where
  show sc = case sc of
    Scheme _ body -> render (pptype body)