module Language.Haskell.Tools.Refactor.Session where
import Control.Exception
import Control.Monad.State.Strict
import Control.Reference
import qualified Data.List as List
import qualified Data.Map as Map
import Data.Maybe
import System.FilePath
import Data.IntSet (member)
import Digraph as GHC
import ErrUtils
import Exception (ExceptionMonad)
import FastString as GHC
import GHC
import HscTypes as GHC
import Language.Haskell.TH.LanguageExtensions
import Outputable
import Language.Haskell.Tools.AST (IdDom)
import Language.Haskell.Tools.Refactor.GetModules
import Language.Haskell.Tools.Refactor.Prepare
import Language.Haskell.Tools.Refactor.RefactorBase
-- | State carried through a refactoring session. It holds the module
-- collections that have been registered in the session.
data RefactorSessionState = RefactorSessionState
  { __refSessMCs :: [ModuleCollection] -- ^ module collections of the session
  }

-- Template Haskell splice for the record above; presumably this is what
-- generates the '_refSessMCs' reference used by the instance in this file.
makeReferences ''RefactorSessionState
-- | Interface of the state types that the session functions of this module
-- can operate on.
class IsRefactSessionState st where
  -- | Reference to the module collections stored inside the state.
  refSessMCs :: Simple Lens st [ModuleCollection]

  -- | The state a session starts with.
  initSession :: st
-- | The plain session state: module collections are accessed through
-- '_refSessMCs' and a fresh session has no module collections.
instance IsRefactSessionState RefactorSessionState where
  refSessMCs = _refSessMCs
  initSession = RefactorSessionState { __refSessMCs = [] }
-- | Registers the module collections found under the given package roots in
-- the session state and loads their modules.
--
-- * @report@ is passed on to 'loadModule' for every module that is loaded.
-- * @loadCallback@ receives the list of modules that are going to be loaded.
-- * @additionalSrcDirs@ is asked, for every package root, for extra source
--   directories that are handed to 'useDirs'.
--
-- On success the result holds the values produced for the loaded modules
-- together with the module names that were ignored because the same name
-- occurs more than once among the new collections. Failures are reported
-- as 'Left' by 'handleErrors'.
loadPackagesFrom :: IsRefactSessionState st => (ModSummary -> IO a) -> ([ModSummary] -> IO ()) -> (st -> FilePath -> IO [FilePath]) -> [FilePath] -> StateT st Ghc (Either RefactorException ([a], [String]))
loadPackagesFrom report loadCallback additionalSrcDirs packages = do
    newColls <- liftIO $ getAllModules packages
    modify $ refSessMCs .- (++ newColls)
    everyColl <- gets (^. refSessMCs)
    sessionState <- get
    extraDirs <- liftIO $ forM packages (additionalSrcDirs sessionState)
    lift $ useDirs ((newColls ^? traversal & mcSourceDirs & traversal) ++ concat extraDirs)
    let newModuleNames = [ key ^. sfkModuleName
                         | modMap <- newColls ^? traversal & mcModules
                         , key <- Map.keys modMap ]
        (ignored, modNames) = extractDuplicates newModuleNames
        -- names of the modules in collections that were registered earlier
        knownBefore = [ key ^. sfkModuleName
                      | coll <- (everyColl List.\\ newColls)
                      , key <- Map.keys (coll ^. mcModules) ]
    lift $ mapM_ (addTarget . moduleTarget) modNames
    handleErrors $ withAlteredDynFlags (liftIO . setupLoadFlags everyColl) $ do
      depGraph <- lift $ depanal [] True
      let modsToParse = flattenSCCs (topSortModuleGraph False depGraph Nothing)
          actuallyCompiled = filter ((`notElem` knownBefore) . modSumName) modsToParse
      liftIO $ loadCallback actuallyCompiled
      void $ checkEvaluatedMods (const (return ())) modsToParse
      loaded <- mapM (loadModule report) actuallyCompiled
      return (loaded, ignored)
  where
    -- The GHC target for a module given by its name.
    moduleTarget name = Target (TargetModule (GHC.mkModuleName name)) True Nothing

    -- Splits a list into (repeated occurrences, one representative of each
    -- element). Processing goes from the right, so the representative is
    -- the last occurrence and earlier occurrences count as duplicates.
    extractDuplicates :: Eq a => [a] -> ([a], [a])
    extractDuplicates = foldr classify ([], [])
      where classify x (dups, uniques)
              | x `elem` uniques = (x : dups, uniques)
              | otherwise = (dups, x : uniques)
