{-# LANGUAGE MultiParamTypeClasses, FlexibleInstances, TypeSynonymInstances #-}
module CSPM.TypeChecker.Decl (typeCheckDecls) where

import Control.Monad
import Control.Monad.Trans
import Data.Graph.Wrapper
import qualified Data.Map as M
import qualified Data.Set as S
import Data.List (nub, intersect, (\\), sortBy)

import CSPM.DataStructures.Names
import CSPM.DataStructures.Syntax hiding (getType)
import CSPM.DataStructures.Types
import CSPM.PrettyPrinter
import CSPM.TypeChecker.BuiltInFunctions
import CSPM.TypeChecker.Common
import CSPM.TypeChecker.Dependencies
import CSPM.TypeChecker.Exceptions
import CSPM.TypeChecker.Expr
import CSPM.TypeChecker.Monad
import CSPM.TypeChecker.Pat
import CSPM.TypeChecker.Unification
import Util.Annotated
import Util.List
import Util.Monad
import Util.PartialFunctions
import Util.PrettyPrint

-- | Type check a list of possibly mutually recursive declarations.
--
-- The declarations are split into the strongly connected components of
-- their dependency graph. These groups are type checked in topological
-- order, so each group is checked after the groups it depends on.
--
-- If a group fails to type check, checking continues with the remaining
-- groups that do not depend on it. 'failM' is called once all such groups
-- have been processed.
--
-- Panics if two declarations bind the same name (renaming should already
-- have made every bound name unique).
typeCheckDecls :: [TCDecl] -> TypeCheckMonad ()
typeCheckDecls decls = do
    namesBoundByDecls <- mapM (\ decl -> do
        namesBound <- namesBoundByDecl decl
        return (decl, namesBound)) decls
    let -- Map from declarations to integer identifiers
        declMap :: [(TCDecl, Int)]
        declMap = zip decls [0..]
        invDeclMap = invert declMap
        -- Map from names to the identifier of the declaration that it is
        -- defined by.
        varToDeclIdMap =
            [(n, apply declMap d) | (d, ns) <- namesBoundByDecls, n <- ns]
        boundVars = map fst varToDeclIdMap

    -- Throw an error if a name is defined multiple times
    unless (noDups boundVars) $ panic "Duplicates found after renaming."

    -- We prebind the datatypes and channels as they can be matched on in
    -- patterns (and thus, given a var in a pattern we can't decide if it
    -- is free or a dependency otherwise).
    mapM_ (registerChannelsAndDataTypes . unAnnotate) decls

    -- Map from decl id -> [decl id] meaning decl id depends on the list of
    -- ids. Only dependencies on names bound in this list are kept.
    declDeps <- forM decls $ \ decl -> do
        deps <- dependencies decl
        let depsInThisGroup = intersect deps boundVars
        return (apply declMap decl, mapPF varToDeclIdMap depsInThisGroup)

    let -- Edge from n -> n' iff n uses n'
        declGraph :: Graph Int Int
        declGraph = fromListSimple declDeps
        -- The graph of strongly connected components, with an edge
        -- from scc i to scc j if j depends on i, but i does not depend
        -- on j.
        sccgraph :: Graph (S.Set Int) (M.Map Int Int)
        sccgraph = transpose (sccGraph declGraph)
        -- The strongly connected components themselves, topologically sorted
        sccs :: [S.Set Int]
        sccs = topologicalSort sccgraph
        -- Get the declarations corresponding to certain ids
        typeInferenceGroup = mapPF invDeclMap

        -- When an error occurs, continue type checking, but only
        -- type check groups that do not depend on the failed group.
        -- The flag records whether any group has failed so far; if one has,
        -- failM is called at the end.
        typeCheckGroups [] failed = when failed failM
        typeCheckGroups (g:gs) failed = do
            err <- tryAndRecover
                (typeCheckMutualyRecursiveGroup
                        (typeInferenceGroup (S.toList g))
                    >> return False)
                (return True)
            -- On error, continue type checking but remove all declaration
            -- groups that are reachable from this group. Also set the flag
            -- so that failM is called at the end.
            if err
                then typeCheckGroups (gs \\ reachableVertices sccgraph g) True
                else typeCheckGroups gs failed
    -- Start type checking the groups
    typeCheckGroups sccs False

-- We have to prebind all datatype clauses and channel names so that we can
-- identify when a particular pattern uses these clauses and channels. We do
-- this by injecting them into the symbol table earlier than normal.
--
-- Keep this function in sync with the DataType clause of typeCheck' in the
-- 'Decl Name' instance: that clause relies on the constructor names having
-- been registered here. If one changes, the other will probably need to
-- change too.
--
-- Declarations other than datatypes and channels are ignored.
registerChannelsAndDataTypes :: Decl Name -> TypeCheckMonad ()
registerChannelsAndDataTypes decl = case decl of
    DataType _ clauses ->
        forM_ clauses $ \ clause -> case unAnnotate clause of
            DataTypeClause clauseName _ -> addDataTypeOrChannel clauseName
    Channel names _ -> forM_ names addDataTypeOrChannel
    _ -> return ()

-- | Type checks a group of certainly mutually recursive functions. Only
-- functions that are mutually recursive should be included, otherwise the
-- types could end up being less general.
--
-- The group is checked in four steps:
--
-- 1. Every name bound by the group is first given a fresh monomorphic type
--    variable.
-- 2. The declarations are type checked and their types generalised together.
-- 3. The inferred (name, type) pairs are stored on each declaration.
-- 4. The final type schemes are compressed.
typeCheckMutualyRecursiveGroup :: [TCDecl] -> TypeCheckMonad ()
typeCheckMutualyRecursiveGroup ds' = do
    -- TODO: fix temporary hack. Datatype declarations are moved to the front
    -- of the group. The relative order of all other declarations is kept
    -- (exactly what a stable sort with "DataType < anything else" gives).
    -- Presumably this makes constructor types known before other
    -- declarations use them -- confirm.
    let isDataType d = case unAnnotate d of
            DataType _ _ -> True
            _ -> False
        ds = filter isDataType ds' ++ filter (not . isDataType) ds'

    -- The list of all variables bound by these declarations
    fvs <- liftM nub (concatMapM namesBoundByDecl ds)

    -- Bind each variable to a fresh monomorphic type variable so that
    -- recursive uses within the group share one type.
    ftvs <- replicateM (length fvs) freshTypeVar
    zipWithM_ setType fvs (map (ForAll []) ftvs)

    -- Type check each declaration then generalise the types
    nts <- generaliseGroup fvs (map typeCheck ds)
    -- Add the type of each declaration (if one exists) to each declaration
    zipWithM_ annotate nts ds

    -- Compress all the types we have inferred here (they should never be
    -- touched again)
    forM_ fvs $ \ n -> getType n >>= compressTypeScheme >>= setType n
  where
    -- Store the inferred (name, type) pairs via the setter held in the
    -- second component of the declaration's symbol-table annotation.
    annotate nts (An _ psymbtable _) = setPSymbolTable (snd psymbtable) nts

-- | Takes the type of a set expression and returns the type of its elements.
--
-- Tuples and dotted types are evaluated component-wise, for example:
--
-- * @TSet t1@ gives @t1@
-- * @TTuple [TSet t1, TSet t2]@ gives @TTuple [t1, t2]@
-- * @TDot (TSet t1) (TSet t2)@ gives @TDot t1 t2@
--
-- Any other type is unified with @TSet fv@ for a fresh type variable @fv@,
-- and @fv@ is returned. Since recursion only bottoms out at a set type, the
-- result is presumably guaranteed to support equality (TODO confirm against
-- the unifier's handling of TSet).
evalTypeExpression :: Type -> TypeCheckMonad Type
evalTypeExpression typ = case typ of
    -- TTuple [TSet t1,...] = TSet (TTuple [t1,...])
    TTuple components -> liftM TTuple (mapM evalTypeExpression components)
    -- TDot (TSet t1) (TSet t2) = TSet (TDot t1 t2)
    TDot left right ->
        liftM2 TDot (evalTypeExpression left) (evalTypeExpression right)
    -- Otherwise, it must be a set: return its element type.
    _ -> do
        elemType <- freshTypeVar
        unify typ (TSet elemType)
        return elemType

instance TypeCheckable TCDecl [(Name, Type)] where
    -- An annotated declaration adds no error context of its own. Type
    -- checking delegates to the wrapped declaration, with the source span
    -- set to the annotation's location.
    errorContext _ = Nothing
    typeCheck' decl = setSrcSpan (loc decl) (typeCheck (inner decl))

-- | Type checks a single declaration. The result is the list of (name, type)
-- pairs for the names the declaration binds.
instance TypeCheckable (Decl Name) [(Name, Type)] where
    errorContext (FunBind n _) = Just $
        -- This will only be helpful if the equations don't match in
        -- type
        text "In the declaration of:" <+> prettyPrint n
    errorContext (p@(PatBind _ _)) = Just $
        hang (text "In a pattern binding:") tabWidth (prettyPrint p)
    errorContext (DataType n _) = Just $
        text "In the declaration of:" <+> prettyPrint n
    errorContext (NameType n _) = Just $
        text "In the declaration of:" <+> prettyPrint n
    errorContext (Channel ns _) = Just $
        text "In the declaration of:" <+> list (map prettyPrint ns)
    errorContext (Assert a) = Just $
        text "In the assertion:" <+> prettyPrint a
    errorContext (Transparent _) = Nothing
    errorContext (External _) = Nothing

    typeCheck' (FunBind n ms) = do
        ts <- mapM (\ m -> addErrorContext (matchCtxt m) $ typeCheck m) ms
        ForAll [] t <- getType n
        -- This unification also ensures that each equation has the same
        -- number of arguments. The pattern requires the result to be a
        -- function type.
        t'@(TFunction _ _) <- unifyAll (t:ts)
        return [(n, t')]
      where
        matchCtxt an =
            hang (text "In an equation for" <+> prettyPrint n <> colon)
                tabWidth (prettyPrintMatch n an)
    typeCheck' (PatBind pat e) = do
        tpat <- typeCheck pat
        texp <- typeCheck e
        -- We evaluate the dots to implement the 'longest match' rule. For
        -- example, suppose we have the following declaration:
        --   datatype A = B.Integers.Integers
        --   f(B.x) = x
        -- Then we make the decision that x should be of type Int.Int.
        tpat' <- evaluateDots tpat
        texp' <- evaluateDots texp
        -- We must disallow symmetric unification here as we don't want
        -- to allow patterns such as:
        --     x.y = B
        disallowSymmetricUnification (unify texp' tpat')
        ns <- namesBoundByDecl' (PatBind pat e)
        ts <- mapM getType ns
        return $ zip ns [t | ForAll _ t <- ts]
    -- The following two clauses rely on the channel names having been
    -- prebound (see registerChannelsAndDataTypes).
    typeCheck' (Channel ns Nothing) = do
        -- We now unify each channel's type with TEvent
        forM_ ns $ \ n -> do
            ForAll [] t <- getType n
            unify TEvent t
        -- (Now getType n for any n in ns will return TEvent)
        return [(n, TEvent) | n <- ns]
    typeCheck' (Channel ns (Just e)) = do
        t <- typeCheck e
        valueType <- evalTypeExpression t
        dotList <- typeToDotList valueType
        -- Each channel has type d1 => d2 => ... => Event
        let tchan = foldr TDotable TEvent dotList
        forM_ ns $ \ n -> do
            ForAll [] t' <- getType n
            unify t' tchan
        return [(n, tchan) | n <- ns]
    typeCheck' (DataType n clauses) = do
        nts <- forM clauses $ \ clause -> do
            let clauseName = case unAnnotate clause of
                    DataTypeClause x _ -> x
            -- The constructor was prebound by registerChannelsAndDataTypes
            ForAll [] tclause <- getType clauseName
            (n', ts) <- typeCheck clause
            let texp = foldr TDotable (TDatatype n) ts
            t <- unify texp tclause
            return (n', t)
        -- The datatype's own name denotes the set of all its values
        ForAll [] t <- getType n
        t' <- unify t (TSet (TDatatype n))
        return $ (n, t'):nts
    typeCheck' (NameType n e) = do
        t <- typeCheck e
        valueType <- evalTypeExpression t
        return [(n, TSet valueType)]
    typeCheck' (Transparent _) = return []
    typeCheck' (External _) = return []
    typeCheck' (Assert a) = typeCheck a >> return []

instance TypeCheckable (Assertion Name) () where
    errorContext assertion = Just $
        hang (text "In the assertion" <> colon) tabWidth (prettyPrint assertion)
    -- A property check only requires its subject to be a process.
    typeCheck' (PropertyCheck process _ _) = ensureIsProc process >> return ()
    -- Both sides of a refinement must be processes, and every model option
    -- is type checked as well.
    typeCheck' (Refinement lhs _ rhs opts) = do
        ensureIsProc lhs
        ensureIsProc rhs
        mapM_ typeCheck opts

instance TypeCheckable (ModelOption Name) () where
    errorContext _ = Nothing
    -- The tau-priority argument must be a set of events.
    typeCheck' (TauPriority e) = typeCheckExpect e (TSet TEvent) >> return ()

instance TypeCheckable TCDataTypeClause (Name, [Type]) where
    -- An annotated clause adds no error context of its own. Type checking
    -- delegates to the wrapped clause, with the source span set to the
    -- annotation's location.
    errorContext _ = Nothing
    typeCheck' clause = setSrcSpan (loc clause) (typeCheck (inner clause))

instance TypeCheckable (DataTypeClause Name) (Name, [Type]) where
    errorContext clause = Just $
        hang (text "In the data type clause" <> colon) tabWidth
            (prettyPrint clause)
    -- A clause without a field expression has no component types.
    typeCheck' (DataTypeClause n' Nothing) = return (n', [])
    -- Otherwise the field expression denotes a set. Its element type is
    -- converted to the clause's component types with typeToDotList.
    typeCheck' (DataTypeClause n' (Just e)) = do
        fieldTypes <- typeCheck e >>= evalTypeExpression >>= typeToDotList
        return (n', fieldTypes)

instance TypeCheckable TCMatch Type where
    -- An annotated match adds no error context of its own. Type checking
    -- delegates to the wrapped match, with the source span set to the
    -- annotation's location.
    errorContext _ = Nothing
    typeCheck' match = setSrcSpan (loc match) (typeCheck (inner match))
-- | Type checks one equation of a function. The result is the curried
-- function type: one TFunction per argument group, ending in the type of
-- the right-hand side.
instance TypeCheckable (Match Name) Type where
    -- We create the error context in FunBind as that has access
    -- to the name
    errorContext _ = Nothing
    typeCheck' (Match groups e) = do
        -- Introduce free variables for all the parameters
        fvs <- liftM concat (mapM freeVars groups)
        local fvs $ do
            -- Type check every pattern of every argument group. We evaluate
            -- the dots here to implement the longest match rule.
            let typeCheckPatterns =
                    mapM (mapM (\ pat -> typeCheck pat >>= evaluateDots)) groups
            -- First pass over the patterns: only its effects in the type
            -- checking monad are needed, so the result is discarded.
            _ <- typeCheckPatterns
            -- We evaluate the dots here to implement the longest match rule
            tr <- typeCheck e >>= evaluateDots
            -- We need to evaluate the dots in the patterns twice just in case
            -- the type inferences on the RHS have resulted in extra dots on
            -- the left being able to be removed.
            tgroups <- typeCheckPatterns
            return $ foldr TFunction tr tgroups