module Language.Qux.Annotated.TypeChecker (
Evaluation, Check,
Context, Locals,
context,
check,
checkProgram, checkDecl, checkStmt, checkExpr, checkValue
) where
import Control.Applicative
import Control.Monad.Except
import Control.Monad.Reader
import Control.Monad.State
import Data.Function (on)
import Data.List ((\\), nubBy)
import Data.Map (Map, (!))
import qualified Data.Map as Map
import Data.Maybe (fromJust, isNothing)
import Language.Qux.Annotated.Exception
import Language.Qux.Annotated.Parser (SourcePos)
import Language.Qux.Annotated.Simplify
import qualified Language.Qux.Annotated.Syntax as Ann
import Language.Qux.Syntax
-- | Base monad of the checker: the local variable environment ('Locals') is
-- threaded as state over a read-only 'Context' of function signatures.
--
-- Declared without a type parameter so it can appear unsaturated in 'Check'.
type Evaluation = StateT Locals (Reader Context)
-- | The type-checking monad: an 'Evaluation' that may abort with a
-- 'TypeException'.
type Check = ExceptT TypeException Evaluation
-- | Read-only environment shared by the whole check.
data Context = Context
    { functions :: Map Id [Type]
      -- ^ Signature of every top-level function, keyed by name: the argument
      -- types followed by the result type.
    }
-- | Types of the variables currently in scope, keyed by identifier. The entry
-- whose key is a single at-sign holds the type that return statements are
-- checked against.
type Locals = Map Id Type
-- | Build the checking context from a (simplified) program by recording the
-- signature of each declaration under its name. If two declarations share a
-- name, the later one wins in the map; duplicates are reported separately by
-- 'checkProgram'.
context :: Program -> Context
context (Program decls) = Context { functions = Map.fromList signatures }
  where
    signatures = [(name decl, types decl) | decl <- decls]
-- | Look up the signature of a name.
--
-- A local variable shadows a function of the same name and is returned as a
-- singleton list holding its type. Otherwise the function's signature from the
-- 'Context' is returned. Yields 'Nothing' when the name is bound in neither.
retrieve :: Id -> Evaluation (Maybe [Type])
retrieve name = do
    locals <- get
    env <- ask
    return $ case Map.lookup name locals of
        Just localType -> Just [localType]
        Nothing        -> Map.lookup name (functions env)
-- | Run an action with the state temporarily transformed by @modifier@, then
-- restore the state that was in place beforehand and return the action's
-- result.
--
-- The restore is an ordinary later step, so in a monad where failure
-- short-circuits it does not run if the action aborts.
once :: Monad m => MonadState s m => (s -> s) -> m a -> m a
once modifier action = do
    saved <- get
    modify modifier
    result <- action
    put saved
    return result
-- | Type-check a whole annotated program.
--
-- Checking starts with no locals and with a 'Context' built from the
-- simplified program. The state/reader layers are run away, leaving only the
-- possible 'TypeException'.
check :: Ann.Program SourcePos -> Except TypeException ()
check program = mapExceptT runEvaluation (checkProgram program)
  where
    env = context (sProgram program)
    runEvaluation action = return $ runReader (evalStateT action Map.empty) env
-- | Check a program: reject duplicate function names, then check every
-- declaration in order.
--
-- Throws 'duplicateFunctionName' with the first declaration that reuses a name
-- declared earlier in the program.
checkProgram :: Ann.Program SourcePos -> Check ()
checkProgram (Ann.Program _ decls) = do
    case decls \\ nubBy sameName decls of
        []              -> return ()
        (duplicate : _) -> throwError $ duplicateFunctionName duplicate
    mapM_ checkDecl decls
  where
    -- Two declarations clash when their simplified identifiers coincide.
    sameName a b = sId (Ann.name a) == sId (Ann.name b)
-- | Check a function declaration.
--
-- First rejects repeated parameter names ('duplicateParameterName' with the
-- first repeated identifier). Then checks the body with the parameters added
-- to the locals; parameters shadow any existing entry with the same name. The
-- previous locals are restored afterwards.
--
-- NOTE(review): 'checkStmt' reads the return type from the locals under the
-- at-sign key. Presumably the parser includes that entry in @parameters@;
-- confirm against the parser.
checkDecl :: Ann.Decl SourcePos -> Check ()
checkDecl (Ann.FunctionDecl _ _ parameters stmts) = do
    case parameters \\ nubBy sameName parameters of
        []                   -> return ()
        ((_, duplicate) : _) -> throwError $ duplicateParameterName duplicate
    once (Map.union scope) (checkBlock stmts)
  where
    sameName (_, a) (_, b) = sId a == sId b
    scope = Map.fromList [(sId ident, sType typ) | (typ, ident) <- parameters]
