{-# LANGUAGE RecordWildCards, TypeOperators, TypeSynonymInstances, FlexibleInstances, ViewPatterns #-}
module Language.Pascal.TypeCheck
  (checkTypes,
   checkSource,
   builtinSymbols
  ) where

import Control.Monad
import Control.Monad.State
import Control.Monad.Error
import qualified Data.Map as M
import Data.Maybe
import Text.Parsec hiding (State)

import Language.Pascal.Types
import Language.Pascal.Builtin
import Language.Pascal.Parser

-- | Look up a named symbol in a stack of scopes.  The list is searched from
-- its head (the most recently pushed scope) outwards, and the first table
-- that contains the name wins, so inner declarations shadow outer ones.
lookupSymbol :: Id -> SymbolTable -> Maybe Symbol
lookupSymbol name table = msum [ M.lookup name scope | scope <- table ]

-- | Outermost symbol table: one entry per builtin function, keyed by name.
-- Builtins have no source location, so their position is recorded as 0:0.
builtinSymbols ::  M.Map Id Symbol
builtinSymbols =
  M.fromList [ (name, builtin name tp) | (name, tp, _) <- builtinFunctions ]
  where
    builtin name tp = Symbol
      { symbolName    = name
      , symbolType    = tp
      , symbolDefLine = 0
      , symbolDefCol  = 0
      }

-- | @sub `isSubtypeOf` super@: can a value of type @sub@ be used where
-- @super@ is expected?  Rules are tried top to bottom:
--
-- * 'TVoid' is only compatible with itself (not even with 'TAny');
-- * everything else is compatible with 'TAny';
-- * arrays compare element types only (sizes are ignored);
-- * a 'TField' wrapper on either side is stripped;
-- * functions compare result types and argument lists pairwise;
-- * otherwise the types must be equal.
isSubtypeOf :: Type -> Type -> Bool
isSubtypeOf sub super =
  case (sub, super) of
    (TVoid, TVoid)                     -> True
    (TVoid, _)                         -> False
    (_, TAny)                          -> True
    (TArray _ elemA, TArray _ elemB)   -> elemA `isSubtypeOf` elemB
    (_, TField _ inner)                -> sub `isSubtypeOf` inner
    (TField _ inner, _)                -> inner `isSubtypeOf` super
    (TFunction argsA resA, TFunction argsB resB) ->
      (resA `isSubtypeOf` resB) && areSubtypesOf argsA argsB
    _                                  -> sub == super

-- | Pairwise 'isSubtypeOf' over two type lists.  The lists must have the same
-- length; a length mismatch yields 'False'.
areSubtypesOf :: [Type] -> [Type] -> Bool
areSubtypesOf [] [] = True
areSubtypesOf (t : ts) (u : us) = t `isSubtypeOf` u && areSubtypesOf ts us
areSubtypesOf _ _ = False

-- | Initial checker state: the builtin symbols form the only scope, and there
-- are no user constants, user types or contexts yet.  Position is 0:0.
emptyState :: CheckState
emptyState = CheckState
  { symbolTable = [builtinSymbols]
  , userConsts  = []
  , userTypes   = M.empty
  , contexts    = []
  , ckLine      = 0
  , ckColumn    = 0
  }

-- | AST node kinds that can be type-checked: 'typeCheck' turns a node
-- annotated with its source position into one annotated with a 'TypeAnn'.
class Typed a where
  typeCheck
    :: a :~ SrcPos            -- ^ node annotated with its source position
    -> Check (a :~ TypeAnn)   -- ^ same node annotated with its type

-- | Is this context a @for@ loop (regardless of the loop's fields)?
isFor :: Context -> Bool
isFor ctx =
  case ctx of
    ForLoop {} -> True
    _          -> False

-- | Wrap a checked node @res@ with a 'TypeAnn' carrying type @t@, the source
-- position copied from the original node @x@, and no local symbols.
returnT ::  Type -> Annotate node1 SrcPos -> node -> Check (Annotate node TypeAnn)
returnT t x res = return (Annotate res ann)
  where
    pos = annotation x
    ann = TypeAnn
      { srcPos       = SrcPos { srcLine = srcLine pos, srcColumn = srcColumn pos }
      , typeOf       = t
      , localSymbols = M.empty
      }

-- | Context stack and error reporting for the type checker.  Errors carry the
-- current position ('ckLine' / 'ckColumn') and the innermost context, or
-- 'Unknown' when no context has been entered.
instance Checker Check where
  -- Push a new innermost context.
  enterContext c = modify $ \st -> st {contexts = c : contexts st}

  -- Pop the innermost context; popping an empty stack is an internal error.
  dropContext = do
    stack <- gets contexts
    case stack of
      []       -> failCheck "Internal error in TypeCheck: dropContext on empty context!"
      (_:rest) -> modify $ \st -> st {contexts = rest}

  failCheck msg = do
    st <- get
    let innermost = case contexts st of
                      (c:_) -> c
                      []    -> Unknown
    throwError TError
      { errLine    = ckLine st
      , errColumn  = ckColumn st
      , errContext = innermost
      , errMessage = msg
      }

-- | Make the node's source position the current one.  'failCheck' reads the
-- current position when it builds an error.
setPos :: Annotate a SrcPos -> Check ()
setPos x = modify $ \st -> st { ckLine = srcLine pos, ckColumn = srcColumn pos }
  where
    pos = annotation x

-- | Sanity check: return the symbol unchanged, but abort with an internal
-- error if its type is still an unresolved 'TUser' name.
errorOnUserTypeSymbol :: Annotate Symbol a -> Annotate Symbol a
errorOnUserTypeSymbol sym =
  case symbolType (content sym) of
    TUser t -> error $ "Internal error (symbol): user type: " ++ t
    _       -> sym

-- | Resolve a declared type:
--
-- * 'TUser' names are replaced by their definition from 'userTypes'
--   (failing with "Undefined type" when unknown);
-- * array element types and record field types are resolved recursively;
-- * record field names are checked for duplicates by adding them to a
--   temporary scope;
-- * any other type is returned unchanged.
checkType :: Type -> Check Type
checkType tp =
    case tp of
      TArray sz elemT -> liftM (TArray sz) (checkType elemT)
      TRecord fields  -> withSymbolTable $ liftM TRecord (mapM checkField fields)
      TUser name      -> do
          known <- gets userTypes
          maybe (failCheck $ "Undefined type: " ++ name) checkType (M.lookup name known)
      other -> return other
  where
    -- Resolve one field's type and register the field name in the current
    -- (temporary) scope so a repeated name is reported by 'addSymbol'.
    checkField (fieldName, fieldType) = do
        resolved <- checkType fieldType
        addSymbol $ Annotate (fieldName # resolved) (SrcPos 0 0)
        return (fieldName, resolved)

-- | Check a variable/argument declaration: resolve its declared type, add the
-- symbol (with the resolved type) to the current scope and return it
-- annotated with that type.
checkSymbol :: Annotate Symbol SrcPos -> Check (Annotate Symbol TypeAnn)
checkSymbol s = do
  setPos s
  resolved <- checkType (symbolTypeC s)
  case resolved of
    TUser name -> failCheck $ "Internal error: undefined user type: " ++ name
    _          -> register (setType s resolved) resolved
  where
    register sym tp = addSymbol sym >> return (sym `withType` tp)

-- | Find a symbol in the scope stack (innermost first), failing with
-- "Unknown symbol" when no scope defines it.
getSymbol :: Id -> Check Symbol
getSymbol name = do
  found <- gets (lookupSymbol name . symbolTable)
  maybe (failCheck $ "Unknown symbol: " ++ name) return found

-- | Register a symbol in the innermost scope, recording the source position
-- from its annotation.
--
-- Fails with "Symbol is already defined" if the innermost scope already has
-- the name.  Outer scopes are not consulted, so shadowing is allowed.  An
-- empty scope stack is reported through 'failCheck' as an internal error,
-- consistent with 'dropSymbolTable'.  It is not left to a refutable
-- do-pattern, which would go through 'fail'.
addSymbol :: Annotate Symbol SrcPos -> Check ()
addSymbol (Annotate (Symbol {..}) (SrcPos {..})) = do
  st <- get
  case symbolTable st of
    [] -> failCheck "Internal error: empty symbol table on addSymbol!"
    (current:other) ->
      case M.lookup symbolName current of
        Just s  -> failCheck $ "Symbol is already defined: " ++ showSymbol s
        Nothing -> do
          let new = M.insert symbolName (Symbol symbolName symbolType srcLine srcColumn) current
          put $ st {symbolTable = new : other}

-- | Push a fresh, empty scope onto the symbol-table stack.
addSymbolTable :: Check ()
addSymbolTable = modify (\st -> st {symbolTable = M.empty : symbolTable st})

-- | Pop the innermost scope; popping an empty stack is an internal error.
dropSymbolTable :: Check ()
dropSymbolTable = do
  tables <- gets symbolTable
  case tables of
    (_:older) -> modify $ \st -> st {symbolTable = older}
    []        -> failCheck "Internal error: empty symbol table on dropSymbolTable!"

-- | Run an action inside a new innermost scope and pop that scope afterwards.
-- The pop only happens when the action succeeds; no 'catchError' is used, so
-- a thrown error leaves the scope pushed.
withSymbolTable :: Check a -> Check a
withSymbolTable action =
  addSymbolTable >> action >>= \result -> dropSymbolTable >> return result

-- | Declare a user type: reject a duplicate name, resolve the definition with
-- 'checkType', store the resolved type in 'userTypes' and return the pair.
addType :: Id -> Type -> Check (Id, Type)
addType name tp = do
  known <- gets userTypes
  when (M.member name known) $
    failCheck $ "Type is already defined: " ++ name
  resolved <- checkType tp
  modify $ \st -> st {userTypes = M.insert name resolved (userTypes st)}
  return (name, resolved)

-- | Fold a constant expression to a literal at check time.
--
-- Supported: literals, references to previously declared constants, and
-- integer arithmetic/comparison operators.  Anything else (unknown constant,
-- non-constant expression, unsupported operator or operand types, @pow@,
-- division by zero) is reported through 'failCheck' as a proper type error.
-- It does not abort the whole process via 'error' or a 'div' exception.
evalConst :: Expression :~ a -> Check Lit
evalConst expr =
    case content expr of
      Variable name -> do
        consts <- gets userConsts
        case lookup name consts of
          Just v  -> evalConst v
          Nothing -> failCheck $ "No such constant: " ++ name
      Literal v -> return v
      Op op x y -> do
        x' <- evalConst x
        y' <- evalConst y
        eval op x' y'
      x -> failCheck $ "Expression is not constant: " ++ show x
  where
    eval Add (LInteger x) (LInteger y) = return $ LInteger (x + y)
    eval Sub (LInteger x) (LInteger y) = return $ LInteger (x - y)
    eval Mul (LInteger x) (LInteger y) = return $ LInteger (x * y)
    -- Guard before 'div', which would otherwise throw an arithmetic exception.
    eval Div (LInteger _) (LInteger 0) = failCheck "Division by zero in constant expression"
    -- NOTE(review): 'div' rounds toward negative infinity; if the target's
    -- runtime division truncates toward zero (as Pascal's @div@ does), 'quot'
    -- would be needed for negative operands — confirm against the code generator.
    eval Div (LInteger x) (LInteger y) = return $ LInteger (x `div` y)
    eval Pow (LInteger _) (LInteger _) = failCheck "pow() is not supported yet"
    eval IsGT (LInteger x) (LInteger y) = return $ LBool (x > y)
    eval IsLT (LInteger x) (LInteger y) = return $ LBool (x < y)
    eval IsEQ (LInteger x) (LInteger y) = return $ LBool (x == y)
    eval IsNE (LInteger x) (LInteger y) = return $ LBool (x /= y)
    eval _ _ _ = failCheck "Unsupported operand types in constant expression"

-- | The type of a literal value.
litType :: Lit -> Type
litType lit =
  case lit of
    LInteger _ -> TInteger
    LString _  -> TString
    LBool _    -> TBool

-- | Declare a named constant.  Rejects a duplicate name, folds the expression
-- with 'evalConst', records the resulting literal in 'userConsts' and returns
-- it annotated with the literal's type and the expression's source position.
addConst :: Id -> Expression :~ SrcPos -> Check (Expression :~ TypeAnn)
addConst name e = do
  existing <- gets (lookup name . userConsts)
  case existing of
    Just c  -> failCheck $ "Constant " ++ name ++ " was already defined as " ++ show c
    Nothing -> do
      val <- evalConst e
      let result = Annotate (Literal val) TypeAnn
                     { srcPos       = annotation e
                     , typeOf       = litType val
                     , localSymbols = M.empty
                     }
      modify $ \st -> st {userConsts = (name, result) : userConsts st}
      return result

-- | Check a whole program inside a fresh scope.  Declarations are processed
-- in a fixed order (constants, types, global variables, functions), each
-- under the 'Outside' context, and then the main body under 'ProgramBody'.
instance Typed Program where
  typeCheck p@(content -> Program consts types vars fns body) = withSymbolTable $ do
      setPos p
      -- Each constant is folded to a literal and also registered as a symbol
      -- whose type is that literal's type.
      consts' <- inContext Outside $
                   forM consts $ \(n,v) -> do
                     v' <- addConst n v
                     let sym = Annotate {
                                 content = Symbol {
                                   symbolName = n,
                                   symbolType = typeOfA v',
                                   symbolDefLine = srcLine   (annotation v),
                                   symbolDefCol  = srcColumn (annotation v) },
                                 annotation = annotation v }
                     addSymbol sym
                     return (n, v')
      -- NOTE(review): M.assocs yields the declarations in ascending name
      -- order, not source order, so a type that refers to a type whose name
      -- sorts later is reported as undefined.
      types' <- inContext Outside $
                  forM (M.assocs types) $ uncurry addType
      vars' <- inContext Outside $
                 forM vars checkSymbol
      fns'  <- inContext Outside $
                 forM fns $ \fn -> do
                   -- Register the signature *before* checking the body, so
                   -- the body can call the function recursively.
                   registerSignature fn
                   typeCheck fn
      body' <- inContext ProgramBody $
                 forM body typeCheck
      let program = Program consts' (M.fromList types') (map errorOnUserTypeSymbol vars') fns' body'
      return $ Annotate program $ TypeAnn {
        srcPos = SrcPos 0 0,
        typeOf = TVoid,
        localSymbols = makeSymbolTable vars'}
    where
      -- Resolve the declared argument and result types (as 'checkSymbol'
      -- and 'checkType' would) and add the function's TFunction symbol to
      -- the current scope at the function's source position.
      registerSignature fn = do
        let f = content fn
        argTps <- forM (fnFormalArgs f) $ \arg -> do
                    setPos arg
                    checkType (symbolTypeC arg)
        setPos fn
        resTp <- checkType (fnResultType f)
        let sig = TFunction argTps resTp
            pos = SrcPos {
                    srcLine = srcLine (annotation fn),
                    srcColumn = srcColumn (annotation fn) }
        addSymbol $ Annotate (fnName f # sig) pos

-- | Build a name-keyed table from checked symbols.  Each symbol's definition
-- position is overwritten with the source position from its annotation.
makeSymbolTable :: [Annotate Symbol TypeAnn] -> M.Map Id Symbol
makeSymbolTable xs =
  M.fromList
    [ (symbolName s, s { symbolDefLine = srcLine pos, symbolDefCol = srcColumn pos })
    | Annotate s ann@(TypeAnn {}) <- xs
    , let pos = srcPos ann
    ]

-- | Locate a record field by name.  Returns the 1-based position and the
-- type of the first field with that name, or 'Nothing'.
findField :: Id -> [(Id, Type)] -> Maybe (Int, Type)
findField name pairs =
  listToMaybe [ (i, t) | (i, (k, t)) <- zip [1 ..] pairs, k == name ]

-- | Check the left-hand side of an assignment.
instance Typed LValue where
  -- Plain variable: its type is the symbol's declared type.
  typeCheck v@(content -> LVariable name) = do
    setPos v
    sym <- getSymbol name
    returnT (symbolType sym) v (LVariable name)

  -- Array element: the base must be an array and the index integer-valued.
  -- 'isSubtypeOf' is used, so an integer record field (@TField _ TInteger@)
  -- is a valid index, as it already is for arithmetic in 'Op'.
  typeCheck v@(content -> LArray name ix) = do
    setPos v
    sym <- getSymbol name
    case symbolType sym of
      TArray _ tp -> do
                     ix' <- typeCheck ix
                     unless (typeOfA ix' `isSubtypeOf` TInteger) $
                       failCheck $ "Invalid array item lvalue: index is " ++ show (typeOfA ix') ++ ", not Integer"
                     returnT tp v (LArray name ix')
      x -> failCheck $ "Invalid lvalue: " ++ name ++ " is " ++ show x ++ ", not Array"

  -- Record field: typed as @TField index fieldType@ (1-based field index).
  typeCheck v@(content -> LField base field) = do
    setPos v
    baseSym <- getSymbol base
    case symbolType baseSym of
      TRecord pairs -> case findField field pairs of
                         Just (ix,t) -> returnT (TField ix t) v (LField base field)
                         Nothing -> failCheck $ "No such field in " ++ base ++ " record: " ++ field
      x -> failCheck $ base ++ " is " ++ show x ++ ", not Record"

-- | Check a statement.  Statements are typed 'TVoid', except assignment
-- (typed as its left-hand side) and @return@ (typed as the returned value).
instance Typed Statement where
  -- Assignment: RHS must be 'TAny' or a subtype of the LHS type.
  typeCheck x@(content -> Assign lvalue expr) = do
    setPos x
    lhs <- typeCheck lvalue
    rhs <- typeCheck expr
    let rhsType = typeOfA rhs
        lhsType = typeOfA lhs
    if (rhsType == TAny) || (rhsType `isSubtypeOf` lhsType)
      then do
           let result = Assign lhs rhs
           returnT lhsType x result
      else failCheck $ "Invalid assignment: LHS type is " ++ show lhsType ++ ", but RHS type is " ++ show rhsType

  -- Procedure call: the symbol must be a TVoid-returning function and the
  -- actual argument types must match the formal ones pairwise.
  typeCheck s@(content -> Procedure name args) = do
    setPos s
    sym <- getSymbol name
    case symbolType sym of
      TFunction formalArgTypes TVoid -> do
          args' <- mapM typeCheck args
          let actualTypes = map typeOfA args'
          if actualTypes `areSubtypesOf` formalArgTypes
            then returnT TVoid s (Procedure name args')
            else failCheck $ "Invalid types in procedure call: " ++ show actualTypes ++ " instead of " ++ show formalArgTypes
      t -> failCheck $ "Symbol " ++ name ++ " is not a procedure, but " ++ show t

  -- break / continue: valid anywhere inside an enclosing for-loop.
  typeCheck s@(content -> Break) = do
      setPos s
      cxs <- gets contexts
      if any isFor cxs
        then returnT TVoid s Break
        else failCheck "break statement not in for loop"

  typeCheck s@(content -> Continue) = do
      setPos s
      cxs <- gets contexts
      if any isFor cxs
        then returnT TVoid s Continue
        else failCheck "continue statement not in for loop"

  -- exit: valid in a procedure or the program body.  For-loops push their
  -- own context, so they are skipped to find the enclosing routine;
  -- otherwise @exit@ inside a loop would be rejected.
  typeCheck s@(content -> Exit) = do
    setPos s
    cxs <- gets contexts
    case dropWhile isFor cxs of
      (InFunction _ TVoid:_) -> returnT TVoid s Exit
      (ProgramBody:_)        -> returnT TVoid s Exit
      _                      -> failCheck "exit statement not in procedure or program body"

  -- return: valid only in a (non-void) function, looking past enclosing
  -- for-loops; the value must be a subtype of the declared result type.
  typeCheck s@(content -> Return x) = do
    setPos s
    x' <- typeCheck x
    let retType = typeOfA x'
    cxs <- gets contexts
    case dropWhile isFor cxs of
      (InFunction _ TVoid:_) -> failCheck "return statement in procedure"
      (InFunction _ t:_)
          | retType `isSubtypeOf` t -> returnT (typeOfA x') s (Return x')
          | otherwise -> failCheck $ "Return value type does not match: expecting " ++ show t ++ ", got " ++ show retType
      _               -> failCheck $ "return statement not in function"

  -- if: the condition must be Boolean.  'isSubtypeOf' also accepts a Boolean
  -- record field (@TField _ TBool@); 'TAny' is still rejected.
  typeCheck s@(content -> IfThenElse c a b) = do
    setPos s
    c' <- typeCheck c
    unless (typeOfA c' `isSubtypeOf` TBool) $
      failCheck $ "Condition type is not Boolean: " ++ show c
    a' <- mapM typeCheck a
    b' <- mapM typeCheck b
    returnT TVoid s (IfThenElse c' a' b')

  -- for: the counter must be an Integer variable.  Start and end values
  -- must be integer-valued (integer record fields included).  The body is
  -- checked inside a ForLoop context so break/continue are allowed there.
  typeCheck s@(content -> For name start end body) = inContext (ForLoop name 0) $ do
    setPos s
    sym <- getSymbol name
    when (symbolType sym /= TInteger) $
      failCheck $ "Counter variable is not Integer: " ++ name
    start' <- typeCheck start
    unless (typeOfA start' `isSubtypeOf` TInteger) $
      failCheck $ "Counter start value is not Integer: " ++ show start
    end' <- typeCheck end
    unless (typeOfA end' `isSubtypeOf` TInteger) $
      failCheck $ "Counter end value is not Integer: " ++ show end
    body' <- mapM typeCheck body
    returnT TVoid s (For name start' end' body')

-- | Check a function or procedure declaration.
--
-- The declared result type is first resolved with 'checkType', so user type
-- names are expanded just like argument and variable types are by
-- 'checkSymbol'.  Without this, a function returning a named user type could
-- never @return@ a value, because the resolved value type never equals the
-- unresolved 'TUser'.  Arguments, local variables and the body are then
-- checked in a fresh scope under an 'InFunction' context.  The returned
-- annotation's 'localSymbols' holds the local variables.
instance Typed Function where
  typeCheck x@(content -> Function {..}) = do
    setPos x
    resultType <- checkType fnResultType
    inContext (InFunction fnName resultType) $ withSymbolTable $ do
        args <- mapM checkSymbol fnFormalArgs
        vars <- mapM checkSymbol fnVars
        body <- mapM typeCheck fnBody
        let fn = Function fnName args resultType vars body
        Annotate result ta <- returnT resultType x fn
        return $ Annotate result $ ta {localSymbols = makeSymbolTable vars}

-- | Check an expression and annotate it with its type.
instance Typed Expression where
  -- Variable reference: typed as the symbol's declared type.
  typeCheck e@(content -> Variable x) = do
    setPos e
    sym <- getSymbol x
    returnT (symbolType sym) e (Variable x)

  -- Array element read: the base must be an array and the index
  -- integer-valued.  'isSubtypeOf' accepts an integer record field
  -- (@TField _ TInteger@) as an index, matching what 'Op' already allows.
  typeCheck e@(content -> ArrayItem name ix) = do
    setPos e
    sym <- getSymbol name
    case symbolType sym of
      TArray _ tp -> do
          ix' <- typeCheck ix
          unless (typeOfA ix' `isSubtypeOf` TInteger) $
            failCheck $ "Array index is " ++ show (typeOfA ix') ++ ", not Integer"
          returnT tp e (ArrayItem name ix')
      x -> failCheck $ name ++ " is " ++ show x ++ ", not Array"

  -- Record field read: typed as @TField index fieldType@ (1-based index).
  -- A base whose symbol type is already a 'TField' yields that same type.
  typeCheck e@(content -> RecordField base field) = do
    setPos e
    baseSym <- getSymbol base
    case symbolType baseSym of
      TRecord pairs -> case findField field pairs of
                         Just (ix,t) -> returnT (TField ix t) e (RecordField base field)
                         Nothing -> failCheck $ "No such field in " ++ base ++ " record: " ++ field
      TField ix t -> returnT (TField ix t) e (RecordField base field)
      x -> failCheck $ base ++ " is " ++ show x ++ ", not Record"

  typeCheck e@(content -> Literal x) = returnT (litType x) e (Literal x)

  -- Function call: actual argument types must match the formal ones
  -- pairwise; the call is typed as the function's result type.
  typeCheck e@(content -> Call name args) = do
    setPos e
    sym <- getSymbol name
    case symbolType sym of
      TFunction formalArgTypes resType -> do
          args' <- mapM typeCheck args
          let actualTypes = map typeOfA args'
          if actualTypes `areSubtypesOf` formalArgTypes
            then returnT resType e (Call name args')
            else failCheck $ "Invalid types in function call: " ++ show actualTypes ++ " instead of " ++ show formalArgTypes
      t -> failCheck $ "Symbol " ++ name ++ " is not a function, but " ++ show t

  -- Binary operator: both operands must accept an Integer.  Comparisons
  -- yield Boolean; all other operators yield Integer.
  typeCheck e@(content -> Op op x y) = do
    setPos e
    x' <- typeCheck x
    y' <- typeCheck y
    let tx = typeOfA x'
        ty = typeOfA y'
    if (TInteger `isSubtypeOf` tx) && (TInteger `isSubtypeOf` ty)
      then if op `elem` [IsEQ, IsNE, IsGT, IsLT]
             then returnT TBool    e (Op op x' y')
             else returnT TInteger e (Op op x' y')
      else failCheck $ "Invalid operand types: " ++ show tx ++ ", " ++ show ty

-- | Type-check a parsed program, starting from 'emptyState'.
--
-- This is a pure function: a type error aborts with 'error', with the message
-- @"type checker: " ++ show err@, when the result is evaluated.  The original
-- called 'fail' in @State CheckState@, which for Identity-based State also
-- meant 'error'.  That call does not compile once 'fail' moved to MonadFail,
-- because @StateT s Identity@ has no MonadFail instance.
checkTypes :: Program :~ SrcPos -> Program :~ TypeAnn
checkTypes prog =
  case evalState (runErrorT (runCheck (typeCheck prog))) emptyState of
    Right result -> result
    Left  err    -> error $ "type checker: " ++ show err

-- | Read a Pascal source file, parse it with 'pProgram' and type-check it.
-- A parse error fails in 'IO' with a message prefixed by @"parser: "@.  The
-- checked program is returned lazily, so type errors from 'checkTypes'
-- surface when the result is evaluated.
checkSource :: FilePath -> IO (Program :~ TypeAnn)
checkSource path = do
  source <- readFile path
  either (fail . ("parser: " ++) . show)
         (return . checkTypes)
         (parse pProgram path source)