{-# LANGUAGE CPP #-}
{-------------------------------------------------------------------------------

        Copyright:              Bernie Pope 2004

        Module:                 TypeCheck

        Description:            Infer types for Baskell programs and
                                expressions.

                                Type inference is based on a simple
                                constraint solving process. A single pass
                                is made over the AST to generate a set of
                                type constraints. The constraints are in the
                                form of equalities:

                                   type1 = type2

                                These constraints are then passed to a solver
                                which simplifies them as much as possible.
                                If a constraint can't be solved it will appear
                                in the solution. For example:

                                   Int = Bool

                                Thus, the type you get back is really a set
                                of constraints, rather than the (more
                                traditional) single type or type error.

                                This type checker never gives errors!

        Primary Authors:        Bernie Pope

-------------------------------------------------------------------------------}

{-
    This file is part of baskell.

    baskell is free software; you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation; either version 2 of the License, or
    (at your option) any later version.

    baskell is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with baskell; if not, write to the Free Software
    Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
-}

module TypeCheck
   ( typeCheckExpression
   , typeCheckProgram
   , renderConstraints
   , Constraint
   , Binding (..)
   , SolverType (..)
   )
   where

import AST
   ( Ident
   , Exp (..)
   , Lit (..)
   , Decl (..)
   , Program (..)
   )

import qualified Data.Map as Map
   ( Map
   , empty
   , fromList
   , union
   , insert
   , lookup
   )

import Pretty
   ( Pretty (..)
   , parensIf
   , text
   , (<+>)
   , render
   , vcat
   , Doc
   , parens
   , cat
   , (<>)
   , punctuate
   , comma
   , brackets
   , int
   , empty
   , ($$)
   )

import Data.List
   ( mapAccumL
   , find
   , delete
   )

import Depend
   ( depend )

import qualified Type
   ( Type (..) )

import Utils
   ( nameSupply )

import Control.Monad
   ( zipWithM
   , liftM
   , liftM2
   , unless
   )

import Control.Monad.State
   ( runStateT
   , get
   , put
   , StateT
   , gets
   , modify
   , execStateT
   )

import Control.Monad.Trans
   ( lift
   , liftIO
   )

import Control.Monad.Reader
   ( ReaderT
   , local
   , ask
   , runReaderT
   )

--------------------------------------------------------------------------------

-- the types manipulated by the constraint generator and the solver
data SolverType
   = TypeOf Binding Ident -- type of an identifier
   | TVar Int -- type variable, identified by its number
   | TInt -- type of integer literals
   | TChar -- type of character literals
   | TBool -- type of boolean literals
   | TList SolverType -- list with the given element type
   | TFun SolverType SolverType -- function: argument type, result type
   | TTuple [SolverType] -- tuple with the given component types
   deriving (Eq, Show)

-- how an identifier is bound
data Binding
   = Free -- not bound at all
   | LamBound -- bound in a lambda abstraction
   | LetBound -- bound in a function declaration (top level)
   deriving (Eq, Show)

-- maps each identifier in scope to the way it was bound
type BinderEnv = Map.Map Ident Binding

-- an equality constraint on types
type Constraint = (SolverType, SolverType)

-- infer the type of an expression from the command line
-- print out the type.
-- An initial set of assumptions tell the
-- types of functions in scope
--
-- The expression is typed against the expected type (TVar 0), with the
-- fresh variable counter starting at 1 and an empty binder environment.
-- The constraint (reservedIdent, TVar 0) is added so that the solved
-- type of the whole expression can be found again in the solution.
typeCheckExpression :: [Constraint] -> Exp -> IO ()
typeCheckExpression assumptions exp = do
   let initialCount = 0
       initialType = TVar initialCount
       initialConstraint = (reservedIdent, initialType)
   (constraints, finalCount) <-
      runTC (typeExp initialType exp) (initialCount + 1) Map.empty
   let initialStore = Store
          { store_active = initialConstraint : constraints
          , store_solution = []
          , store_assumptions = assumptions
          , store_count = finalCount
          }
   store <- runSolve solve initialStore
   putStrLn $ render $ prettyTypeOfExp $ store_solution store

-- pretty print the type inferred for an expression on the
-- command line
-- * arg1 is the solution produced by the solver
-- The result is the empty document when the solution holds no
-- constraint for the reserved identifier. Otherwise the first such
-- constraint supplies the type, and every other constraint is printed
-- as a condition on that type.
prettyTypeOfExp :: [Constraint] -> Doc
prettyTypeOfExp cs =
   case find typeOfReservedIdent cs of
      Nothing -> empty
      Just c@(_typeOf, theType) -> prettyTypeSol (theType, delete c cs)
   where
      -- reservedIdent is itself a let-bound TypeOf, so plain equality
      -- on the left hand side is all that is needed
      typeOfReservedIdent :: Constraint -> Bool
      typeOfReservedIdent (t1, _t2) = t1 == reservedIdent

-- pretty print a type together with its residual constraints
-- * with no residual constraints only the type is printed
-- * otherwise the output has the shape: if <constraints> then <type>
-- One naming state is shared by the type and all of the constraints,
-- so a given type variable is printed with the same name everywhere.
-- The type is named first, which means it is printed exactly as it
-- would be on its own.
prettyTypeSol :: (SolverType, [Constraint]) -> Doc
prettyTypeSol (t, []) = prettyType t
prettyTypeSol (t, cs) =
   text "if" $$ indent (vcat csDocs) $$ text "then" $$ indent tDoc
   where
      (tState, tDoc) = prettyTypeWorker False initPrettyState t
      (_finalState, csDocs) = mapAccumL prettyConstraintWorker tState cs

-- prefix a document with a fixed amount of leading space
indent :: Doc -> Doc
indent doc = text " " <> doc

-- infer the types for a whole program
-- the decls must be sorted into dependency order
-- * arg1 is the initial set of type assumptions
-- * arg2 is the program; its declarations are grouped with depend
-- The result is the solution left in the store after all groups of
-- declarations have been typed.
typeCheckProgram :: [Constraint] -> Program -> IO [Constraint]
typeCheckProgram assumptions (Program decls) =
   liftM store_solution $ runSolve (typeDeclss (depend decls)) initialStore
   where
      initialStore = Store
         { store_active = []
         , store_solution = []
         , store_assumptions = assumptions
         , store_count = 0
         }

--------------------------------------------------------------------------------

-- infer the types of declarations in dependency order
-- type solutions of earlier declarations become
-- type assumptions of later declarations
-- thus if f depends on g, g will be typed first
-- and its type will be
-- an assumption when f is typed
typeDeclss :: [[Decl]] -> Solve ()
typeDeclss = mapM_ typeDecls

-- type one group of declarations:
-- * generate the constraints of every declaration in the group,
--   continuing from the fresh variable counter kept in the store
-- * make them active and solve them
-- * add the resulting solution to the assumptions
-- * leave the store with no active constraints
typeDecls :: [Decl] -> Solve ()
typeDecls ds = do
   startCount <- gets store_count
   (groupConstraints, endCount) <-
      liftIO $ runTC (mapM typeDecl ds) startCount Map.empty
   modify $ \store -> store { store_count = endCount }
   updateActive (concat groupConstraints)
   solve
   gets store_solution >>= updateAssumptions
   modify $ \store -> store { store_active = [] }

--------------------------------------------------------------------------------

-- the state of constraint generation: the next unused type variable number
type TcState = Int

-- constraint generation reads the binder environment and
-- threads the fresh variable counter
type TC a = ReaderT BinderEnv (StateT TcState IO) a

-- run a constraint generator
-- * arg1 is the generator
-- * arg2 is the first unused type variable number
-- * arg3 is the binder environment
-- the result is paired with the final counter value
runTC :: TC a -> TcState -> BinderEnv -> IO (a, TcState)
runTC action state env = runStateT (action `runReaderT` env) state

-- make a type variable from the counter, and bump the counter
freshVar :: TC SolverType
freshVar = lift $ do
   next <- get
   put (next + 1)
   return (TVar next)

-- run an action in an environment extended with one more binder
extendEnv :: Ident -> Binding -> TC a -> TC a
extendEnv ident binding = local (Map.insert ident binding)

-- how is this identifier bound? Free if it is not in the environment
lookupIdentBinding :: Ident -> TC Binding
lookupIdentBinding ident = liftM (maybe Free id . Map.lookup ident) ask

-- type a single declaration
-- signatures give no constraints; a definition constrains the type of
-- its (let bound) identifier to be a fresh variable, which is also the
-- expected type of its body
typeDecl :: Decl -> TC [Constraint]
typeDecl (Sig {}) = return [] -- XXX
typeDecl (Decl ident body) = do
   bodyType <- freshVar
   bodyConstraints <- typeExp bodyType body
   return $ (TypeOf LetBound ident, bodyType) : bodyConstraints

-- type expressions
-- * arg1 is the expected type of this expression,
--   as required by its context
-- * arg2 is the expression itself
-- the binding style of variables comes from the reader environment
typeExp :: SolverType -> Exp -> TC [Constraint]
typeExp t (Var ident) = do
   binding <- lookupIdentBinding ident
   return [(TypeOf binding ident, t)]
-- XXX can we avoid the need to introduce t2?
typeExp t (Lam ident body) = typeLambda t ident body
typeExp t (LamStrict ident body) = typeLambda t ident body
typeExp t (App e1 e2) = do
   argType <- freshVar
   csRight <- typeExp argType e2
   csLeft <- typeExp (TFun argType t) e1
   return $ csLeft ++ csRight
typeExp t (Literal lit) = typeLit t lit
-- XXX delete this?
typeExp t (Tuple exps) = do
   -- one fresh variable for each component of the tuple
   vars <- mapM (const freshVar) exps
   cssExps <- typeExpList vars exps
   return $ (t, TTuple vars) : concat cssExps
-- primitives don't give rise to constraints
-- their types are already known
typeExp _t (Prim _name _impl) = return []

-- the constraints of a lambda abstraction; lazy and strict lambdas
-- are typed in exactly the same way
-- * arg1 is the expected type of the whole abstraction
-- * arg2 is the bound identifier
-- * arg3 is the body, typed with the identifier lambda bound
typeLambda :: SolverType -> Ident -> Exp -> TC [Constraint]
typeLambda t ident body = do
   resultType <- freshVar
   paramVar <- freshVar
   let paramType = TypeOf LamBound ident
       c1 = (t, TFun paramType resultType)
       c2 = (paramVar, paramType)
   csBody <- extendEnv ident LamBound $ typeExp resultType body
   return $ [c1, c2] ++ csBody

-- XXX delete this? Is it some kind of mapM, or mapAccum ?
-- type a list of expressions, pairing each one with its
-- expected type; one list of constraints per expression
typeExpList :: [SolverType] -> [Exp] -> TC [[Constraint]]
typeExpList = zipWithM typeExp

-- literals
-- * arg1 is the expected type of this literal,
--   as required by its context
-- * arg2 is the literal itself
typeLit :: SolverType -> Lit -> TC [Constraint]
typeLit t lit =
   case lit of
      LitInt _i -> return [(t, TInt)]
      LitChar _c -> return [(t, TChar)]
      LitBool _b -> return [(t, TBool)]
      -- cons: elem -> [elem] -> [elem], for a fresh elem
      LitCons -> do
         elemType <- freshVar
         let listType = TList elemType
         return [(t, TFun elemType (TFun listType listType))]
      -- nil: [elem], for a fresh elem
      LitNil -> do
         elemType <- freshVar
         return [(t, TList elemType)]

--------------------------------------------------------------------------------

-- constraint resolution

-- the solver threads the constraint store
type Solve a = StateT Store IO a

-- run a solver action and hand back the final store
runSolve :: Solve () -> Store -> IO Store
runSolve = execStateT

-- the constraint store
data Store
   = Store
   { store_active :: [Constraint] -- not yet solved
   , store_solution :: [Constraint] -- solved in this pass
   , store_assumptions :: [Constraint] -- prior assumptions
   , store_count :: Int -- counter for generating fresh vars
   }
   deriving (Eq, Show)

-- keep reducing the store until there are no active constraints left
-- the first active constraint is removed from the store before
-- its rule is applied
solve :: Solve ()
solve = do
#ifdef DEBUG
   -- store <- get
   -- liftIO $ debugPrintStore store
#endif
   active <- gets store_active
   case active of
      [] -> return ()
      (next : rest) -> do
         modify $ \store -> store { store_active = rest }
         applyRule next
         solve

-- eliminate a given active constraint
-- there are three situations to consider, the constraint deals with:
-- 1) the type of an identifier (typeOf x = Bool)
-- 2) a type variable (tvar 12 = Char)
-- 3) an equality between two concrete types (List (tvar 24) = List Int)
-- the equations are tried in order; when the interesting type is on
-- the right hand side the constraint is flipped first
applyRule :: Constraint -> Solve ()
-- type of identifier constraints
applyRule c@(TypeOf {}, _) = typeOfRule c
applyRule (t1, t2@(TypeOf {})) = typeOfRule (t2, t1)
-- type variable constraints
applyRule c@(TVar {}, _) = substitute c
applyRule (t1, t2@(TVar {})) = substitute (t2, t1)
-- constraints on concrete types
applyRule c = match c
-- resolve constraints on types of identifiers
-- the variable could be:
-- * free: look for its type elsewhere (see freeIdent)
-- * let bound: the constraint goes straight into the solution
-- * lambda bound: its typeOf is substituted away throughout the store
typeOfRule :: Constraint -> Solve ()
typeOfRule c@(TypeOf binding _ident, _t2) =
   case binding of
      Free -> freeIdent c
      LetBound -> updateSolution [c]
      LamBound -> substitute c

