module Language.Haskell.Modules.Imports
( cleanImports
) where
import Control.Applicative ((<$>))
import "MonadCatchIO-mtl" Control.Monad.CatchIO as IO (bracket, catch, throw)
import Control.Monad.Trans (liftIO)
import Data.Char (toLower)
import Data.Default (def, Default)
import Data.Foldable (fold)
import Data.Function (on)
import Data.List (find, groupBy, intercalate, nub, nubBy, sortBy)
import Data.Maybe (catMaybes, fromMaybe)
import Data.Monoid ((<>), mempty)
import Data.Sequence ((|>))
import Data.Set as Set (empty, member, Set, singleton, toList, union, unions)
import Language.Haskell.Exts.Annotated (ParseResult(..))
import Language.Haskell.Exts.Annotated.Simplify as S (sImportDecl, sImportSpec, sModuleName, sName)
import qualified Language.Haskell.Exts.Annotated.Syntax as A (Decl(DerivDecl), ImportDecl(..), ImportSpec(..), ImportSpecList(ImportSpecList), InstHead(..), Module(..), ModuleHead(ModuleHead), ModuleName(ModuleName), QName(..), Type(..))
import Language.Haskell.Exts.Extension (Extension(PackageImports, StandaloneDeriving, TypeSynonymInstances, FlexibleInstances))
import Language.Haskell.Exts.Pretty (defaultMode, PPHsMode(layout), PPLayout(PPInLine), prettyPrintWithMode)
import Language.Haskell.Exts.SrcLoc (SrcSpanInfo)
import qualified Language.Haskell.Exts.Syntax as S (ImportDecl(importLoc, importModule, importSpecs), ModuleName(..), Name(..))
import Language.Haskell.Modules.Common (modulePathBase, withCurrentDirectory)
import Language.Haskell.Modules.Fold (ModuleInfo, foldDecls, foldExports, foldHeader, foldImports)
import Language.Haskell.Modules.Internal (getParams, markForDelete, modifyParams, ModuleResult(..), MonadClean, Params(..), parseFile, parseFileWithComments, runMonadClean, scratchDir)
import Language.Haskell.Modules.Params (modifyTestMode)
import Language.Haskell.Modules.Util.DryIO (replaceFile, tildeBackup)
import Language.Haskell.Modules.Util.QIO (qLnPutStr, quietly)
import Language.Haskell.Modules.Util.Symbols (symbols)
import System.Cmd (system)
import System.Directory (createDirectoryIfMissing, getCurrentDirectory)
import System.Exit (ExitCode(..))
import System.FilePath ((<.>), (</>))
import System.Process (readProcessWithExitCode, showCommandForUser)
import Test.HUnit (assertEqual, Test(..))
-- | Clean up the import list of the module at @path@.
--
-- The module is parsed (with comments), 'dumpImports' runs GHC with
-- @-ddump-minimal-imports@, and 'checkImports' splices the resulting import
-- list back into the source.  Imports written with a @hiding@ list are passed
-- along as extra imports so they are kept in the output.
--
-- Calls 'error' if the file fails to parse, or if it parses as an XML page or
-- XML hybrid module.
cleanImports :: MonadClean m => FilePath -> m ModuleResult
cleanImports path =
    do -- Read the file strictly.  A lazily read String keeps the handle on
       -- @path@ open until it has been fully consumed, which may only happen
       -- after the file is rewritten later on, or never when nothing
       -- changes.  Lazy reading would also defer decoding errors to an
       -- arbitrary later point.
       text <- liftIO $ readFile path >>= \ s -> length s `seq` return s
       source <- parseFileWithComments path
       case source of
         ParseOk (m@(A.Module _ h _ imports _decls), comments) ->
             do let name = case h of
                             Just (A.ModuleHead _ x _ _) -> sModuleName x
                             _ -> S.ModuleName "Main"
                    hiddenImports = filter isHiddenImport imports
                dumpImports path >> checkImports path name (m, text, comments) hiddenImports
         ParseOk (A.XmlPage {}, _) -> error "cleanImports: XmlPage"
         ParseOk (A.XmlHybrid {}, _) -> error "cleanImports: XmlHybrid"
         ParseFailed _loc msg -> error ("cleanImports: - parse of " ++ path ++ " failed: " ++ msg)
    where
      -- True for imports that carry a @hiding (...)@ spec list.
      isHiddenImport (A.ImportDecl {A.importSpecs = Just (A.ImportSpecList _ True _)}) = True
      isHiddenImport _ = False
