{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE LambdaCase #-}

-- | A formatter for Haskell source code.
module Ormolu
  ( ormolu,
    ormoluFile,
    ormoluStdin,
    Config (..),
    ColorMode (..),
    RegionIndices (..),
    SourceType (..),
    defaultConfig,
    detectSourceType,
    DynOption (..),
    PrinterOpts (..),
    PrinterOptsPartial,
    PrinterOptsTotal,
    defaultPrinterOpts,
    loadConfigFile,
    ConfigFileLoadResult (..),
    configFileName,
    fillMissingPrinterOpts,
    OrmoluException (..),
    withPrettyOrmoluExceptions,
  )
where

import Control.Exception
import Control.Monad
import Control.Monad.IO.Class (MonadIO (..))
import Data.Text (Text)
import qualified Data.Text as T
import Debug.Trace
import qualified GHC.Driver.CmdLine as GHC
import qualified GHC.Types.SrcLoc as GHC
import Ormolu.Config
import Ormolu.Diff.ParseResult
import Ormolu.Diff.Text
import Ormolu.Exception
import Ormolu.Fixity
import Ormolu.Parser
import Ormolu.Parser.CommentStream (showCommentStream)
import Ormolu.Parser.Result
import Ormolu.Printer
import Ormolu.Utils (showOutputable)
import Ormolu.Utils.IO
import System.FilePath

-- | Format a 'String', return formatted version as 'Text'.
--
-- The function
--
--     * Takes 'String' because that's what GHC parser accepts.
--     * Needs 'IO' because some functions from GHC that are necessary to
--       setup parsing context require 'IO'. There should be no visible
--       side-effects though.
--     * Takes file name just to use it in parse error messages.
--     * Throws 'OrmoluException': 'OrmoluParsingFailed' when the input
--       does not parse, 'OrmoluOutputParsingFailed' when the rendered
--       output does not parse, 'OrmoluASTDiffers' when the AST of the output
--       differs from the AST of the input, and 'OrmoluNonIdempotentOutput'
--       when 'cfgCheckIdempotence' is set and re-formatting changes the
--       output.
--
-- The AST check runs unless 'cfgUnsafe' is set. The output is re-parsed
-- whenever the AST check or the idempotence check is requested.
--
-- __NOTE__: The caller is responsible for setting the appropriate value in
-- the 'cfgSourceType' field. Autodetection of source type won't happen
-- here, see 'detectSourceType'.
ormolu ::
  MonadIO m =>
  -- | Ormolu configuration
  Config RegionIndices ->
  -- | Location of source file
  FilePath ->
  -- | Input to format
  String ->
  m Text
ormolu cfgWithIndices path originalInput = do
  let totalLines = length (lines originalInput)
      -- Region indices are converted to deltas relative to the input size.
      cfg = regionIndicesToDeltas totalLines <$> cfgWithIndices
      fixityMap =
        -- It is important to keep all arguments (but last) of
        -- 'buildFixityMap' constant (such as 'defaultStrategyThreshold'),
        -- otherwise it is going to break memoization.
        buildFixityMap
          defaultStrategyThreshold
          (cfgDependencies cfg) -- memoized on the set of dependencies
  (warnings, result0) <-
    parseModule' cfg fixityMap OrmoluParsingFailed path originalInput
  when (cfgDebug cfg) $ do
    traceM "warnings:\n"
    traceM (concatMap showWarn warnings)
    forM_ result0 $ \case
      ParsedSnippet r -> traceM . showCommentStream . prCommentStream $ r
      _ -> pure ()
  -- We're forcing 'formattedText' here because otherwise errors (such as
  -- messages about not-yet-supported functionality) will be thrown later
  -- when we try to parse the rendered code back, inside of GHC monad
  -- wrapper which will lead to error messages presenting the exceptions as
  -- GHC bugs.
  let !formattedText = printSnippets result0 $ cfgPrinterOpts cfgWithIndices
  when (not (cfgUnsafe cfg) || cfgCheckIdempotence cfg) $ do
    -- Parse the result of pretty-printing again and make sure that AST
    -- is the same as AST of original snippet modulo span positions.
    (_, result1) <-
      parseModule'
        cfg
        fixityMap
        OrmoluOutputParsingFailed
        path
        (T.unpack formattedText)
    unless (cfgUnsafe cfg) . liftIO $ do
      -- 'diff' is only forced once an AST mismatch has been detected.
      -- Identical text cannot yield differing ASTs, so the 'Nothing'
      -- branch signals an internal invariant violation.
      let diff = case diffText (T.pack originalInput) formattedText path of
            Nothing -> error "AST differs, yet no changes have been introduced"
            Just x -> x
      when (length result0 /= length result1) $
        throwIO (OrmoluASTDiffers diff [])
      forM_ (result0 `zip` result1) $ \case
        (ParsedSnippet s, ParsedSnippet s') -> case diffParseResult s s' of
          Same -> return ()
          Different ss -> throwIO (OrmoluASTDiffers (selectSpans ss diff) ss)
        (RawSnippet {}, RawSnippet {}) -> pure ()
        -- A parsed snippet paired with a raw one means the snippet
        -- structure itself changed.
        _ -> throwIO (OrmoluASTDiffers diff [])
    -- Try re-formatting the formatted result to check if we get exactly
    -- the same output.
    when (cfgCheckIdempotence cfg) . liftIO $
      let reformattedText =
            printSnippets result1 $ cfgPrinterOpts cfgWithIndices
       in case diffText formattedText reformattedText path of
            Nothing -> return ()
            Just diff -> throwIO (OrmoluNonIdempotentOutput diff)
  return formattedText

-- | Load a file and format it. The file stays intact and the rendered
-- version is returned as 'Text'.
--
-- The file contents are read with 'readFileUtf8' and passed to 'ormolu',
-- with the same path used for error messages. Any 'OrmoluException'
-- thrown by 'ormolu' propagates to the caller.
--
-- __NOTE__: The caller is responsible for setting the appropriate value in
-- the 'cfgSourceType' field. Autodetection of source type won't happen
-- here, see 'detectSourceType'.
ormoluFile ::
  MonadIO m =>
  -- | Ormolu configuration
  Config RegionIndices ->
  -- | Location of source file
  FilePath ->
  -- | Resulting rendition
  m Text
ormoluFile cfg path =
  readFileUtf8 path >>= ormolu cfg path . T.unpack

-- | Read input from stdin and format it.
--
-- The input is read with 'getContentsUtf8' and passed to 'ormolu'. The
-- literal name @\"<stdin>\"@ is used as the file path in error messages.
--
-- __NOTE__: The caller is responsible for setting the appropriate value in
-- the 'cfgSourceType' field. Autodetection of source type won't happen
-- here, see 'detectSourceType'.
ormoluStdin ::
  MonadIO m =>
  -- | Ormolu configuration
  Config RegionIndices ->
  -- | Resulting rendition
  m Text
ormoluStdin cfg =
  getContentsUtf8 >>= ormolu cfg "<stdin>" . T.unpack

----------------------------------------------------------------------------
-- Helpers

-- | A wrapper around 'parseModule'.
--
-- Runs 'parseModule' and, on failure, builds an exception from the
-- reported span and message with the supplied constructor. It then throws
-- that exception in 'IO' via 'throwIO'. On success the parser warnings are
-- returned together with the parsed snippets.
parseModule' ::
  MonadIO m =>
  -- | Ormolu configuration
  Config RegionDeltas ->
  -- | Fixity Map for operators
  LazyFixityMap ->
  -- | How to obtain 'OrmoluException' to throw when parsing fails
  (GHC.SrcSpan -> String -> OrmoluException) ->
  -- | File name to use in errors
  FilePath ->
  -- | Actual input for the parser
  String ->
  m ([GHC.Warn], [SourceSnippet])
parseModule' cfg fixityMap mkException path str = do
  (warnings, r) <- parseModule cfg fixityMap path str
  case r of
    Left (spn, err) -> liftIO $ throwIO (mkException spn err)
    Right x -> return (warnings, x)

-- | Pretty-print a 'GHC.Warn'.
--
-- Renders the warning reason and the located warning message with
-- 'showOutputable'. Each one goes on its own line, and 'unlines' adds a
-- trailing newline after each.
showWarn :: GHC.Warn -> String
showWarn (GHC.Warn reason l) =
  unlines
    [ showOutputable reason,
      showOutputable l
    ]

-- | Detect 'SourceType' based on the file extension.
--
-- A path whose last extension (as returned by 'takeExtension') is exactly
-- @.hsig@ is treated as a 'SignatureSource'. Every other path, including
-- one with no extension, is treated as a 'ModuleSource'. The comparison is
-- case-sensitive.
detectSourceType :: FilePath -> SourceType
detectSourceType mpath =
  if takeExtension mpath == ".hsig"
    then SignatureSource
    else ModuleSource