-- a free identifier could be typed in the assumptions
-- or it may be unknown. If it is in the assumptions
-- then t2 is equated with an *instance* of the type found there.
-- Otherwise t2 is equated with every type that the active and solved
-- constraints give to the identifier; if there are none the
-- constraint is put in the solution unchanged.
freeIdent :: Constraint -> Solve ()
freeIdent c@(TypeOf _binding ident, t2) = do
   store <- get
   case lookupAssumption (store_assumptions store) ident of
      Just scheme -> applyScheme (scheme, t2)
      Nothing -> do
         let known = store_active store ++ store_solution store
             related = map ((,) t2) (lookupFreeIdent known ident)
         case related of
            [] -> updateSolution [c]
            _ -> updateActive related

-- look for a type assumption for an identifier in
-- a set of constraints
-- only constraints with a typeOf on the left are considered, the way
-- the identifier is bound is ignored, and the first hit wins
lookupAssumption :: [Constraint] -> Ident -> Maybe SolverType
lookupAssumption cs key =
   case [ t | (TypeOf _binding ident, t) <- cs, ident == key ] of
      [] -> Nothing
      (t : _) -> Just t

-- collect, in order, the type on the other side of every constraint
-- that has a typeOf for the identifier on one of its sides
-- the left hand side is inspected first: if it is a typeOf for some
-- other identifier the right hand side is not looked at
lookupFreeIdent :: [Constraint] -> Ident -> [SolverType]
lookupFreeIdent cs key = concatMap otherSide cs
   where
      otherSide :: Constraint -> [SolverType]
      otherSide (TypeOf _binding ident, t2) = [ t2 | ident == key ]
      otherSide (t1, TypeOf _binding ident) = [ t1 | ident == key ]
      otherSide _ = []