-- | Loads a single module through 'reloadModule'. When the session state
-- says that the module needs generated code, the summary is first passed
-- through 'forceCodeGen'.
loadModule :: IsRefactSessionState st => (ModSummary -> IO a) -> ModSummary -> StateT st Ghc a
loadModule report ms = do
    mcs <- gets (^. refSessMCs)
    let summary
          | needsGeneratedCode (keyFromMS ms) mcs = forceCodeGen ms
          | otherwise = ms
    reloadModule report summary
-- | Runs the action and returns its result as 'Right'. GHC source errors
-- are converted to a 'SourceCodeProblem', and a thrown 'RefactorException'
-- is returned as it is; both end up in 'Left'.
handleErrors :: ExceptionMonad m => m a -> m (Either RefactorException a)
handleErrors action = guarded `gcatch` onRefactorError
  where
    guarded = handleSourceError onSourceError (fmap Right action)
    onSourceError srcErr = return (Left (SourceCodeProblem (srcErrorMessages srcErr)))
    onRefactorError refErr = return (Left refErr)
-- | The session key of a module summary: the module name paired with the
-- boot kind. Only 'HsSrcFile' summaries get 'NormalHs'; every other source
-- kind is keyed as 'IsHsBoot'.
keyFromMS :: ModSummary -> SourceFileKey
keyFromMS ms = SourceFileKey bootKind (modSumName ms)
  where
    bootKind
      | HsSrcFile <- ms_hsc_src ms = NormalHs
      | otherwise = IsHsBoot
-- | Retrieves the type-checked modules of the session.
--
-- The first component is the module with the given key, if a key was given,
-- the module is found and it has a typed representation. The second
-- component lists all other modules that have a typed representation.
getMods :: (Monad m, IsRefactSessionState st)
        => Maybe SourceFileKey -> StateT st m ( Maybe (SourceFileKey, UnnamedModule IdDom)
                                              , [(SourceFileKey, UnnamedModule IdDom)] )
getMods actMod = do
    mcs <- gets (^. refSessMCs)
    let -- replaces the record of an entry by its typed module, if it has one
        typedEntry = _2 !~ (^? typedRecModule)
        selected = actMod >>= (`lookupModInSCs` mcs) >>= typedEntry
        others = [ entry
                 | mc <- mcs
                 , Just entry <- map typedEntry (Map.assocs (mc ^. mcModules))
                 , actMod /= Just (fst entry) ]
    return (selected, others)
-- | Like 'getMods', but the selected module is identified by the path of its
-- source file.
--
-- The path is compared with the normalised source file location stored in
-- the module summaries. The key of the first matching module is passed to
-- 'getMods'; when no module matches, 'getMods' is called with 'Nothing'.
--
-- Module records without a module summary are skipped. The previous
-- implementation used 'fromJust' on the summary lookup and crashed when
-- such a record was inspected.
getFileMods :: (GhcMonad m, IsRefactSessionState st)
            => FilePath -> StateT st m ( Maybe (SourceFileKey, UnnamedModule IdDom)
                                       , [(SourceFileKey, UnnamedModule IdDom)] )
