module Language.Haskell.Tools.Daemon.Session where
import Control.Monad.State.Strict
import Control.Reference
import Data.Function (on)
import Data.IORef (writeIORef, readIORef)
import qualified Data.List as List
import Data.List.Split (splitOn)
import qualified Data.Map as Map
import Data.Maybe
import System.Directory (doesFileExist)
import System.FilePath
import Digraph as GHC (flattenSCCs)
import DynFlags (DynFlags(..), xopt)
import Exception (gtry)
import GHC
import GHCi (purgeLookupSymbolCache)
import GhcMonad (modifySession)
import HscTypes
import Language.Haskell.TH.LanguageExtensions as Exts (Extension(..))
import Linker (unload)
import Module
import NameCache (NameCache(..))
import Packages (initPackages)
import Language.Haskell.Tools.Daemon.GetModules (getAllModules, setupLoadFlags)
import Language.Haskell.Tools.Daemon.ModuleGraph (supportingModules, dependentModules)
import Language.Haskell.Tools.Daemon.Representation
import Language.Haskell.Tools.Daemon.State
import Language.Haskell.Tools.Daemon.Utils
import Language.Haskell.Tools.Refactor hiding (ModuleName)
-- | The monad in which the daemon session runs: the daemon's own state
-- ('DaemonSessionState') layered as a 'StateT' over a GHC session ('Ghc').
type DaemonSession a = StateT DaemonSessionState Ghc a
-- | Loads the packages found at the given roots into the session and type checks their modules.
--
-- Arguments, in order:
--
--   * @report@: invoked with a module summary after that module has been reloaded successfully.
--   * @loadCallback@: invoked once with the list of modules that are going to be compiled.
--   * @additionalSrcDirs@: given the current session state and a package root, yields extra
--     source directories to make visible.
--   * @packages@: the roots of the packages to load.
--
-- Returns the errors that were caught. If the dependency analysis itself fails, the single error
-- it raised is returned. Otherwise one error is returned for each module that failed to compile;
-- modules depending on a failed module are skipped.
loadPackagesFrom :: (ModSummary -> IO ())
                      -> ([ModSummary] -> IO ())
                      -> (DaemonSessionState -> FilePath -> IO [FilePath])
                      -> [FilePath]
                      -> DaemonSession [SourceError]