-- | Compile @path@ with GHC so that it dumps the module's minimal imports.
--
-- GHC is invoked as @ghc --make -c -ddump-minimal-imports -outputdir SCRATCH@.
-- It is given the configured 'hsFlags', the configured 'sourceDirs' as a
-- search path, and one @-X@ flag per configured extension.  The scratch
-- directory is created first if necessary.  A successful run is logged
-- through 'quietly'.  A failed compile calls 'error' with GHC's stderr.
dumpImports :: MonadClean m => FilePath -> m ()
dumpImports path =
    do scratch <- scratchDir <$> getParams
       liftIO $ createDirectoryIfMissing True scratch
       let cmd = "ghc"
       args <- hsFlags <$> getParams
       dirs <- sourceDirs <$> getParams
       exts <- extensions <$> getParams
       let -- A bare "-i" resets GHC's search path to empty (dropping the
           -- default "."), so only pass the flag when there are directories.
           searchPath = if null dirs then [] else ["-i" ++ intercalate ":" dirs]
           args' = args ++ ["--make", "-c", "-ddump-minimal-imports", "-outputdir", scratch]
                        ++ searchPath
                        ++ [path]
                        ++ map (("-X" ++) . show) exts
       (code, _out, err) <- liftIO $ readProcessWithExitCode cmd args' ""
       case code of
         ExitSuccess -> quietly (qLnPutStr (showCommandForUser cmd args' ++ " -> Ok")) >> return ()
         ExitFailure _ -> error ("dumpImports: compile failed\n " ++ showCommandForUser cmd args' ++ " ->\n" ++ err)
-- | Parse the @<ModuleName>.imports@ file (presumably the one GHC wrote for
-- @-ddump-minimal-imports@) and hand the result to 'updateSource'.
--
-- The imports file is registered with 'markForDelete'.  It is parsed with
-- 'PackageImports' temporarily added to the configured extensions; 'bracket'
-- restores the saved extension list afterwards.  An 'IOError' from the parse
-- is re-thrown as a 'userError' prefixed with the current directory.  A parse
-- failure calls 'error'.
checkImports :: MonadClean m => FilePath -> S.ModuleName -> ModuleInfo -> [A.ImportDecl SrcSpanInfo] -> m ModuleResult
checkImports path name@(S.ModuleName name') m extraImports =
    do markForDelete importsPath
       parsed <- bracket (extensions <$> getParams) restoreExtensions parseWithPackageImports
       case parsed of
         ParseOk newImports -> updateSource path m newImports name extraImports
         _ -> error ("checkImports: parse of " ++ importsPath ++ " failed - " ++ show parsed)
    where
      importsPath = name' <.> ".imports"
      -- Put back the extension list captured before parsing.
      restoreExtensions saved = modifyParams (\ p -> p {extensions = saved})
      -- Enable PackageImports on top of the saved extensions, then parse.
      parseWithPackageImports saved =
          do modifyParams (\ p -> p {extensions = PackageImports : saved})
             parseFile importsPath `IO.catch` reportDir
      -- Tag an IO failure with the directory it happened in.
      reportDir (e :: IOError) =
          liftIO (getCurrentDirectory >>= \ here -> throw . userError $ here ++ ": " ++ show e)
-- | Merge the freshly generated imports (plus @extraImports@) with
-- 'fixNewImports' and splice them into the module text with
-- 'replaceImports'.
--
-- If the rendered import section is unchanged, a message is logged and
-- 'Unchanged' is returned.  Otherwise the file is replaced (keeping a tilde
-- backup) and 'Modified' is returned with the new text.  Any module shape
-- other than 'A.Module' for both arguments calls 'error'.
updateSource :: MonadClean m => FilePath -> ModuleInfo -> A.Module SrcSpanInfo -> S.ModuleName -> [A.ImportDecl SrcSpanInfo] -> m ModuleResult
updateSource path m@(A.Module _ _ _ oldImports _, _, _) (A.Module _ _ _ newImports _) name extraImports =
    do remove <- removeEmptyImports <$> getParams
       let fixed = fixNewImports remove m oldImports (newImports ++ extraImports)
       case replaceImports fixed m of
         Nothing ->
             do qLnPutStr ("cleanImports: no changes to " ++ path)
                return (Unchanged name)
         Just text' ->
             do qLnPutStr ("cleanImports: modifying " ++ path)
                replaceFile tildeBackup path text'
                return (Modified name text')
updateSource _ _ _ _ _ = error "updateSource"
-- | Render @newImports@ in place of the module's current import section.
--
-- The new section keeps the prefix of the first old import and the suffix
-- of the last old import, and pretty-prints each new import on one line
-- (joined with newlines).  Returns 'Nothing' when this equals the old import
-- text.  Otherwise returns the whole module text: header, exports, the new
-- imports, then the declarations.
replaceImports :: [A.ImportDecl SrcSpanInfo] -> ModuleInfo -> Maybe String
replaceImports newImports m
    | oldPretty == newPretty = Nothing
    | otherwise = Just (headerText <> exportsText <> newPretty <> declsText)
    where
      -- Append an item rendered as prefix, text and suffix.
      emitItem _ pref s suff acc = acc |> (pref <> s <> suff)
      -- Append a bare piece of text.
      emitText s acc = acc |> s

      oldPretty = fold (foldImports emitItem m mempty)

      firstPrefix = foldImports (\ _ pref _ _ acc -> Just (fromMaybe pref acc)) m Nothing
      lastSuffix = foldImports (\ _ _ _ suff _ -> suff) m mempty
      rendered = intercalate "\n" (map (prettyPrintWithMode (defaultMode {layout = PPInLine})) newImports)
      newPretty = fromMaybe "" firstPrefix <> rendered <> lastSuffix

      headerText = fold (foldHeader emitText emitItem emitItem emitItem m mempty)
      exportsText = fold (foldExports emitText emitItem emitText m mempty)
      declsText = fold (foldDecls emitItem emitText m mempty)
-- | Normalise a candidate import list before it is written into the module.
--
-- Imports are sorted and grouped with 'importMergable', and each group is
-- merged into one declaration.  Specs for types that appear in standalone
-- deriving declarations are widened to @Type(..)@.  When @remove@ is set,
-- imports whose spec list is empty are dropped unless the original module
-- also imported that module with an empty list.
fixNewImports :: Bool                        -- ^ drop imports with an empty spec list
              -> ModuleInfo                  -- ^ the module being cleaned
              -> [A.ImportDecl SrcSpanInfo]  -- ^ the module's original imports
              -> [A.ImportDecl SrcSpanInfo]  -- ^ candidate new imports
              -> [A.ImportDecl SrcSpanInfo]
fixNewImports remove m oldImports imports =
    filter importPred $
    map (expandSDTypes . mergeDecls) $
    groupBy (\ a b -> importMergable a b == EQ) $
    sortBy importMergable imports
    where
      -- Combine a group of declarations that differ only in their spec lists.
      mergeDecls [] = error "mergeDecls"
      mergeDecls xs@(x : _) = x {A.importSpecs = mergeSpecLists (map A.importSpecs xs)}

      -- An import without a spec list brings the whole module into scope, so
      -- if any member of the group has none, the merged import has none too.
      -- This avoids both crashing on such groups and narrowing the import.
      -- NOTE(review): hiding lists are merged by union, which hides more
      -- than two separate hiding imports would; confirm this is intended.
      mergeSpecLists specLists =
          case sequence specLists of
            Just (A.ImportSpecList loc flag specs : ys) ->
                Just (A.ImportSpecList loc flag
                        (mergeSpecs (concat (specs : map (\ (A.ImportSpecList _ _ specs') -> specs') ys))))
            _ -> Nothing

      -- Sort with 'compareSpecs' and keep one spec per 'equalSpecs' class.
      -- No separate 'nub' pass is needed: exact duplicates are also
      -- equalSpecs-duplicates.
      mergeSpecs ys = nubBy equalSpecs (sortBy compareSpecs ys)

      expandSDTypes :: A.ImportDecl SrcSpanInfo -> A.ImportDecl SrcSpanInfo
      expandSDTypes i@(A.ImportDecl {A.importSpecs = Just (A.ImportSpecList l f specs)}) =
          i {A.importSpecs = Just (A.ImportSpecList l f (map (expandSpec i) specs))}
      expandSDTypes i = i

      -- Widen a spec to Name(..) when its name, as seen through this import
      -- (unqualified, via its alias, or via its module name), is one of the
      -- types in a standalone deriving declaration.
      expandSpec i s =
          if not (A.importQualified i) && member (Nothing, sName n) sdTypes ||
             maybe False (\ mn -> member (Just (sModuleName mn), sName n) sdTypes) (A.importAs i) ||
             member (Just (sModuleName (A.importModule i)), sName n) sdTypes
          then s'
          else s
          where
            n = case s of
                  (A.IVar _ x) -> x
                  (A.IAbs _ x) -> x
                  (A.IThingAll _ x) -> x
                  (A.IThingWith _ x _) -> x
            s' = case s of
                   (A.IVar l x) -> A.IThingAll l x
                   (A.IAbs l x) -> A.IThingAll l x
                   (A.IThingWith l x _) -> A.IThingAll l x
                   (A.IThingAll _ _) -> s

      -- Keep an empty import unless removal is on.  When removal is on, keep
      -- it only if some original import of the same module was empty too.
      -- Every original import of the module is checked, not just the first.
      importPred (A.ImportDecl {A.importModule = mn, A.importSpecs = Just (A.ImportSpecList _ _ [])}) =
          not remove || any (\ old -> sameModule old && isEmptyImport (A.importSpecs old)) oldImports
          where
            sameModule old = unModuleName (A.importModule old) == unModuleName mn
            isEmptyImport (Just (A.ImportSpecList _ _ [])) = True
            isEmptyImport _ = False
      importPred _ = True

      sdTypes :: Set (Maybe S.ModuleName, S.Name)
      sdTypes = standaloneDerivingTypes m
-- | Collect the type constructors named in the instance heads of the
-- module's standalone deriving declarations.
--
-- Each type is returned as @(qualifier, name)@; unqualified references have
-- a 'Nothing' qualifier.  Class names in the instance head are ignored.
-- Operators of infix types (e.g. @:+:@ in @a :+: b@) are collected as well,
-- since they are type constructors too.  Type forms not handled below
-- contribute nothing instead of crashing.  XML page/hybrid modules call
-- 'error'.
standaloneDerivingTypes :: ModuleInfo -> Set (Maybe S.ModuleName, S.Name)
standaloneDerivingTypes (A.XmlPage {}, _, _) = error "standaloneDerivingTypes A.XmlPage"
standaloneDerivingTypes (A.XmlHybrid {}, _, _) = error "standaloneDerivingTypes A.XmlHybrid"
standaloneDerivingTypes (A.Module _ _ _ _ decls, _, _) =
    unions (map derivDeclTypes decls)
    where
      -- Types appearing in the head of a standalone deriving declaration.
      derivDeclTypes (A.DerivDecl _ _ (A.IHead _ _ xs)) = unions (map derivDeclTypes' xs)
      derivDeclTypes (A.DerivDecl a b (A.IHParen _ x)) = derivDeclTypes (A.DerivDecl a b x)
      derivDeclTypes (A.DerivDecl _ _ (A.IHInfix _ x _op y)) = union (derivDeclTypes' x) (derivDeclTypes' y)
      derivDeclTypes _ = empty

      -- Type constructors mentioned anywhere inside a type.
      derivDeclTypes' (A.TyForall _ _ _ x) = derivDeclTypes' x
      derivDeclTypes' (A.TyFun _ x y) = union (derivDeclTypes' x) (derivDeclTypes' y)
      derivDeclTypes' (A.TyTuple _ _ xs) = unions (map derivDeclTypes' xs)
      derivDeclTypes' (A.TyList _ x) = derivDeclTypes' x
      derivDeclTypes' (A.TyApp _ x y) = union (derivDeclTypes' x) (derivDeclTypes' y)
      derivDeclTypes' (A.TyVar _ _) = empty
      derivDeclTypes' (A.TyCon _ qn) = qNameType qn
      derivDeclTypes' (A.TyParen _ x) = derivDeclTypes' x
      derivDeclTypes' (A.TyInfix _ x op y) = unions [derivDeclTypes' x, qNameType op, derivDeclTypes' y]
      derivDeclTypes' (A.TyKind _ x _) = derivDeclTypes' x
      derivDeclTypes' _ = empty

      -- A named type constructor; special constructors (unit, list, ...) are skipped.
      qNameType (A.Qual _ mn n) = singleton (Just (sModuleName mn), sName n)
      qNameType (A.UnQual _ n) = singleton (Nothing, sName n)
      qNameType _ = empty
-- | Ordering used to sort and group import declarations.
--
-- Declarations are compared by module name first.  Ties are broken by the
-- simplified declaration with its location reset to 'def' and its spec list
-- erased; a hiding list becomes an empty hiding list, so a hiding import
-- only matches another hiding import.  Two declarations are 'EQ' exactly when
-- they differ only in location and spec-list contents.
importMergable :: A.ImportDecl SrcSpanInfo -> A.ImportDecl SrcSpanInfo -> Ordering
importMergable = compare `on` (sortKey . sImportDecl)
    where
      sortKey d = (S.importModule d, noSpecs d)
      noSpecs :: S.ImportDecl -> S.ImportDecl
      noSpecs x = x { S.importLoc = def,
                      S.importSpecs = case S.importSpecs x of
                                        Just (True, _) -> Just (True, [])
                                        Just (False, _) -> Nothing
                                        Nothing -> Nothing }
-- | The plain string of an annotated module name.
unModuleName :: A.ModuleName SrcSpanInfo -> String
unModuleName mn = case mn of
    A.ModuleName _ str -> str
-- | Order import specs case-insensitively by the names they bring into
-- scope.  Ties are broken by comparing the simplified (annotation-free)
-- specs.
compareSpecs :: A.ImportSpec SrcSpanInfo -> A.ImportSpec SrcSpanInfo -> Ordering
compareSpecs a b =
    (compare `on` foldedNames) a b <> (compare `on` sImportSpec) a b
    where
      foldedNames spec = map (map toLower . nameString) (catMaybes (toList (symbols spec)))
-- | Two specs are equal when 'compareSpecs' ranks them the same.
equalSpecs :: A.ImportSpec SrcSpanInfo -> A.ImportSpec SrcSpanInfo -> Bool
equalSpecs a b =
    case compareSpecs a b of
      EQ -> True
      _ -> False
-- | The raw text of a name, whether an identifier or an operator symbol.
nameString :: S.Name -> String
nameString name =
    case name of
      S.Ident str -> str
      S.Symbol str -> str