getFileMods fname = do
    mcs <- gets (^. refSessMCs)
    let records = concatMap Map.assocs (mcs ^? traversal & mcModules :: [Map.Map SourceFileKey ModuleRecord])
        matchingKeys = [ key
                       | (key, record) <- records
                       , Just ms <- [record ^? modRecMS]
                       , Just fname == fmap normalise (ml_hs_file (ms_location ms)) ]
    getMods (listToMaybe matchingKeys)
-- | Reloads the modules selected by the predicate together with the modules
-- that 'getReachableModules' reports for them.
--
-- Before reloading, 'checkEvaluatedMods' is run on the same module list.
-- The result holds what @report@ returned for each reloaded module, or the
-- failure caught by 'handleErrors'.
reloadChangedModules :: IsRefactSessionState st => (ModSummary -> IO a) -> ([ModSummary] -> IO ()) -> (ModSummary -> Bool) -> StateT st Ghc (Either RefactorException [a])
reloadChangedModules report loadCallback isChanged = handleErrors reloadAll
  where
    reloadAll = do
      toReload <- getReachableModules loadCallback isChanged
      _ <- checkEvaluatedMods report toReload
      forM toReload (reloadModule report)
-- | Collects the modules that have to be reloaded when the selected modules
-- change.
--
-- The result contains every module from which a selected module can be
-- reached in the module graph (reachability is computed on the transposed
-- graph), listed in the reverse of the graph's topological order. The list
-- is passed to @loadCallback@ before it is returned.
getReachableModules :: IsRefactSessionState st => ([ModSummary] -> IO ()) -> (ModSummary -> Bool) -> StateT st Ghc [ModSummary]
getReachableModules loadCallback selected = do
    everyColl <- gets (^. refSessMCs)
    withAlteredDynFlags (liftIO . setupLoadFlags everyColl) $ do
      allMods <- lift $ depanal [] True
      let (graph, findNode) = moduleGraphNodes False allMods
          nodeOf ms = findNode (ms_hsc_src ms) (moduleName (ms_mod ms))
          changedNodes = mapMaybe nodeOf (filter selected allMods)
          affected = map (ms_mod . getModFromNode) (reachablesG (transposeG graph) changedNodes)
          inLoadOrder = map getModFromNode (reverse (topologicalSortG graph))
          toRecompile = filter (\ms -> ms_mod ms `elem` affected) inLoadOrder
      liftIO $ loadCallback toRecompile
      return toRecompile
-- | Parses and type checks a module again and stores the result in the
-- session state.
--
-- The module is compiled with the flags that 'compileInContext' produces
-- for the module collection containing it. The new record is inserted into
-- every collection that has the same root as that collection. It is a
-- 'ModuleCodeGenerated' record when the state says the module needs
-- generated code, and a 'ModuleTypeChecked' record otherwise. Finally
-- @report@ is called with the original summary.
--
-- Throws 'ModuleNotInPackage' when no module collection contains the module.
reloadModule :: IsRefactSessionState st => (ModSummary -> IO a) -> ModSummary -> StateT st Ghc a
reloadModule report ms = do
    mcs <- gets (^. refSessMCs)
    let name = modSumName ms
        key = keyFromMS ms
        codeGen = needsGeneratedCode key mcs
        mkRecord
          | codeGen = ModuleCodeGenerated
          | otherwise = ModuleTypeChecked
    case lookupModuleColl name mcs of
      Nothing -> liftIO $ throwIO $ ModuleNotInPackage name
      Just mc -> do
        flags <- liftIO $ compileInContext mc mcs (ms_hspp_opts ms)
        let adjusted = ms { ms_hspp_opts = flags }
            hasSameRoot other = (other ^. mcRoot) == (mc ^. mcRoot)
        typed <- lift $ withAlteredDynFlags (liftIO . compileInContext mc mcs) $
                   parseTyped (if codeGen then forceCodeGen adjusted else adjusted)
        -- the stored record keeps the original summary, not the adjusted one
        modify $ refSessMCs & traversal & filtered hasSameRoot & mcModules
                   .- Map.insert key (mkRecord typed ms)
        liftIO $ report ms