-- apply a type scheme to the constraint store
-- a fresh instance of the scheme is equated with the given type
applyScheme :: Constraint -> Solve ()
applyScheme (scheme, t) =
   typeInstance scheme >>= \schemeInstance -> updateActive [(schemeInstance, t)]

-- match two concrete types.
-- This might generate new active
-- constraints if either of the types has arguments
-- types that cannot be made equal are recorded in the solution,
-- which is how a type error is reported
match :: Constraint -> Solve ()
match c@(t1, t2) =
   case c of
      (TVar i, TVar j)
         | i == j -> return ()
         | otherwise -> updateActive [c]
      (TInt, TInt) -> return ()
      (TChar, TChar) -> return ()
      (TBool, TBool) -> return ()
      (TList elem1, TList elem2) -> updateActive [(elem1, elem2)]
      (TFun arg1 res1, TFun arg2 res2) ->
         updateActive [(arg1, arg2), (res1, res2)]
      (TTuple ts1, TTuple ts2)
         | length ts1 == length ts2 -> updateActive (zip ts1 ts2)
         -- type error
         | otherwise -> updateSolution [c]
      (TypeOf _ _, TypeOf _ _)
         | t1 == t2 -> return ()
         | otherwise -> updateActive [c]
      -- type error
      _ -> updateSolution [c]

