{-# LANGUAGE MultiParamTypeClasses, FlexibleInstances, TypeSynonymInstances #-}
module CSPM.TypeChecker.Decl (typeCheckDecls) where

import Control.Monad
import Data.Graph.Wrapper
import qualified Data.Map as M
import qualified Data.Set as S
import Data.List (intersect, (\\), sortBy)

import CSPM.DataStructures.FreeVars
import CSPM.DataStructures.Names
import CSPM.DataStructures.Syntax hiding (getType)
import CSPM.DataStructures.Types
import CSPM.PrettyPrinter
import CSPM.TypeChecker.Common
import CSPM.TypeChecker.Exceptions
import CSPM.TypeChecker.Expr()
import CSPM.TypeChecker.Monad
import CSPM.TypeChecker.Pat()
import CSPM.TypeChecker.Unification
import Util.Annotated
import Util.List
import Util.PartialFunctions
import Util.PrettyPrint

-- | Type check a list of possibly mutually recursive functions
typeCheckDecls :: [TCDecl] -> TypeCheckMonad ()
typeCheckDecls decls = do

    -- Flatten the decls so that definitions in modules are also type-checked
    -- in the correct order.
    let flattenDecl :: TCDecl -> [TCDecl]
        flattenDecl (An a b (Module mn args ds1 ds2)) =
            [An a b (Module mn args
                        (concatMap flattenDecl ds1)
                        (concatMap flattenDecl ds2))]
        flattenDecl (d@(An _ _ (TimedSection _ _ ds))) =
            -- We need to type-check the function in the timed section, but we
            -- flatten the decls so that dependencies work out ok.
            d : concatMap flattenDecl ds
        flattenDecl x = [x]
        flatDecls = concatMap flattenDecl decls

        isInstance (An _ _ (ModuleInstance _ _ _ _ _)) = True
        isInstance _ = False

        instanceDecls = filter isInstance flatDecls

        -- | Map from declarations to integer identifiers
        declMap = zip flatDecls [0..]
        invDeclMap = invert declMap
    
    let
        namesBoundByDecls = concatMap (\ (decl, declId) ->
            case decl of
                An _ _ (TimedSection _ _ _) -> []
                _ -> [(declId, boundNames decl)]) declMap

        -- | Map from names to the identifier of the declaration that it is
        -- defined by.
        varToDeclIdMap = 
            [(n, declId) | (declId, ns) <- namesBoundByDecls, n <- ns]
        boundVars = map fst varToDeclIdMap

    -- Throw an error if a name is defined multiple times
    when (not (noDups boundVars)) $ panic "Duplicates found after renaming."

    -- Map from decl id -> [decl id] meaning decl id depends on the list of
    -- ids
    declDeps <- mapM (\ (decl, declId) -> do
            let deps = freeVars decl
            let depsInThisGroup = intersect deps boundVars
            return (declId, mapPF varToDeclIdMap depsInThisGroup)
        ) declMap

    let 
        -- | Edge from n -> n' iff n uses n'
        declGraph :: Graph Int Int
        declGraph = fromListSimple [(id, deps) | (id, deps) <- declDeps]
        -- | The graph of strongly connected components, with an edge
        -- from scc i to scc j if j depends on i, but i does not depend
        -- on j.
        sccgraph :: Graph (S.Set Int) (M.Map Int Int)
        sccgraph = transpose (sccGraph declGraph)
        -- | The strongly connected components themselves, topologically sorted
        sccs :: [S.Set Int]
        sccs = topologicalSort sccgraph
        
        -- | Get the declarations corresponding to certain ids
        typeInferenceGroup = mapPF invDeclMap

        -- | Checks that this SCC does not contain both a module and an
        -- instance of the module.
        checkSCCForModuleCycles :: [TCDecl] -> TypeCheckMonad ()
        checkSCCForModuleCycles decls =
            let 
                instances = [i | An _ _ (i@(ModuleInstance _ _ _ _ _)) <- decls]
                mods = [m | An _ _ (m@(Module _ _ _ _)) <- decls]

                instancesOfMod n =
                    [i | i@(ModuleInstance _ nt _ _ _) <- instances, nt == n]

                checkMod (Module nm _ _ _) = 
                    case instancesOfMod nm of
                        [] -> return ()
                        (ModuleInstance n _ _ instanceMap _ : is) -> do
                            -- Find the cycle
                            mapM_ (\ n -> do
                                raiseMessageAsError $
                                    illegalModuleInstanceCycleErrorMessage nm n 
                                        (mapPF (invert varToDeclIdMap)
                                            (pathBetweenVerticies declGraph
                                                (apply varToDeclIdMap nm)
                                                (apply varToDeclIdMap n)))
                                ) (map fst instanceMap)
                pathBetweenVerticies :: Ord i => Graph i v -> i -> i -> [i]
                pathBetweenVerticies g i i' =
                    let
                        -- Do a DFS
                        findPath visited i =
                            let scs = successors g i
                            in if i' `elem` scs then ([i'], [])
                                else check visited scs
                        check visited [] = ([], visited)
                        check visited (x:xs) =
                            let (path, visited') = findPath visited x
                            in case path of
                                [] -> check visited' xs
                                _ -> (x : path, visited')
                    in i : fst (findPath [] i)
            in mapM_ checkMod mods

        -- When an error occurs continue type checking, but only
        -- type check groups that do not depend on the failed group.
        -- failM is called at the end if any error has occured.
        typeCheckGroups [] b = if b then failM else return ()
        typeCheckGroups (g:gs) b = do
            err <- tryAndRecover True (do
                let ds = typeInferenceGroup $ S.toList g
                checkSCCForModuleCycles ds
                typeCheckMutualyRecursiveGroup ds
                return False
                ) (return True)
            if not err then typeCheckGroups gs b
            -- Else, continue type checking but remove all declaration groups
            -- that are reachable from this group. Also, set the flag to be 
            -- True to indicate that an error has occured so that failM is 
            -- called at the end.
            else typeCheckGroups (gs \\ (reachableVertices sccgraph g)) True

    -- Start type checking the groups
    typeCheckGroups sccs False

    let
        annotate (decl@(An _ psymbtable (Module _ _ ds1 ds2))) = do
            let ns = boundNames decl
            ts <- mapM getType ns
            setPSymbolTable (snd psymbtable) (zip ns ts)
            mapM_ annotate (ds1++ds2)
        annotate (decl@(An _ psymbtable (TimedSection _ _ ds))) = do
            let ns = boundNames decl
            ts <- mapM getType ns
            setPSymbolTable (snd psymbtable) (zip ns ts)
            mapM_ annotate ds
        annotate (decl@(An _ psymbtable _)) = do
            let ns = boundNames decl
            ts <- mapM getType ns
            setPSymbolTable (snd psymbtable) (zip ns ts)

    -- Add the type of each declaration (if one exists to each declaration)
    mapM_ annotate decls

-- | Type checks a group of certainly mutually recursive functions. Only 
-- functions that are mutually recursive should be included otherwise the
-- types could end up being less general.
typeCheckMutualyRecursiveGroup :: [TCDecl] -> TypeCheckMonad ()
typeCheckMutualyRecursiveGroup ds' = do
    -- TODO: fix temporary hack
    let 
        cmp x y = case (unAnnotate x, unAnnotate y) of
            (DataType _ _, DataType _ _) -> EQ
            (DataType _ _, _) -> LT
            (_, DataType _ _) -> GT
            (_, _) -> EQ
        ds = sortBy cmp ds'
        fvs = boundNames ds

    ftvs <- replicateM (length fvs) freshTypeVar
    zipWithM setType fvs (map (ForAll []) ftvs)

    -- Type check each declaration then generalise the types
    nts <- generaliseGroup fvs $ map (\ d -> 
        case boundNames d of
            [] -> typeCheck d
            (n:_) -> addDefinitionName n (typeCheck d)) ds

    -- Compress all the types we have inferred here (they should never be 
    -- touched again)
    mapM_ (\ n -> do
        t <- getType n
        t' <- compressTypeScheme t
        setType n t') fvs

-- | Takes a type and returns the inner type, i.e. the type that this
-- is a set of. For example TSet t1 -> t, TTuple [TSet t1, TSet t2] -> (t1, t2).
-- The type that is returned is guaranteed to satisfy Eq since, at the
-- recursion only bottoms out on reaching something that is of type TSet.
evalTypeExpression :: Type -> TypeCheckMonad Type
evalTypeExpression (TTuple ts) = do
    -- TTuple [TSet t1,...] = TSet (TTuple [t1,...])
    ts' <- mapM evalTypeExpression ts
    return $ TTuple ts'
evalTypeExpression (TDot t1 t2) = do
    -- TDot (TSet t1) (TSet t2) = TSet (TDot t1 t2)
    t1' <- evalTypeExpression t1
    t2' <- evalTypeExpression t2
    return $ TDot t1' t2'
-- Otherwise, it must be a set.
evalTypeExpression t = do
    fv <- freshTypeVar
    unify t (TSet fv)
    return fv

instance TypeCheckable TCDecl [(Name, Type)] where
    errorContext an = Nothing
    typeCheck' an = setSrcSpan (loc an) $ typeCheck (inner an)

instance TypeCheckable (Decl Name) [(Name, Type)] where
    errorContext (FunBind n ms _) = Just $ 
        -- This will only be helpful if the equations don't match in
        -- type
        text "In the declaration of:" <+> prettyPrint n
    errorContext (p@(PatBind pat exp _)) = Just $
        hang (text "In a pattern binding:") tabWidth (prettyPrint p)
    errorContext (DataType n cs) = Just $
        text "In the declaration of:" <+> prettyPrint n
    errorContext (SubType n cs) = Just $
        text "In the declaration of:" <+> prettyPrint n
    errorContext (NameType n e) = Just $
        text "In the declaration of:" <+> prettyPrint n
    errorContext (Channel ns es) = Just $
        text "In the declaration of:" <+> list (map prettyPrint ns)
    errorContext (Assert a) = Just $
        text "In the assertion:" <+> prettyPrint a
    errorContext (TimedSection _ _ _) = Nothing
    errorContext (Transparent ns) = Nothing
    errorContext (External ns) = Nothing
    errorContext (ModuleInstance n _ _ _ _) = Just $
        text "In the declaration of the module instance:" <+> prettyPrint n
    errorContext (Module _ _ _ _) = Nothing
    errorContext (PrintStatement _) = Nothing
    
    typeCheck' (FunBind n ms mta) = do
        let boundTypeVars =
                case mta of
                    Just (An _ _ (STypeScheme boundNs _ _)) -> boundNs
                    _ -> []
        ts <- local boundTypeVars $ do
            mta <- case mta of
                    Just ta -> do
                        ForAll _ t <- typeCheck ta
                        return $ Just t
                    Nothing -> return Nothing
            mapM (\ m -> addErrorContext (matchCtxt m) $ 
                case mta of
                    Nothing -> typeCheck m
                    Just ta -> typeCheckExpect m ta) ms
        ForAll [] t <- getType n
        -- This unification also ensures that each equation has the same number
        -- of arguments.
        (t' @ (TFunction tsargs _)) <- unifyAll (t:ts)
        return [(n, t')]
        where
            matchCtxt an = 
                hang (text "In an equation for" <+> prettyPrint n <> colon) 
                    tabWidth (prettyPrintMatch n an)
    typeCheck' (p@(PatBind pat exp mta)) = do
        (texp, tpat) <-
            case mta of
                Nothing -> do
                    tpat <- typeCheck pat
                    texp <- typeCheck exp
                    return (tpat, texp)
                Just ta -> do
                    -- todo: check
                    ForAll _ typ <- typeCheck ta
                    tpat <- typeCheckExpect pat typ
                    texp <- typeCheckExpect exp typ
                    return (tpat, texp)
        -- We evaluate the dots to implement the 'longest match' rule. For
        -- example, suppose we have the following declaration:
        --   datatype A = B.Integers.Integers
        --   f(B.x) = x
        -- Then we make the decision that x should be of type Int.Int.
        tpat <- evaluateDots tpat
        texp <- evaluateDots texp
        -- We must disallow symmetric unification here as we don't want
        -- to allow patterns such as:
        --     x.y = B
        disallowSymmetricUnification (unify texp tpat)
        let ns = boundNames p
        ts <- mapM getType ns
        return $ zip ns [t | ForAll _ t <- ts]
    -- The following two clauses rely on the fact that they have been 
    -- prebound.
    typeCheck' (Channel ns Nothing) = do
        -- We now unify the each type to be a TEvent
        mapM (\ n -> do
            ForAll [] t <- getType n
            unify TEvent t) ns
        -- (Now getType n for any n in ns will return TEvent)
        return [(n, TEvent) | n <- ns]
    typeCheck' (Channel ns (Just e)) = do
        t <- typeCheck e
        -- Events must be comparable for equality.
        ensureHasConstraint CEq t
        valueType <- evalTypeExpression t
        dotList <- typeToDotList valueType
        let t = foldr TDotable TEvent dotList
        mapM (\ n -> do
            ForAll [] t' <- getType n
            unify t' t) ns
        return $ [(n, t) | n <- ns]
    typeCheck' (SubType n clauses) = do
        -- Get the type fromthe first clause
        parentType <- freshTypeVar
        mapM_ (\ clause -> do
                let nclause = case unAnnotate clause of
                            DataTypeClause x _ -> x
                (_, tsFields) <- typeCheck clause
                ForAll [] typeCon <- getType nclause
                (actFields, dataType) <- dotableToDotList typeCon
                -- Check that the datatype is the correct subtype.
                tvref' <- freshTypeVarRef []
                unify (TExtendable parentType tvref') dataType
                -- Check that the fields are compatible with the expected fields.
                zipWithM unify actFields tsFields
            ) clauses
        ForAll [] t <- getType n
        t' <- unify t (TSet parentType)
        return [(n, TSet parentType)]
    typeCheck' (DataType n clauses) = do
        ForAll [] t <- getType n
        unify t (TSet (TDatatype n))
        ntss <- mapM (\ clause -> do
            let 
                n' = case unAnnotate clause of
                        DataTypeClause x _ -> x
            ForAll [] t <- getType n'
            (n', ts) <- typeCheck clause
            let texp = foldr TDotable (TDatatype n) ts
            t <- unify texp t
            return ((n', t), ts)
            ) clauses
        let (nts, tcss) = unzip ntss
            tclauses = concat tcss
        tclauses <- mapM (\t -> compress t) tclauses
        -- We now need to decide if we should allow this type to be comparable
        -- for equality. Thus, we check to see if each of the fields in each of
        -- the constructors is comparable for equality.

        -- We mark the type for equality, as if the type depends only on itself
        -- (i.e. it is recursive), then it should be comparable for equality.
        markDatatypeAsComparableForEquality n
        b <- tryAndRecover False 
                (mapM_ (ensureHasConstraint CEq) tclauses >> return True)
                (return False)
        when (not b) $ unmarkDatatypeAsComparableForEquality n
        ForAll [] t <- getType n
        t' <- unify t (TSet (TDatatype n))
        return $ (n, t'):nts
    typeCheck' (NameType n e) = do
        t <- typeCheck e
        valueType <- evalTypeExpression t
        return [(n, TSet valueType)]
    typeCheck' (Transparent ns) = return []
    typeCheck' (External ns) = return []
    typeCheck' (Assert a) = typeCheck a >> return []
    typeCheck' (TimedSection (Just tn) f _) = do
        typeCheckExpect (Var tn) TEvent
        case f of
            Just f -> typeCheckExpect f (TFunction [TEvent] TInt) >> return ()
            Nothing -> return ()
        return []
    typeCheck' (Module n args pubDs privDs) = do
        let fvs = boundNames args
        local fvs $ do
            tpats <- mapM (\ pat -> typeCheck pat >>= evaluateDots) args
            typeCheckDecls (pubDs ++ privDs)
            tpats <- mapM (\ pat -> typeCheck pat >>= evaluateDots) args
            return [(n, TTuple tpats)]

    typeCheck' (ModuleInstance n nt args nm (Just mod)) = do
        -- Check instance
        stk <- getDefinitionStack
        when (nt `elem` stk) $ raiseMessageAsError $
            illegalModuleInstanceCycleErrorMessage nt n $ nt :
                reverse (takeWhile (\n -> not (n == nt)) stk)

        ts <- getType nt
        (TTuple ts, sub) <- instantiate' ts
        targs <- zipWithM typeCheckExpect args ts
        let subName n = case safeApply (invert nm) n of
                            Just n' -> n'
                            Nothing -> n
        -- Set the types of each of our arguments
        nts <- mapM (\ (ourName, theirName) -> do
                ForAll xs t <- getType theirName
                t' <- substituteTypes sub t
                -- We also need to change any datatype according to name map
                let sub (TVar tvref) = do
                        res <- readTypeRef tvref
                        case res of 
                            Left _ -> return $ TVar tvref
                            Right t -> sub t
                    sub (TSet t) = sub t >>= return . TSet
                    sub (TSeq t) = sub t >>= return . TSeq
                    sub (TDot t1 t2) = do
                        t1 <- sub t1
                        t2 <- sub t2
                        return $! TDot t1 t2
                    sub (TTuple ts) = mapM sub ts >>= return . TTuple
                    sub (TFunction ts t) = do
                        ts <- mapM sub ts
                        t <- sub t
                        return  $! TFunction ts t
                    sub (TDatatype n) = return $ TDatatype $! subName n
                    sub (TDotable t1 t2) = do
                        t1 <- sub t1
                        t2 <- sub t2
                        return $! TDotable t1 t2
                    sub (TExtendable t tvref) = do
                        t' <- sub t
                        res <- readTypeRef tvref
                        case res of
                            Left _ -> return $ TExtendable t tvref
                            Right t -> do
                                tsub' <- sub t
                                writeTypeRef tvref tsub'
                                return $ TExtendable t' tvref
                    sub TInt = return TInt
                    sub TBool = return TBool
                    sub TProc = return TProc
                    sub TEvent = return TEvent
                    sub TChar = return TChar
                    sub TExtendableEmptyDotList = return TExtendableEmptyDotList

                t' <- sub t'
                setType ourName $ ForAll xs t'
                return (ourName, t')
            ) nm

        let An _ _ (Module _ _ privDs pubDs) = mod

        -- Mark datatypes as comparable as appropriate
        mapM_ (\ d -> case unAnnotate d of
            DataType n clauses -> do
                let n' = subName n
                tclauses <- mapM (\ (An _ _ (DataTypeClause n _)) -> do
                    ForAll _ t <- getType (subName n)
                    return t) clauses
                markDatatypeAsComparableForEquality n'
                b <- tryAndRecover False 
                        (mapM_ (ensureHasConstraint CEq) tclauses >> return True)
                        (return False)
                when (not b) $ unmarkDatatypeAsComparableForEquality n'
            _ -> return ()
            ) (privDs ++ pubDs)
        return nts
    typeCheck' (PrintStatement _) = return []

instance TypeCheckable TCAssertion () where
    errorContext an = Nothing
    typeCheck' an = setSrcSpan (loc an) $ typeCheck (inner an)

instance TypeCheckable (Assertion Name) () where
    errorContext a = Just $ 
        hang (text "In the assertion" <> colon) tabWidth (prettyPrint a)
    typeCheck' (PropertyCheck e1 p m) = do
        ensureIsProc e1
        return ()
    typeCheck' (Refinement e1 m e2 opts) = do
        ensureIsProc e1
        ensureIsProc e2
        mapM_ typeCheck opts
    typeCheck' (ASNot a) = typeCheck a

instance TypeCheckable (ModelOption Name) () where
    errorContext a = Nothing
    typeCheck' (TauPriority e) = do
        typeCheckExpect e (TSet TEvent)
        return ()

instance TypeCheckable TCDataTypeClause (Name, [Type]) where
    errorContext an = Nothing
    typeCheck' an = setSrcSpan (loc an) $ typeCheck (inner an)

instance TypeCheckable (DataTypeClause Name) (Name, [Type]) where
    errorContext c = Just $
        hang (text "In the data type clause" <> colon) tabWidth 
            (prettyPrint c)
    typeCheck' (DataTypeClause n' Nothing) = do
        return (n', [])
    typeCheck' (DataTypeClause n' (Just e)) = do
        t <- typeCheck e
        valueType <- evalTypeExpression t
        dotList <- typeToDotList valueType
        return (n', dotList)

instance TypeCheckable TCMatch Type where
    errorContext an = Nothing
    typeCheck' an = setSrcSpan (loc an) $ typeCheck (inner an)
    typeCheckExpect an t = setSrcSpan (loc an) $ typeCheckExpect (inner an) t
instance TypeCheckable (Match Name) Type where
    -- We create the error context in FunBind as that has access
    -- to the name
    errorContext (Match groups exp) = Nothing
    typeCheck' (Match groups exp) = do
        -- Introduce free variables for all the parameters
        let fvs = boundNames groups
        local fvs $ do
            tgroups <- mapM (\ pats -> mapM (\ pat -> 
                    -- We evaluate the dots here to implment the longest 
                    -- match rule
                    typeCheck pat >>= evaluateDots
                ) pats) groups
    
            -- We evaluate the dots here to implment the longest match rule
            tr <- typeCheck exp >>= evaluateDots
            
            -- We need to evaluate the dots in the patterns twice just in case 
            -- the type inferences on the RHS have resulted in extra dots on 
            -- the left being able to be removed.
            tgroups <- mapM (\ pats -> mapM (\ pat -> 
                    typeCheck pat >>= evaluateDots
                ) pats) groups

            return $ foldr (\ targs tr -> TFunction targs tr) tr tgroups
    typeCheckExpect (Match groups exp) tsig = do
        -- Introduce free variables for all the parameters
        let fvs = boundNames groups
        local fvs $ do
            -- Check that the function signature is of a plausible shape
            rt <- freshTypeVar
            argts <- mapM (flip replicateM freshTypeVar) (map length groups)
            unify tsig $ foldr (\ targs tr -> TFunction targs tr) rt argts

            -- The rest of the code is as before (comments before also apply)
            tgroups <- zipWithM (\ pats argts -> zipWithM (\ pat argt -> 
                    typeCheckExpect pat argt >>= evaluateDots
                ) pats argts) groups argts    
            tr <- typeCheckExpect exp rt >>= evaluateDots
            tgroups <- mapM (\ pats -> mapM (\ pat -> 
                    typeCheck pat >>= evaluateDots
                ) pats) groups
            return $ foldr (\ targs tr -> TFunction targs tr) tr tgroups

instance TypeCheckable TCSTypeScheme TypeScheme where
    errorContext an = Nothing
    typeCheck' an = setSrcSpan (loc an) $ typeCheck (inner an)
instance TypeCheckable (STypeScheme Name) TypeScheme where
    errorContext _ = Nothing
    typeCheck' (STypeScheme boundNs cs t) = do
        tvs <- mapM (\ n -> do
            let ncs = map (\ (STypeConstraint c _) -> c) $
                        filter (\ (STypeConstraint _ n') -> n == n') $
                        (map unAnnotate cs)
            t@(TVar tvref) <- freshRigidTypeVarWithConstraints n ncs
            setType n (ForAll [] t)
            return (typeVar tvref, constraints tvref)
            ) boundNs
        t' <- typeCheck t
        return $ ForAll tvs t'

instance TypeCheckable TCSType Type where
    errorContext _ = Nothing
    typeCheck' an = setSrcSpan (loc an) $ typeCheck (inner an)
instance TypeCheckable (SType Name) Type where
    errorContext _ = Nothing
    typeCheck' (STVar var) = getType var >>= \ (ForAll [] t) -> return t
    typeCheck' (STExtendable t var) = do
        t <- typeCheck t
        TVar tvref <- getType var >>= \ (ForAll [] t) -> return t
        return $ TExtendable t tvref
    typeCheck' (STSet t) = typeCheck t >>= return . TSet
    typeCheck' (STSeq t) = typeCheck t >>= return . TSeq
    typeCheck' (STDot t1 t2) = do
        t1' <- typeCheck t1
        t2' <- typeCheck t2
        return $ TDot t1' t2'
    typeCheck' (STMap t1 t2) = do
        t1' <- typeCheck t1
        t2' <- typeCheck t2
        return $ TMap t1' t2'
    typeCheck' (STTuple ts) = mapM typeCheck ts >>= return . TTuple
    typeCheck' (STFunction args rt) = do
        targs <- mapM typeCheck args
        trt <- typeCheck rt
        return $ TFunction targs trt
    typeCheck' (STDotable t1 t2) = do
        t1' <- typeCheck t1
        t2' <- typeCheck t2
        return $ TDotable t1' t2'
    typeCheck' (STParen t) = typeCheck t

    typeCheck' (STDatatype n) = return $ TDatatype n
    typeCheck' STProc = return TProc
    typeCheck' STInt = return TInt
    typeCheck' STBool = return TBool
    typeCheck' STChar = return TChar
    typeCheck' STEvent = return TEvent