module CSPM.TypeChecker.Monad (
    readTypeRef, writeTypeRef, freshTypeVar, freshTypeVarWithConstraints,
    getType, setType, 
    isTypeUnsafe, markTypeAsUnsafe, 
    replacementForDeprecatedName, isDeprecated, markAsDeprecated,
    compress, compressTypeScheme, 
    
    TypeCheckMonad, runTypeChecker, 
    newTypeInferenceState,  TypeInferenceState(..), getState,
    
    local, getEnvironment,
    ErrorContext, addErrorContext, getErrorContexts,
    getSrcSpan, setSrcSpan,
    getUnificationStack, addUnificationPair,
    symmetricUnificationAllowed, disallowSymmetricUnification,
    getInError, setInError,
    resetWarnings, getWarnings, addWarning,
    
    raiseMessageAsError, raiseMessagesAsError, panic,
    manyErrorsIfFalse, errorIfFalseM, errorIfFalse, tryAndRecover, failM,
    
    addDataTypeOrChannel, isDataTypeOrChannel,
)
where

import Control.Monad.State
import Prelude hiding (lookup)

import CSPM.DataStructures.Names
import CSPM.DataStructures.Types
import qualified CSPM.TypeChecker.Environment as Env
import CSPM.TypeChecker.Exceptions
import Util.Annotated
import Util.Exception
import Util.PrettyPrint

-- *************************************************************************
-- Type Checker Monad
-- *************************************************************************
type ErrorContext = Doc

data TypeInferenceState = TypeInferenceState {
        -- | The type environment, which is a map from names to types.
        environment :: Env.Environment,
        -- | List of names that correspond to channels and
        -- datatypes - used when detecting dependencies.
        dataTypesAndChannels :: [Name],
        -- | Location of the current AST element - used for error
        -- pretty printing
        srcSpan :: SrcSpan,
        -- | Error stack - provides context information for any
        -- errors that might be raised
        errorContexts :: [ErrorContext],
        -- | Errors that have occured
        errors :: [ErrorMessage],
        -- | List of warnings that have occured
        warnings :: [ErrorMessage],
        -- | Stack of attempted unifications - the current one
        -- is at the front. In the form (expected, actual).
        unificationStack :: [(Type, Type)],
        -- | Are we currently in an error state
        inError :: Bool,
        symUnificationAllowed :: Bool
    }
    
newTypeInferenceState :: TypeInferenceState
newTypeInferenceState = TypeInferenceState {
        environment = Env.new,
        dataTypesAndChannels = [],
        srcSpan = Unknown,
        errorContexts = [],
        errors = [],
        warnings = [],
        unificationStack = [],
        inError = False,
        symUnificationAllowed = True
    }

type TypeCheckMonad = StateT TypeInferenceState IO

-- | Runs the typechecker, starting from state 'st'. If any errors are
-- encountered then a 'SourceError' will be thrown with the relevent
-- error messages.
runTypeChecker :: TypeInferenceState -> TypeCheckMonad a -> IO a
runTypeChecker st prog = do
    (a, st) <- runStateT prog st
    let errs = errors st
    case errs of
        [] -> return a
        _ -> throwSourceError errs

getState :: TypeCheckMonad TypeInferenceState
getState = gets id

getEnvironment :: TypeCheckMonad Env.Environment
getEnvironment = gets environment

setEnvironment :: Env.Environment -> TypeCheckMonad ()
setEnvironment env = modify (\ st -> st { environment = env })

local :: [Name] -> TypeCheckMonad a -> TypeCheckMonad a
local ns m = 
    do
        env <- getEnvironment
        newArgs <- replicateM (length ns) freshTypeVar
        let symbs = map (Env.mkSymbolInformation . ForAll []) newArgs
        setEnvironment (Env.bind env (zip ns symbs))
        
        res <- m
    
        env <- getEnvironment
        setEnvironment (foldr (flip Env.delete) env ns)

        return res

addDataTypeOrChannel :: Name -> TypeCheckMonad ()
addDataTypeOrChannel n = modify (\ st -> st { 
        dataTypesAndChannels = n:dataTypesAndChannels st
    })

isDataTypeOrChannel :: Name -> TypeCheckMonad Bool
isDataTypeOrChannel n = do
    ns <- gets dataTypesAndChannels
    return $ n `elem` ns

getErrorContexts :: TypeCheckMonad [ErrorContext]
getErrorContexts = gets errorContexts