-- | Check a sequence of statements in order, stopping at the first error.
checkBlock :: [Ann.Stmt SourcePos] -> Check ()
checkBlock stmts = mapM_ checkStmt stmts
-- | Check a single statement.
--
-- Conditions of @if@ and @while@ must be booleans. A returned expression must
-- have the type stored in the locals under the at-sign key.
checkStmt :: Ann.Stmt SourcePos -> Check ()
checkStmt (Ann.IfStmt _ condition trueStmts falseStmts) =
    expectExpr condition [BoolType] >> checkBlock trueStmts >> checkBlock falseStmts
checkStmt (Ann.ReturnStmt _ expr) =
    gets (! "@") >>= \returnType -> void (expectExpr expr [returnType])
checkStmt (Ann.WhileStmt _ condition body) =
    expectExpr condition [BoolType] >> checkBlock body
-- | Infer the type of an expression, checking its sub-expressions on the way.
--
-- Throws 'undefinedFunctionCall' for an unbound name and
-- 'invalidArgumentsCount' for an application with the wrong number of
-- arguments. Operands of the wrong type produce 'mismatchedType', raised via
-- 'expectExpr'.
--
-- NOTE(review): @ListType undefined@ is used as "a list of any element type".
-- Comparisons against it appear to rely on the 'Eq' instance for 'Type'
-- ignoring the element type; confirm in "Language.Qux.Syntax". Indexing an
-- empty list literal yields that undefined element type, which crashes if it
-- is later forced.
checkExpr :: Ann.Expr SourcePos -> Check Type
checkExpr e@(Ann.ApplicationExpr _ callee arguments) = do
    types <- lift (retrieve $ sId callee) >>= maybe (throwError $ undefinedFunctionCall e) return
    -- A signature is the argument types followed by the result type. A local
    -- variable comes back as a one-element signature, i.e. a nullary call.
    let expected = init types
    unless (length expected == length arguments) $
        throwError $ invalidArgumentsCount e (length expected)
    zipWithM_ expectExpr arguments (map (:[]) expected)
    return $ last types
checkExpr (Ann.BinaryExpr _ op lhs rhs)
    | op `elem` [Acc] = do
        list <- expectExpr lhs [ListType undefined]
        -- Lazy pattern: 'expectExpr' has already guaranteed a list type here.
        let ListType inner = list
        _ <- expectExpr rhs [IntType]
        return inner
    | op `elem` [Mul, Div, Mod] = expectExpr lhs [IntType] >> expectExpr rhs [IntType]
    -- The right operand must have the same type as the left one.
    | op `elem` [Add, Sub] = expectExpr lhs [IntType, ListType undefined] >>= expectExpr rhs . (:[])
    | op `elem` [Lt, Lte, Gt, Gte] = expectExpr lhs [IntType] >> expectExpr rhs [IntType] >> return BoolType
    | op `elem` [Eq, Neq] = (checkExpr lhs >>= expectExpr rhs . (:[])) >> return BoolType
checkExpr (Ann.ListExpr _ []) = return $ ListType undefined
checkExpr (Ann.ListExpr _ (first : rest)) = do
    expected <- checkExpr first
    mapM_ (\element -> expectExpr element [expected]) rest
    -- A list literal has list type, not the type of its elements.
    return $ ListType expected
checkExpr (Ann.UnaryExpr _ op expr)
    | op `elem` [Len] = expectExpr expr [ListType undefined] >> return IntType
    | op `elem` [Neg] = expectExpr expr [IntType]
checkExpr (Ann.ValueExpr _ value) = checkValue value
-- | The type of a literal value.
checkValue :: Value -> Check Type
checkValue value = case value of
    BoolValue _ -> return BoolType
    IntValue _  -> return IntType
    NilValue    -> return NilType
-- | Infer the type of an expression and require it to be one of @expects@.
--
-- The inferred type is tagged with the expression's source position so that a
-- 'mismatchedType' error points at the offending expression.
expectExpr :: Ann.Expr SourcePos -> [Type] -> Check Type
expectExpr expr expects = do
    inferred <- checkExpr expr
    expectType (attach (Ann.ann expr) inferred) expects
-- | Require an annotated type to be one of @expects@.
--
-- Returns the bare type on success. Otherwise throws 'mismatchedType' carrying
-- the received (annotated) type and the accepted alternatives.
expectType :: Ann.Type SourcePos -> [Type] -> Check Type
expectType received expects
    | bare `elem` expects = return bare
    | otherwise           = throwError $ mismatchedType received expects
  where
    bare = sType received
-- | Annotate a bare 'Type' with a source position so it can be reported in a
-- type error.
--
-- An inferred type has no separate source location for a list's element type,
-- so the element reuses the enclosing position rather than 'undefined'. This
-- makes the result safe to show or inspect.
attach :: SourcePos -> Type -> Ann.Type SourcePos
attach pos BoolType = Ann.BoolType pos
attach pos IntType = Ann.IntType pos
attach pos (ListType inner) = Ann.ListType pos (attach pos inner)
attach pos NilType = Ann.NilType pos