module Language.Haskell.Modules.Imports
( cleanImports
, cleanResults
) where
import Control.Applicative ((<$>))
import Control.Exception.Lifted as IO (bracket, catch, throw)
import Control.Monad.Trans (liftIO)
import Data.Char (toLower)
import Data.Foldable (fold)
import Data.Function (on)
import Data.List (find, groupBy, intercalate, nub, sortBy)
import Data.Maybe (catMaybes, fromMaybe)
import Data.Monoid ((<>), mempty)
import Data.Sequence ((|>))
import Data.Set as Set (empty, fromList, member, Set, singleton, toList, union, unions)
import Language.Haskell.Exts.Annotated.Simplify as S (sImportDecl, sImportSpec, sModuleName, sName)
import qualified Language.Haskell.Exts.Annotated.Syntax as A (Decl(DerivDecl), ImportDecl(..), ImportSpec(..), ImportSpecList(ImportSpecList), InstHead(..), Module(..), ModuleHead(..), ModuleName(..), QName(..), Type(..))
#if MIN_VERSION_haskell_src_exts(1,16,0)
import qualified Language.Haskell.Exts.Annotated.Syntax as A (InstRule(..))
#endif
import Language.Haskell.Exts.Pretty (defaultMode, PPHsMode(layout), PPLayout(PPInLine), prettyPrintWithMode)
import Language.Haskell.Exts.SrcLoc (SrcLoc(..), SrcSpanInfo)
import qualified Language.Haskell.Exts.Syntax as S (ImportDecl(importLoc, importModule, importSpecs), ModuleName(..), Name(..))
import Language.Haskell.Modules.Common (ModuleResult(..))
import Language.Haskell.Modules.Fold (foldDecls, foldExports, foldHeader, foldImports)
import Language.Haskell.Modules.ModuVerse (findModule, getExtensions, loadModule, ModuleInfo(..), moduleName, parseModule)
import Language.Haskell.Modules.Params (markForDelete, MonadClean(getParams), Params(hsFlags, removeEmptyImports, scratchDir, testMode))
import Language.Haskell.Modules.SourceDirs (modifyDirs, pathKey, APath(..), PathKey(..), PathKey(unPathKey), SourceDirs(getDirs, putDirs))
import Language.Haskell.Modules.Util.DryIO (replaceFile, tildeBackup)
import Language.Haskell.Modules.Util.QIO (qLnPutStr, quietly)
import Language.Haskell.Modules.Util.SrcLoc (srcLoc)
import Language.Haskell.Modules.Util.Symbols (symbols)
import System.Directory (createDirectoryIfMissing, getCurrentDirectory)
import System.Exit (ExitCode(..))
import System.FilePath ((</>))
import System.Process (readProcessWithExitCode, showCommandForUser)
#if MIN_VERSION_haskell_src_exts(1,14,0)
import Language.Haskell.Exts.Extension (Extension(..))
#endif
-- | Clean the import lists of the modules stored at the given paths.
--
-- Each path is resolved to a 'PathKey'.  The keys are collected into a
-- 'Set', so duplicate paths collapse into one.  GHC is asked once, via
-- 'dumpImports', to dump minimal imports for all of them.  Each module is
-- then parsed and handed to 'checkImports' in key order, and the
-- resulting 'ModuleResult's are returned.
cleanImports :: MonadClean m => [FilePath] -> m [ModuleResult]
cleanImports paths =
    do keySet <- fromList <$> mapM (pathKey . APath) paths
       dumpImports keySet
       mapM cleanOne (toList keySet)
    where
      cleanOne key = parseModule key >>= checkImports
