module Language.Pascal.TypeCheck
(checkTypes,
checkSource,
builtinSymbols
) where
import Control.Monad
import Control.Monad.State
import Control.Monad.Error
import qualified Data.Map as M
import Data.Maybe
import Text.Parsec hiding (State)
import Language.Pascal.Types
import Language.Pascal.Builtin
import Language.Pascal.Parser
-- | Look a name up in a stack of scopes.
--
-- Scopes are consulted in list order and the first one that contains the
-- name wins. 'Nothing' is returned when no scope defines the name.
lookupSymbol :: Id -> SymbolTable -> Maybe Symbol
lookupSymbol name table = listToMaybe (mapMaybe (M.lookup name) table)
-- | Scope containing one symbol per entry of 'builtinFunctions'.
--
-- Each entry contributes its name and type (the third component is
-- ignored); the definition position is recorded as line 0, column 0.
builtinSymbols :: M.Map Id Symbol
builtinSymbols =
  M.fromList
    [ ( name
      , Symbol
          { symbolName = name
          , symbolType = tp
          , symbolDefLine = 0
          , symbolDefCol = 0
          }
      )
    | (name, tp, _) <- builtinFunctions
    ]
-- | @sub `isSubtypeOf` super@ decides whether a value of type @sub@ may be
-- used where @super@ is expected.
--
-- The alternatives are tried top to bottom, so their order matters:
--
-- * 'TVoid' is compatible only with 'TVoid' (not even with 'TAny');
-- * everything else is compatible with 'TAny';
-- * arrays are compared by element type, ignoring the size component;
-- * a 'TField' wrapper on either side is looked through;
-- * functions are compared by result type and, pairwise, argument types;
-- * otherwise the two types must be equal.
isSubtypeOf :: Type -> Type -> Bool
isSubtypeOf sub super =
  case (sub, super) of
    (TVoid, TVoid) -> True
    (TVoid, _) -> False
    (_, TAny) -> True
    (TArray _ a, TArray _ b) -> a `isSubtypeOf` b
    (_, TField _ b) -> sub `isSubtypeOf` b
    (TField _ a, _) -> a `isSubtypeOf` super
    (TFunction args1 res1, TFunction args2 res2) ->
      (res1 `isSubtypeOf` res2) && areSubtypesOf args1 args2
    _ -> sub == super
-- | Pairwise 'isSubtypeOf' over two type lists.
--
-- Holds only when both lists have the same length and every type in the
-- first list is a subtype of the type at the same position in the second.
areSubtypesOf :: [Type] -> [Type] -> Bool
areSubtypesOf [] [] = True
areSubtypesOf (t1 : rest1) (t2 : rest2) =
  (t1 `isSubtypeOf` t2) && areSubtypesOf rest1 rest2
areSubtypesOf _ _ = False -- lengths differ
-- | Checker state used at the start of a run.
--
-- The scope stack holds a single scope, 'builtinSymbols'; there are no
-- user constants, user types or open contexts, and the current position
-- is line 0, column 0.
emptyState :: CheckState
emptyState =
  CheckState
    { symbolTable = [builtinSymbols]
    , contexts = []
    , userConsts = []
    , userTypes = M.empty
    , ckLine = 0
    , ckColumn = 0
    }
-- | Syntax tree node kinds that can be type-checked.
class Typed a where
  -- | Check a node annotated with its source position ('SrcPos') and
  -- return the node annotated with type information ('TypeAnn').
  typeCheck :: a :~ SrcPos -> Check (a :~ TypeAnn)
-- | 'True' exactly for 'ForLoop' contexts.
isFor :: Context -> Bool
isFor cx =
  case cx of
    ForLoop _ _ -> True
    _ -> False
-- | Wrap a checked node in a 'TypeAnn' annotation and return it.
--
-- The source position is copied from the annotation of @x@ (the original,
-- position-annotated node), the type is @t@, and the local symbol map is
-- left empty.
returnT :: Type -> Annotate node1 SrcPos -> node -> Check (Annotate node TypeAnn)
returnT t x res = return (Annotate res ann)
  where
    pos = annotation x
    ann =
      TypeAnn
        { srcPos = SrcPos {srcLine = srcLine pos, srcColumn = srcColumn pos}
        , typeOf = t
        , localSymbols = M.empty
        }