addErrorContext :: ErrorContext -> TypeCheckMonad a -> TypeCheckMonad a
addErrorContext c p = do
    ctxts <- getErrorContexts
    modify (\st -> st { errorContexts = c:ctxts })
    a <- p
    modify (\st -> st { errorContexts = ctxts })
    return a
    
getErrors :: TypeCheckMonad ErrorMessages
getErrors = gets errors

addErrors :: [ErrorMessage] -> TypeCheckMonad ()
addErrors es = modify (\ st -> st { errors = es++(errors st) })

getWarnings :: TypeCheckMonad [ErrorMessage]
getWarnings = gets warnings

resetWarnings :: TypeCheckMonad ()
resetWarnings = modify (\st -> st { warnings = [] })

addWarning :: Warning -> TypeCheckMonad ()
addWarning w = do
    src <- getSrcSpan
    let m = mkWarningMessage src w
    modify (\ st -> st { warnings = m:(warnings st) })

getInError :: TypeCheckMonad Bool
getInError = gets inError

setInError :: Bool -> TypeCheckMonad a -> TypeCheckMonad a
setInError b prog = do
    o <- getInError
    modify (\st -> st { inError = b })
    a <- prog
    modify (\st -> st { inError = o })
    return a

getSrcSpan :: TypeCheckMonad SrcSpan
getSrcSpan = gets srcSpan

-- | Sets the SrcSpan only within prog.
setSrcSpan :: SrcSpan -> TypeCheckMonad a -> TypeCheckMonad a
setSrcSpan loc prog = do
    oLoc <- getSrcSpan
    modify (\st -> st { srcSpan = loc })
    a <- prog
    modify (\st -> st { srcSpan = oLoc })
    return a

getUnificationStack :: TypeCheckMonad [(Type, Type)]
getUnificationStack = gets unificationStack

addUnificationPair :: (Type, Type) -> TypeCheckMonad a -> TypeCheckMonad a
addUnificationPair tp p = do
    stk <- getUnificationStack
    modify (\st -> st { unificationStack = tp:stk })
    a <- p
    modify (\ st -> st { unificationStack = stk })
    return a

symmetricUnificationAllowed :: TypeCheckMonad Bool
symmetricUnificationAllowed = gets symUnificationAllowed

disallowSymmetricUnification :: TypeCheckMonad a -> TypeCheckMonad a
disallowSymmetricUnification prog = do
    b <- symmetricUnificationAllowed
    modify (\st -> st { symUnificationAllowed = False })
    v <- prog
    modify (\st -> st { symUnificationAllowed = b })
    return v
    
-- Error handling

-- | Report the error if first parameter is False.
errorIfFalse :: Bool -> Error -> TypeCheckMonad ()
errorIfFalse b e = manyErrorsIfFalse b [e]

manyErrorsIfFalse :: Bool -> [Error] -> TypeCheckMonad ()
manyErrorsIfFalse True es = return ()
manyErrorsIfFalse False es = raiseMessagesAsError es

errorIfFalseM :: TypeCheckMonad Bool -> Error -> TypeCheckMonad ()
errorIfFalseM m e = m >>= \res -> errorIfFalse res e

failM :: TypeCheckMonad a
failM = do
    errs <- getErrors
    throwSourceError errs

raiseMessageAsError :: Error -> TypeCheckMonad a
raiseMessageAsError msg = raiseMessagesAsError [msg]

-- | Report a message as an error
raiseMessagesAsError :: [Error] -> TypeCheckMonad a
raiseMessagesAsError msgs = do
    src <- getSrcSpan
    ctxts <- getErrorContexts
    let ctxtDocs = ctxts
    let contextMsg = trimAndRenderContexts maxContexts ctxtDocs
    addErrors [mkErrorMessage src (msg $$ contextMsg) | msg <- msgs]
    failM
    where
        -- Don't print too much context
        trimAndRenderContexts 0 _ = empty
        trimAndRenderContexts n [] = empty
        trimAndRenderContexts n (doc:docs) = 
            doc $$ trimAndRenderContexts (n-1) docs
        maxContexts = 3

tryAndRecover :: 
    TypeCheckMonad a -> TypeCheckMonad a -> TypeCheckMonad a