-- | Clean the imports of the modules described by the results of an
-- earlier operation.
--
-- In test mode ('testMode') the results are returned untouched.
-- Otherwise:
--
--  * GHC's minimal-imports dump is produced for every 'JustModified'
--    module and every 'JustCreated' module that 'findModule' can locate;
--  * each modified or created module is then reloaded and passed to
--    'checkImports';
--  * 'JustRemoved' and 'Unchanged' results are passed through as they are.
--
-- A created module is always reported as 'JustCreated', even when
-- cleaning finds nothing to change.  Any other 'ModuleResult' constructor
-- is rejected with 'error'.
cleanResults :: MonadClean m => [ModuleResult] -> m [ModuleResult]
cleanResults results =
    do mode <- testMode <$> getParams
       if mode then return results else (dump >> clean)
    where
      dump =
          mapM dumpKey results >>=
          dumpImports . fromList . catMaybes
      -- Which module (if any) needs its minimal imports dumped.
      dumpKey x =
          case x of
            JustRemoved _ _ -> return Nothing
            Unchanged _ _ -> return Nothing
            JustModified _ key -> return (Just key)
            JustCreated name _ -> findModule name >>= return . fmap key_
            _ -> error $ "cleanResults - unexpected ModuleResult " ++ show x
      clean =
          mapM (\ x -> case x of
                         JustRemoved _ _ -> return x
                         Unchanged _ _ -> return x
                         JustModified _name key -> doModule key
                         JustCreated _name key -> doModule key >>= return . toCreated
                         _ -> error $ "cleanResults - unexpected ModuleResult " ++ show x) results
      toCreated (JustModified name key) = JustCreated name key
      -- checkImports answers Unchanged when the new module's imports are
      -- already minimal.  The module was still created by the earlier
      -- operation, so keep reporting it as such instead of failing below.
      toCreated (Unchanged name key) = JustCreated name key
      toCreated x@(JustCreated {}) = x
      toCreated _ = error "toCreated"
      doModule key =
          do info <- loadModule key
             checkImports info