-- substitute a type variable or a typeOf with
-- another type in the store
-- * nothing happens if both sides are already the same
-- * if the left type occurs in the right one the constraint is put in
--   the solution instead (occurs check failure, infinite type)
-- * otherwise the active and the solved constraints are rewritten
substitute :: Constraint -> Solve ()
substitute c@(t1, t2)
   | t1 == t2 = return ()
   -- occurs check failure, infinite type
   | occursInType t1 t2 = updateSolution [c]
   | otherwise =
        modify $ \store -> store
           { store_active = map rewrite (store_active store)
           , store_solution = map rewrite (store_solution store)
           }
   where
      rewrite = subTypeInConstraint t1 t2

-- replace one type by another on both sides of a constraint
subTypeInConstraint :: SolverType -> SolverType -> Constraint -> Constraint
subTypeInConstraint t1 t2 (typeLeft, typeRight) = (sub typeLeft, sub typeRight)
   where
      sub = subTypeInType t1 t2

-- replace one type by another inside a type
-- * arg1 is the type to look for
-- * arg2 is its replacement
-- * arg3 is the type to rewrite
-- only type variables and typeOfs are ever compared with arg1
subTypeInType :: SolverType -> SolverType -> SolverType -> SolverType
subTypeInType old new = go
   where
      go ty =
         case ty of
            TVar _ -> leaf ty
            TypeOf _ _ -> leaf ty
            TList t -> TList (go t)
            TFun t1 t2 -> TFun (go t1) (go t2)
            TTuple ts -> TTuple (map go ts)
            _ -> ty
      leaf ty
         | ty == old = new
         | otherwise = ty