-- | Context tracking and error reporting for the type checker monad.
instance Checker Check where
  -- Push a context on top of the context stack.
  enterContext c = modify $ \st -> st {contexts = c : contexts st}

  -- Pop the innermost context; popping an empty stack is an internal error.
  dropContext = do
    cxs <- gets contexts
    case cxs of
      [] -> failCheck "Internal error in TypeCheck: dropContext on empty context!"
      (_ : outer) -> modify $ \st -> st {contexts = outer}

  -- Abort checking with a 'TError' that carries the current position and
  -- the innermost context ('Unknown' when no context is open).
  failCheck msg = do
    st <- get
    throwError $
      TError
        { errLine = ckLine st
        , errColumn = ckColumn st
        , errContext = fromMaybe Unknown (listToMaybe (contexts st))
        , errMessage = msg
        }
-- | Record the source position of a node as the checker's current
-- position; 'failCheck' reports this position in its errors.
setPos :: Annotate a SrcPos -> Check ()
setPos x =
  modify $ \st -> st {ckLine = srcLine pos, ckColumn = srcColumn pos}
  where
    pos = annotation x
-- | Sanity check: return the symbol unchanged unless its type is still a
-- 'TUser' reference, in which case abort with an internal error.
errorOnUserTypeSymbol :: Annotate Symbol a -> Annotate Symbol a
errorOnUserTypeSymbol x@(Annotate sym _) =
  case symbolType sym of
    TUser t -> error $ "Internal error (symbol): user type: " ++ t
    _ -> x