-- | Run @ghc --make -c -ddump-minimal-imports@ over the given modules.
--
-- Output goes to the scratch directory ('scratchDir'), which is created
-- if missing.  The search path is the current source directories.  The
-- module's language extensions are passed as @-X@ flags.
--
-- When the key set is empty nothing is run.  Invoking @ghc --make@
-- without any input files fails, and that failure would surface below as
-- a spurious \"compile failed\" error.
--
-- A non-zero GHC exit status is reported with 'error', including GHC's
-- stderr.
dumpImports :: MonadClean m => Set PathKey -> m ()
dumpImports keys
    | null (toList keys) = return ()
    | otherwise =
        do scratch <- scratchDir <$> getParams
           liftIO $ createDirectoryIfMissing True scratch
           let cmd = "ghc"
           args <- hsFlags <$> getParams
           dirs <- getDirs
           exts <- getExtensions
           let args' = args ++
                       ["--make", "-c", "-ddump-minimal-imports", "-outputdir", scratch, "-i" ++ intercalate ":" dirs] ++
                       concatMap ppExtension exts ++
                       map unPathKey (toList keys)
           (code, _out, err) <- liftIO $ readProcessWithExitCode cmd args' ""
           case code of
             ExitSuccess -> quietly (qLnPutStr (showCommandForUser cmd args' ++ " -> Ok")) >> return ()
             ExitFailure _ -> error ("dumpImports: compile failed\n " ++ showCommandForUser cmd args' ++ " ->\n" ++ err)
    where
#if MIN_VERSION_haskell_src_exts(1,14,0)
      -- Only enabled extensions are forwarded; other Extension values are dropped.
      ppExtension (EnableExtension x) = ["-X"++ show x]
      ppExtension _ = []
#else
      ppExtension = (:[]) . ("-X" ++) . show
#endif
-- | Compare a parsed module's imports with the minimal imports GHC dumped
-- for it, and hand both to 'updateSource'.
--
-- The dump is read from @\<scratch\>/\<ModuleName\>.imports@.  The name
-- @Main@ is used when the module has no header.  The scratch directory is
-- @.@ when compiled with GHC older than 7.8.  The dump file is registered
-- with 'markForDelete' and parsed with the source directories temporarily
-- set to @["."]@.  An 'IOError' while parsing it is re-thrown with the
-- current directory prepended to the message.
--
-- Imports of the original module that carry a @hiding@ list are passed to
-- 'updateSource' as extra imports.  Modules that are not plain 'A.Module's
-- are rejected with 'error'.
checkImports :: MonadClean m => ModuleInfo -> m ModuleResult
checkImports info@(ModuleInfo (A.Module _ header _ currentImports _) _ _ _) =
    do
#if __GLASGOW_HASKELL__ >= 708
       scratch <- scratchDir <$> getParams
#else
       let scratch = "."
#endif
       let dumpPath = scratch </> (headerName header ++ ".imports")
       markForDelete dumpPath
       (ModuleInfo dumped _ _ _) <- withDot (parseModule (PathKey dumpPath) `IO.catch` explain)
       updateSource info dumped hidingImports
    where
      headerName = maybe "Main" (\ (A.ModuleHead _ (A.ModuleName _ s) _ _) -> s)
      -- Prefix the failure with the directory we were in when it happened.
      explain (e :: IOError) =
          liftIO (getCurrentDirectory >>= \ here -> throw . userError $ here ++ ": " ++ show e)
      hidingImports = filter isHiding currentImports
      isHiding (A.ImportDecl {A.importSpecs = Just (A.ImportSpecList _ True _)}) = True
      isHiding _ = False
checkImports _ = error "Unsupported module type"
-- | Run an action with the source directory list temporarily replaced by
-- @["."]@.  The saved list is put back through 'bracket', so it is also
-- restored if the action throws.
withDot :: MonadClean m => m a -> m a
withDot action = bracket getDirs restore (const (putDirs ["."] >> action))
    where
      restore saved = modifyDirs (const saved)
-- | Rewrite a module's source file so its imports become the cleaned-up
-- combination of the new imports and the extra imports.
--
-- The import list is normalised by 'fixNewImports', honouring the
-- 'removeEmptyImports' parameter.  When 'replaceImports' reports no
-- textual change the module is left alone and 'Unchanged' is returned.
-- Otherwise the file is replaced using the 'tildeBackup' strategy and
-- 'JustModified' is returned.  Non-'A.Module' inputs are rejected with
-- 'error'.
updateSource :: MonadClean m => ModuleInfo -> A.Module SrcSpanInfo -> [A.ImportDecl SrcSpanInfo] -> m ModuleResult
updateSource m@(ModuleInfo (A.Module _ _ _ oldImports _) _ _ key) (A.Module _ _ _ newImports _) extraImports =
    do remove <- removeEmptyImports <$> getParams
       let wanted = fixNewImports remove m oldImports (newImports ++ extraImports)
       case replaceImports wanted m of
         Nothing ->
             do qLnPutStr ("cleanImports: no changes to " ++ show key)
                return (Unchanged (moduleName m) key)
         Just newText ->
             do qLnPutStr ("cleanImports: modifying " ++ show key)
                replaceFile tildeBackup (unPathKey key) newText
                return (JustModified (moduleName m) key)
updateSource _ _ _ = error "updateSource"
-- | Render the module text with its import section replaced by
-- @newImports@.  Returns 'Nothing' when the new section is textually
-- identical to the current one.
--
-- The new section keeps two pieces of the old one:
--
--  * the prefix text of the first import visited by 'foldImports'
--    (presumably the first import in source order);
--  * the suffix text of the last one visited.
--
-- Between them, each new import is pretty-printed on one line
-- ('PPInLine'), separated by newlines.  Header, export and declaration
-- text is copied verbatim.
replaceImports :: [A.ImportDecl SrcSpanInfo] -> ModuleInfo -> Maybe String
replaceImports newImports m
    | currentSection == newSection = Nothing
    | otherwise = Just (concat [headerText, exportsText, newSection, declsText])
    where
      -- Append one element's prefix, text and suffix to the accumulator.
      keepPiece _ pref s suff acc = acc |> (pref <> s <> suff)
      -- Append a bare piece of text to the accumulator.
      keepText s acc = acc |> s
      currentSection = fold (foldImports keepPiece m mempty)
      newSection =
          leading <>
          intercalate "\n" (map (prettyPrintWithMode oneLine) newImports) <>
          trailing
      oneLine = defaultMode {layout = PPInLine}
      leading = fromMaybe "" (foldImports (\ _ pref _ _ acc -> maybe (Just pref) Just acc) m Nothing)
      trailing = foldImports (\ _ _ _ suff _ -> suff) m mempty
      headerText = fold (foldHeader keepText keepPiece keepPiece keepPiece m mempty)
      exportsText = fold (foldExports keepText keepPiece keepText m mempty)
      declsText = fold (foldDecls keepPiece keepText m mempty)
-- | Normalise the list of imports that is about to be written back.
--
--  * Imports that 'importMergable' considers equal are grouped and
--    combined into one.  Their spec lists are concatenated, de-duplicated
--    and sorted with 'compareSpecs'.  If any import in a group has no spec
--    list at all, it already brings the whole module into scope, so the
--    combined import has no spec list either.
--  * Specs naming a type that appears in a standalone deriving
--    declaration of the module are widened to @T(..)@.  A spec matches
--    unqualified when the import is not qualified, or via the import's
--    alias or module name.
--  * When @remove@ is set, an import whose spec list is empty is dropped,
--    unless the first original import of the same module also had an
--    empty spec list.
fixNewImports :: Bool
              -> ModuleInfo
              -> [A.ImportDecl SrcSpanInfo]
              -> [A.ImportDecl SrcSpanInfo]
              -> [A.ImportDecl SrcSpanInfo]
fixNewImports remove m oldImports imports =
    filter importPred $ map (expandSDTypes . mergeDecls) $ groupBy (\ a b -> importMergable a b == EQ) $ sortBy importMergable imports
    where
      mergeDecls [] = error "mergeDecls"
      mergeDecls xs@(x : _) = x {A.importSpecs = mergeSpecLists (map A.importSpecs xs)}
          where
            -- Nothing (an unrestricted import) in the group wins; otherwise
            -- the explicit lists are combined under the first list's
            -- annotation and hiding flag.
            mergeSpecLists :: [Maybe (A.ImportSpecList SrcSpanInfo)] -> Maybe (A.ImportSpecList SrcSpanInfo)
            mergeSpecLists specLists =
                case sequence specLists of
                  Just (A.ImportSpecList loc flag specs : ys) ->
                      Just (A.ImportSpecList loc flag (mergeSpecs (sortBy compareSpecs (nub (concat (specs : map (\ (A.ImportSpecList _ _ specs') -> specs') ys))))))
                  _ -> Nothing
      expandSDTypes :: A.ImportDecl SrcSpanInfo -> A.ImportDecl SrcSpanInfo
      expandSDTypes i@(A.ImportDecl {A.importSpecs = Just (A.ImportSpecList l f specs)}) =
          i {A.importSpecs = Just (A.ImportSpecList l f (map (expandSpec i) specs))}
      expandSDTypes i = i
      expandSpec i s =
          if not (A.importQualified i) && member (Nothing, sName n) sdTypes ||
             maybe False (\ mn -> (member (Just (sModuleName mn), sName n) sdTypes)) (A.importAs i) ||
             member (Just (sModuleName (A.importModule i)), sName n) sdTypes
            then s'
            else s
        where
          n = case s of
#if MIN_VERSION_haskell_src_exts(1,16,0)
                (A.IVar _ _ x) -> x
#else
                (A.IVar _ x) -> x
#endif
                (A.IAbs _ x) -> x
                (A.IThingAll _ x) -> x
                (A.IThingWith _ x _) -> x
          s' = case s of
#if MIN_VERSION_haskell_src_exts(1,16,0)
                 (A.IVar l _ x) -> A.IThingAll l x
#else
                 (A.IVar l x) -> A.IThingAll l x
#endif
                 (A.IAbs l x) -> A.IThingAll l x
                 (A.IThingWith l x _) -> A.IThingAll l x
                 (A.IThingAll _ _) -> s
#if MIN_VERSION_haskell_src_exts(1,16,0)
      importPred (A.ImportDecl _ mn _ _ _ _ _ (Just (A.ImportSpecList _ _ []))) =
#else
      importPred (A.ImportDecl _ mn _ _ _ _ (Just (A.ImportSpecList _ _ []))) =
#endif
          not remove || maybe False (isEmptyImport . A.importSpecs) (find ((== (unModuleName mn)) . unModuleName . A.importModule) oldImports)
          where
            isEmptyImport (Just (A.ImportSpecList _ _ [])) = True
            isEmptyImport _ = False
      importPred _ = True
      sdTypes :: Set (Maybe S.ModuleName, S.Name)
      sdTypes = standaloneDerivingTypes m
-- | Collect the type constructor names that appear in the instance heads
-- of the module's standalone deriving declarations.  Each name is paired
-- with its module qualifier when it was written qualified, and with
-- 'Nothing' otherwise.  XML modules are rejected with 'error'.
standaloneDerivingTypes :: ModuleInfo -> Set (Maybe S.ModuleName, S.Name)
standaloneDerivingTypes (ModuleInfo (A.XmlPage _ _ _ _ _ _ _) _ _ _) = error "standaloneDerivingTypes A.XmlPage"
standaloneDerivingTypes (ModuleInfo (A.XmlHybrid _ _ _ _ _ _ _ _ _) _ _ _) = error "standaloneDerivingTypes A.XmlHybrid"
standaloneDerivingTypes (ModuleInfo (A.Module _ _ _ _ decls) _ _ _) =
    unions (map derivDeclTypes decls)
#if MIN_VERSION_haskell_src_exts(1,16,0)

-- | Type constructor names (with optional module qualifier) mentioned in
-- a standalone deriving declaration.  Declarations other than
-- 'A.DerivDecl' contribute nothing.
class DerivDeclTypes a where
    derivDeclTypes :: a -> Set (Maybe S.ModuleName, S.Name)

instance DerivDeclTypes (A.Decl l) where
    derivDeclTypes (A.DerivDecl _ _ x) = derivDeclTypes x
    derivDeclTypes _ = empty

instance DerivDeclTypes (A.InstRule l) where
    derivDeclTypes (A.IRule _ _ _ x) = derivDeclTypes x
    derivDeclTypes (A.IParen _ x) = derivDeclTypes x

instance DerivDeclTypes (A.InstHead l) where
    derivDeclTypes (A.IHCon _ _) = empty
    derivDeclTypes (A.IHParen _ x) = derivDeclTypes x
    derivDeclTypes (A.IHInfix _ x _op) = derivDeclTypes x
    derivDeclTypes (A.IHApp _ x y) = union (derivDeclTypes x) (derivDeclTypes y)

instance DerivDeclTypes (A.Type l) where
    derivDeclTypes (A.TyForall _ _ _ x) = derivDeclTypes x
    derivDeclTypes (A.TyFun _ x y) = union (derivDeclTypes x) (derivDeclTypes y)
    derivDeclTypes (A.TyTuple _ _ xs) = unions (map derivDeclTypes xs)
    derivDeclTypes (A.TyList _ x) = derivDeclTypes x
    derivDeclTypes (A.TyApp _ x y) = union (derivDeclTypes x) (derivDeclTypes y)
    derivDeclTypes (A.TyVar _ _) = empty
    derivDeclTypes (A.TyCon _ (A.Qual _ m n)) = singleton (Just (sModuleName m), sName n)
    derivDeclTypes (A.TyCon _ (A.UnQual _ n)) = singleton (Nothing, sName n)
    derivDeclTypes (A.TyCon _ _) = empty
    derivDeclTypes (A.TyParen _ x) = derivDeclTypes x
    derivDeclTypes (A.TyInfix _ x _op y) = union (derivDeclTypes x) (derivDeclTypes y)
    derivDeclTypes (A.TyKind _ x _) = derivDeclTypes x
    derivDeclTypes (A.TyParArray _ x) = derivDeclTypes x
    derivDeclTypes (A.TyPromoted _ _) = empty
    derivDeclTypes (A.TyEquals _ _ _) = empty
    derivDeclTypes (A.TySplice _ _) = empty
    derivDeclTypes (A.TyBang _ _ x) = derivDeclTypes x
    -- Any constructor not matched above contributes nothing.  This keeps
    -- the traversal total instead of failing at runtime.
    derivDeclTypes _ = empty
#else
  where
    derivDeclTypes (A.DerivDecl _ _ (A.IHead _ _ xs)) = unions (map derivDeclTypes' xs)
    derivDeclTypes (A.DerivDecl a b (A.IHParen _ x)) = derivDeclTypes (A.DerivDecl a b x)
    derivDeclTypes (A.DerivDecl _ _ (A.IHInfix _ x _op y)) = union (derivDeclTypes' x) (derivDeclTypes' y)
    derivDeclTypes _ = empty
    derivDeclTypes' (A.TyForall _ _ _ x) = derivDeclTypes' x
    derivDeclTypes' (A.TyFun _ x y) = union (derivDeclTypes' x) (derivDeclTypes' y)
    derivDeclTypes' (A.TyTuple _ _ xs) = unions (map derivDeclTypes' xs)
    derivDeclTypes' (A.TyList _ x) = derivDeclTypes' x
    derivDeclTypes' (A.TyApp _ x y) = union (derivDeclTypes' x) (derivDeclTypes' y)
    derivDeclTypes' (A.TyVar _ _) = empty
    derivDeclTypes' (A.TyCon _ (A.Qual _ m n)) = singleton (Just (sModuleName m), sName n)
    derivDeclTypes' (A.TyCon _ (A.UnQual _ n)) = singleton (Nothing, sName n)
    derivDeclTypes' (A.TyCon _ _) = empty
    derivDeclTypes' (A.TyParen _ x) = derivDeclTypes' x
    derivDeclTypes' (A.TyInfix _ x _op y) = union (derivDeclTypes' x) (derivDeclTypes' y)
    derivDeclTypes' (A.TyKind _ x _) = derivDeclTypes' x
    -- The remaining Type constructors were previously unmatched and crashed
    -- with a pattern-match failure.  Treat them as naming no type.
    derivDeclTypes' _ = empty
#endif
-- | Ordering used to sort imports and group those that can be merged.
--
-- Two imports compare 'EQ' exactly when they agree on everything except
-- their source location and the contents of a non-hiding spec list.  So
-- @import M (a)@, @import M (b)@ and @import M@ are mergeable, while a
-- @hiding@ import only merges with another @hiding@ import.  Imports that
-- are not mergeable are ordered by module name first.
importMergable :: A.ImportDecl SrcSpanInfo -> A.ImportDecl SrcSpanInfo -> Ordering
importMergable a b =
    (compare `on` S.importModule) simpleA simpleB <> (compare `on` mergeKey) simpleA simpleB
    where
      simpleA = sImportDecl a
      simpleB = sImportDecl b
      -- Both keys get the same fixed location, so positions never matter.
      SrcLoc path _ _ = srcLoc a
      mergeKey :: S.ImportDecl -> S.ImportDecl
      mergeKey d =
          d { S.importLoc = SrcLoc path 1 1
            , S.importSpecs = case S.importSpecs d of
                                Just (True, _) -> Just (True, [])
                                _ -> Nothing
            }
-- | The module name string held by an annotated 'A.ModuleName'.
unModuleName :: A.ModuleName SrcSpanInfo -> String
unModuleName name =
    case name of
      A.ModuleName _ str -> str
-- | Order import specs case-insensitively by the names they mention.
-- Ties are broken by comparing the simplified (annotation-free) specs.
compareSpecs :: A.ImportSpec SrcSpanInfo -> A.ImportSpec SrcSpanInfo -> Ordering
compareSpecs a b =
    (compare `on` foldedNames) a b <> (compare `on` sImportSpec) a b
    where
      foldedNames spec = map (map toLower . nameString) (catMaybes (toList (symbols spec)))
-- | True when 'compareSpecs' considers the two specs equivalent.
equalSpecs :: A.ImportSpec SrcSpanInfo -> A.ImportSpec SrcSpanInfo -> Bool
equalSpecs a b =
    case compareSpecs a b of
      EQ -> True
      _ -> False
-- | Combine the specs of a merged import.
--
-- NOTE(review): this currently returns its input unchanged, inspecting
-- only the list shape.  It looks like a placeholder for combining
-- overlapping specs; confirm the intent before relying on it.
mergeSpecs :: [A.ImportSpec SrcSpanInfo] -> [A.ImportSpec SrcSpanInfo]
mergeSpecs specs =
    case specs of
      [] -> []
      [single] -> [single]
      _ -> specs
-- | The raw text of a simplified name, whether identifier or operator symbol.
nameString :: S.Name -> String
nameString name =
    case name of
      S.Ident str -> str
      S.Symbol str -> str