module Language.Haskell.Modules.Imports
( cleanImports
, tests
) where
import Control.Applicative ((<$>))
import "MonadCatchIO-mtl" Control.Monad.CatchIO as IO (bracket, catch, throw)
import Control.Monad.Trans (liftIO)
import Data.Char (toLower)
import Data.Default (def, Default)
import Data.Function (on)
import Data.List (find, groupBy, intercalate, nub, nubBy, sortBy)
import Data.Maybe (catMaybes, fromMaybe)
import Data.Monoid ((<>))
import Data.Set as Set (empty, member, Set, singleton, toList, union, unions)
import Language.Haskell.Exts.Annotated (ParseResult(..))
import Language.Haskell.Exts.Annotated.Simplify as S (sImportDecl, sImportSpec, sModuleName, sName)
import qualified Language.Haskell.Exts.Annotated.Syntax as A (Decl(DerivDecl), ImportDecl(..), ImportSpec(..), ImportSpecList(ImportSpecList), InstHead(..), Module(..), ModuleHead(ModuleHead), ModuleName(ModuleName), QName(..), Type(..))
import Language.Haskell.Exts.Extension (Extension(PackageImports, StandaloneDeriving, TypeSynonymInstances, FlexibleInstances))
import Language.Haskell.Exts.Pretty (defaultMode, PPHsMode(layout), PPLayout(PPInLine), prettyPrintWithMode)
import Language.Haskell.Exts.SrcLoc (SrcSpanInfo)
import qualified Language.Haskell.Exts.Syntax as S (ImportDecl(importLoc, importModule, importSpecs), ModuleName(..), Name(..))
import Language.Haskell.Modules.Common (modulePathBase, withCurrentDirectory)
import Language.Haskell.Modules.Fold (foldDecls, foldExports, foldHeader, foldImports)
import Language.Haskell.Modules.Internal (getParams, markForDelete, modifyParams, ModuleResult(..), MonadClean, Params(..), parseFile, runMonadClean, scratchDir)
import Language.Haskell.Modules.Util.DryIO (replaceFile, tildeBackup)
import Language.Haskell.Modules.Util.QIO (qPutStrLn, quietly)
import Language.Haskell.Modules.Util.Symbols (symbols)
import System.Cmd (system)
import System.Directory (createDirectoryIfMissing, getCurrentDirectory)
import System.Exit (ExitCode(..))
import System.FilePath ((<.>), (</>))
import System.Process (readProcessWithExitCode, showCommandForUser)
import Test.HUnit (assertEqual, Test(..))
-- | Clean up the import list of the module stored in the file @path@.
-- The file is parsed, 'dumpImports' is run on it, and the parsed module
-- is passed on to 'checkImports' together with those of its imports
-- that carry a @hiding@ list.  Anything other than a successfully
-- parsed ordinary module (an XML page, an XML hybrid, or a parse
-- failure) is reported by calling 'error'.
cleanImports :: MonadClean m => FilePath -> m ModuleResult
cleanImports path = parseFile path >>= clean
    where
      -- An ordinary module: dump the imports first, then check them.
      clean (ParseOk parsed@(A.Module _ header _ imports _)) =
          do dumpImports path
             checkImports path (headerName header) parsed (filter hasHidingList imports)
      clean (ParseOk (A.XmlPage {})) = error "cleanImports: XmlPage"
      clean (ParseOk (A.XmlHybrid {})) = error "cleanImports: XmlHybrid"
      clean (ParseFailed _ msg) = error ("cleanImports: - parse of " ++ path ++ " failed: " ++ msg)
      -- A module without a header is given the name @Main@.
      headerName (Just (A.ModuleHead _ x _ _)) = sModuleName x
      headerName _ = S.ModuleName "Main"
      -- True when the import's spec list has its Bool flag set, i.e. the
      -- declaration has the form @import M hiding (...)@.
      hasHidingList (A.ImportDecl {A.importSpecs = Just (A.ImportSpecList _ True _)}) = True
      hasHidingList _ = False
