-- Copyright (c) Facebook, Inc. and its affiliates.
--
-- This source code is licensed under the MIT license found in the
-- LICENSE file in the root directory of this source tree.
--
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards #-}
module Retrie.CPP
  ( CPP(..)
  , addImportsCPP
  , parseCPPFile
  , parseCPP
  , printCPP
    -- ** Internal interface exported for tests
  , cppFork
  ) where

import Data.Char (isSpace)
import Data.Function (on)
import Data.Functor.Identity
import Data.List (nubBy, sortOn)
import Data.Text (Text)
import qualified Data.Text as Text
import qualified Data.Text.IO as Text
import Debug.Trace
import Retrie.ExactPrint
import Retrie.GHC
import Retrie.Replace

-- Note [CPP]
-- We can't just run the pre-processor on files and then rewrite them, because
-- the rewrites will apply to a module that never exists as code! Exactprint
-- has no support for roundtripping CPP, because the GHC parser doesn't
-- actually parse it (it looks for the pragma and then delegates to the
-- pre-processor).
--
-- To solve this, we instead generate all possible versions of the module
-- (exponential in the number of #if directives :-P). We then apply rewrites
-- to all versions, and collect all the 'Replacement's that they generate.
-- We can then use these to splice results back into the original file.
--
-- Surprisingly, this works. It depends on a few observations:
--
-- * We don't need to actually evaluate any CPP directives. This is because
--   we want all versions of the file.
--
-- * Since we don't need to evaluate, we can simply replace all CPP directives
--   with blank lines and the locations of all AST elements in each version of
--   the module will be exactly the same as in the original module. This is the
--   key to splicing properly.
--
-- * Replacements can be spliced in directly with no smarts about binders, etc,
--   because retrie did the instantiation during matching.
--

-- The CPP Type ----------------------------------------------------------------

-- | A module, possibly expanded into several versions by CPP forking.
--
-- 'parseCPP' yields 'NoCPP' when no line of the source starts with @#@.
-- Otherwise it yields 'CPP' holding:
--
-- * the original source text,
-- * imports still to be added (accumulated by 'addImportsCPP' and spliced
--   into the text by 'printCPP'), and
-- * one element per version produced by 'cppFork'.
data CPP a
  = NoCPP a
  | CPP Text [AnnotatedImports] [a]

instance Functor CPP where
  fmap f cpp = case cpp of
    NoCPP x -> NoCPP (f x)
    CPP orig imps versions -> CPP orig imps (fmap f versions)

instance Foldable CPP where
  foldMap f cpp = case cpp of
    NoCPP x -> f x
    CPP _ _ versions -> foldMap f versions

instance Traversable CPP where
  -- The original text and pending imports are carried through unchanged.
  traverse f cpp = case cpp of
    NoCPP x -> fmap NoCPP (f x)
    CPP orig imps versions -> fmap (CPP orig imps) (traverse f versions)

-- | Add imports to a (possibly forked) module.
--
-- A single module gets the imports inserted into its AST right away. A
-- forked module only records them (new ones first). 'printCPP' later
-- splices them into the original text.
addImportsCPP
  :: [AnnotatedImports]
  -> CPP AnnotatedModule
  -> CPP AnnotatedModule
addImportsCPP is cpp = case cpp of
  NoCPP m ->
    NoCPP (runIdentity (transformA m (insertImports is)))
  CPP orig pending versions ->
    CPP orig (is ++ pending) versions

-- Parsing a CPP Module --------------------------------------------------------

-- | Read a file and parse it with 'parseCPP'.
--
-- The parser is given the file path along with each version's source.
-- 'Data.Text.IO.readFile' reads the whole file strictly.
parseCPPFile
  :: (FilePath -> String -> IO AnnotatedModule)
  -> FilePath
  -> IO (CPP AnnotatedModule)
parseCPPFile parser path = do
  contents <- Text.readFile path
  parseCPP (parser path) contents

-- | Parse module source, forking it into every CPP version when any line
-- starts with @#@. Each version is parsed separately with the supplied
-- parser. Without CPP lines the source is parsed once, as-is.
parseCPP
  :: Monad m
  => (String -> m AnnotatedModule)
  -> Text -> m (CPP AnnotatedModule)
parseCPP p orig =
  if usesCPP
    then CPP orig [] <$> traverse (p . Text.unpack) (cppFork orig)
    else NoCPP <$> p (Text.unpack orig)
  where
    usesCPP = any isCPP (Text.lines orig)

-- Printing a CPP Module -------------------------------------------------------

-- | Render a (possibly forked) module back to source text.
--
-- A 'NoCPP' module is exact-printed directly; the replacements are not used
-- in that case.
--
-- For a 'CPP' module, replacements that have a 'RealSrcSpan' are sorted by
-- location and spliced into the original text. If there are pending imports,
-- they are printed and inserted after the last line that starts with
-- @import@, @module@ or @{-#@.
printCPP :: [Replacement] -> CPP AnnotatedModule -> String
printCPP _ (NoCPP m) = printA m
printCPP repls (CPP orig is ms) = Text.unpack $ Text.unlines $
  case is of
    [] -> splice "" 1 1 sorted origLines
    _ ->
      splice
        (Text.unlines newHeader)
        (length revHeader + 1) -- line number of the first line after the header
        1
        sorted
        (reverse revDecls)
  where
    sorted = sortOn fst
      [ (r, replReplacement)
      | Replacement{..} <- repls
      , RealSrcSpan r <- [replLocation]
      ]

    origLines = Text.lines orig

    -- The module name is used to avoid inserting a self-import. It is taken
    -- from the first version (presumably the same in every version). 'CPP'
    -- is exported with its constructor, so there may be no versions at all.
    -- In that case nothing is filtered, rather than crashing on 'head'.
    mbName = case ms of
      (m:_) -> unLoc <$> hsmodName (unLoc $ astA m)
      [] -> Nothing

    importLines = runIdentity $ fmap astA $ transformA (filterAndFlatten mbName is) $
      mapM $ fmap (Text.pack . dropWhile isSpace . printA) . pruneA

    -- Split the file after the last import/module/pragma line. Working on
    -- the reversed lines makes 'break' find the last such line.
    p t = isImport t || isModule t || isPragma t
    (revDecls, revHeader) = break p (reverse origLines)
    newHeader = reverse revHeader ++ importLines

-- | Splice sorted replacements into the lines of the original module.
--
-- @splice prefix l c repls ts@ requires the following:
--
-- * the head of @ts@ is what remains of original line @l@, starting at
--   (1-based) column @c@;
-- * @prefix@ is already-produced output that belongs in front of it;
-- * @repls@ is sorted by location.
--
-- Some replacements start before the current position, because the same
-- rewrite was made in several versions of the module; these are dropped.
-- A replacement whose span covers a CPP directive is also dropped, with a
-- 'trace' message.
splice :: Text -> Int -> Int -> [(RealSrcSpan, String)] -> [Text] -> [Text]
splice prefix _ _ _ [] =
  -- Out of lines. Emit any pending prefix instead of dropping it. An example
  -- is the new import header from 'printCPP' when nothing follows the last
  -- import. For an empty prefix this yields [].
  Text.lines prefix
splice prefix _ _ [] (t:ts) = prefix <> t : ts
splice prefix l c rs@((r, repl):rs') ts@(t:ts')
  | srcSpanStartLine r > l =
      -- Next rewrite is not on this line. Output line.
      prefix <> t : splice "" (l+1) 1 rs ts'
  | srcSpanStartLine r < l || srcSpanStartCol r < c =
      -- Next rewrite starts before current position. This happens when
      -- the same rewrite is made in multiple versions of the CPP'd module.
      -- Drop the duplicate rewrite and keep going.
      splice prefix l c rs' ts
  | (old, ln:lns) <- splitAt (srcSpanEndLine r - l) ts =
      -- The next rewrite starts on this line.
      let
        start = srcSpanStartCol r
        end = srcSpanEndCol r

        -- Column of the first character of 'ln'. If the rewrite ends on the
        -- current line, 'ln' is 't', which begins at column 'c'. Otherwise
        -- 'ln' is an original line not yet consumed, which begins at
        -- column 1.
        lnCol
          | srcSpanEndLine r == l = c
          | otherwise = 1

        prefix' = prefix <> Text.take (start - c) t <> Text.pack repl
        ln' = Text.drop (end - lnCol) ln

        -- For an example of how this can happen, see the CPPConflict test.
        errMsg = unlines
          [ "Refusing to rewrite across CPP directives."
          , ""
          , "Location: " ++ locStr
          , ""
          , "Original:"
          , ""
          , Text.unpack orig
          , ""
          , "Replacement:"
          , ""
          , repl
          ]
        orig =
          Text.unlines $ (prefix <> t : drop 1 old) ++ [Text.take (end - lnCol) ln]
        locStr = unpackFS (srcSpanFile r) ++ ":" ++ show l ++ ":" ++ show start
      in
        if any isCPP old
        then trace errMsg $ splice prefix l c rs' ts
        else splice prefix' (srcSpanEndLine r) end rs' (ln':lns)
  | otherwise = error "printCPP: impossible replacement past end of file"

-- Forking the module ----------------------------------------------------------

-- | Produce the source text of every version of the module, one per leaf of
-- the tree built by 'mkCPPTree'. CPP directives are blanked, so every
-- version keeps the original line numbering.
cppFork :: Text -> [Text]
cppFork src = cppTreeToList (mkCPPTree src)

-- | Tree representing all versions of the module.
--
-- @Node suffix above1 above2@: @suffix@ holds the lines below a branching
-- point, in top-to-bottom order. The two subtrees are alternative versions
-- of everything above that point, each of which is followed by @suffix@
-- (see 'cppTreeToList'). Nodes are created by 'mkCPPTree' at an #endif
-- whose enclosing branch is shown, and at an #elif.
--
-- @Leaf ls@: the remaining lines, top-to-bottom, with no further branching.
data CPPTree
  = Node [Text] CPPTree CPPTree
  | Leaf [Text]

-- | Stack type used to keep track of how many #ifs we are nested into.
-- Controls whether we emit lines into each version of the module.
--
-- 'mkCPPTree' processes the module bottom-to-top. An entry is therefore
-- pushed at an #endif and popped at the matching #if, and "until" below
-- means "moving upwards until".
data CPPBranch
  = CPPTrue -- emit lines until an #else (flips to CPPFalse), an #elif (forks), or the #if
  | CPPFalse -- emit blanks until an #else (flips to CPPTrue), an #elif (becomes CPPOmit), or the #if
  | CPPOmit -- emit blanks until the matching #if

-- | Build CPPTree from lines of the module.
-- | Build a 'CPPTree' from the text of the module.
--
-- Recognised conditional directive lines become blank lines. Other CPP lines
-- are blanked via 'blankCPP'. Lines in branches not taken in a given version
-- are blanked too. Every line therefore keeps its original line number in
-- every version (see Note [CPP]).
mkCPPTree :: Text -> CPPTree
mkCPPTree = go False [] [] . reverse . Text.lines
  -- We reverse the lines once up front, then process the module from bottom
  -- to top, branching at #endifs. If we were to process from top to bottom,
  -- we'd have to reverse each version later, rather than reversing the original
  -- once. This also makes it easy to spot import statements and stop branching
  -- since we don't care about differences in imports.
  where
    -- go seenImport stack suffix remaining, where:
    --   seenImport: an import line has been seen (see Note [Imports])
    --   stack:      one 'CPPBranch' per enclosing conditional, innermost first
    --   suffix:     output lines below the current point, top-to-bottom
    --   remaining:  unprocessed lines, bottom-most first
    go :: Bool -> [CPPBranch] -> [Text] -> [Text] -> CPPTree
    go _ _ suffix [] = Leaf suffix
    go True [] suffix ls =
      Leaf (blankifyAndReverse suffix ls) -- See Note [Imports]
    go seenImport st suffix (l:ls) =
      case extractCPPCond l of
        Just If -> -- pops from stack
          case st of
            (_:st') -> emptyLine st'
            [] -> error "mkCPPTree: if with empty stack"
        Just ElIf -> -- stack same size
          case st of
            (CPPOmit:_) -> emptyLine st
            (CPPFalse:st') -> emptyLine (CPPOmit:st')
            (CPPTrue:st') -> -- See Note [ElIf]
              let
                omittedSuffix = replicate (length suffix) ""
              in
                Node
                  []
                  (emptyLine (CPPOmit:st'))
                  (go seenImport (CPPTrue:st') ("":omittedSuffix) ls)
            [] -> error "mkCPPTree: elif with empty stack"
        Just Else -> -- stack same size
          case st of
            (CPPOmit:_) -> emptyLine st
            (CPPTrue:st') -> emptyLine (CPPFalse:st')
            (CPPFalse:st') -> emptyLine (CPPTrue:st')
            [] -> error "mkCPPTree: else with empty stack"
        Just EndIf -> -- push to stack
          case st of
            (CPPOmit:_) -> emptyLine (CPPOmit:st)
            (CPPFalse:_) -> emptyLine (CPPOmit:st)
            _ ->
              -- This conditional is visible: fork into a version that takes
              -- the branch just above this #endif and one that does not.
              Node
                suffix
                (go seenImport (CPPTrue:st) [""] ls)
                (go seenImport (CPPFalse:st) [""] ls)
        Nothing -> -- stack same size
          case st of
            (CPPOmit:_) -> go seenImport' st ("":suffix) ls
            (CPPFalse:_) -> go seenImport' st ("":suffix) ls
            _ -> go seenImport' st (blankCPP l:suffix) ls
      where
        -- Emit a blank line for the current line and continue upwards.
        emptyLine st' = go seenImport st' ("":suffix) ls
        seenImport' = seenImport || isImport l

    -- Move the remaining (bottom-most first) lines onto the front of the
    -- suffix, so the result is top-to-bottom, blanking CPP lines on the way.
    blankifyAndReverse :: [Text] -> [Text] -> [Text]
    blankifyAndReverse suffix [] = suffix
    blankifyAndReverse suffix (l:ls) = blankifyAndReverse (blankCPP l:suffix) ls

-- Note [Imports]
-- If we have seen an import statement, and have an empty stack, that means all
-- conditionals above this point only control imports/exports, etc. Retrie
-- doesn't match in those places anyway, and the imports don't matter because
-- we only parse, no renaming. As a micro-optimization, we can stop branching.
-- This saves forking the module in the common case that CPP is used to choose
-- imports. We have to wait for stack to be empty because we might have seen an
-- import in one branch, but there is a decl in the other branch.

-- Note [ElIf]
-- The way we handle #elif is pretty subtle. Some observations:
-- If we're on the CPPOmit branch, keep omitting up to the next #if, like usual.
-- If we're on the CPPFalse branch, we didn't show the #elif, but either we
-- showed the #else, or this whole #if might not output anything. So either way,
-- we need to omit up to the next #if.
-- If we're on the CPPTrue branch, we definitely showed the #elif, so we need to
-- fork with a Node. One side of the branch omits up to the next #if. The other
-- side is as if we have omitted everything from the last #endif, and we
-- continue showing up from here. This will show whatever is above the #elif.
-- It is crucial we do this branching on the CPPTrue branch, so any #elif
-- above this point is also handled correctly.

-- | Expand a 'CPPTree' into the text of every version of the module. There
-- is one version per 'Leaf', emitted in left-to-right order. A 'Node''s
-- lines follow whatever each of its subtrees produces.
cppTreeToList :: CPPTree -> [Text]
cppTreeToList tree = expand [] tree []
  where
    -- 'below' is every line after the current subtree. 'acc' holds the
    -- versions already produced for subtrees to the right.
    expand below (Leaf ls) acc = Text.unlines (ls ++ below) : acc
    expand below (Node ls left right) acc =
      expand below' left (expand below' right acc)
      where
        below' = ls ++ below

-- Spotting CPP directives -----------------------------------------------------

-- | Conditional CPP directives that affect how the module is forked.
data CPPCond = If | ElIf | Else | EndIf
  deriving (Eq, Show)

-- | Classify a line as a conditional CPP directive, if it is one.
--
-- Only lines whose first character is @#@ are considered, and whitespace
-- between the @#@ and the directive name is allowed. @#ifdef@ and @#ifndef@
-- open a conditional exactly like @#if@. If they were not recognised, the
-- matching @#endif@ would push onto the branch stack in 'mkCPPTree' and
-- nothing would ever pop it.
extractCPPCond :: Text -> Maybe CPPCond
extractCPPCond t
  | Just ('#',t') <- Text.uncons t =
    case Text.words t' of
      ("if":_) -> Just If
      ("ifdef":_) -> Just If
      ("ifndef":_) -> Just If
      ("else":_) -> Just Else
      ("elif":_) -> Just ElIf
      ("endif":_) -> Just EndIf
      _ -> Nothing
  | otherwise = Nothing

-- | Replace a CPP line (one starting with @#@) with an empty line and leave
-- every other line untouched. Blanking rather than deleting keeps the line
-- numbering of the module intact.
blankCPP :: Text -> Text
blankCPP t = if isCPP t then Text.empty else t

-- | Is this line a CPP directive, i.e. does it start with @#@ in column 1?
isCPP :: Text -> Bool
isCPP t = case Text.uncons t of
  Just ('#', _) -> True
  _ -> False

-- | Does this line start an import declaration?
--
-- The keyword must begin in column 1. It must be followed by whitespace or
-- the end of the line, so that a top-level binding such as
-- @importantValue = ...@ is not mistaken for an import.
isImport :: Text -> Bool
isImport t = case Text.stripPrefix (Text.pack "import") t of
  Nothing -> False
  Just rest -> maybe True (isSpace . fst) (Text.uncons rest)

-- | Does this line start the module header?
--
-- The keyword must begin in column 1. It must be followed by whitespace or
-- the end of the line, so that a top-level binding such as
-- @moduleName = ...@ is not mistaken for the header.
isModule :: Text -> Bool
isModule t = case Text.stripPrefix (Text.pack "module") t of
  Nothing -> False
  Just rest -> maybe True (isSpace . fst) (Text.uncons rest)

-- | Does this line start with a pragma opener (@{-#@) in column 1?
isPragma :: Text -> Bool
isPragma t = Text.take 3 t == Text.pack "{-#"

-------------------------------------------------------------------------------
-- This would make more sense in Retrie.Expr, but that creates an import cycle.
-- Ironic, I know.

-- | Graft imports into a module's AST.
--
-- Imports of the module itself are filtered out first. New imports go after
-- the existing ones. Duplicates (per 'eqImportDecl') are then removed,
-- keeping the first occurrence, so existing imports take precedence.
insertImports
  :: Monad m
  => [AnnotatedImports]   -- ^ imports and their annotations
  -> Located (HsModule GhcPs)    -- ^ target module
  -> TransformT m (Located (HsModule GhcPs))
insertImports is (L loc hsMod) = do
  grafted <- graftA wanted
  let merged = nubBy sameImport (hsmodImports hsMod ++ grafted)
  pure (L loc hsMod { hsmodImports = merged })
  where
    wanted = filterAndFlatten (unLoc <$> hsmodName hsMod) is
    sameImport = eqImportDecl `on` unLoc

-- | Concatenate several annotated import lists into one. When a module name
-- is given, imports of that module are removed, so a module never gains an
-- import of itself.
filterAndFlatten :: Maybe ModuleName -> [AnnotatedImports] -> AnnotatedImports
filterAndFlatten mbName is =
  runIdentity (transformA (mconcat is) (pure . dropSelf))
  where
    dropSelf :: [LImportDecl GhcPs] -> [LImportDecl GhcPs]
    dropSelf = case mbName of
      Nothing -> id
      Just self -> filter (\imp -> unLoc (ideclName (unLoc imp)) /= self)

-- | Structural equality on import declarations, used for de-duplication.
--
-- Compared: module name, qualification, alias, import/hiding list, package
-- qualifier, SOURCE flag and Safe flag.
eqImportDecl :: ImportDecl GhcPs -> ImportDecl GhcPs -> Bool
eqImportDecl x y = and
  [ same (unLoc . ideclName)
  , same ideclQualified
  , same ideclAs
  , same ideclHiding
  , same ideclPkgQual
  , same ideclSource
  , same ideclSafe
  ]
  -- intentionally leave out ideclImplicit and ideclSourceSrc
  -- former doesn't matter for this check, latter is prone to whitespace issues
  where
    -- The signature keeps 'same' polymorphic, because the compared fields
    -- have different types.
    same :: Eq b => (ImportDecl GhcPs -> b) -> Bool
    same field = field x == field y