{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE LambdaCase #-}

module Axel.Macros where

import Axel.AST
  ( Identifier
  , MacroDefinition
  , Statement(SDataDeclaration, SFunctionDefinition, SMacroDefinition,
          SModuleDeclaration, SPragma, SQualifiedImport, SRestrictedImport,
          STopLevel, STypeSignature, STypeSynonym, STypeclassInstance,
          SUnrestrictedImport)
  , ToHaskell(toHaskell)
  , functionDefinition
  , name
  , statements
  )
import Axel.Denormalize (denormalizeStatement)
import Axel.Error (Error(MacroError))
import Axel.Haskell.GHC (ghcInterpret)
import Axel.Haskell.Prettify (prettifyHaskell)
import Axel.Monad.FileSystem (MonadFileSystem)
import qualified Axel.Monad.FileSystem as FS
  ( MonadFileSystem(createDirectoryIfMissing, writeFile)
  , withCurrentDirectory
  , withTemporaryDirectory
  )
import Axel.Monad.Process (MonadProcess)
import Axel.Monad.Resource (MonadResource, readResource)
import qualified Axel.Monad.Resource as Res
  ( astDefinition
  , macroDefinitionAndEnvironmentFooter
  , macroDefinitionAndEnvironmentHeader
  , macroScaffold
  )
import Axel.Normalize (normalizeStatement)
import qualified Axel.Parse as Parse
  ( Expression(LiteralChar, LiteralInt, LiteralString, SExpression,
           Symbol)
  , parseMultiple
  , programToTopLevelExpressions
  , topLevelExpressionsToProgram
  )
import Axel.Utils.Display (Delimiter(Newlines), delimit, isOperator)
import Axel.Utils.Function (uncurry3)
import Axel.Utils.Recursion
  ( Recursive(bottomUpFmap, bottomUpTraverse)
  , exhaustM
  )
import Axel.Utils.String (replace)

import Control.Lens.Cons (snoc)
import Control.Lens.Operators ((%~), (^.))
import Control.Lens.Tuple (_1, _2)
import Control.Monad (foldM)
import Control.Monad.Except (MonadError, catchError, runExceptT, throwError)

import Data.Function ((&))
import Data.List.NonEmpty (NonEmpty, nonEmpty)
import qualified Data.List.NonEmpty as NE (head, toList)
import Data.Semigroup ((<>))

import System.FilePath ((</>))

-- | Assemble the three source texts that 'evalMacro' needs in order to run a
-- macro application.
--
-- The result triple is:
--
-- 1. the AST definition resource, unchanged;
--
-- 2. the macro scaffold resource with the @%%%MACRO_NAME%%%@ placeholder
--    replaced by the generated macro name, then the @%%%ARGUMENTS%%%@
--    placeholder replaced by the 'show'n application arguments, then
--    prettified;
--
-- 3. the macro-definition-and-environment header resource (macro name
--    substituted), followed by @env@ and the renamed @macroDefs@ rendered with
--    'toHaskell' and prettified, followed by the footer resource (macro name
--    substituted).
--
-- Every definition in @macroDefs@ has occurrences of the first definition's
-- name rewritten (via 'replaceName') to a generated name. That name is the
-- original name plus @%%%%%%%%%%@ for operators, or plus
-- @_AXEL_AUTOGENERATED_MACRO_DEFINITION@ otherwise. The renaming is
-- presumably for hygiene.
generateMacroProgram ::
     (MonadError Error m, MonadFileSystem m, MonadResource m)
  => NonEmpty MacroDefinition
  -> [Statement]
  -> [Parse.Expression]
  -> m (String, String, String)
generateMacroProgram macroDefs env applicationArgs = do
  astDefinitionSource <- readResource Res.astDefinition
  scaffoldTemplate <- readResource Res.macroScaffold
  headerTemplate <- readResource Res.macroDefinitionAndEnvironmentHeader
  renamedMacroDefs <-
    traverse
      (replaceName originalMacroName generatedMacroName . SMacroDefinition)
      macroDefs
  footerTemplate <- readResource Res.macroDefinitionAndEnvironmentFooter
  let scaffold =
        prettifyHaskell $
        replace "%%%ARGUMENTS%%%" (show applicationArgs) $
        substituteMacroName scaffoldTemplate
      environmentSource =
        prettifyHaskell . delimit Newlines . map toHaskell $
        env <> NE.toList renamedMacroDefs
      macroDefAndEnv =
        substituteMacroName headerTemplate <>
        unlines [environmentSource, substituteMacroName footerTemplate]
  pure (astDefinitionSource, scaffold, macroDefAndEnv)
  where
    -- Fills the macro-name placeholder shared by the scaffold, header and
    -- footer resources.
    substituteMacroName = replace "%%%MACRO_NAME%%%" generatedMacroName
    originalMacroName = NE.head macroDefs ^. functionDefinition . name
    generatedMacroName =
      originalMacroName <>
      if isOperator originalMacroName
        then "%%%%%%%%%%"
        else "_AXEL_AUTOGENERATED_MACRO_DEFINITION"

-- | Run one macro-expansion pass over a whole program.
--
-- The program's top-level expressions are expanded with 'expandMacros'. The
-- resulting statements are denormalized back into expressions and rewrapped
-- as a program.
expansionPass ::
     (MonadError Error m, MonadFileSystem m, MonadProcess m, MonadResource m)
  => Parse.Expression
  -> m Parse.Expression
expansionPass programExpr = do
  expandedStmts <-
    expandMacros (Parse.programToTopLevelExpressions programExpr)
  pure . Parse.topLevelExpressionsToProgram $
    map denormalizeStatement expandedStmts

-- | Unwrap a program of the form @(begin expr ...)@ into its top-level
-- expressions.
--
-- Calls 'error' on any expression that is not such a @begin@ form.
programToTopLevelExpressions :: Parse.Expression -> [Parse.Expression]
programToTopLevelExpressions program =
  case program of
    Parse.SExpression (Parse.Symbol "begin" : topLevelExprs) -> topLevelExprs
    _ ->
      error "programToTopLevelExpressions must be passed a top-level program!"

-- | Wrap top-level expressions into a program of the form
-- @(begin expr ...)@.
topLevelExpressionsToProgram :: [Parse.Expression] -> Parse.Expression
topLevelExpressionsToProgram = Parse.SExpression . (Parse.Symbol "begin" :)

-- | Apply 'expansionPass' to a program repeatedly via 'exhaustM'.
--
-- This presumably repeats until a pass leaves the program unchanged; confirm
-- against 'exhaustM'.
exhaustivelyExpandMacros ::
     (MonadError Error m, MonadFileSystem m, MonadProcess m, MonadResource m)
  => Parse.Expression
  -> m Parse.Expression
exhaustivelyExpandMacros programExpr = exhaustM expansionPass programExpr

-- | Whether a statement may be included in the environment that is handed to
-- a macro program.
--
-- Module declarations and 'STopLevel' statements are rejected; every other
-- statement kind is accepted. Every constructor is listed explicitly, so
-- adding a new 'Statement' constructor triggers an incomplete-pattern
-- warning here.
isStatementNonconflicting :: Statement -> Bool
isStatementNonconflicting =
  \case
    SModuleDeclaration _ -> False
    STopLevel _ -> False
    SDataDeclaration _ -> True
    SFunctionDefinition _ -> True
    SMacroDefinition _ -> True
    SPragma _ -> True
    SQualifiedImport _ -> True
    SRestrictedImport _ -> True
    STypeclassInstance _ -> True
    STypeSignature _ -> True
    STypeSynonym _ -> True
    SUnrestrictedImport _ -> True

-- | Expand every macro application in a program's top-level expressions and
-- normalize the results into statements.
--
-- Top-level expressions are processed left to right. Each one is expanded
-- using only the statements and macro definitions produced by the
-- expressions before it. Every expression the expansion yields (one
-- application may splice in several) is then normalized.
--
-- Macro definitions are collected so that later expressions can use them.
-- They are not part of the returned statements.
expandMacros ::
     (MonadError Error m, MonadFileSystem m, MonadProcess m, MonadResource m)
  => [Parse.Expression]
  -> m [Statement]
expandMacros topLevelExprs =
  fst <$>
  foldM
    (\acc@(stmts, macroDefs) expr -> do
       expandedExprs <- fullyExpandExpr stmts macroDefs expr
       foldM
         (\acc' expandedExpr -> do
            stmt <- normalizeStatement expandedExpr
            pure $ acc' &
              case stmt of
                -- Macro definitions go only into the second component: they
                -- become visible to later expressions but are dropped from
                -- the result.
                SMacroDefinition macroDefinition ->
                  _2 %~ flip snoc macroDefinition
                _ -> _1 %~ flip snoc stmt)
         acc
         expandedExprs)
    ([], [])
    topLevelExprs
  where
    -- Expand macro applications anywhere inside @expr@. The pass is
    -- re-applied via 'exhaustM', presumably until nothing changes. Returns
    -- the resulting top-level expressions.
    fullyExpandExpr ::
         ( MonadError Error m
         , MonadFileSystem m
         , MonadProcess m
         , MonadResource m
         )
      => [Statement]
      -> [MacroDefinition]
      -> Parse.Expression
      -> m [Parse.Expression]
    fullyExpandExpr stmts allMacroDefs expr =
      Parse.programToTopLevelExpressions <$>
      exhaustM
        (bottomUpTraverse expandChildren)
        (Parse.topLevelExpressionsToProgram [expr])
      where
        -- Loop-invariant: computed once per call rather than once per macro
        -- application.
        auxEnv = filter isStatementNonconflicting stmts
        -- Rewrites an s-expression's children, splicing each macro
        -- application's expansion in place. 'traverse' preserves the
        -- left-to-right effect order, and 'concat' avoids the quadratic
        -- repeated appends.
        expandChildren =
          \case
            Parse.SExpression xs ->
              Parse.SExpression . concat <$> traverse expandChild xs
            x -> pure x
        -- A child whose head symbol names a known macro is replaced by that
        -- macro's expansion. Anything else is kept as a singleton.
        expandChild x =
          case x of
            Parse.SExpression (function:args)
              | Just macroDefs <- lookupMacroDefinitions function allMacroDefs ->
                expandMacroApplication macroDefs auxEnv args
            _ -> pure [x]

-- | Expand a single macro application.
--
-- Builds the macro program for @macroDefs@ applied to @args@, with @auxEnv@
-- as its environment. It is evaluated with 'evalMacro', and the source text
-- it produces is parsed back into expressions.
expandMacroApplication ::
     (MonadError Error m, MonadFileSystem m, MonadProcess m, MonadResource m)
  => NonEmpty MacroDefinition
  -> [Statement]
  -> [Parse.Expression]
  -> m [Parse.Expression]
expandMacroApplication macroDefs auxEnv args =
  generateMacroProgram macroDefs auxEnv args >>= uncurry3 evalMacro >>=
  Parse.parseMultiple

-- | All macro definitions named by the given head expression, in their
-- original order.
--
-- Returns 'Nothing' when none match, which includes every non-symbol
-- expression.
lookupMacroDefinitions ::
     Parse.Expression -> [MacroDefinition] -> Maybe (NonEmpty MacroDefinition)
lookupMacroDefinitions identifierExpr macroDefs =
  nonEmpty
    [ macroDef
    | macroDef <- macroDefs
    , isMacroBeingCalled macroDef identifierExpr
    ]

-- | Whether @identifierExpr@ is a symbol equal to the macro's name.
--
-- Literals and s-expressions never match.
isMacroBeingCalled :: MacroDefinition -> Parse.Expression -> Bool
isMacroBeingCalled macroDef (Parse.Symbol identifier) =
  identifier == macroDef ^. functionDefinition . name
isMacroBeingCalled _ (Parse.SExpression _) = False
isMacroBeingCalled _ (Parse.LiteralChar _) = False
isMacroBeingCalled _ (Parse.LiteralInt _) = False
isMacroBeingCalled _ (Parse.LiteralString _) = False

-- | Drop macro definitions from the statement list of an 'STopLevel'.
--
-- Only the 'STopLevel''s own statement list is filtered. Any other statement
-- is returned unchanged.
stripMacroDefinitions :: Statement -> Statement
stripMacroDefinitions stmt =
  case stmt of
    STopLevel topLevel ->
      STopLevel $
      topLevel & statements %~ filter (not . isMacroDefinitionStatement)
    _ -> stmt

-- | Whether the statement is an 'SMacroDefinition'.
isMacroDefinitionStatement :: Statement -> Bool
isMacroDefinitionStatement =
  \case
    SMacroDefinition _ -> True
    _ -> False

-- | Rename every occurrence of the symbol @oldName@ to @newName@ inside a
-- statement.
--
-- The statement is denormalized to an expression, every matching 'Parse.Symbol'
-- is rewritten bottom-up, and the result is normalized again. If that
-- normalization fails, its error is discarded and replaced by a 'MacroError'
-- that names @oldName@.
replaceName ::
     (MonadError Error m)
  => Identifier
  -> Identifier
  -> Statement
  -> m Statement
replaceName oldName newName stmt =
  normalizeStatement renamedExpr `catchError` const invalidNameError
  where
    renamedExpr = bottomUpFmap renameSymbol (denormalizeStatement stmt)
    invalidNameError =
      throwError (MacroError $ "Invalid macro name: `" <> oldName <> "`!")
    renameSymbol (Parse.Symbol identifier)
      | identifier == oldName = Parse.Symbol newName
    renameSymbol expr = expr

-- | Evaluate a generated macro program and return the interpreter's result.
--
-- Inside a directory obtained from 'FS.withTemporaryDirectory' (made the
-- current directory for the duration), it writes three files:
--
-- * the AST definition to @Axel/Parse/AST.hs@;
-- * the definition/environment module to @MacroDefinitionAndEnvironment.hs@;
-- * the scaffold to @Scaffold.hs@.
--
-- It then runs 'ghcInterpret' on the scaffold. An interpretation failure is
-- rethrown as a 'MacroError' that includes the temporary directory path and
-- the interpreter's error text.
evalMacro ::
     (MonadError Error m, MonadFileSystem m, MonadProcess m)
  => String
  -> String
  -> String
  -> m String
evalMacro astDefinition scaffold macroDefinitionAndEnvironment =
  FS.withTemporaryDirectory $ \tempDir ->
    FS.withCurrentDirectory tempDir $ do
      let astDir = "Axel" </> "Parse"
          scaffoldPath = "Scaffold.hs"
      FS.createDirectoryIfMissing True astDir
      -- Written in this order: AST, definition/environment module, scaffold.
      mapM_
        (uncurry FS.writeFile)
        [ (astDir </> "AST.hs", astDefinition)
        , ("MacroDefinitionAndEnvironment.hs", macroDefinitionAndEnvironment)
        , (scaffoldPath, scaffold)
        ]
      interpretation <- runExceptT (ghcInterpret scaffoldPath)
      case interpretation of
        Right output -> pure output
        Left err ->
          throwError . MacroError $
          "Temporary directory: " <> tempDir <> "\n\n" <> "Error:\n" <> err