-- | Makes sure that code is generated for the modules that
-- 'getEvaluatedMods' selects for the given list.
--
-- Each selected module that has no generated code yet is marked in the
-- session state with 'codeGeneratedFor'. If such a module is already
-- loaded, it is also recompiled through 'codeGenForModule'. The results of
-- @report@ for the recompiled modules are returned.
--
-- All checks use the module collections as they were before the first
-- module was processed.
checkEvaluatedMods :: IsRefactSessionState st => (ModSummary -> IO a) -> [ModSummary] -> StateT st Ghc [a]
checkEvaluatedMods report mods = do
    modsNeedCode <- lift (getEvaluatedMods mods)
    mcs <- gets (^. refSessMCs)
    res <- forM modsNeedCode $ \ms -> reloadIfNeeded ms mcs
    return $ catMaybes res
  where
    reloadIfNeeded ms mcs
      | hasGeneratedCode key mcs = return Nothing
      | otherwise = do
          modify $ refSessMCs .- codeGeneratedFor key
          if isAlreadyLoaded key mcs
            -- loaded without generated code: it has to be compiled again
            then Just <$> lift (codeGenForModule report (codeGeneratedFor key mcs) ms)
            else return Nothing
      where key = keyFromMS ms
-- | Parses and type checks the module again with 'forceCodeGen' applied to
-- its summary, discards the result and then calls @report@ on the summary.
--
-- The flags come from 'compileInContext' for the module collection that
-- contains the module. If no collection contains it, using that collection
-- raises an 'error'.
codeGenForModule :: (ModSummary -> IO a) -> [ModuleCollection] -> ModSummary -> Ghc a
codeGenForModule report mcs ms = do
    withAlteredDynFlags (liftIO . compileInContext owner mcs) $
      void $ parseTyped (forceCodeGen ms)
    liftIO $ report ms
  where
    name = modSumName ms
    owner = case lookupModuleColl name mcs of
      Just coll -> coll
      Nothing -> error ("codeGenForModule: The following module is not found: " ++ name)
-- | Selects the modules of the current module graph that are reachable from
-- those given modules that have the @TemplateHaskell@ extension enabled.
--
-- The result is listed in the reverse of the graph's topological order.
getEvaluatedMods :: [ModSummary] -> Ghc [GHC.ModSummary]
getEvaluatedMods mods = do
    allMods <- getModuleGraph
    let (graph, findNode) = moduleGraphNodes False allMods
        usesTH ms = member (fromEnum TemplateHaskell) (extensionFlags (ms_hspp_opts ms))
        nodeOf ms = findNode (ms_hsc_src ms) (moduleName (ms_mod ms))
        thNodes = mapMaybe nodeOf (filter usesTH mods)
        neededNames = map (moduleName . ms_mod . getModFromNode) (reachablesG graph thNodes)
        ordered = map getModFromNode (reverse (topologicalSortG graph))
    return [ ms | ms <- ordered, moduleName (ms_mod ms) `elem` neededNames ]
-- | The name of the summarised module as a plain string.
modSumName :: ModSummary -> String
modSumName ms = GHC.moduleNameString (moduleName (ms_mod ms))
-- | Key identifying a node of the module graph: the module name together
-- with its boot kind.
type NodeKey = (ModuleName, IsBoot)
-- | A map indexed by module graph node keys.
type NodeMap a = Map.Map NodeKey a
-- | A module graph node: the module summary, the integer key of the node
-- and the keys of the nodes its outgoing edges point to.
type SummaryNode = (ModSummary, Int, [Int])
-- | Extracts the module summary stored in a graph node.
getModFromNode :: SummaryNode -> ModSummary
getModFromNode node = case node of
  (summary, _, _) -> summary