tryAndRecover prog handler = tryM prog >>= \x ->
    case x of
        Right a -> return a
        Left ex -> case ex of
            SourceError msgs -> do
                -- The errors will be discarded as the errors propogate up.
                -- Hence, we reset them.
                setErrors msgs
                handler
            UserError -> handler
            -- An exception type we don't want to interfer with
            e           -> throwException e
    where
        setErrors :: ErrorMessages -> TypeCheckMonad ()
        setErrors es = modify (\ st -> st { errors = es })  


-- *************************************************************************
-- Type Operations
-- *************************************************************************
readTypeRef :: TypeVarRef -> TypeCheckMonad (Either (TypeVar, [Constraint]) Type)
readTypeRef (TypeVarRef tv cs ioref) = 
    do
        mtyp <- readPType ioref
        case mtyp of
            Just t  -> return (Right t)
            Nothing -> return (Left (tv, cs))

writeTypeRef :: TypeVarRef -> Type -> TypeCheckMonad ()
writeTypeRef (TypeVarRef tv cs ioref) t = setPType ioref t

getSymbolInformation :: Name -> TypeCheckMonad Env.SymbolInformation
getSymbolInformation n = do
    env <- gets environment
    -- Force evaluation of lookup n
    -- If we don't do this the error is deferred until later
    case Env.maybeLookup env n of
        Just symb -> return symb
        Nothing -> panic "Name not found after renaming."

-- | Get the type of 'n' and throw an exception if it doesn't exist.
getType :: Name -> TypeCheckMonad TypeScheme
getType n = do
    symb <- getSymbolInformation n
    return $ Env.typeScheme symb
    
-- | Sets the type of n to be t in the current scope only. No unification is 
-- performed.
setType :: Name -> TypeScheme -> TypeCheckMonad ()
setType n t = do
    env <- getEnvironment
    case Env.maybeLookup env n of
        Just symb -> setSymbolInformation n (symb { Env.typeScheme = t })
        Nothing -> setSymbolInformation n (Env.mkSymbolInformation t)

setSymbolInformation :: Name -> Env.SymbolInformation -> TypeCheckMonad ()
setSymbolInformation n symb = do
    env <- getEnvironment
    setEnvironment (Env.update env n symb)

isDeprecated :: Name -> TypeCheckMonad Bool
isDeprecated n = do
    symb <- getSymbolInformation n
    return $ Env.isDeprecated symb

isTypeUnsafe :: Name -> TypeCheckMonad Bool
isTypeUnsafe n = do
    symb <- getSymbolInformation n
    return $ Env.isTypeUnsafe symb

markAsDeprecated :: Name -> Maybe Name -> TypeCheckMonad ()
markAsDeprecated n repl = do
    symb <- getSymbolInformation n
    setSymbolInformation n (symb { 
        Env.isDeprecated = True, 
        Env.deprecationReplacement = repl
    })
markTypeAsUnsafe :: Name -> TypeCheckMonad ()
markTypeAsUnsafe n = do
    symb <- getSymbolInformation n
    setSymbolInformation n (symb { Env.isTypeUnsafe = True })

replacementForDeprecatedName :: Name -> TypeCheckMonad (Maybe Name)
replacementForDeprecatedName n = do
    symb <- getSymbolInformation n
    return $ Env.deprecationReplacement symb

-- | Apply compress to the type of a type scheme.
compressTypeScheme :: TypeScheme -> TypeCheckMonad TypeScheme
compressTypeScheme (ForAll ts t) = 
    do
        t' <- compress t
        return $ ForAll ts t'

-- | Takes a type and compresses the type by reading all type variables and
-- if they point to another type, it returns that type instead.
compress :: Type -> TypeCheckMonad Type
compress (tr @ (TVar typeRef)) = do
    res <- readTypeRef typeRef
    case res of
        Left tv -> return tr
        Right t -> compress t
compress (TFunction targs tr) = do
    targs' <- mapM compress targs
    tr' <- compress tr
    return $ TFunction targs' tr'
compress (TSeq t) = compress t >>= return . TSeq
compress (TSet t) = compress t >>= return . TSet
compress (TTuple ts) = mapM compress ts >>= return . TTuple
compress (TDotable t1 t2)= do
    t1' <- compress t1
    t2' <- compress t2
    return $ TDotable t1' t2'
compress (TDatatype n) = return $ TDatatype n
compress (TDot t1 t2) = do
    t1' <- compress t1
    t2' <- compress t2
    return $ TDot t1' t2'
compress t = return t