-- | Run @ghc --make -c -ddump-minimal-imports@ on the file @path@.
-- The scratch directory from the current 'Params' is created if
-- necessary and passed as @-outputdir@.  The user supplied flags
-- ('hsFlags') come first on the command line, the source directories
-- are joined with @:@ into a single @-i@ flag, and every enabled
-- extension is appended as an @-X@ flag.  A non-zero exit status is
-- reported with 'error', including the command line and whatever the
-- compiler wrote to stderr.
dumpImports :: MonadClean m => FilePath -> m ()
dumpImports path =
    do scratch <- fmap scratchDir getParams
       liftIO (createDirectoryIfMissing True scratch)
       userFlags <- fmap hsFlags getParams
       searchDirs <- fmap sourceDirs getParams
       enabled <- fmap extensions getParams
       let compiler = "ghc"
           fixedFlags = ["--make", "-c", "-ddump-minimal-imports", "-outputdir", scratch, "-i" ++ intercalate ":" searchDirs, path]
           extensionFlags = ["-X" ++ show ext | ext <- enabled]
           allArgs = userFlags ++ fixedFlags ++ extensionFlags
           -- The command as it is shown in the log and in error messages.
           shown = showCommandForUser compiler allArgs
       (code, _, err) <- liftIO (readProcessWithExitCode compiler allArgs "")
       case code of
         ExitFailure _ -> error ("dumpImports: compile failed\n " ++ shown ++ " ->\n" ++ err)
         ExitSuccess -> quietly (qPutStrLn (shown ++ " -> Ok")) >> return ()
