{-# LANGUAGE TypeFamilies, ScopedTypeVariables #-}
module Language.Haskell.Names.Recursive
  ( resolve
  , annotate
  ) where
import Data.Foldable (traverse_)
import Data.Graph(stronglyConnComp, flattenSCC)
import Data.Data (Data)
import Control.Monad (forM, forM_, unless)
import qualified Data.Map as Map (insert)
import Control.Monad.State.Strict (State, execState, get, modify)
import Language.Haskell.Exts
import Language.Haskell.Names.Types
import Language.Haskell.Names.SyntaxUtils
import Language.Haskell.Names.ScopeUtils
import Language.Haskell.Names.ModuleSymbols
import Language.Haskell.Names.Exports
import Language.Haskell.Names.Imports
import Language.Haskell.Names.Open.Base
import Language.Haskell.Names.Annotated
-- | Extend an 'Environment' with the symbols of the given modules.
--
-- The modules are partitioned into groups by 'groupModules' and
-- 'findFixPoint' is run over each group in turn, threading the
-- environment through as state. The final state is the result.
resolve :: (Data l, Eq l) => [Module l] -> Environment -> Environment
resolve modules =
  execState (traverse_ findFixPoint (groupModules modules))
-- | Partition modules into strongly connected components of the import
-- graph described by 'moduleNode', each component flattened to a list.
--
-- The order of the groups is whatever 'stronglyConnComp' yields
-- (documented by @containers@ as reverse topological, i.e. imported
-- modules come before their importers).
groupModules :: [Module l] -> [[Module l]]
groupModules = map flattenSCC . stronglyConnComp . map moduleNode
-- | Build the graph node for one module in the shape expected by
-- 'stronglyConnComp': the module itself as payload, its own
-- (annotation-free) name as key, and the names of every module it
-- imports as outgoing edges.
moduleNode :: Module l -> (Module l, ModuleName (), [ModuleName ()])
moduleNode modul = (modul, key, dependencies) where
  key = dropAnn (getModuleName modul)
  dependencies = [ dropAnn (importModule imp) | imp <- getImports modul ]
-- | Iterate the export computation for one group of modules until the
-- exported symbols stop changing.
--
-- Every module of the group starts with an empty symbol list. Each round
-- stores the current lists in the environment under the modules' names,
-- recomputes each module's exports against that environment, and repeats
-- when the recomputed lists differ from the stored ones.
findFixPoint :: (Data l, Eq l) => [Module l] -> State Environment ()
findFixPoint modules = iterateExports (map (const []) modules) where
  -- Environment keys, in the same order as 'modules'.
  moduleKeys = map (dropAnn . getModuleName) modules

  -- Exports of a single module as seen through the given environment.
  exportsUnder environment modul =
    exportedSymbols (moduleTable (importTable environment modul) modul) modul

  iterateExports current = do
    -- Publish the current approximation before recomputing anything.
    sequence_
      (zipWith (\moduleKey exports -> modify (Map.insert moduleKey exports))
               moduleKeys
               current)
    environment <- get
    let next = map (exportsUnder environment) modules
    unless (current == next) (iterateExports next)
-- | Annotate a module's syntax tree with 'Scoped' information, using the
-- given environment to build the module's import and global tables.
--
-- Only the plain 'Module' constructor is handled; any other module form
-- calls 'error'.
annotate :: (Data l, Eq l, SrcInfo l) => Environment -> Module l -> Module (Scoped l)
annotate environment modul@(Module l maybeModuleHead modulePragmas importDecls decls) =
  Module (none l) scopedHead scopedPragmas scopedImports scopedDecls where
    moduleName = getModuleName modul
    -- Table of everything visible at the top level of this module.
    globalTable = moduleTable (importTable environment modul) modul
    -- Only the export list of the head is resolved against the global
    -- table; the remaining head components go through 'none' / 'noScope'.
    scopedHead = case maybeModuleHead of
      Just (ModuleHead lh headName maybeWarning maybeExports) ->
        Just (ModuleHead
          (none lh)
          (noScope headName)
          (fmap noScope maybeWarning)
          (fmap (annotateExportSpecList globalTable) maybeExports))
      Nothing -> Nothing
    scopedPragmas = fmap noScope modulePragmas
    scopedImports = annotateImportDecls moduleName environment importDecls
    -- Every top-level declaration is annotated in the same initial scope.
    declScope = initialScope (dropAnn moduleName) globalTable
    scopedDecls = map (annotateDecl declScope) decls
-- NOTE(review): the message says "annotateModule" although this function is
-- named 'annotate'; kept unchanged in case callers or tests match on it.
annotate _ _ = error "annotateModule: non-standard modules are not supported"