loadPackagesFrom report loadCallback additionalSrcDirs packages =
  do
     -- collect the module collections of the packages and make their source directories visible
     modColls <- liftIO $ getAllModules packages
     st <- get
     moreSrcDirs <- liftIO $ mapM (additionalSrcDirs st) packages
     lift $ useDirs ((modColls ^? traversal & mcSourceDirs & traversal) ++ concat moreSrcDirs)
     -- resolve module names to source files and register the new collections in the session state
     mcs' <- liftIO (traversal !~ locateModules $ modColls)
     modify' (refSessMCs .- (++ mcs'))
     mcs <- gets (^. refSessMCs)
     -- files of modules that already have a typed representation in collections other than the
     -- ones being loaded now; these are not compiled again
     let alreadyLoadedFiles
           = concatMap (map (^. sfkFileName) . Map.keys . Map.filter (isJust . (^? typedRecModule)) . (^. mcModules))
                       (filter (\mc -> (mc ^. mcRoot) `notElem` packages) mcs)
     -- add the exposed modules of the new collections as targets, unless they are targets already
     currentTargets <- map targetId <$> (lift getTargets)
     lift $ mapM_ (\t -> when (targetId t `notElem` currentTargets) (addTarget t))
                  (map makeTarget $ List.nubBy sameTarget
                                  $ List.sort $ concatMap getExposedModules mcs')
     loadRes <- gtry (loadModules mcs alreadyLoadedFiles)
     case loadRes of
       Right mods -> do
         -- mark the collections discovered in this call as loaded (strict update, as elsewhere)
         modify' (refSessMCs & traversal & filtered (\mc -> (mc ^. mcId) `elem` map (^. mcId) modColls) & mcLoadDone .= True)
         compileModules report mods
       Left err -> return [err]
  where getExposedModules :: ModuleCollection k -> [k]
        -- keys of the modules that are exposed; records without exposure information count as exposed
        getExposedModules
          = Map.keys . Map.filter (\v -> fromMaybe True (v ^? recModuleExposed)) . (^. mcModules)

        -- Decides whether two keys would produce the same target. Keys with a source file are the
        -- same target when they refer to the same file. Modules that could not be located all
        -- carry an empty file name (see 'locateModule'), so comparing only the file name would
        -- collapse every unlocated module into a single target; those are told apart by the
        -- whole key, that is, by their module name.
        sameTarget :: SourceFileKey -> SourceFileKey -> Bool
        sameTarget k1 k2
          | null (k1 ^. sfkFileName) && null (k2 ^. sfkFileName) = k1 == k2
          | otherwise = (k1 ^. sfkFileName) == (k2 ^. sfkFileName)

        -- replaces the module-name keys of a collection with keys that also carry the source file
        locateModules :: ModuleCollection ModuleNameStr -> IO (ModuleCollection SourceFileKey)
        locateModules mc
          = mcModules !~ ((Map.fromList <$>)
                            . mapM (locateModule (mc ^. mcSourceDirs) (mc ^. mcModuleFiles))
                            . Map.assocs) $ mc

        -- pairs the module with the source file found for it; the file name is empty if none exists
        locateModule :: [FilePath] -> [(ModuleNameStr, FilePath)]
                          -> (ModuleNameStr, ModuleRecord) -> IO (SourceFileKey,ModuleRecord)
        locateModule srcDirs modMaps (modName, record)
          = do candidate <- createTargetCandidate srcDirs modMaps modName
               return (SourceFileKey (either (const "") id candidate) modName, record)

        -- Looks for the source file of a module in the source folders. The relative file name is
        -- taken from the explicit mapping when present, otherwise derived from the module name
        -- with an @hs@ extension. Yields the first existing file, or the module name if there is none.
        createTargetCandidate :: [FilePath] -> [(ModuleNameStr, FilePath)] -> ModuleNameStr
                                    -> IO (Either ModuleName FilePath)
        createTargetCandidate srcFolders mapping modName
          = wrapEither <$> filterM doesFileExist
                             (map (</> toFileName modName) srcFolders)
          where toFileName modName
                  = case lookup modName mapping of
                      Just fileName -> fileName
                      Nothing -> List.intercalate [pathSeparator] (splitOn "." modName) <.> "hs"
                wrapEither [] = Left (GHC.mkModuleName modName)
                wrapEither (fn:_) = Right fn

        -- a module target when no file was found, a file target otherwise
        makeTarget (SourceFileKey "" modName) = Target (TargetModule (GHC.mkModuleName modName)) True Nothing
        makeTarget (SourceFileKey filePath _) = Target (TargetFile filePath Nothing) True Nothing

        -- Runs the dependency analysis with the load flags of all collections and returns the
        -- topologically sorted modules whose files are not loaded yet. Missing modules are
        -- inserted into the session state and the load callback is notified with the result.
        loadModules mcs alreadyLoaded = do
          mods <- withLoadFlagsForModules mcs $ do
            loadVisiblePackages
            modsForColls <- lift $ depanal [] True
            let modsToParse = flattenSCCs $ topSortModuleGraph False modsForColls Nothing
                actuallyCompiled = filter (\ms -> getModSumOrig ms `notElem` alreadyLoaded) modsToParse
            modify' (refSessMCs .- foldl (.) id (map (insertIfMissing . keyFromMS) actuallyCompiled))
            return actuallyCompiled
          liftIO $ loadCallback mods
          return mods

        -- Reloads the modules one by one. When a module fails, its error is collected and the
        -- modules depending on it are dropped from the remaining work.
        compileModules report mods = do
            checkEvaluatedMods mods
            compileWhileOk mods
          where compileWhileOk [] = return []
                compileWhileOk (mod:mods)
                  = do res <- gtry (reloadModule report mod)
                       case res of
                          Left err -> do dependents <- lift $ dependentModules (return . (ms_mod mod ==) . ms_mod)
                                         (err :) <$> compileWhileOk (filter ((`notElem` map ms_mod dependents) . ms_mod) mods)
                          Right _ -> compileWhileOk mods
        
-- | Initializes the package state for the current session flags, installs the resulting flags
-- in the session and stores a flag transformation in 'pkgDbFlags' that copies the package
-- database and package state of those flags into any other 'DynFlags'.
loadVisiblePackages :: DaemonSession ()
loadVisiblePackages = do
  (withPackages, _) <- liftIO . initPackages =<< getSessionDynFlags
  void (setSessionDynFlags withPackages)
  let copyPackageInfo flags = flags { pkgDatabase = pkgDatabase withPackages
                                    , pkgState = pkgState withPackages
                                    }
  modify' (pkgDbFlags .= copyPackageInfo)
-- | Selects one of the type-checked modules of the session by module name or by file path.
--
-- Only modules that have both a module summary and a typed representation are considered.
-- If exactly one ordinary source module ('HsSrcFile') has the given name, it is selected.
-- Otherwise the module whose original path shares the most trailing path components with the
-- argument is selected, provided that at least one component matches.
--
-- The result is the selected module (if any) together with the modules that were not selected.
-- Calls 'error' when several modules tie for the longest matching path suffix.
getFileMods :: String -> DaemonSession ( Maybe (SourceFileKey, UnnamedModule)
                                       , [(SourceFileKey, UnnamedModule)] )
getFileMods fnameOrModule = do
  modMaps <- gets (^? refSessMCs & traversal & mcModules)
  let -- every module that is both summarized and type checked, paired with its summary
      modules = [ (ms, (key, typed))
                | (key, record) <- concatMap @[] Map.assocs modMaps
                , Just ms <- [record ^? modRecMS]
                , Just typed <- [record ^? typedRecModule]
                ]
      -- selection by module name, restricted to ordinary source files
      isSourceFile ms = case ms_hsc_src ms of HsSrcFile -> True
                                              _ -> False
      namedAsRequested (ms, _) = getModSumName ms == fnameOrModule && isSourceFile ms
      (modSel, modOthers) = List.partition namedAsRequested modules
      -- selection by the number of trailing path components shared with the argument
      sharedTail xs ys = length (takeWhile id (zipWith (==) (reverse xs) (reverse ys)))
      sufLength (ms, _) = sharedTail (splitPath fnameOrModule) (splitPath (getModSumOrig ms))
      longest = maximum (map sufLength modules)
      (fnSel, fnOthers)
        | null modules || longest == 0 = ([], modules)
        | otherwise = List.partition ((== longest) . sufLength) modules
      notSelectedByPath = map snd fnOthers
      byPath = case fnSel of
                 [(_, found)] -> return (Just found, notSelectedByPath)
                 []           -> return (Nothing, notSelectedByPath)
                 _            -> error "getFileMods: multiple modules selected"
  case modSel of
    [(_, found)] -> return (Just found, map snd modOthers)
    _            -> byPath
-- | Reloads the modules selected by the predicate together with the modules that
-- 'getReachableModules' reports for them. The load callback receives the list of modules to
-- reload; the report function is run for every module after it was reloaded and its results
-- are collected.
reloadChangedModules :: (ModSummary -> IO a) -> ([ModSummary] -> IO ()) -> (ModSummary -> Bool)
                           -> DaemonSession [a]
reloadChangedModules report loadCallback isChanged = do
  toReload <- getReachableModules loadCallback isChanged
  checkEvaluatedMods toReload
  -- the modules are removed from the session before they are loaded again
  clearModules toReload
  forM toReload (reloadModule report)
-- | Removes the given modules from the GHC session: from the home package table, the linker
-- state, the name cache and the module graph. Does nothing for an empty list.
clearModules :: [ModSummary] -> DaemonSession ()
clearModules [] = return ()
clearModules mods = do
  env <- getSession
  let clearedNames = map ms_mod_name mods
      staysLoaded homeMod
        = GHC.moduleName (mi_module (hm_iface homeMod)) `notElem` clearedNames
      hptStay = filterHpt staysLoaded (hsc_HPT env)
  liftIO $ do
    -- purge the cache of looked-up symbols
    purgeLookupSymbolCache env
    -- hand the linkables of the modules that stay to the linker's unload
    unload env (mapMaybe hm_linkable (eltsHpt hptStay))
    -- drop the names of the cleared modules from the name cache
    nameCache <- readIORef (hsc_NC env)
    writeIORef (hsc_NC env)
               (nameCache { nsNames = delModuleEnvList (nsNames nameCache) (map ms_mod mods) })
  -- keep only the remaining modules in the home package table and the module graph
  lift $ modifySession $ \session ->
    session { hsc_HPT = hptStay
            , hsc_mod_graph = filter ((`notElem` clearedNames) . ms_mod_name) (hsc_mod_graph session)
            }
-- | Runs the dependency analysis with the load flags of all module collections and asks
-- 'dependentModules' for the modules belonging to the ones selected by the predicate.
-- The load callback is invoked with the resulting list, which is also returned.
getReachableModules :: ([ModSummary] -> IO ()) -> (ModSummary -> Bool) -> DaemonSession [ModSummary]
getReachableModules loadCallback selected = do
  collections <- gets (^. refSessMCs)
  withLoadFlagsForModules collections $ do
    -- only the effect on the session is needed, the resulting module graph is discarded
    _ <- lift (depanal [] True)
    affected <- lift (dependentModules (return . selected))
    liftIO (loadCallback affected)
    return affected
-- | Type checks a single module with the flags of its module collection and stores the result
-- in the session state, replacing the previous record of the module. The report function is
-- run with the module summary afterwards and its result is returned.
reloadModule :: (ModSummary -> IO a) -> ModSummary -> DaemonSession a
reloadModule report ms = do
  ghcfl <- gets (^. ghcFlagsSet)
  mcs <- gets (^. refSessMCs)
  let mc = decideMC ms mcs
      codeGen = needsGeneratedCode (keyFromMS ms) mcs
      -- adjusts the module summary according to the code generation policy of the module
      applyCodeGen = case codeGen of
        NoCodeGen       -> id
        InterpretedCode -> forceCodeGen
        GeneratedCode   -> forceAsmGen
  newm <- withFlagsForModule mc $ lift $ do
    -- the flags of the summary go through the collection's flag setup and the session-wide flags
    dfs <- liftIO (ghcfl <$> (mc ^. mcFlagSetup) (ms_hspp_opts ms))
    parseTyped (applyCodeGen (ms { ms_hspp_opts = dfs }))
  -- replace the record of the module in the collection it belongs to
  let inSameCollection coll = (coll ^. mcId) == (mc ^. mcId)
      storeTyped = Map.insert (keyFromMS ms) (ModuleTypeChecked newm ms codeGen) . removeModuleMS ms
  modify' ((refSessMCs & traversal & filtered inSameCollection & mcModules) .- storeTyped)
  liftIO (report ms)
-- | Chooses the module collection a module is handled with. The candidates, in order of
-- preference: the collection found by 'lookupModuleCollection', the first collection whose
-- root is a prefix of the module's original path, and finally the first collection at all.
-- Calls 'error' if there are no module collections.
decideMC :: ModSummary -> [ModuleCollection SourceFileKey] -> ModuleCollection SourceFileKey
decideMC ms mcs = fromMaybe fallback (lookupModuleCollection ms mcs)
  where
    containsFile coll = (coll ^. mcRoot) `List.isPrefixOf` getModSumOrig ms
    -- collections containing the file come first, followed by all collections as a last resort
    fallback = fromMaybe (error "reloadModule: module collections empty")
                         (listToMaybe (filter containsFile mcs ++ mcs))
-- | Runs the action with the dynamic flags altered for a module of the given collection.
-- The flags go through the collection's load flag setup first, then through its module flag
-- setup, then through the session-wide flag settings ('ghcFlagsSet') and finally through the
-- package database settings ('pkgDbFlags').
withFlagsForModule :: ModuleCollection SourceFileKey -> DaemonSession a -> DaemonSession a
withFlagsForModule mc action = do
  ghcfl <- gets (^. ghcFlagsSet)
  dbFlags <- gets (^. pkgDbFlags)
  let setupFlags dfs = liftIO $ do
        loadDfs <- (mc ^. mcLoadFlagSetup) dfs
        moduleDfs <- (mc ^. mcFlagSetup) loadDfs
        return (dbFlags (ghcfl moduleDfs))
  withAlteredDynFlags setupFlags action
-- | Runs the action with the dynamic flags altered for loading the given module collections.
-- 'setupLoadFlags' receives the ids, roots and dependencies of the collections together with
-- the chained load flag setups of all collections; its result then goes through the
-- session-wide flag settings ('ghcFlagsSet') and the package database settings ('pkgDbFlags').
withLoadFlagsForModules :: [ModuleCollection SourceFileKey] -> DaemonSession a -> DaemonSession a
withLoadFlagsForModules mcs action = do
  ghcfl <- gets (^. ghcFlagsSet)
  dbFlags <- gets (^. pkgDbFlags)
  let setupFlags dfs
        = liftIO $ (dbFlags . ghcfl)
            <$> setupLoadFlags (mcs ^? traversal & mcId)
                               (mcs ^? traversal & mcRoot)
                               (mcs ^? traversal & mcDependencies & traversal)
                               (foldl @[] (>=>) return (mcs ^? traversal & mcLoadFlagSetup))
                               dfs
  withAlteredDynFlags setupFlags action
-- | Finds the modules for which code has to be generated because of the given changed modules
-- (see 'getEvaluatedMods'), records the required code generation policy for them in the session
-- state, and reloads those of them that 'isAlreadyLoaded' reports for the module collections as
-- they were before this update.
checkEvaluatedMods :: [ModSummary] -> DaemonSession ()
checkEvaluatedMods changed = do
    mcs <- gets (^. refSessMCs)
    -- The flags of a module are set up by its module collection; the flags of the summary are
    -- used unchanged if the module belongs to no collection.
    let lookupFlags ms = maybe return (^. mcFlagSetup) mc $ ms_hspp_opts ms
          where mc = lookupModuleCollection ms mcs
    (modsNeedCode, modsNeedAsm) <- lift (getEvaluatedMods changed lookupFlags)
    -- Record the code generation policy of the modules. The strict modify' is used, as in the
    -- rest of this module, so that the updates do not pile up as thunks in the session state.
    forM_ modsNeedCode (\ms -> modify' $ refSessMCs .- codeGeneratedFor (keyFromMS ms) InterpretedCode)
    forM_ modsNeedAsm (\ms -> modify' $ refSessMCs .- codeGeneratedFor (keyFromMS ms) GeneratedCode)
    -- NOTE: the checks below deliberately keep using the collections read at the start, not the
    -- ones updated above.
    let interpreted = filter (\ms -> isAlreadyLoaded (keyFromMS ms) InterpretedCode mcs)
                             modsNeedCode
        codeGenerated = filter (\ms -> isAlreadyLoaded (keyFromMS ms) GeneratedCode mcs) modsNeedAsm
    clearModules (interpreted ++ codeGenerated)
    -- load the cleared modules again, this time with code generation
    forM_ interpreted (codeGenForModule mcs InterpretedCode)
    forM_ codeGenerated (codeGenForModule mcs GeneratedCode)
-- | Type checks a module again with the flags of its module collection, forcing the kind of
-- code generation demanded by the policy. The result of the type checking is discarded.
-- Calls 'error' if the module belongs to none of the given collections.
codeGenForModule :: [ModuleCollection SourceFileKey] -> CodeGenPolicy -> ModSummary -> DaemonSession ()
codeGenForModule mcs codeGen ms = withFlagsForModule mc (lift (void (parseTyped prepared)))
  where
    -- the module summary adjusted to the requested code generation
    prepared = case codeGen of
      InterpretedCode -> forceCodeGen ms
      GeneratedCode   -> forceAsmGen ms
      _               -> ms
    mc = case lookupModuleCollection ms mcs of
      Just found -> found
      Nothing    -> error ("codeGenForModule: The following module is not found: " ++ getModSumName ms)
-- | Asks 'supportingModules' twice about the changed modules: once for those that have
-- @TemplateHaskell@ enabled, and once for those that have @StaticPointers@, @UnboxedTuples@ or
-- @UnboxedSums@ enabled. The flags of a module are obtained with the given function.
--
-- The first component of the result holds the modules of the first query that are not also in
-- the second, the second component holds the modules of the second query.
getEvaluatedMods :: [ModSummary] -> (ModSummary -> IO DynFlags) -> Ghc ([ModSummary],[ModSummary])
getEvaluatedMods changed additionalFlags = do
    eval <- supportingChanged (xopt TemplateHaskell)
    asm <- supportingChanged needsAsm
    let asmOrigs = map getModSumOrig asm
    return (filter ((`notElem` asmOrigs) . getModSumOrig) eval, asm)
  where
    changedPaths = map getModSumOrig changed
    needsAsm flags = any (`xopt` flags) [StaticPointers, UnboxedTuples, UnboxedSums]
    -- queries with a predicate that holds for changed modules whose flags satisfy the condition
    supportingChanged wanted = supportingModules $ \ms -> do
      flags <- liftIO (additionalFlags ms)
      return (getModSumOrig ms `elem` changedPaths && wanted flags)