-- | Parse the file @<module name>.imports@ and use the result to
-- update the source file via 'updateSource'.  The imports file is
-- first passed to 'markForDelete'.  While it is being parsed the
-- 'PackageImports' extension is put at the front of the extension
-- list in the 'Params'; 'bracket' restores the saved list afterwards.
-- An 'IOError' raised by the parse is re-thrown as a user error
-- prefixed with the current working directory.  A parse result other
-- than 'ParseOk' is reported with 'error'.
checkImports :: MonadClean m => FilePath -> S.ModuleName -> A.Module SrcSpanInfo -> [A.ImportDecl SrcSpanInfo] -> m ModuleResult
checkImports path name@(S.ModuleName name') m extraImports =
    do markForDelete importsPath
       result <- bracket currentExtensions restoreExtensions parseImports
       case result of
         ParseOk newImports -> updateSource path m newImports name extraImports
         _ -> error ("checkImports: parse of " ++ importsPath ++ " failed - " ++ show result)
    where
      importsPath = name' <.> ".imports"
      -- Acquire: remember the extension list in effect on entry.
      currentExtensions = extensions <$> getParams
      -- Release: put the remembered extension list back.
      restoreExtensions saved = modifyParams (\ p -> p {extensions = saved})
      -- Body: parse the imports file with PackageImports switched on.
      parseImports saved =
          do modifyParams (\ p -> p {extensions = PackageImports : saved})
             parseFile importsPath `catch` addWorkingDirectory
      -- The imports path is relative, so say where we were looking.
      addWorkingDirectory (e :: IOError) =
          liftIO (getCurrentDirectory >>= \ here -> throw (userError (here ++ ": " ++ show e)))
-- | Replace the import section of the file @path@.
--
-- The arguments are the path of the source file, the parsed source
-- module, the parsed module holding the new imports, the module name
-- (returned inside the 'ModuleResult'), and extra import declarations
-- that are appended to the new imports before 'fixNewImports' is
-- applied.
--
-- If 'replaceImports' reports that nothing changes the result is
-- 'Unchanged'; otherwise the file is rewritten through 'replaceFile'
-- with 'tildeBackup' and the result is 'Modified' with the new text.
-- Calling this with anything but two ordinary modules is an 'error'.
updateSource :: MonadClean m => FilePath -> A.Module SrcSpanInfo -> A.Module SrcSpanInfo -> S.ModuleName -> [A.ImportDecl SrcSpanInfo] -> m ModuleResult
updateSource path (m@(A.Module _ _ _ oldImports _)) (A.Module _ _ _ newImports _) name extraImports =
    do remove <- removeEmptyImports <$> getParams
       -- 'readFile' reads lazily and keeps its handle open until the
       -- whole contents have been demanded.  The same path is replaced
       -- below, so force the text now to make sure the handle has been
       -- closed before anything is written.
       text <- liftIO $ readFile path >>= \ s -> length s `seq` return s
       maybe (qPutStrLn ("cleanImports: no changes to " ++ path) >> return (Unchanged name))
             (\ text' ->
                  qPutStrLn ("cleanImports: modifying " ++ path) >>
                  replaceFile tildeBackup path text' >>
                  return (Modified name text'))
             (replaceImports (fixNewImports remove m oldImports (newImports ++ extraImports)) m text)
updateSource _ _ _ _ _ = error "updateSource"
-- | Build the text of the module with its import section replaced by
-- @newImports@.  The new section consists of the text preceding the
-- first existing import, the new declarations pretty printed in
-- 'PPInLine' layout and separated by newlines, and the suffix text
-- produced by the import fold.  The result is 'Nothing' when the new
-- section is identical to the one already in @sourceText@; otherwise
-- it is the header text, the export text, the new import section and
-- the declaration text, concatenated in that order.
replaceImports :: [A.ImportDecl SrcSpanInfo] -> A.Module SrcSpanInfo -> String -> Maybe String
replaceImports newImports m sourceText
    | oldPretty == newPretty = Nothing
    | otherwise = Just (headerText ++ exportText ++ newPretty ++ declText)
    where
      -- Append an item, with the text around it, to the accumulator.
      echo _ pref s suff r = r <> pref <> s <> suff
      -- Append a bare piece of text to the accumulator.
      echoText s r = r <> s
      -- Remember only the prefix seen first.
      keepFirst _ pref _ _ r =
          case r of
            Nothing -> Just pref
            Just _ -> r
      -- Remember only the suffix seen last.
      keepLast _ _ _ suff _ = suff
      printImport decl = prettyPrintWithMode (defaultMode {layout = PPInLine}) decl
      oldPretty = foldImports echo m sourceText ""
      newPretty =
          fromMaybe "" (foldImports keepFirst m sourceText Nothing) <>
          intercalate "\n" (map printImport newImports) <>
          foldImports keepLast m sourceText ""
      headerText = foldHeader echoText echo echo echo m sourceText ""
      exportText = foldExports echoText echo echoText m sourceText ""
      -- NOTE(review): unlike the text callbacks used for the header and
      -- the exports, the second callback here combines its arguments
      -- first-then-second.  Confirm against the signature of foldDecls
      -- that this is the intended order.
      declText = foldDecls echo (<>) m sourceText ""
-- | Tidy a list of import declarations before it is written back to
-- the source file.
--
-- The arguments are: whether imports with an empty import list should
-- be dropped, the module being cleaned (consulted for its standalone
-- deriving declarations), the module's original imports, and the
-- imports to tidy.
--
-- The imports are sorted with 'importMergable', declarations that
-- compare 'EQ' are merged into one, imported types that are mentioned
-- in a standalone deriving declaration are widened to @T(..)@, and
-- finally imports with an empty list are filtered by 'importPred'.
fixNewImports :: Bool
              -> A.Module SrcSpanInfo
              -> [A.ImportDecl SrcSpanInfo]
              -> [A.ImportDecl SrcSpanInfo]
              -> [A.ImportDecl SrcSpanInfo]
fixNewImports remove m oldImports imports =
    filter importPred $ map expandSDTypes $ map mergeDecls $ groupBy (\ a b -> importMergable a b == EQ) $ sortBy importMergable imports
    where
      -- Merge a group of mergable declarations into the first one.
      -- 'groupBy' never produces an empty group.
      mergeDecls [] = error "mergeDecls"
      mergeDecls xs@(x : _) = x {A.importSpecs = mergeSpecLists (map A.importSpecs xs)}
          where
            -- 'importMergable' treats a non-hiding import list and a
            -- missing import list alike, so a group may contain
            -- declarations with no list at all.  Such a declaration
            -- imports everything, so if one is present the merged
            -- declaration must have no list either.  Otherwise the
            -- lists are concatenated, sorted and stripped of duplicates.
            mergeSpecLists specLists =
                case sequence specLists of
                  Just (A.ImportSpecList loc flag specs : ys) ->
                      Just (A.ImportSpecList loc flag (mergeSpecs (sortBy compareSpecs (nub (concat (specs : map (\ (A.ImportSpecList _ _ specs') -> specs') ys))))))
                  _ -> Nothing
            mergeSpecs ys = nubBy equalSpecs ys
      -- Widen the specs of an import that name a standalone deriving type.
      expandSDTypes :: A.ImportDecl SrcSpanInfo -> A.ImportDecl SrcSpanInfo
      expandSDTypes i@(A.ImportDecl {A.importSpecs = Just (A.ImportSpecList l f specs)}) =
          i {A.importSpecs = Just (A.ImportSpecList l f (map (expandSpec i) specs))}
      expandSDTypes i = i
      -- A spec is widened when its name occurs in 'sdTypes' unqualified
      -- (and the import is not qualified), qualified by the import's
      -- @as@ name, or qualified by the imported module's own name.
      expandSpec i s =
          if not (A.importQualified i) && member (Nothing, sName n) sdTypes ||
             maybe False (\ mn -> (member (Just (sModuleName mn), sName n) sdTypes)) (A.importAs i) ||
             member (Just (sModuleName (A.importModule i)), sName n) sdTypes
          then s'
          else s
          where
            -- The name the spec imports.
            n = case s of
                  (A.IVar _ x) -> x
                  (A.IAbs _ x) -> x
                  (A.IThingAll _ x) -> x
                  (A.IThingWith _ x _) -> x
            -- The same spec in the form @x(..)@.
            s' = case s of
                   (A.IVar l x) -> A.IThingAll l x
                   (A.IAbs l x) -> A.IThingAll l x
                   (A.IThingWith l x _) -> A.IThingAll l x
                   (A.IThingAll _ _) -> s
      -- An import with an empty list is kept when removal is switched
      -- off, or when the first original import of the same module also
      -- had an empty list.  Every other import is kept.
      importPred (A.ImportDecl _ mn _ _ _ _ (Just (A.ImportSpecList _ _ []))) =
          not remove || maybe False (isEmptyImport . A.importSpecs) (find ((== (unModuleName mn)) . unModuleName . A.importModule) oldImports)
          where
            isEmptyImport (Just (A.ImportSpecList _ _ [])) = True
            isEmptyImport _ = False
      importPred _ = True
      -- The types named in the module's standalone deriving declarations.
      sdTypes :: Set (Maybe S.ModuleName, S.Name)
      sdTypes = standaloneDerivingTypes m
-- | Collect the type constructors mentioned in the instance heads of
-- the standalone deriving declarations of a module.  Each constructor
-- is paired with the module name it was qualified with in the source,
-- or with 'Nothing' if it was written unqualified.  XML pages and XML
-- hybrids are rejected with 'error'.
standaloneDerivingTypes :: A.Module SrcSpanInfo -> Set (Maybe S.ModuleName, S.Name)
standaloneDerivingTypes (A.XmlPage {}) = error "standaloneDerivingTypes A.XmlPage"
standaloneDerivingTypes (A.XmlHybrid {}) = error "standaloneDerivingTypes A.XmlHybrid"
standaloneDerivingTypes (A.Module _ _ _ _ decls) = unions [fromDecl d | d <- decls]
    where
      -- Only deriving declarations contribute; parentheses around the
      -- instance head are peeled off.
      fromDecl (A.DerivDecl _ _ (A.IHead _ _ tys)) = unions [fromType t | t <- tys]
      fromDecl (A.DerivDecl l ctx (A.IHParen _ inner)) = fromDecl (A.DerivDecl l ctx inner)
      fromDecl (A.DerivDecl _ _ (A.IHInfix _ lhs _ rhs)) = fromType lhs `union` fromType rhs
      fromDecl _ = empty
      -- Walk a type and gather the constructors at its leaves.
      fromType (A.TyCon _ (A.Qual _ qualifier n)) = singleton (Just (sModuleName qualifier), sName n)
      fromType (A.TyCon _ (A.UnQual _ n)) = singleton (Nothing, sName n)
      fromType (A.TyCon _ _) = empty
      fromType (A.TyVar _ _) = empty
      fromType (A.TyForall _ _ _ t) = fromType t
      fromType (A.TyList _ t) = fromType t
      fromType (A.TyParen _ t) = fromType t
      fromType (A.TyKind _ t _) = fromType t
      fromType (A.TyTuple _ _ ts) = unions [fromType t | t <- ts]
      fromType (A.TyFun _ t u) = fromType t `union` fromType u
      fromType (A.TyApp _ t u) = fromType t `union` fromType u
      fromType (A.TyInfix _ t _ u) = fromType t `union` fromType u
-- | Ordering used both to sort import declarations and to decide
-- whether two of them can be merged.  The result is 'EQ' exactly when
-- the simplified declarations are equal once their source location
-- and the contents of their import lists are disregarded.  Unequal
-- declarations are ordered by module name first and by that
-- comparison second.
importMergable :: A.ImportDecl SrcSpanInfo -> A.ImportDecl SrcSpanInfo -> Ordering
importMergable a b
    | shapeOrdering == EQ = EQ
    | moduleOrdering == EQ = shapeOrdering
    | otherwise = moduleOrdering
    where
      a' = sImportDecl a
      b' = sImportDecl b
      shapeOrdering = (compare `on` noSpecs) a' b'
      moduleOrdering = (compare `on` S.importModule) a' b'
      -- Blank out the location and the listed names.  A hiding list
      -- becomes an empty hiding list; an ordinary import list and a
      -- missing import list both become 'Nothing'.
      noSpecs :: S.ImportDecl -> S.ImportDecl
      noSpecs x = x { S.importLoc = def,
                      S.importSpecs = S.importSpecs x >>= \ (hiding, _) ->
                                      if hiding then Just (True, []) else Nothing }
-- | The name string of an annotated module name, without the annotation.
unModuleName :: A.ModuleName SrcSpanInfo -> String
unModuleName annotated =
    case annotated of
      A.ModuleName _ str -> str
-- | Order import specs by the lower-cased names of the symbols they
-- mention (as returned by 'symbols').  Ties are broken by comparing
-- the simplified specs themselves.
compareSpecs :: A.ImportSpec SrcSpanInfo -> A.ImportSpec SrcSpanInfo -> Ordering
compareSpecs a b =
    -- The Ordering monoid only consults the second comparison when the
    -- first one yields EQ.
    compare (sortKey a) (sortKey b) <> compare (sImportSpec a) (sImportSpec b)
    where
      sortKey spec = map (map toLower . nameString) (catMaybes (toList (symbols spec)))
-- | Two import specs are equal when 'compareSpecs' cannot tell them apart.
equalSpecs :: A.ImportSpec SrcSpanInfo -> A.ImportSpec SrcSpanInfo -> Bool
equalSpecs a b =
    case compareSpecs a b of
      EQ -> True
      _ -> False
-- | The text of a name, whether it is an identifier or a symbol.
nameString :: S.Name -> String
nameString name =
    case name of
      S.Ident str -> str
      S.Symbol str -> str
-- | All the tests of this module, grouped under the label @Clean@.
tests :: Test
tests = TestLabel "Clean" $ TestList cases
    where
      cases = [test1, test2, test3, test4, test5]
-- | Refresh @testdata/copy@ from @testdata/original@ with rsync, run
-- 'cleanImports' on @Debian.Repo.Types.PackageIndex@ inside the copy,
-- and compare @diff -ru@ of the original and the cleaned file (minus
-- its first two lines) with the expected hunk.
test1 :: Test
test1 =
    TestLabel "Imports.test1" (TestCase check)
    where
      name = S.ModuleName "Debian.Repo.Types.PackageIndex"
      base = modulePathBase name
      check =
          do _ <- system "rsync -aHxS --delete testdata/original/ testdata/copy"
             _ <- withCurrentDirectory "testdata/copy" (runMonadClean (cleanImports base))
             (code, diff, err) <- readProcessWithExitCode "diff" ["-ru", "testdata/original" </> base, "testdata/copy" </> base] ""
             assertEqual "cleanImports"
                         (ExitFailure 1, expected, "")
                         (code, drop 2 (lines diff), err)
      -- The diff output with its two file name lines dropped.
      expected =
          ["@@ -22,13 +22,13 @@",
           " , prettyPkgVersion",
           " ) where",
           " ",
           "-import Data.Text (Text, map)",
           "+import Data.Text (Text)",
           " import Debian.Arch (Arch(..))",
           " import qualified Debian.Control.Text as T (Paragraph)",
           " import Debian.Relation (BinPkgName(..), SrcPkgName(..))",
           " import qualified Debian.Relation as B (PkgName, Relations)",
           " import Debian.Release (Section(..))",
           "-import Debian.Repo.Orphans ({- instances -})",
           "+import Debian.Repo.Orphans ()",
           " import Debian.Version (DebianVersion, prettyDebianVersion)",
           " import System.Posix.Types (FileOffset)",
           " import Text.PrettyPrint.ANSI.Leijen ((<>), Doc, Pretty(pretty), text)"]
-- | Refresh @testdata/copy@ with rsync, run 'cleanImports' on
-- @Debian.Repo.PackageIndex@ inside the copy, and expect @diff -ru@
-- to find no difference between the original and the cleaned file.
test2 :: Test
test2 =
    TestLabel "Imports.test2" (TestCase check)
    where
      base = modulePathBase (S.ModuleName "Debian.Repo.PackageIndex")
      check =
          do _ <- system "rsync -aHxS --delete testdata/original/ testdata/copy"
             _ <- withCurrentDirectory "testdata/copy" (runMonadClean (cleanImports base))
             outcome <- readProcessWithExitCode "diff" ["-ru", "testdata/original" </> base, "testdata/copy" </> base] ""
             assertEqual "cleanImports" (ExitSuccess, "", "") outcome
-- | Run 'cleanImports' on @testdata/NotMain.hs@ with @testdata@ as
-- the only source directory.  The closing assertion compares @()@
-- with @()@, so the test only checks that the run completes.
test3 :: Test
test3 =
    TestLabel "Imports.test3" . TestCase $
        do _ <- runMonadClean (useTestdata >> cleanImports "testdata/NotMain.hs")
           assertEqual "module name" () ()
    where
      useTestdata = modifyParams (\ p -> p {sourceDirs = ["testdata"]})
-- | Copy @testdata/HidingOrig.hs@ to @testdata/Hiding.hs@ and run
-- 'cleanImports' on the copy with @testdata@ as the only source
-- directory.  The closing assertion compares @()@ with @()@, so the
-- test only checks that the run completes.
test4 :: Test
test4 =
    TestLabel "Imports.test4" . TestCase $
        do _ <- system "cp testdata/HidingOrig.hs testdata/Hiding.hs"
           _ <- runMonadClean (useTestdata >> cleanImports "testdata/Hiding.hs")
           assertEqual "module name" () ()
    where
      useTestdata = modifyParams (\ p -> p {sourceDirs = ["testdata"]})
-- | Copy @testdata/DerivingOrig.hs@ to @testdata/Deriving.hs@, run
-- 'cleanImports' on the copy with the StandaloneDeriving,
-- TypeSynonymInstances and FlexibleInstances extensions added, and
-- compare @diff -ru@ of the two files (minus its first two lines)
-- with the expected hunk.
test5 :: Test
test5 =
    TestLabel "Imports.test5" (TestCase check)
    where
      check =
          do _ <- system "cp testdata/DerivingOrig.hs testdata/Deriving.hs"
             _ <- runMonadClean (modifyParams configure >> cleanImports "testdata/Deriving.hs")
             (code, diff, err) <- readProcessWithExitCode "diff" ["-ru", "testdata/DerivingOrig.hs", "testdata/Deriving.hs"] ""
             assertEqual "standalone deriving"
                         (ExitFailure 1, unlines expected, "")
                         (code, unlines (drop 2 (lines diff)), err)
      -- Extra extensions needed by the test module, and where to find it.
      configure p = p {extensions = extensions p ++ [StandaloneDeriving, TypeSynonymInstances, FlexibleInstances],
                       sourceDirs = ["testdata"]}
      -- The diff output with its two file name lines dropped.
      expected =
          ["@@ -1,7 +1,6 @@",
           " module Deriving where",
           " ",
           "-import Data.Text (Text)",
           "-import Debian.Control (Paragraph(..), Paragraph'(..), Field'(..))",
           "+import Debian.Control (Field'(..), Paragraph(..))",
           " ",
           " deriving instance Show (Field' String)",
           " deriving instance Show Paragraph"]