-----------------------------------------------------------------------------
-- TIMain:      Type Inference Algorithm
--
-- Part of `Typing Haskell in Haskell', version of November 23, 2000
-- Copyright (c) Mark P Jones and the Oregon Graduate Institute
-- of Science and Technology, 1999-2000
--
-- This program is distributed as Free Software under the terms
-- in the file "License" that is included in the distribution
-- of this software, copies of which may be obtained from:
--             http://www.cse.ogi.edu/~mpj/thih/
--
-----------------------------------------------------------------------------

module TIMain where

import Data.List( (\\), intersect, union, partition )
import Id
import Kind
import Type
import Subst
import Pred
import Scheme
import Assump
import TIMonad
import Infer
import Lit
import Pat
import StaticPrelude

-----------------------------------------------------------------------------

-- | Abstract syntax of the expressions handled by 'tiExpr'.
data Expr = Var   Id                     -- ^ variable, looked up in the assumptions
          | Lit   Literal                -- ^ literal, typed by 'tiLit'
          | Const Assump                 -- ^ constant carrying its own type scheme
          | Ap    Expr Expr              -- ^ application
          | Let   BindGroup Expr         -- ^ local binding group and body
          | Lam   Alt                    -- ^ lambda, represented by a single alternative
          | If    Expr Expr Expr         -- ^ conditional
          | Case  Expr [(Pat,Expr)]      -- ^ case analysis over pattern/branch pairs

-----------------------------------------------------------------------------
-- The following helper functions are used to construct sample programs; if we
-- change the representation of Expr above, then we need only change the
-- definitions of the following combinators to match, and do not need to
-- rewrite all the test code.

-- | Left-nested application of a non-empty list of expressions.
--   The explicit list signature keeps the type fixed even where 'foldl1' is
--   polymorphic in the container.
ap :: [Expr] -> Expr
ap = foldl1 Ap

-- | Variable expression.
evar :: Id -> Expr
evar v = Var v

-- | Literal expression.
elit :: Literal -> Expr
elit l = Lit l

-- | Constant expression.
econst :: Assump -> Expr
econst c = Const c

-- | Wrap the body @f@ in one 'Let' per binding group of @e@ (first group
--   outermost).
elet :: [[(Id, Maybe Scheme, [Alt])]] -> Expr -> Expr
elet e f = foldr Let f (map toBg e)

-- | Convert a list of bindings into a 'BindGroup': bindings that carry a
--   scheme become explicitly typed bindings; the remaining ones form a single
--   implicit group, which is dropped when it would be empty.
toBg :: [(Id, Maybe Scheme, [Alt])] -> BindGroup
toBg g = ( [ (v, t, alts) | (v, Just t, alts) <- g ]
         , filter (not . null) [ [ (v, alts) | (v, Nothing, alts) <- g ] ] )

-- | Pattern built from 'nilCfun' with no sub-patterns.
pNil :: Pat
pNil = PCon nilCfun []

-- | Pattern built from 'consCfun' with the two given sub-patterns.
pCons :: Pat -> Pat -> Pat
pCons x y = PCon consCfun [x,y]

-- | The 'nilCfun' constant as an expression.
eNil :: Expr
eNil = econst nilCfun

-- | Application of the 'consCfun' constant to two expressions.
eCons :: Expr -> Expr -> Expr
eCons x y = ap [ econst consCfun, x, y ]

{-
ecase = Case
elambda = Lam
eif = If
-}

-- | Case expression encoded as a local function @_case@, with one alternative
--   per branch, applied to the scrutinee.
ecase :: Expr -> [(Pat, Expr)] -> Expr
ecase d as = elet [[ ("_case", Nothing, [ ([p],e) | (p,e) <- as ]) ]]
                  (ap [evar "_case", d])

-- | Conditional encoded through 'ecase' on 'trueCfun' / 'falseCfun' patterns.
eif :: Expr -> Expr -> Expr -> Expr
eif c t f = ecase c [ (PCon trueCfun [], t), (PCon falseCfun [], f) ]

-- | Lambda encoded as a local binding @_lambda@ with a single alternative.
elambda :: Alt -> Expr
elambda alt = elet [[ ("_lambda", Nothing, [alt]) ]] (evar "_lambda")

-- | Chain of guarded expressions; falls through to 'efail' when the list is
--   exhausted.  The explicit signature is required where 'foldr' is
--   container-polymorphic, because this point-free binding is otherwise
--   subject to the monomorphism restriction and is not used in this module.
eguarded :: [(Expr, Expr)] -> Expr
eguarded = foldr (\(c,t) e -> eif c t e) efail

-- | Constant named @FAIL@ whose scheme quantifies a single variable of kind
--   'Star' and carries no predicates.
efail :: Expr
efail = Const ("FAIL" :>: Forall [Star] ([] :=> TGen 0))

-- | Expression with a type annotation, encoded as an explicitly typed local
--   binding @_val@.
esign :: Expr -> Scheme -> Expr
esign e t = elet [[ ("_val", Just t, [([],e)]) ]] (evar "_val")

-- | Generator step of a comprehension: apply 'mbindMfun' to @e@ and a lambda
--   binding @p@ in @c@.
eCompFrom :: Pat -> Expr -> Expr -> Expr
eCompFrom p e c = ap [ econst mbindMfun, e, elambda ([p],c) ]

-- | Guard step of a comprehension: @c@ when @e@ holds, otherwise 'eNil'.
eCompGuard :: Expr -> Expr -> Expr
eCompGuard e c = eif e c eNil

-- | Local-binding step of a comprehension.
eCompLet :: [[(Id, Maybe Scheme, [Alt])]] -> Expr -> Expr
eCompLet bgs c = elet bgs c

-- | Singleton list built with 'eCons' and 'eNil'.
eListRet :: Expr -> Expr
eListRet e = eCons e eNil

-----------------------------------------------------------------------------

-- | Infer the predicates and type of an expression under class environment
--   @ce@ and assumptions @as@.
tiExpr :: Infer Expr Type
tiExpr ce as (Var i)
  = do sc         <- find i as
       (ps :=> t) <- freshInst sc
       return (ps, t)
tiExpr ce as (Const (i:>:sc))
  = do (ps :=> t) <- freshInst sc
       return (ps, t)
tiExpr ce as (Lit l)
  = tiLit l
tiExpr ce as (Ap e f)
  = do (ps,te) <- tiExpr ce as e
       (qs,tf) <- tiExpr ce as f
       t       <- newTVar Star
       -- the type of e must be a function from the type of f to the
       -- fresh result type t
       unify (tf `fn` t) te
       return (ps++qs, t)
tiExpr ce as (Let bg e)
  = do (ps, as') <- tiBindGroup ce as bg
       (qs, t)   <- tiExpr ce (as' ++ as) e
       return (ps ++ qs, t)
tiExpr ce as (Lam alt)
  = tiAlt ce as alt
tiExpr ce as (If e e1 e2)
  = do (ps,t)   <- tiExpr ce as e
       unify t tBool
       (ps1,t1) <- tiExpr ce as e1
       (ps2,t2) <- tiExpr ce as e2
       unify t1 t2
       return (ps++ps1++ps2, t1)
tiExpr ce as (Case e branches)
  = do (ps, t) <- tiExpr ce as e
       v       <- newTVar Star
       -- each branch: the pattern type must match the scrutinee type t and
       -- the branch body type must match the shared result type v
       let tiBr (pat, f)
             = do (ps', as', t') <- tiPat pat
                  unify t t'
                  (qs, t'')      <- tiExpr ce (as'++as) f
                  unify v t''
                  return (ps'++qs)
       pss     <- mapM tiBr branches
       return (ps ++ concat pss, v)

-----------------------------------------------------------------------------

-- | A function alternative: argument patterns and a right-hand side.
type Alt = ([Pat], Expr)

-- | Infer the type of one alternative: a function type from the pattern
--   types to the type of the body, the body being checked with the
--   pattern-bound assumptions in front of @as@.
tiAlt :: Infer Alt Type
tiAlt ce as (pats, e)
  = do (ps, as', ts) <- tiPats pats
       (qs, t)       <- tiExpr ce (as'++as) e
       return (ps++qs, foldr fn t ts)

-- | Infer every alternative, unify each resulting type with @t@ and return
--   the collected predicates.
tiAlts :: ClassEnv -> [Assump] -> [Alt] -> Type -> TI [Pred]
tiAlts ce as alts t
  = do psts <- mapM (tiAlt ce as) alts
       mapM_ (unify t . snd) psts
       return (concatMap fst psts)

-----------------------------------------------------------------------------

-- | Reduce @ps@ and split the result into deferred predicates (all of whose
--   type variables occur in @fs@) and retained predicates.  Retained
--   predicates reported by 'defaultedPreds' for the variables @fs++gs@ are
--   removed from the retained list.
split :: Monad m => ClassEnv -> [Tyvar] -> [Tyvar] -> [Pred]
                      -> m ([Pred], [Pred])
split ce fs gs ps
  = do let ps'      = reduce ce ps
           (ds, rs) = partition (all (`elem` fs) . tv) ps'
       rs' <- defaultedPreds ce (fs++gs) rs
       return (ds, rs \\ rs')

-- | A type variable together with the predicates that mention it.
type Ambiguity = (Tyvar, [Pred])

-- | The variables of @ps@ that do not occur in @vs@, each paired with the
--   predicates of @ps@ in which it occurs.
ambiguities :: ClassEnv -> [Tyvar] -> [Pred] -> [Ambiguity]
ambiguities ce vs ps = [ (v, filter (elem v . tv) ps) | v <- tv ps \\ vs ]

-- | Class names treated as numeric by 'candidates'.
numClasses :: [Id]
numClasses = ["Num", "Integral", "Floating", "Fractional",
              "Real", "RealFloat", "RealFrac"]

-- | Class names treated as standard by 'candidates' (includes 'numClasses').
stdClasses :: [Id]
stdClasses = ["Eq", "Ord", "Show", "Read", "Bounded", "Enum", "Ix",
              "Functor", "Monad", "MonadPlus"] ++ numClasses

-- | Candidate default types for an ambiguity, drawn from @defaults ce@.
--   A type is a candidate only if every predicate's argument list is exactly
--   @[TVar v]@, at least one class is in 'numClasses', every class is in
--   'stdClasses', and the type is entailed (with no hypotheses) to be an
--   instance of every class involved.
candidates :: ClassEnv -> Ambiguity -> [Type]
candidates ce (v, qs) = [ t' | let is = [ i | IsIn i t <- qs ]
                             , let ts = [ t | IsIn i t <- qs ]
                             , all ([TVar v]==) ts
                             , any (`elem` numClasses) is
                             , all (`elem` stdClasses) is
                             , t' <- defaults ce
                             , all (entail ce []) [ IsIn i [t'] | i <- is ] ]

-- | Compute the ambiguities of @ps@ relative to @vs@ and apply @f@ to them
--   together with the first candidate type of each.  Fails with
--   @"cannot resolve ambiguity"@ when some ambiguity has no candidates.
--
--   NOTE(review): this calls 'fail' under a plain 'Monad' constraint; confirm
--   that the target compiler still provides 'fail' through 'Monad'.
withDefaults :: Monad m => ([Ambiguity] -> [Type] -> a)
                  -> ClassEnv -> [Tyvar] -> [Pred] -> m a
withDefaults f ce vs ps
    | any null tss = fail "cannot resolve ambiguity"
    | otherwise    = return (f vps (map head tss))  -- head is safe: no tss entry is empty here
      where vps = ambiguities ce vs ps
            tss = map (candidates ce) vps

-- | The predicates attached to the ambiguities that defaulting resolves.
defaultedPreds :: Monad m => ClassEnv -> [Tyvar] -> [Pred] -> m [Pred]
defaultedPreds = withDefaults (\vps _ -> concatMap snd vps)

-- | The substitution pairing each ambiguous variable with its chosen default.
defaultSubst :: Monad m => ClassEnv -> [Tyvar] -> [Pred] -> m Subst
defaultSubst = withDefaults (\vps ts -> zip (map fst vps) ts)

-----------------------------------------------------------------------------

-- | An explicitly typed binding: name, declared scheme and alternatives.
type Expl = (Id, Scheme, [Alt])

-- | Check an explicitly typed binding against its declared scheme and return
--   the deferred predicates.  Fails with @"signature too general"@ when the
--   scheme rebuilt from the inferred type differs from the declared one, and
--   with @"context too weak"@ when predicates are retained that the declared
--   context does not entail.
tiExpl :: ClassEnv -> [Assump] -> Expl -> TI [Pred]
tiExpl ce as (_, sc, alts)
  = do (qs :=> t) <- freshInst sc
       ps         <- tiAlts ce as alts t
       s          <- getSubst
       let qs' = apply s qs
           t'  = apply s t
           fs  = tv (apply s as)                 -- variables fixed by the assumptions
           gs  = tv t' \\ fs                     -- variables that may be quantified
           sc' = quantify gs (qs' :=> t')
           -- only predicates not already entailed by the declared context
           ps' = filter (not . entail ce qs') (apply s ps)
       (ds, rs)   <- split ce fs gs ps'
       if sc /= sc'
         then fail "signature too general"
         else if not (null rs)
           then fail "context too weak"
           else return ds

-----------------------------------------------------------------------------

-- | An implicitly typed binding: name and alternatives.
type Impl = (Id, [Alt])

-- | True when some binding in the group has an alternative without argument
--   patterns.
restricted :: [Impl] -> Bool
restricted bs = any simple bs
  where simple (_, alts) = any (null . fst) alts

-- | Infer schemes for a group of implicitly typed bindings.  For a
--   'restricted' group the retained predicates are returned with the deferred
--   ones and their variables are not quantified; otherwise the retained
--   predicates become the context of every inferred scheme.
tiImpls :: Infer [Impl] [Assump]
-- An empty group binds nothing; handling it here also keeps the 'foldr1'
-- calls below away from an empty list.
tiImpls _  _  [] = return ([], [])
tiImpls ce as bs
  = do ts <- mapM (\_ -> newTVar Star) bs
       let is    = map fst bs
           scs   = map toScheme ts
           as'   = zipWith (:>:) is scs ++ as
           altss = map snd bs
       pss <- sequence (zipWith (tiAlts ce as') altss ts)
       s   <- getSubst
       let ps' = apply s (concat pss)
           ts' = apply s ts
           fs  = tv (apply s as)
           vss = map tv ts'
           gs  = foldr1 union vss \\ fs
       (ds, rs) <- split ce fs (foldr1 intersect vss) ps'
       if restricted bs
         then let gs'  = gs \\ tv rs
                  scs' = map (quantify gs' . ([] :=>)) ts'
              in return (ds++rs, zipWith (:>:) is scs')
         else let scs' = map (quantify gs . (rs :=>)) ts'
              in return (ds, zipWith (:>:) is scs')

-----------------------------------------------------------------------------

-- | Explicitly typed bindings plus a sequence of implicitly typed groups.
type BindGroup = ([Expl], [[Impl]])

-- | Infer a binding group: the implicit groups are processed in order with
--   the declared schemes of the explicit bindings in scope, then every
--   explicit binding is checked with all resulting assumptions in scope.
tiBindGroup :: Infer BindGroup [Assump]
tiBindGroup ce as (es, iss)
  = do let as' = [ v :>: sc | (v, sc, _) <- es ]
       (ps, as'') <- tiSeq tiImpls ce (as'++as) iss
       qss        <- mapM (tiExpl ce (as''++as'++as)) es
       return (ps ++ concat qss, as''++as')

-- | Run an inference function over a list of groups from left to right,
--   making the assumptions produced by each group available to later ones.
tiSeq :: Infer bg [Assump] -> Infer [bg] [Assump]
tiSeq ti ce as []
  = return ([], [])
tiSeq ti ce as (bs:bss)
  = do (ps, as')  <- ti ce as bs
       (qs, as'') <- tiSeq ti ce (as'++as) bss
       return (ps++qs, as''++as')

-----------------------------------------------------------------------------