{-|
Module      : Language.Qux.Annotated.TypeChecker
Description : Type checking functions to verify that a 'Program' is type-safe.

Copyright   : (c) Henry J. Wylde, 2015
License     : BSD3
Maintainer  : public@hjwylde.com

Type checking functions to verify that a 'Program' is type-safe.

These functions only verify that types are used correctly.
They don't verify other properties such as definite assignment.
-}

module Language.Qux.Annotated.TypeChecker (
    -- * Environment
    Evaluation, Check,

    -- * Contexts
    Context, Locals,
    context,

    -- * Type checking

    -- ** Program checking
    check,

    -- ** Other node checking
    checkProgram, checkDecl, checkStmt, checkExpr, checkValue
) where

import Control.Applicative
import Control.Monad.Except
import Control.Monad.Reader
import Control.Monad.State

import Data.Function (on)
import Data.List ((\\), nubBy)
import Data.Map (Map, (!))
import qualified Data.Map as Map
import Data.Maybe (fromJust, isNothing)

import Language.Qux.Annotated.Exception
import Language.Qux.Annotated.Parser (SourcePos)
import Language.Qux.Annotated.Simplify
import qualified Language.Qux.Annotated.Syntax as Ann
import Language.Qux.Syntax


-- |    Local context, mapping variable names to their types.
--      While a function body is checked this holds the function's parameters; the return
--      statement check also reads its expected type from here (under the key \"\@\").
type Locals = Map Id Type

-- |    Global context that holds function definition types, keyed by function name.
data Context = Context
    { functions :: Map Id [Type] -- ^ The list of types recorded for each function name.
    }

-- |    An environment that holds the global types (@Reader Context@) and the local types
--      (@Locals@).
type Evaluation = StateT Locals (Reader Context)

-- |    Either a 'TypeException' or an @a@.
--      Contains an underlying 'Evaluation' in the monad transformer.
type Check = ExceptT TypeException Evaluation


-- | Builds the global 'Context' for a program.
--   Every declaration contributes one entry mapping its 'name' to its 'types'; if two
--   declarations share a name, 'Map.fromList' keeps the later one.
context :: Program -> Context
context (Program decls) = Context { functions = Map.fromList entries }
    where
        entries = [(name decl, types decl) | decl <- decls]

-- | Looks up the types associated with a name.
--   A local variable shadows a function of the same name and is reported as a
--   single-element list holding its type; otherwise the function's types from the global
--   'Context' are returned. Yields 'Nothing' when the name is bound in neither.
retrieve :: Id -> Evaluation (Maybe [Type])
retrieve name = do
    locals  <- get
    globals <- asks functions

    return $ case Map.lookup name locals of
        Just localType  -> Just [localType]
        Nothing         -> Map.lookup name globals

-- | Runs an action with the state temporarily transformed by @f@.
--   The state seen before the call is put back after the action completes, and the
--   action's result is returned.
--   NOTE: if the action short-circuits (e.g. via an outer 'ExceptT'), the restore step is
--   skipped.
once :: (Monad m, MonadState s m) => (s -> s) -> m a -> m a
once f action = do
    saved  <- get
    modify f
    result <- action
    put saved
    return result


-- |    Type checks the program.
--      The result is a 'TypeException' if the program is ill-typed, otherwise @()@.
--      This function wraps 'checkProgram': it builds the global 'context' from the
--      simplified program and runs the checker starting from empty 'Locals'.
check :: Ann.Program SourcePos -> Except TypeException ()
check program = mapExceptT runEvaluation (checkProgram program)
    where
        globals = context (sProgram program)

        runEvaluation evaluation = return $ runReader (evalStateT evaluation Map.empty) globals

-- | Type checks a program.
--   If any function name is declared more than once, a duplicate-function-name error is
--   thrown for the first repeated declaration. Otherwise each declaration is checked in order.
checkProgram :: Ann.Program SourcePos -> Check ()
checkProgram (Ann.Program _ decls) = case repeated of
    (firstRepeat : _)   -> throwError $ duplicateFunctionName firstRepeat
    []                  -> mapM_ checkDecl decls
    where
        -- Whatever remains after keeping the first declaration of each distinct name.
        repeated = decls \\ nubBy ((==) `on` sId . Ann.name) decls

-- | Type checks a declaration.
--   If a parameter name repeats, a duplicate-parameter-name error is thrown for the first
--   repeat. Otherwise the body is checked with the parameters merged into the local
--   context; parameters take precedence over existing locals. The previous locals are
--   restored once the body has been checked.
checkDecl :: Ann.Decl SourcePos -> Check ()
checkDecl (Ann.FunctionDecl _ _ parameters stmts) = case repeated of
    (firstRepeat : _)   -> throwError $ duplicateParameterName (snd firstRepeat)
    []                  -> once (Map.union scope) (checkBlock stmts)
    where
        -- Parameter name -> simplified parameter type.
        scope       = Map.fromList [(sId param, sType paramType) | (paramType, param) <- parameters]
        -- Whatever remains after keeping the first parameter of each distinct name.
        repeated    = parameters \\ nubBy ((==) `on` sId . snd) parameters

-- | Type checks each statement of a block in order, stopping at the first error.
checkBlock :: [Ann.Stmt SourcePos] -> Check ()
checkBlock stmts = forM_ stmts checkStmt

-- | Type checks a statement.
--   The condition of an if or while statement must be a 'BoolType', and the nested
--   blocks are then checked in turn. The expression of a return statement must have the
--   type stored in the local context under the key \"\@\".
checkStmt :: Ann.Stmt SourcePos -> Check ()
checkStmt (Ann.IfStmt _ condition trueStmts falseStmts) =
    expectExpr condition [BoolType] >> checkBlock trueStmts >> checkBlock falseStmts
checkStmt (Ann.ReturnStmt _ expr) =
    gets (! "@") >>= void . expectExpr expr . (: [])
checkStmt (Ann.WhileStmt _ condition stmts) =
    expectExpr condition [BoolType] >> checkBlock stmts

-- | Type checks an expression, returning its type.
--
--   Names in an application are resolved with 'retrieve'. A local variable is treated as a
--   nullary function, so applying it to arguments is an arity error. For a function, the
--   last recorded type is the return type and the others are the parameter types.
--
--   Throws an undefined-function error for unknown names and an invalid-arguments-count
--   error when the arity differs. Throws a mismatched-type error (via 'expectExpr') when a
--   sub-expression has an unexpected type.
--
--   A non-empty list literal has type @ListType t@, where @t@ is the type of its first
--   element; every other element must also have type @t@.
--
--   NOTE(review): several cases use @ListType undefined@ as a "list of anything"
--   placeholder. It is compared with '==' inside 'expectType', which is only safe if the
--   'Eq' instance for 'Type' never inspects the inner type of a 'ListType'. Confirm.
checkExpr :: Ann.Expr SourcePos -> Check Type
checkExpr e@(Ann.ApplicationExpr _ name arguments)      = do
    maybeTypes <- lift $ retrieve (sId name)

    case maybeTypes of
        Nothing     -> throwError $ undefinedFunctionCall e
        Just types  -> do
            -- The last type is the return type; the rest are the parameter types.
            let expected = init types

            unless (length expected == length arguments) $
                throwError $ invalidArgumentsCount e (length expected)

            zipWithM_ expectExpr arguments (map (: []) expected)

            return $ last types
checkExpr (Ann.BinaryExpr _ op lhs rhs)
    | op == Acc                     = do
        list <- expectExpr lhs [ListType undefined]
        -- 'expectExpr' only returns a type equal to one of the expected ones, so this
        -- should be a 'ListType'.
        let ListType inner = list
        expectExpr rhs [IntType] >> return inner
    | op `elem` [Mul, Div, Mod]     = expectExpr lhs [IntType] >> expectExpr rhs [IntType]
    | op `elem` [Add, Sub]          = expectExpr lhs [IntType, ListType undefined] >>= expectExpr rhs . (: [])
    | op `elem` [Lt, Lte, Gt, Gte]  = expectExpr lhs [IntType] >> expectExpr rhs [IntType] >> return BoolType
    | op `elem` [Eq, Neq]           = (checkExpr lhs >>= expectExpr rhs . (: [])) >> return BoolType
checkExpr (Ann.ListExpr _ [])                           = return $ ListType undefined
checkExpr (Ann.ListExpr _ (first : rest))               = do
    -- The first element fixes the element type; every other element must match it.
    expected <- checkExpr first

    mapM_ (`expectExpr` [expected]) rest

    -- A list literal has a list type, consistent with the empty-list case above.
    return $ ListType expected
checkExpr (Ann.UnaryExpr _ op expr)
    | op == Len                     = expectExpr expr [ListType undefined] >> return IntType
    | op == Neg                     = expectExpr expr [IntType]
checkExpr (Ann.ValueExpr _ value)                       = checkValue value

-- | Type checks a value, returning the type of the literal.
checkValue :: Value -> Check Type
checkValue value = case value of
    BoolValue _ -> return BoolType
    IntValue _  -> return IntType
    NilValue    -> return NilType


-- | Type checks an expression and requires its type to be one of @expects@.
--   The expression's annotation is attached to the computed type before it is passed to
--   'expectType', so a mismatch error receives an annotated type. On success the
--   simplified type is returned.
expectExpr :: Ann.Expr SourcePos -> [Type] -> Check Type
expectExpr expr expects = do
    actual <- checkExpr expr
    expectType (attach (Ann.ann expr) actual) expects

-- | Returns the simplified form of @received@ if it is one of @expects@. Otherwise throws
--   a mismatched-type error carrying the annotated type and the expected types.
expectType :: Ann.Type SourcePos -> [Type] -> Check Type
expectType received expects =
    if actual `elem` expects
        then return actual
        else throwError $ mismatchedType received expects
    where
        actual = sType received

-- | Annotates a 'Type' with a source position.
--   The element type of a list is annotated with the same position as the enclosing list
--   type, rather than with 'undefined', because both come from the same expression. A
--   bottom annotation would crash any consumer that inspects it, such as error rendering.
--   NOTE(review): the inner type may itself be 'undefined' for an empty list literal (see
--   'checkExpr'); forcing it will still fail.
attach :: SourcePos -> Type -> Ann.Type SourcePos
attach pos BoolType         = Ann.BoolType pos
attach pos IntType          = Ann.IntType pos
attach pos (ListType inner) = Ann.ListType pos (attach pos inner)
attach pos NilType          = Ann.NilType pos