-- make a fresh instance of an existing type.
-- instance has same shape as existing type
-- but all variables are fresh

-- maps the variables of the existing type to those of the instance
type TyVarMap = Map.Map Int Int

-- instantiation threads the variable map and the fresh variable counter
type Inst a = StateT (TyVarMap, Int) IO a

-- has this variable been given a fresh counterpart already?
lookupTyVarMap :: Int -> Inst (Maybe Int)
lookupTyVarMap i = gets (Map.lookup i . fst)

-- hand out the current counter value and bump the counter
freshTyVarCounter :: Inst Int
freshTyVarCounter = do
   (varMap, next) <- get
   put (varMap, next + 1)
   return next

-- record the fresh counterpart of a variable
extendTyVarMap :: Int -> Int -> Inst ()
extendTyVarMap x y = modify $ \(varMap, next) -> (Map.insert x y varMap, next)

-- instantiate a type, taking fresh variables from the counter in the
-- store and writing the advanced counter back afterwards
-- every occurrence of a variable gets the same fresh variable
typeInstance :: SolverType -> Solve SolverType
typeInstance t = do
   startCount <- gets store_count
   (resultType, (_env, finalCount)) <-
      liftIO $ runStateT (mkInstance t) (Map.empty, startCount)
   modify $ \store -> store { store_count = finalCount }
   return resultType
   where
      mkInstance :: SolverType -> Inst SolverType
      mkInstance ty =
         case ty of
            TVar var -> do
               known <- lookupTyVarMap var
               case known of
                  Just newVar -> return (TVar newVar)
                  Nothing -> do
                     newVar <- freshTyVarCounter
                     extendTyVarMap var newVar
                     return (TVar newVar)
            TList elemTy -> liftM TList (mkInstance elemTy)
            TFun argTy resTy -> liftM2 TFun (mkInstance argTy) (mkInstance resTy)
            TTuple tys -> liftM TTuple (mapM mkInstance tys)
            -- everything else is left as it is
            _ -> return ty