-- | Resolve user-defined type names inside a type.
--
-- * Arrays: the element type is resolved, the size is kept.
-- * Records: every field type is resolved. The fields are also registered
--   as symbols in a temporary scope, so a repeated field name is rejected
--   by 'addSymbol'.
-- * 'TUser': the name is looked up in the declared user types and the
--   result is resolved again; an unknown name is a check error.
-- * Any other type is returned unchanged.
checkType :: Type -> Check Type
checkType tp =
  case tp of
    TArray sz t -> liftM (TArray sz) (checkType t)
    TRecord pairs -> withSymbolTable $ liftM TRecord (mapM checkField pairs)
    TUser name -> do
      types <- gets userTypes
      maybe (failCheck $ "Undefined type: " ++ name) checkType (M.lookup name types)
    _ -> return tp
  where
    -- Resolve one record field and register it in the current scope.
    checkField (n, t) = do
      t' <- checkType t
      addSymbol $ Annotate (n # t') (SrcPos 0 0)
      return (n, t')
-- | Check a symbol declaration.
--
-- The declared type is resolved with 'checkType', the symbol (with the
-- resolved type) is added to the current scope, and the symbol is returned
-- annotated with that type. A type that is still 'TUser' after resolution
-- is reported as an internal error.
checkSymbol :: Annotate Symbol SrcPos -> Check (Annotate Symbol TypeAnn)
checkSymbol s = do
  setPos s
  resolved <- checkType (symbolTypeC s)
  case resolved of
    TUser name -> failCheck $ "Internal error: undefined user type: " ++ name
    _ -> do
      let declared = setType s resolved
      addSymbol declared
      return (declared `withType` resolved)
-- | Find a symbol by name in the current scope stack, failing the check
-- with an "Unknown symbol" error when it is not defined.
getSymbol :: Id -> Check Symbol
getSymbol name = do
  scopes <- gets symbolTable
  maybe (failCheck $ "Unknown symbol: " ++ name) return (lookupSymbol name scopes)
-- | Define a symbol in the innermost scope.
--
-- The stored symbol takes its name and type from the given symbol and its
-- definition position from the annotation. Fails the check when the name
-- is already defined in the innermost scope (outer scopes may be
-- shadowed), or when there is no scope at all (internal error).
addSymbol :: Annotate Symbol SrcPos -> Check ()
addSymbol (Annotate sym pos) = do
  scopes <- gets symbolTable
  case scopes of
    -- Reported explicitly instead of relying on a failing pattern bind,
    -- mirroring 'dropSymbolTable'.
    [] -> failCheck "Internal error: empty symbol table on addSymbol!"
    (current : outer) ->
      case M.lookup name current of
        Just existing ->
          failCheck $ "Symbol is already defined: " ++ showSymbol existing
        Nothing -> do
          let defined = Symbol name (symbolType sym) (srcLine pos) (srcColumn pos)
          modify $ \st -> st {symbolTable = M.insert name defined current : outer}
  where
    name = symbolName sym
-- | Open a new, empty scope on top of the scope stack.
addSymbolTable :: Check ()
addSymbolTable =
  modify $ \st -> st {symbolTable = M.empty : symbolTable st}
-- | Close the innermost scope. Closing a scope when none is open is
-- reported as an internal error.
dropSymbolTable :: Check ()
dropSymbolTable = do
  scopes <- gets symbolTable
  case scopes of
    [] -> failCheck "Internal error: empty symbol table on dropSymbolTable!"
    (_ : outer) -> modify $ \st -> st {symbolTable = outer}
-- | Run a check inside a fresh scope: the scope is opened before the
-- action and closed after it, and the action's result is returned.
withSymbolTable :: Check a -> Check a
withSymbolTable check =
  addSymbolTable >> check >>= \result ->
    dropSymbolTable >> return result
-- | Declare a user type.
--
-- Fails the check when a type with this name is already declared.
-- Otherwise the definition is resolved with 'checkType' (so it may only
-- refer to types declared earlier), stored under the name, and returned
-- together with the name.
addType :: Id -> Type -> Check (Id, Type)
addType name tp = do
  types <- gets userTypes
  when (M.member name types) $
    failCheck $ "Type is already defined: " ++ name
  tp' <- checkType tp
  -- Update the state as it is *after* 'checkType' has run, rather than
  -- writing back a snapshot taken before it.
  modify $ \st -> st {userTypes = M.insert name tp' (userTypes st)}
  return (name, tp')
-- | Evaluate a constant expression to a literal.
--
-- Supported forms are literals, references to previously declared
-- constants, and binary operators applied to constant operands. Integer
-- operands support addition, subtraction, multiplication, division and
-- the four comparisons. Anything else fails the check: an unknown
-- constant, a non-constant expression, an unsupported operator or operand
-- type, or division by zero.
evalConst :: Expression :~ a -> Check Lit
evalConst expr =
  case content expr of
    Variable name -> do
      consts <- gets userConsts
      case lookup name consts of
        Just v -> evalConst v
        Nothing -> failCheck $ "No such constant: " ++ name
    Literal v -> return v
    Op op x y -> do
      x' <- evalConst x
      y' <- evalConst y
      eval op x' y'
    x -> failCheck $ "Expression is not constant: " ++ show x
  where
    -- Fold one binary operator; failures are reported as check errors.
    eval Add (LInteger x) (LInteger y) = return $ LInteger (x + y)
    eval Sub (LInteger x) (LInteger y) = return $ LInteger (x - y)
    eval Mul (LInteger x) (LInteger y) = return $ LInteger (x * y)
    eval Div (LInteger _) (LInteger 0) =
      failCheck "Division by zero in constant expression"
    eval Div (LInteger x) (LInteger y) = return $ LInteger (x `div` y)
    eval Pow (LInteger _) (LInteger _) = failCheck "pow() is not supported yet"
    eval IsGT (LInteger x) (LInteger y) = return $ LBool (x > y)
    eval IsLT (LInteger x) (LInteger y) = return $ LBool (x < y)
    eval IsEQ (LInteger x) (LInteger y) = return $ LBool (x == y)
    eval IsNE (LInteger x) (LInteger y) = return $ LBool (x /= y)
    eval _ _ _ = failCheck "Unsupported operand types in constant expression"
-- | Type of a literal value.
litType :: Lit -> Type
litType lit =
  case lit of
    LInteger _ -> TInteger
    LString _ -> TString
    LBool _ -> TBool
-- | Declare a named constant.
--
-- Fails the check when the name is already a constant. Otherwise the
-- defining expression is evaluated with 'evalConst' and stored as a
-- literal annotated with the expression's position and the literal's
-- type; that annotated literal is also the result.
addConst :: Id -> Expression :~ SrcPos -> Check (Expression :~ TypeAnn)
addConst name e = do
  known <- gets userConsts
  case lookup name known of
    Just previous ->
      failCheck $ "Constant " ++ name ++ " was already defined as " ++ show previous
    Nothing -> do
      val <- evalConst e
      let ann =
            TypeAnn
              { srcPos = annotation e
              , typeOf = litType val
              , localSymbols = M.empty
              }
          result = Annotate (Literal val) ann
      modify $ \st -> st {userConsts = (name, result) : userConsts st}
      return result
-- | Type-check a whole program inside a fresh scope.
--
-- Declarations are processed in this order: constants, user types,
-- variables and functions (each group in the 'Outside' context), and
-- finally the program body (in the 'ProgramBody' context). The resulting
-- annotation has position (0, 0), type 'TVoid', and the program variables
-- as its local symbols.
instance Typed Program where
  typeCheck p@(content -> Program consts types vars fns body) = withSymbolTable $ do
      setPos p
      consts' <- inContext Outside $ mapM declareConst consts
      types' <- inContext Outside $ mapM (uncurry addType) (M.assocs types)
      vars' <- inContext Outside $ mapM checkSymbol vars
      fns' <- inContext Outside $ mapM declareFunction fns
      body' <- inContext ProgramBody $ mapM typeCheck body
      let program =
            Program consts' (M.fromList types') (map errorOnUserTypeSymbol vars') fns' body'
      return $ Annotate program $ TypeAnn
        { srcPos = SrcPos 0 0
        , typeOf = TVoid
        , localSymbols = makeSymbolTable vars'
        }
    where
      -- Evaluate a constant and also define it as a symbol of the
      -- literal's type, positioned at the defining expression.
      declareConst (n, v) = do
        v' <- addConst n v
        let pos = annotation v
            sym =
              Symbol
                { symbolName = n
                , symbolType = typeOfA v'
                , symbolDefLine = srcLine pos
                , symbolDefCol = srcColumn pos
                }
        addSymbol (Annotate sym pos)
        return (n, v')

      -- Check a function, then define its symbol. The symbol is added
      -- only after the function itself has been checked.
      declareFunction fn = do
        fn' <- typeCheck fn
        let f = content fn'
            tp = TFunction (argTypes f) (fnResultType f)
            pos = annotation fn
        addSymbol $ Annotate (fnName f # tp) (SrcPos (srcLine pos) (srcColumn pos))
        return fn'

      -- Types of the formal arguments of a checked function.
      argTypes :: Function TypeAnn -> [Type]
      argTypes = map symbolTypeC . fnFormalArgs
-- | Build a name-indexed symbol map from checked symbols. Each symbol's
-- definition line and column are taken from the source position stored in
-- its type annotation.
makeSymbolTable :: [Annotate Symbol TypeAnn] -> M.Map Id Symbol
makeSymbolTable xs =
  M.fromList
    [ ( symbolName s
      , s {symbolDefLine = srcLine pos, symbolDefCol = srcColumn pos}
      )
    | Annotate s ann <- xs
    , let pos = srcPos ann
    ]
-- | Find a record field by name.
--
-- Returns the 1-based position of the first field with that name together
-- with the field's type, or 'Nothing' when there is no such field.
findField :: Id -> [(Id, Type)] -> Maybe (Int, Type)
findField name pairs = lookup name indexed
  where
    -- Every field tagged with its 1-based position, keyed by field name.
    indexed = zipWith (\i (k, v) -> (k, (i, v))) [1 ..] pairs
-- | Type-check assignment targets.
--
-- * A plain variable has the type of its symbol.
-- * An array item requires an array symbol and an Integer index; its type
--   is the array's element type.
-- * A record field requires a record symbol that has the field; its type
--   is 'TField' carrying the field's position and type.
instance Typed LValue where
  typeCheck v = do
    setPos v
    case content v of
      LVariable name -> do
        sym <- getSymbol name
        returnT (symbolType sym) v (LVariable name)
      LArray name ix -> do
        sym <- getSymbol name
        case symbolType sym of
          TArray _ tp -> do
            ix' <- typeCheck ix
            let ixType = typeOfA ix'
            unless (ixType == TInteger) $
              failCheck $ "Invalid array item lvalue: index is " ++ show ixType ++ ", not Integer"
            returnT tp v (LArray name ix')
          other ->
            failCheck $ "Invalid lvalue: " ++ name ++ " is " ++ show other ++ ", not Array"
      LField base field -> do
        baseSym <- getSymbol base
        case symbolType baseSym of
          TRecord pairs ->
            case findField field pairs of
              Just (ix, t) -> returnT (TField ix t) v (LField base field)
              Nothing -> failCheck $ "No such field in " ++ base ++ " record: " ++ field
          other -> failCheck $ base ++ " is " ++ show other ++ ", not Record"
-- | Type-check statements.
--
-- Every statement is annotated with 'TVoid', except an assignment (type of
-- its left-hand side) and @return@ (type of the returned expression).
instance Typed Statement where
  typeCheck s =
    case content s of
      -- The right-hand side must be 'TAny' or a subtype of the target.
      Assign lvalue expr -> do
        setPos s
        lhs <- typeCheck lvalue
        rhs <- typeCheck expr
        let lhsType = typeOfA lhs
            rhsType = typeOfA rhs
        unless ((rhsType == TAny) || (rhsType `isSubtypeOf` lhsType)) $
          failCheck $ "Invalid assignment: LHS type is " ++ show lhsType ++ ", but RHS type is " ++ show rhsType
        returnT lhsType s (Assign lhs rhs)

      -- The callee must be a function returning 'TVoid' whose formal
      -- argument types accept the actual argument types.
      Procedure name args -> do
        setPos s
        sym <- getSymbol name
        case symbolType sym of
          TFunction formalArgTypes TVoid -> do
            args' <- mapM typeCheck args
            let actualTypes = map typeOfA args'
            unless (actualTypes `areSubtypesOf` formalArgTypes) $
              failCheck $ "Invalid types in procedure call: " ++ show actualTypes ++ " instead of " ++ show formalArgTypes
            returnT TVoid s (Procedure name args')
          t -> failCheck $ "Symbol " ++ name ++ " is not a procedure, but " ++ show t

      -- Allowed when any enclosing context is a for loop.
      Break -> do
        setPos s
        cxs <- gets contexts
        unless (any isFor cxs) $
          failCheck "break statement not in for loop"
        returnT TVoid s Break

      -- Allowed when any enclosing context is a for loop.
      Continue -> do
        setPos s
        cxs <- gets contexts
        unless (any isFor cxs) $
          failCheck "continue statement not in for loop"
        returnT TVoid s Continue

      -- NOTE(review): only the innermost context is inspected, so an
      -- `exit` nested directly inside a for loop is presumably rejected
      -- even within a procedure -- confirm this is intended.
      Exit -> do
        setPos s
        cxs <- gets contexts
        case cxs of
          (InFunction _ TVoid : _) -> returnT TVoid s Exit
          (ProgramBody : _) -> returnT TVoid s Exit
          _ -> failCheck "exit statement not in procedure or program body"

      -- NOTE(review): as with `exit`, only the innermost context is
      -- inspected -- confirm that rejecting `return` inside a for loop is
      -- intended.
      Return x -> do
        setPos s
        x' <- typeCheck x
        let retType = typeOfA x'
        cxs <- gets contexts
        case cxs of
          (InFunction _ TVoid : _) -> failCheck "return statement in procedure"
          (InFunction _ t : _)
            | retType `isSubtypeOf` t -> returnT retType s (Return x')
            | otherwise ->
                failCheck $ "Return value type does not match: expecting " ++ show t ++ ", got " ++ show retType
          _ -> failCheck "return statement not in function"

      -- The condition must be exactly 'TBool'.
      IfThenElse c a b -> do
        setPos s
        c' <- typeCheck c
        unless (typeOfA c' == TBool) $
          failCheck $ "Condition type is not Boolean: " ++ show c
        a' <- mapM typeCheck a
        b' <- mapM typeCheck b
        returnT TVoid s (IfThenElse c' a' b')

      -- The counter variable and both bounds must be Integer; the whole
      -- statement is checked inside a 'ForLoop' context.
      For name start end body -> inContext (ForLoop name 0) $ do
        setPos s
        sym <- getSymbol name
        unless (symbolType sym == TInteger) $
          failCheck $ "Counter variable is not Integer: " ++ name
        start' <- typeCheck start
        unless (typeOfA start' == TInteger) $
          failCheck $ "Counter start value is not Integer: " ++ show start
        end' <- typeCheck end
        unless (typeOfA end' == TInteger) $
          failCheck $ "Counter end value is not Integer: " ++ show end
        body' <- mapM typeCheck body
        returnT TVoid s (For name start' end' body')
-- | Type-check a function definition.
--
-- Formal arguments, local variables and the body are checked in a fresh
-- scope, inside an 'InFunction' context carrying the function's name and
-- result type. The result is annotated with the function's result type,
-- and its local symbols are the function's local variables.
instance Typed Function where
  typeCheck x = do
    setPos x
    let fn = content x
        name = fnName fn
        resultType = fnResultType fn
    inContext (InFunction name resultType) $ withSymbolTable $ do
      args <- mapM checkSymbol (fnFormalArgs fn)
      vars <- mapM checkSymbol (fnVars fn)
      body <- mapM typeCheck (fnBody fn)
      checked <- returnT resultType x (Function name args resultType vars body)
      return checked
        { annotation = (annotation checked) {localSymbols = makeSymbolTable vars}
        }
-- | Type-check expressions.
--
-- * A variable has the type of its symbol.
-- * An array item requires an array symbol and an Integer index.
-- * A record field requires a record symbol that has the field (or a
--   symbol whose type is already a 'TField').
-- * A literal has the type given by 'litType'.
-- * A call requires a function symbol whose formal argument types accept
--   the actual ones; its type is the function's result type.
-- * A binary operator requires both operand types to accept 'TInteger';
--   comparisons yield 'TBool', every other operator yields 'TInteger'.
instance Typed Expression where
  typeCheck e =
    case content e of
      Variable x -> do
        setPos e
        sym <- getSymbol x
        returnT (symbolType sym) e (Variable x)

      ArrayItem name ix -> do
        setPos e
        sym <- getSymbol name
        case symbolType sym of
          TArray _ tp -> do
            ix' <- typeCheck ix
            let ixType = typeOfA ix'
            unless (ixType == TInteger) $
              failCheck $ "Array index is " ++ show ixType ++ ", not Integer"
            returnT tp e (ArrayItem name ix')
          other -> failCheck $ name ++ " is " ++ show other ++ ", not Array"

      RecordField base field -> do
        setPos e
        baseSym <- getSymbol base
        case symbolType baseSym of
          TRecord pairs ->
            case findField field pairs of
              Just (ix, t) -> returnT (TField ix t) e (RecordField base field)
              Nothing -> failCheck $ "No such field in " ++ base ++ " record: " ++ field
          TField ix t -> returnT (TField ix t) e (RecordField base field)
          other -> failCheck $ base ++ " is " ++ show other ++ ", not Record"

      Literal x -> returnT (litType x) e (Literal x)

      Call name args -> do
        setPos e
        sym <- getSymbol name
        case symbolType sym of
          TFunction formalArgTypes resType -> do
            args' <- mapM typeCheck args
            let actualTypes = map typeOfA args'
            unless (actualTypes `areSubtypesOf` formalArgTypes) $
              failCheck $ "Invalid types in function call: " ++ show actualTypes ++ " instead of " ++ show formalArgTypes
            returnT resType e (Call name args')
          t -> failCheck $ "Symbol " ++ name ++ " is not a function, but " ++ show t

      Op op x y -> do
        setPos e
        x' <- typeCheck x
        y' <- typeCheck y
        let tx = typeOfA x'
            ty = typeOfA y'
            resultType
              | op `elem` [IsEQ, IsNE, IsGT, IsLT] = TBool
              | otherwise = TInteger
        unless ((TInteger `isSubtypeOf` tx) && (TInteger `isSubtypeOf` ty)) $
          failCheck $ "Invalid operand types: " ++ show tx ++ ", " ++ show ty
        returnT resultType e (Op op x' y')
-- | Type-check a parsed program, starting from 'emptyState'.
--
-- Returns the program annotated with type information. This function is
-- partial: a type error is raised with 'error', using a message prefixed
-- with @"type checker: "@. Because the result is lazy, the error surfaces
-- only when the result is demanded.
checkTypes :: Program :~ SrcPos -> Program :~ TypeAnn
checkTypes prog =
  case evalState (runErrorT (runCheck (typeCheck prog))) emptyState of
    Right result -> result
    -- Raised explicitly rather than through 'fail' in the State monad.
    Left err -> error $ "type checker: " ++ show err
-- | Read a source file, parse it with 'pProgram' and type-check the
-- result. A parse error fails the IO action with a message prefixed with
-- @"parser: "@.
checkSource :: FilePath -> IO (Program :~ TypeAnn)
checkSource path =
  readFile path >>= either parseFailure (return . checkTypes) . parse pProgram path
  where
    parseFailure err = fail $ "parser: " ++ show err