-- Builds a graph from the given module summaries and returns it together
-- with a function that finds the node of a module by its source kind and
-- module name.
--
-- When the first argument is True, summaries for which 'isBootSummary'
-- holds get no node, and source imports are resolved against 'HsSrcFile'
-- nodes instead of 'HsBootFile' nodes.
--
-- NOTE: 'nodes', 'node_map' and 'lookup_key' are defined in terms of each
-- other, so the definitions rely on lazy evaluation.
moduleGraphNodes :: Bool -> [ModSummary]
-> (Graph SummaryNode, HscSource -> ModuleName -> Maybe SummaryNode)
moduleGraphNodes drop_hs_boot_nodes summaries = (graphFromEdgedVertices nodes, lookup_node)
where
-- every summary gets a key, counting from 1 in the order of the input list
numbered_summaries = zip summaries [1..]
lookup_node :: HscSource -> ModuleName -> Maybe SummaryNode
lookup_node hs_src mod = Map.lookup (mod, hscSourceToIsBoot hs_src) node_map
lookup_key :: HscSource -> ModuleName -> Maybe Int
lookup_key hs_src mod = fmap summaryNodeKey (lookup_node hs_src mod)
-- nodes indexed by module name and boot kind
node_map :: NodeMap SummaryNode
node_map = Map.fromList [ ((moduleName (ms_mod s),
hscSourceToIsBoot (ms_hsc_src s)), node)
| node@(s, _, _) <- nodes ]
nodes :: [SummaryNode]
nodes = [ (s, key, out_keys)
| (s, key) <- numbered_summaries
, not (isBootSummary s && drop_hs_boot_nodes)
-- outgoing edges: home source imports, then home normal imports, then
-- (see below) possibly the boot node of the module itself
, let out_keys = out_edge_keys hs_boot_key (map unLoc (ms_home_srcimps s)) ++
out_edge_keys HsSrcFile (map unLoc (ms_home_imps s)) ++
(
-- unless boot nodes are dropped or this is a boot node itself, add an
-- edge to the HsBootFile node of the same module when one exists
if drop_hs_boot_nodes || ms_hsc_src s == HsBootFile
then []
else case lookup_key HsBootFile (ms_mod_name s) of
Nothing -> []
Just k -> [k]) ]
-- the source kind that source imports are looked up with
hs_boot_key | drop_hs_boot_nodes = HsSrcFile
| otherwise = HsBootFile
-- keys of the listed modules that have a node; the others are left out
out_edge_keys :: HscSource -> [ModuleName] -> [Int]
out_edge_keys hi_boot ms = mapMaybe (lookup_key hi_boot) ms
-- | Converts the GHC source kind to the boot kind used in node keys. Only
-- 'HsBootFile' counts as a boot module; anything else is 'NormalHs'.
hscSourceToIsBoot :: HscSource -> IsBoot
hscSourceToIsBoot src = case src of
  HsBootFile -> IsHsBoot
  _ -> NormalHs
-- | Extracts the integer key of a graph node.
summaryNodeKey :: SummaryNode -> Int
summaryNodeKey node = case node of
  (_, key, _) -> key
-- | The normal imports of the module that pass the 'home_imps' filter.
ms_home_imps :: ModSummary -> [Located ModuleName]
ms_home_imps summary = home_imps (ms_imps summary)
-- | The source imports of the module that pass the 'home_imps' filter.
ms_home_srcimps :: ModSummary -> [Located ModuleName]
ms_home_srcimps summary = home_imps (ms_srcimps summary)
-- | Keeps the imported module names that carry no package qualifier or
-- whose package qualifier is @"this"@.
home_imps :: [(Maybe FastString, Located ModuleName)] -> [Located ModuleName]
home_imps = map snd . filter (isLocal . fst)
  where
    -- an absent qualifier is local; a present one must be "this"
    isLocal = maybe True (== fsLit "this")