-- does a type var or typeOf occur within another type?
-- * arg1 is the type to search for
-- * arg2 is the type to search in
-- only the type variables and typeOfs inside arg2 are compared with arg1
occursInType :: SolverType -> SolverType -> Bool
occursInType search = go
   where
      go ty =
         case ty of
            TVar _ -> search == ty
            TypeOf _ _ -> search == ty
            TList t -> go t
            TFun t1 t2 -> go t1 || go t2
            TTuple ts -> any go ts
            _ -> False

-- put new constraints at the front of the solution in the store
updateSolution :: [Constraint] -> Solve ()
updateSolution cs =
   modify $ \store -> store { store_solution = cs ++ store_solution store }

-- put new constraints at the front of the active constraints in the store
updateActive :: [Constraint] -> Solve ()
updateActive cs =
   modify $ \store -> store { store_active = cs ++ store_active store }

-- put new constraints at the front of the assumptions in the store
updateAssumptions :: [Constraint] -> Solve ()
updateAssumptions cs =
   modify $ \store -> store { store_assumptions = cs ++ store_assumptions store }

--------------------------------------------------------------------------------

-- pretty printing of types and constraints

-- dump the active constraints, the solution and the counter of a store
-- type variables are shown as numbers (see renderConstraintsUgly)
debugPrintStore :: Store -> IO ()
debugPrintStore store = do
   mapM_ putStrLn
      [ "---- the current store ----"
      , "active constraints:"
      , renderConstraintsUgly (store_active store)
      , "solution:"
      , renderConstraintsUgly (store_solution store)
      ]
   -- putStrLn "assumptions:"
   -- putStrLn $ renderConstraints $ store_assumptions store
   putStr "count: "
   print (store_count store)

-- the state of the pretty printer
data PrettyState
   = PrettyState
   { prettyState_varMap :: Map.Map Int String -- names given out so far
   , prettyState_nameSupply :: [String] -- names not yet used
   }

-- no variable has a name yet, and the whole name supply is available
initPrettyState :: PrettyState
initPrettyState
   = PrettyState
   { prettyState_varMap = Map.empty
   , prettyState_nameSupply = nameSupply
   }

instance Pretty SolverType where
   pretty = prettyType

-- pretty printing of types, type variables get nice names
prettyType :: SolverType -> Doc
prettyType t = doc
   where
      (_finalState, doc) = prettyTypeWorker False initPrettyState t

-- pretty print a type, threading the naming state
-- * arg1 says whether a function type must be put in parentheses
-- * arg2 is the naming state before the type is printed
-- * arg3 is the type itself
-- the result pairs the naming state afterwards with the document
-- a variable seen for the first time takes the next name in the supply
prettyTypeWorker :: Bool -> PrettyState -> SolverType -> (PrettyState, Doc)
prettyTypeWorker bracks state ty =
   case ty of
      TVar i ->
         case Map.lookup i names of
            Just known -> (state, text known)
            Nothing -> (extended i, text fresh)
      TInt -> (state, text "Int")
      TChar -> (state, text "Char")
      TBool -> (state, text "Bool")
      TList t ->
         let (elemState, elemDoc) = prettyTypeWorker False state t
         in (elemState, brackets elemDoc)
      TFun t1 t2 ->
         -- a function in argument position needs parentheses,
         -- a function in result position does not
         let (argState, argDoc) = prettyTypeWorker True state t1
             (resState, resDoc) = prettyTypeWorker False argState t2
         in (resState, parensIf bracks (argDoc <+> text "->" <+> resDoc))
      TTuple ts ->
         let (tsState, tsDocs) = mapAccumL (prettyTypeWorker False) state ts
         in (tsState, parens (cat (punctuate comma tsDocs)))
      TypeOf binding ident ->
         (state, text "type" <> parens (prettyBinder binding <+> text ident))
   where
      names = prettyState_varMap state
      -- lazy: the supply is only taken apart when a new name is needed
      fresh : remaining = prettyState_nameSupply state
      extended i
         = PrettyState
         { prettyState_varMap = Map.insert i fresh names
         , prettyState_nameSupply = remaining
         }

