{-# OPTIONS_GHC -fno-warn-name-shadowing #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE OverloadedStrings, ScopedTypeVariables, PatternGuards #-}

-- | Haskell indenter.
module HIndent
  ( -- * Formatting functions.
    reformat
  , prettyPrint
  , parseMode
  -- * Testing
  , test
  , testFile
  , testAst
  , testFileAst
  , defaultExtensions
  , getExtensions
  ) where

import Control.Monad.State.Strict
import Control.Monad.Trans.Maybe
import Data.ByteString (ByteString)
import qualified Data.ByteString as S
import Data.ByteString.Builder (Builder)
import qualified Data.ByteString.Builder as S
import qualified Data.ByteString.Char8 as S8
import qualified Data.ByteString.Internal as S
import qualified Data.ByteString.Lazy as L
import qualified Data.ByteString.Lazy.Char8 as L8
import qualified Data.ByteString.UTF8 as UTF8
import qualified Data.ByteString.Unsafe as S
import Data.Char
import Data.Foldable (foldr')
import Data.Either
import Data.Function
import Data.Functor.Identity
import Data.List
import Data.Maybe
import Data.Monoid
import Data.Text (Text)
import qualified Data.Text as T
import Data.Traversable hiding (mapM)
import HIndent.CodeBlock
import HIndent.Pretty
import HIndent.Types
import qualified Language.Haskell.Exts as Exts
import Language.Haskell.Exts hiding (Style, prettyPrint, Pretty, style, parse)
import Prelude

-- | Format the given source.
--
-- The input is split with 'cppSplitBlocks'. Shebang and CPP blocks are
-- emitted unchanged; every Haskell block is parsed, pretty printed and
-- re-indented with the prefix it originally carried. The formatted blocks
-- are joined with newlines. A parse failure yields 'Left' with the
-- location (shifted by the block's line offset) and the parser message.
reformat ::
     Config
  -> Maybe [Extension]
  -> Maybe FilePath
  -> ByteString
  -> Either String Builder
reformat config mexts mfilepath =
  preserveTrailingNewline
    (fmap (mconcat . intersperse "\n") . mapM processBlock . cppSplitBlocks)
  where
    -- Format one block of the input.
    processBlock :: CodeBlock -> Either String Builder
    processBlock (Shebang text) = Right (S.byteString text)
    processBlock (CPPDirectives text) = Right (S.byteString text)
    processBlock (HaskellSource line text) =
      case parseModuleWithComments blockMode (UTF8.toString code) of
        ParseOk (m, comments) ->
          fmap
            (S.lazyByteString . addPrefix prefix . S.toLazyByteString)
            (prettyPrint config m comments)
        ParseFailed loc e ->
          Left
            (Exts.prettyPrint (loc {srcLine = srcLine loc + line}) ++ ": " ++ e)
      where
        ls = S8.lines text
        prefix = findPrefix ls
        code = unlines' (map (stripPrefix prefix) ls)
        -- Extensions named by the block's own pragmas come first, then
        -- the configured ones, then those of the base parse mode.
        blockMode =
          case readExtensions (UTF8.toString code) of
            Nothing -> mode'
            Just (mlang, pragmaExts) ->
              let withExts =
                    mode'
                      { extensions =
                          pragmaExts ++
                          configExtensions config ++ extensions mode'
                      }
              in maybe withExts (\lang -> withExts {baseLanguage = lang}) mlang
    -- Join lines with a newline, without a trailing one (strict and lazy).
    unlines' = S.concat . intersperse "\n"
    unlines'' = L.concat . intersperse "\n"
    -- Put the prefix back in front of every line of the output.
    addPrefix :: ByteString -> L8.ByteString -> L8.ByteString
    addPrefix prefix out =
      unlines'' [L8.fromStrict prefix <> l | l <- L8.lines out]
    -- Remove the prefix from a line; blank lines are left alone.
    stripPrefix :: ByteString -> ByteString -> ByteString
    stripPrefix prefix line
      | S.null (S8.dropWhile (== '\n') line) = line
      | otherwise =
        fromMaybe
          (error "Missing expected prefix")
          (s8_stripPrefix prefix line)
    -- The prefix shared by all non-blank lines, restricted to spaces,
    -- tabs and at most one '>'.
    findPrefix :: [ByteString] -> ByteString
    findPrefix ls = takePrefix False (findSmallestPrefix (dropNewlines ls))
    dropNewlines :: [ByteString] -> [ByteString]
    dropNewlines = filter (not . S.null . S8.dropWhile (== '\n'))
    -- Take leading spaces/tabs and at most one '>'; the flag records
    -- whether a '>' has already been taken.
    takePrefix :: Bool -> ByteString -> ByteString
    takePrefix bracketUsed txt =
      case S8.uncons txt of
        Nothing -> ""
        Just (c, rest)
          | c == '>' ->
            if bracketUsed
              then ""
              else S8.cons '>' (takePrefix True rest)
          | c == ' ' || c == '\t' -> S8.cons c (takePrefix bracketUsed rest)
          | otherwise -> ""
    -- Longest common prefix of all the given lines.
    findSmallestPrefix :: [ByteString] -> ByteString
    findSmallestPrefix [] = ""
    findSmallestPrefix (p:ps) =
      case S8.uncons p of
        Nothing -> ""
        Just (first, rest)
          | all (startsWith first) ps ->
            S8.cons first (findSmallestPrefix (rest : map S.tail ps))
          | otherwise -> ""
      where
        startsWith c x = not (S8.null x) && S8.head x == c
    -- Parse mode for this run: the caller's extensions replace the
    -- defaults when given, and the file name is recorded for messages.
    mode' = baseMode {parseFilename = fromMaybe "" mfilepath}
      where
        baseMode =
          maybe parseMode (\exts -> parseMode {extensions = exts}) mexts
    -- Whitespace-only input produces empty output. Otherwise, when the
    -- input ended in a newline or the config asks for one, make sure the
    -- output ends in a newline too.
    preserveTrailingNewline f x
      | S8.null x || S8.all isSpace x = return mempty
      | hasTrailingLine x || configTrailingNewline config =
        fmap ensureFinalNewline (f x)
      | otherwise = f x
      where
        ensureFinalNewline out
          | hasTrailingLine (L.toStrict (S.toLazyByteString out)) = out
          | otherwise = out <> "\n"

-- | Does the strict bytestring have a trailing newline?
hasTrailingLine :: ByteString -> Bool
hasTrailingLine xs = not (S8.null xs) && S8.last xs == '\n'

-- | Print the module.
--
-- Fixities from 'baseFixities' are applied when that succeeds (the module
-- is used as-is otherwise), comments are attached to the tree, and the
-- result is rendered. This function always returns 'Right'.
prettyPrint :: Config -> Module SrcSpanInfo -> [Comment] -> Either a Builder
prettyPrint config m comments =
  Right (runPrinterStyle config (pretty annotated))
  where
    fixed = fromMaybe m (applyFixities baseFixities m)
    annotated = evalState (collectAllComments fixed) comments

-- | Pretty print the given printable thing.
--
-- Runs the printer from a fresh state and returns its accumulated output;
-- calls 'error' if the printer fails.
runPrinterStyle :: Config -> Printer () -> Builder
runPrinterStyle config m =
  case runIdentity (runMaybeT (execStateT (runPrinter m) initialState)) of
    Nothing -> error "Printer failed with mzero call."
    Just finalState -> psOutput finalState
  where
    initialState =
      PrintState
        { psIndentLevel = 0
        , psOutput = mempty
        , psNewline = False
        , psColumn = 0
        , psLine = 1
        , psConfig = config
        , psInsideCase = False
        , psFitOnOneLine = False
        , psEolComment = False
        }

-- | Parse mode, includes all extensions, doesn't assume any fixities.
parseMode :: ParseMode
parseMode = defaultParseMode {extensions = allExtensions, fixities = Nothing}
  where
    -- Every known extension except the explicitly disabled forms.
    allExtensions = [e | e <- knownExtensions, notDisabled e]
    notDisabled (DisableExtension _) = False
    notDisabled _ = True

-- | Read the given file and run 'test' on its contents.
testFile :: FilePath -> IO ()
testFile fp = do
  bytes <- S.readFile fp
  test bytes

-- | Read the given file and print the result of 'testAst' on its contents.
testFileAst :: FilePath -> IO ()
testFileAst fp = do
  bytes <- S.readFile fp
  print (testAst bytes)

-- | Reformat the source with 'defaultConfig' and print the result to
-- stdout. A formatting failure is raised with 'error'.
test :: ByteString -> IO ()
test src =
  case reformat defaultConfig Nothing Nothing src of
    Left err -> error err
    Right out -> L8.putStrLn (S.toLazyByteString out)

-- | Parse the source and annotate it with comments, yielding the resulting AST.
testAst :: ByteString -> Either String (Module NodeInfo)
testAst x =
  case parseModuleWithComments parseMode (UTF8.toString x) of
    ParseOk (m, comments) ->
      Right
        (evalState
           (collectAllComments (fromMaybe m (applyFixities baseFixities m)))
           comments)
    ParseFailed _ e -> Left e

-- | Default extensions.
--
-- All enabled forms from 'knownExtensions', minus 'badExtensions'.
defaultExtensions :: [Extension]
defaultExtensions = enabled \\ map EnableExtension badExtensions
  where
    enabled = [e | e@EnableExtension {} <- knownExtensions]

-- | Extensions which steal too much syntax.
badExtensions :: [KnownExtension]
badExtensions =
  [ Arrows -- steals proc
  , TransformListComp -- steals the group keyword
  , XmlSyntax, RegularPatterns -- steals a-b
  , UnboxedTuples -- breaks (#) lens operator
  -- ,QuasiQuotes -- breaks [x| ...], making whitespace free list comps break
  , PatternSynonyms -- steals the pattern keyword
  , RecursiveDo -- steals the rec keyword
  , DoRec -- same
  , TypeApplications -- since GHC 8 and haskell-src-exts-1.19
  ]

-- | Strip the first bytestring from the front of the second, if it is a
-- prefix of it.
s8_stripPrefix :: ByteString -> ByteString -> Maybe ByteString
s8_stripPrefix bs1@(S.PS _ _ l1) bs2 =
  if S.isPrefixOf bs1 bs2
    -- The prefix check guarantees bs2 has at least l1 bytes.
    then Just (S.unsafeDrop l1 bs2)
    else Nothing

--------------------------------------------------------------------------------
-- Extensions stuff stolen from hlint

-- | Consume an extensions list from arguments.
--
-- Starts from 'defaultExtensions' and processes the names left to right:
-- @Haskell98@ resets the list to empty, @NoFoo@ removes @Foo@, and any
-- other recognised name is moved to the front. An unrecognised name
-- raises an 'error'.
getExtensions :: [Text] -> [Extension]
getExtensions = foldl step defaultExtensions . map T.unpack
  where
    -- NOTE: the fold is kept lazy on purpose. "Haskell98" discards the
    -- accumulator unevaluated, so an unknown name that precedes it is
    -- currently not reported; a strict fold would change that.
    step _ "Haskell98" = []
    step acc ('N':'o':name)
      | Just ext <- readExtension name = delete ext acc
    step acc name
      | Just ext <- readExtension name = ext : delete ext acc
    step _ name = error $ "Unknown extension: " ++ name

--------------------------------------------------------------------------------
-- Comments

-- | Traverse the structure, running the action on the elements in the
-- order given by the comparison, while every result ends up in the
-- position of the element it was computed from.
traverseInOrder ::
     (Monad m, Traversable t, Functor m)
  => (b -> b -> Ordering)
  -> (b -> m b)
  -> t b
  -> m (t b)
traverseInOrder cmp f ast = do
  -- Number the elements in traversal order.
  indexed <-
    fmap
      (zip [0 :: Integer ..] . reverse)
      (execStateT (traverse (modify . (:)) ast) [])
  let sorted = sortBy (\(_, x) (_, y) -> cmp x y) indexed
  -- Run the effects in the requested order, remembering each index.
  results <-
    mapM
      (\(i, m) -> do
         v <- f m
         return (i, v))
      sorted
  -- Indices are unique, so sorting by index restores traversal order.
  -- Popping one result per element avoids a linear lookup per element.
  let restored = map snd (sortBy (compare `on` fst) results)
  evalStateT (traverse (const pop) ast) restored
  where
    pop = do
      pending <- get
      case pending of
        [] -> error "traverseInOrder"
        (x:rest) -> do
          put rest
          return x

-- | Collect all comments in the module by traversing the tree. Read
-- this from bottom to top.
collectAllComments :: Module SrcSpanInfo -> State [Comment] (Module NodeInfo)
collectAllComments =
  shortCircuit
    (traverseBackwards
       -- Finally, collect backwards comments which come after each node.
       (collectCommentsBy
          CommentAfterLine
          (\nodeSpan commentSpan ->
             fst (srcSpanStart commentSpan) >= fst (srcSpanEnd nodeSpan)))) <=<
  shortCircuit addCommentsToTopLevelWhereClauses <=<
  shortCircuit
    (traverse
       -- Collect forwards comments which start at the end line of a
       -- node: Does the start line of the comment match the end-line
       -- of the node?
       (collectCommentsBy
          CommentSameLine
          (\nodeSpan commentSpan ->
             fst (srcSpanStart commentSpan) == fst (srcSpanEnd nodeSpan)))) <=<
  shortCircuit
    (traverseBackwards
       -- Collect backwards comments which are on the same line as a
       -- node: Does the start line & end line of the comment match
       -- that of the node?
       (collectCommentsBy
          CommentSameLine
          (\nodeSpan commentSpan ->
             fst (srcSpanStart commentSpan) == fst (srcSpanStart nodeSpan) &&
             fst (srcSpanStart commentSpan) == fst (srcSpanEnd nodeSpan)))) <=<
  shortCircuit
    (traverse
       -- First, collect forwards comments for declarations which both
       -- start on column 1 and occur before the declaration.
       (collectCommentsBy
          CommentBeforeLine
          (\nodeSpan commentSpan ->
             (snd (srcSpanStart nodeSpan) == 1 &&
              snd (srcSpanStart commentSpan) == 1) &&
             fst (srcSpanStart commentSpan) < fst (srcSpanStart nodeSpan)))) .
  fmap nodify
  where
    -- Wrap each source span in a node that has no comments yet.
    nodify s = NodeInfo s mempty
    -- Visit the nodes in descending order of their end position.
    traverseBackwards =
      traverseInOrder
        (\x y ->
           on (flip compare) (srcSpanEnd . srcInfoSpan . nodeInfoSpan) x y)
    -- Stop traversing if all comments have been consumed.
    shortCircuit m v = do
      comments <- get
      if null comments
        then return v
        else m v

-- | Collect comments by satisfying the given predicate, to collect a
-- comment means to remove it from the pool of available comments in
-- the State. This allows for a multiple pass approach.
--
-- The predicate receives the node's span and the comment's span. The
-- matching comments are attached to the node using the given constructor.
collectCommentsBy ::
     (SrcSpan -> SomeComment -> NodeComment)
  -> (SrcSpan -> SrcSpan -> Bool)
  -> NodeInfo
  -> State [Comment] NodeInfo
collectCommentsBy cons predicate nodeInfo@(NodeInfo (SrcSpanInfo nodeSpan _) _) = do
  comments <- get
  let (mine, others) =
        partition
          (\(Comment _ commentSpan _) -> predicate nodeSpan commentSpan)
          comments
  put others
  return $ addCommentsToNode cons mine nodeInfo

-- | Reintroduce comments which were immediately above declarations in where clauses.
-- Affects where clauses of top level declarations only.
addCommentsToTopLevelWhereClauses ::
     Module NodeInfo -> State [Comment] (Module NodeInfo)
addCommentsToTopLevelWhereClauses (Module info mhead pragmas imps topLevelDecls) = do
  decls <- traverse annotateWhereClause topLevelDecls
  return (Module info mhead pragmas imps decls)
  where
    -- Only pattern bindings that carry a where clause made of plain
    -- declarations are touched; everything else passes through.
    annotateWhereClause :: Decl NodeInfo -> State [Comment] (Decl NodeInfo)
    annotateWhereClause (PatBind l pat rhs (Just (BDecls bl whereDecls))) =
      fmap
        (PatBind l pat rhs . Just . BDecls bl)
        (traverse annotatePatBind whereDecls)
    annotateWhereClause other = return other
    -- Inside the where clause, only bindings of a plain identifier
    -- receive comments.
    annotatePatBind :: Decl NodeInfo -> State [Comment] (Decl NodeInfo)
    annotatePatBind (PatBind bindInfo pat@(PVar _ (Ident _ _)) rhs binds) = do
      bindInfoWithComments <- takeCommentsAbove bindInfo
      return (PatBind bindInfoWithComments pat rhs binds)
    annotatePatBind other = return other
    -- Move the comments sitting directly above the node out of the pool
    -- and onto the node.
    takeCommentsAbove :: NodeInfo -> State [Comment] NodeInfo
    takeCommentsAbove nodeInfo = do
      pool <- get
      let (notAbove, above) = partitionAboveNotAbove pool nodeInfo
      put notAbove
      return (addCommentsToNode CommentBeforeLine above nodeInfo)
    -- Split the comments into (not above, above). The list is folded
    -- from the right; each accepted comment becomes the span that the
    -- next comment has to sit directly above.
    partitionAboveNotAbove :: [Comment] -> NodeInfo -> ([Comment], [Comment])
    partitionAboveNotAbove cs (NodeInfo (SrcSpanInfo nodeSpan _) _) =
      fst (foldr' step (([], []), nodeSpan) cs)
      where
        step comment@(Comment _ commentSpan _) ((ls, rs), lastSpan)
          | comment `isAbove` lastSpan = ((ls, comment : rs), commentSpan)
          | otherwise = ((comment : ls, rs), lastSpan)
    -- A comment is above a span when it starts in the same column and
    -- ends on the line just before the span starts.
    isAbove :: Comment -> SrcSpan -> Bool
    isAbove (Comment _ commentSpan _) target =
      commentColStart == colStart && commentLnEnd + 1 == lnStart
      where
        (_, commentColStart) = srcSpanStart commentSpan
        (commentLnEnd, _) = srcSpanEnd commentSpan
        (lnStart, colStart) = srcSpanStart target
addCommentsToTopLevelWhereClauses other = return other

-- | Append the given comments to the node's existing comments, wrapping
-- each one with the supplied constructor.
addCommentsToNode ::
     (SrcSpan -> SomeComment -> NodeComment)
  -> [Comment]
  -> NodeInfo
  -> NodeInfo
addCommentsToNode mkNodeComment newComments nodeInfo@(NodeInfo (SrcSpanInfo _ _) existingComments) =
  nodeInfo
    {nodeInfoComments = existingComments ++ map toNodeComment newComments}
  where
    toNodeComment :: Comment -> NodeComment
    toNodeComment (Comment multiLine commentSpan commentString) =
      mkNodeComment commentSpan (wrap commentString)
      where
        wrap
          | multiLine = MultiLine
          | otherwise = EndOfLine