{-# LANGUAGE CPP #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}

-- | Haskell indenter.
module HIndent
  ( -- * Formatting functions.
    reformat
  , prettyPrint
  , -- * Testing
    testAst
  ) where

import Control.Monad.State.Strict
import Control.Monad.Trans.Maybe
import Data.ByteString (ByteString)
import qualified Data.ByteString as S
import Data.ByteString.Builder (Builder)
import qualified Data.ByteString.Builder as S
import qualified Data.ByteString.Char8 as S8
import qualified Data.ByteString.Internal as S
import qualified Data.ByteString.Lazy as L
import qualified Data.ByteString.Lazy.Char8 as L8
import qualified Data.ByteString.UTF8 as UTF8
import qualified Data.ByteString.Unsafe as S
import Data.Char
import Data.Either
import Data.Function
import Data.Functor.Identity
import Data.List hiding (stripPrefix)
import Data.Maybe
import Data.Monoid
import GHC.Hs
import GHC.Parser.Lexer hiding (buffer)
import GHC.Types.SrcLoc
import HIndent.CodeBlock
import HIndent.Config
import HIndent.LanguageExtension
import qualified HIndent.LanguageExtension.Conversion as CE
import HIndent.LanguageExtension.Types
import HIndent.ModulePreprocessing
import HIndent.Parse
import HIndent.Pretty
import HIndent.Printer
import Prelude

-- | Format the given source.
--
-- The input is split into blocks with 'cppSplitBlocks'.
--
-- * Shebang and CPP-directive blocks are copied through verbatim.
-- * Each Haskell source block is processed in four steps:
--
--     1. The prefix shared by all of its non-empty lines (leading spaces/tabs
--        and at most one @>@) is stripped.
--     2. The result is parsed.
--     3. It is pretty-printed.
--     4. The stripped prefix is re-applied to every output line.
--
-- Formatted blocks are joined with newlines.
--
-- Arguments:
--
-- * the formatting 'Config';
-- * the extensions to enable (@Nothing@ means 'allExtensions'); in either case
--   the config's extensions and those found by
--   'collectLanguageExtensionsFromSource' are added;
-- * an optional file path, passed to 'parseModule';
-- * the source text.
--
-- Whitespace-only input yields empty output. If the input ends in a newline,
-- or 'configTrailingNewline' is set, the output is made to end in a newline.
-- On a parse failure the result is @Left@ with the failing (line, column).
reformat ::
     Config
  -> Maybe [Extension]
  -> Maybe FilePath
  -> ByteString
  -> Either String Builder
reformat config mexts mfilepath =
  preserveTrailingNewline
    (fmap (mconcat . intersperse "\n") . traverse processBlock . cppSplitBlocks)
  where
    -- Format one block produced by 'cppSplitBlocks'.
    processBlock :: CodeBlock -> Either String Builder
    processBlock (Shebang text) = Right $ S.byteString text
    processBlock (CPPDirectives text) = Right $ S.byteString text
    processBlock (HaskellSource yPos text) =
      let ls = S8.lines text
          prefix = findPrefix ls
          code = unlines' (map (stripPrefix prefix) ls)
          codeStr = UTF8.toString code
          allExts =
            fromMaybe allExtensions mexts ++
            configExtensions config ++ collectLanguageExtensionsFromSource codeStr
          exts = CE.uniqueExtensions allExts
       in case parseModule mfilepath exts codeStr of
            POk _ m ->
              Right $
              S.lazyByteString $
              addPrefix prefix $ S.toLazyByteString $ prettyPrint config m
            PFailed st ->
              let rawErrLoc = psRealLoc $ loc st
                  -- NOTE(review): yPos is presumably the block's line offset
                  -- in the whole input; adding it makes the reported line
                  -- refer to the original input (confirm in HIndent.CodeBlock).
                  adjustedLoc = (srcLocLine rawErrLoc + yPos, srcLocCol rawErrLoc)
               in Left $ "Parse failed near " ++ show adjustedLoc
    -- Join strict lines with '\n' (no trailing newline).
    unlines' :: [ByteString] -> ByteString
    unlines' = S.intercalate "\n"
    -- Join lazy lines with '\n' (no trailing newline).
    unlines'' :: [L8.ByteString] -> L8.ByteString
    unlines'' = L.intercalate "\n"
    -- Prepend the prefix to every line of the formatted output.
    addPrefix :: ByteString -> L8.ByteString -> L8.ByteString
    addPrefix prefix = unlines'' . map (L8.fromStrict prefix <>) . L8.lines
    -- Remove the common prefix from a line. Empty lines were excluded when the
    -- prefix was computed, so they are returned unchanged. Lines come from
    -- 'S8.lines' and therefore never contain '\n'.
    stripPrefix :: ByteString -> ByteString -> ByteString
    stripPrefix prefix line
      | S.null line = line
      | otherwise =
        fromMaybe (error "Missing expected prefix") (s8_stripPrefix prefix line)
    -- The prefix shared by all non-empty lines, restricted to leading
    -- spaces/tabs with at most one '>' (presumably a literate-Haskell bird
    -- track).
    findPrefix :: [ByteString] -> ByteString
    findPrefix = takePrefix False . findSmallestPrefix . dropNewlines
    -- Drop empty lines so they do not shorten the common prefix.
    dropNewlines :: [ByteString] -> [ByteString]
    dropNewlines = filter (not . S.null)
    -- Take leading spaces/tabs, allowing a single '>'; stop at anything else.
    takePrefix :: Bool -> ByteString -> ByteString
    takePrefix bracketUsed txt =
      case S8.uncons txt of
        Nothing -> ""
        Just ('>', txt')
          | not bracketUsed -> S8.cons '>' (takePrefix True txt')
          | otherwise -> ""
        Just (c, txt')
          | c == ' ' || c == '\t' -> S8.cons c (takePrefix bracketUsed txt')
          | otherwise -> ""
    -- Longest common prefix of all the given lines; empty if any line (the
    -- first one in particular) runs out or differs.
    findSmallestPrefix :: [ByteString] -> ByteString
    findSmallestPrefix [] = ""
    findSmallestPrefix (p:ps) =
      case S8.uncons p of
        Nothing -> ""
        Just (first, rest)
          | all (S8.isPrefixOf (S8.singleton first)) ps ->
            -- Every line starts with 'first', so dropping one byte is safe.
            S8.cons first (findSmallestPrefix (rest : map (S.drop 1) ps))
          | otherwise -> ""
    -- Mirror the input's trailing-newline convention in the output.
    preserveTrailingNewline ::
         (ByteString -> Either String Builder)
      -> ByteString
      -> Either String Builder
    preserveTrailingNewline f x
      | S8.all isSpace x = Right mempty -- also covers the empty input
      | hasTrailingLine x || configTrailingNewline config = ensureNewline <$> f x
      | otherwise = f x
      where
        ensureNewline b
          | endsWithNewline (S.toLazyByteString b) = b
          | otherwise = b <> "\n"
        -- Inspect the last byte of the lazy output instead of copying it all
        -- into a strict ByteString first.
        endsWithNewline bs = not (L.null bs) && L8.last bs == '\n'

-- | Generate an AST from the given module for debugging.
--
-- The source is parsed with only the extensions that
-- 'collectLanguageExtensionsFromSource' finds in it (no defaults, no config).
-- On success the module is passed through 'modifyASTForPrettyPrinting'.
-- On failure the result is @Left@ with the (line, column) of the error.
testAst :: ByteString -> Either String HsModule
testAst x =
  case parseModule Nothing exts source of
    PFailed st -> Left ("Parse failed near " ++ show (errorPosition st))
    POk _ m -> Right (modifyASTForPrettyPrinting m)
  where
    source = UTF8.toString x
    exts = CE.uniqueExtensions (collectLanguageExtensionsFromSource source)
    -- (line, column) of the parser's location when it failed.
    errorPosition st =
      let realLoc = psRealLoc (loc st)
       in (srcLocLine realLoc, srcLocCol realLoc)

-- | Does the strict bytestring have a trailing newline?
--
-- True exactly when the last byte is @'\\n'@; the empty string has none.
hasTrailingLine :: ByteString -> Bool
hasTrailingLine xs = S8.singleton '\n' `S8.isSuffixOf` xs

-- | Print the module.
--
-- The AST is first adjusted with 'modifyASTForPrettyPrinting', then rendered
-- with its 'pretty' printer under the given config via 'runPrinterStyle'.
prettyPrint :: Config -> HsModule -> Builder
prettyPrint config =
  runPrinterStyle config . pretty . modifyASTForPrettyPrinting

-- | Pretty print the given printable thing.
--
-- Runs the 'Printer' from a fresh 'PrintState' and returns the output it
-- accumulated. The fresh state has:
--
-- * indent level 0, column 0, line 1;
-- * empty output;
-- * every flag off;
-- * the given config.
--
-- Calls 'error' if the printer fails (an @mzero@ in the underlying 'MaybeT').
runPrinterStyle :: Config -> Printer () -> Builder
runPrinterStyle config m =
  case runIdentity (runMaybeT (execStateT (runPrinter m) initialState)) of
    Nothing -> error "Printer failed with mzero call."
    Just finalState -> psOutput finalState
  where
    initialState =
      PrintState
        { psIndentLevel = 0
        , psOutput = mempty
        , psNewline = False
        , psColumn = 0
        , psLine = 1
        , psConfig = config
        , psFitOnOneLine = False
        , psEolComment = False
        }

-- | @s8_stripPrefix prefix bs@ behaves as follows:
--
-- * If @prefix@ is a prefix of @bs@, it returns the remainder of @bs@ after
--   @prefix@.
-- * Otherwise it returns @Nothing@.
--
-- Delegates to 'S.stripPrefix' rather than pattern-matching on the internal
-- 'PS' constructor. In bytestring >= 0.11 'PS' is only a deprecated pattern
-- synonym.
s8_stripPrefix :: ByteString -> ByteString -> Maybe ByteString
s8_stripPrefix = S.stripPrefix