-- less pretty printing of types. Type variables do not get nice
-- names, they are printed as their underlying numbers. This is
-- helpful for debugging the constraint solver.
-- * arg1 says whether a function type must be put in parentheses
-- * arg2 is the type itself
uglyType :: Bool -> SolverType -> Doc
uglyType bracks ty =
   case ty of
      TVar i -> text "t" <> int i
      TInt -> text "Int"
      TChar -> text "Char"
      TBool -> text "Bool"
      TList t -> brackets (uglyType False t)
      TFun t1 t2 ->
         parensIf bracks (uglyType True t1 <+> text "->" <+> uglyType False t2)
      TTuple ts -> parens (cat (punctuate comma (map (uglyType False) ts)))
      TypeOf binding ident ->
         text "type" <> parens (prettyBinder binding <+> text ident)

-- how the binding style of an identifier is shown
prettyBinder :: Binding -> Doc
prettyBinder binding =
   text $ case binding of
      Free -> "free"
      LamBound -> "lambda-bound"
      LetBound -> "let-bound"

-- the type of the let bound identifier "$"
-- typeCheckExpression constrains it to be the type of the expression
-- being checked, and prettyTypeOfExp finds that type again through it
reservedIdent :: SolverType
reservedIdent = TypeOf LetBound "$"

-- show a constraint as an equation, with numbered type variables
uglyConstraint :: Constraint -> Doc
uglyConstraint (t1, t2) = ugly t1 <+> text "=" <+> ugly t2
   where
      ugly = uglyType False

-- show a single constraint, naming its type variables from scratch
prettyConstraint :: Constraint -> Doc
prettyConstraint c = doc
   where
      (_finalState, doc) = prettyConstraintWorker initPrettyState c

-- pretty print a constraint, threading the naming state
-- * a constraint on the type of a let bound identifier is shown
--   as a type signature: ident :: type
-- * any other constraint is shown as an equation: type = type
prettyConstraintWorker :: PrettyState -> Constraint -> (PrettyState, Doc)
prettyConstraintWorker state c =
   case c of
      (TypeOf LetBound ident, t2) ->
         let (sigState, sigDoc) = prettyTypeWorker False state t2
         in (sigState, text ident <+> text "::" <+> sigDoc)
      (t1, t2) ->
         let (leftState, leftDoc) = prettyTypeWorker False state t1
             (rightState, rightDoc) = prettyTypeWorker False leftState t2
         in (rightState, leftDoc <+> text "=" <+> rightDoc)

-- one constraint per line; each constraint gets its own variable names
renderConstraints :: [Constraint] -> String
renderConstraints = render . vcat . map prettyConstraint

-- one constraint per line, with numbered type variables
renderConstraintsUgly :: [Constraint] -> String
renderConstraintsUgly = render . vcat . map uglyConstraint

--------------------------------------------------------------------------------

-- translate a type from the Type module into a solver type,
-- constructor by constructor
toSolverType :: Type.Type -> SolverType
toSolverType ty =
   case ty of
      Type.TVar i -> TVar i
      Type.TInt -> TInt
      Type.TChar -> TChar
      Type.TBool -> TBool
      Type.TList t -> TList (toSolverType t)
      Type.TFun t1 t2 -> TFun (toSolverType t1) (toSolverType t2)
      Type.TTuple ts -> TTuple (map toSolverType ts)