{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DuplicateRecordFields #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE OverloadedRecordDot #-}
{-# LANGUAGE TemplateHaskellQuotes #-}

module Language.Nanopass.LangDef
  ( Define
  , runDefine
  , defineLang
  , reifyLang
  , runModify
  ) where

import Nanopass.Internal.Representation

import Control.Monad (forM,forM_,foldM,when)
import Nanopass.Internal.Extend (extendLang)
import Control.Monad.State (StateT,gets,modify,evalStateT)
import Data.Bifunctor (second)
import Data.Functor ((<&>))
import Data.List (nub,(\\),stripPrefix)
import Data.List.NonEmpty (NonEmpty)
import Data.Map (Map)
import Data.Maybe (fromMaybe)
import Language.Haskell.TH (Q, Dec)

import qualified Control.Monad.Trans as M
import qualified Data.Map as Map
import qualified Data.Text.Lazy as LT
import qualified Language.Haskell.TH as TH
import qualified Language.Haskell.TH.Syntax as TH
import qualified Text.Pretty.Simple as PP

---------------------------------
------ Language Definition ------
---------------------------------

-- | The monad in which language definitions are generated:
-- Template Haskell's 'Q' with a 'DefState' threaded through it.
type Define a = StateT DefState Q a

-- | State carried while generating the declarations of one language.
data DefState = DefState
  { langTyvars :: [TH.Name]
    -- ^ Template Haskell names of the language's type parameters;
    -- set by @defineLang@ before any type is generated.
  , nontermNames :: Map UpName TH.Name
    -- ^ Maps each non-terminal's name to the Template Haskell name of the
    -- data type generated for it; filled in by @defineLang@.
  }

-- | Run a 'Define' computation in 'Q', discarding the final state.
--
-- The initial state knows no non-terminals, and its type-variable list is an
-- error thunk: forcing it before it has been initialized aborts with an
-- internal-error message.
runDefine :: Define a -> Q a
runDefine action = evalStateT action initialState
  where
  initialState = DefState
    { langTyvars = errorWithoutStackTrace "internal nanopass error: uninitialized langTyVars"
    , nontermNames = Map.empty
    }

-- | Generate every declaration that makes up a language.
--
-- The result is the language's info type (built by @defineLangHeader@)
-- followed by one data type per non-terminal. Along the way the 'Define'
-- state is initialized with the language's type variables and the Template
-- Haskell names of its non-terminals.
--
-- Fails when a language parameter name or a non-terminal name is repeated.
defineLang :: Language 'Valid UpName -> Define [Dec]
defineLang l = do
  -- record the language's type variables, insisting that they are distinct
  let repeatedParams = l.langInfo.langParams \\ nub l.langInfo.langParams
  when (not $ null repeatedParams) $ fail $ concat
    [ "in a nanopass language definition: "
    , "duplicate language parameter names "
    , show (nub repeatedParams)
    ]
  modify $ \st -> st{ langTyvars = (.th) <$> l.langInfo.langParams }
  -- register each non-terminal's Template Haskell name, rejecting repeats
  forM_ l.langInfo.nonterms $ \nonterm -> do
    seen <- gets nontermNames
    if Map.member nonterm.nontermName.name seen
      then fail $ concat
        [ "in a nanopass language definition: "
        , "duplicate non-terminal (terminal/nonterminal) name "
        , fromUpName nonterm.nontermName.name
        ]
      else modify $ \st ->
        st{nontermNames = Map.insert nonterm.nontermName.name nonterm.nontermName.th seen}
  -- the info type: one nullary constructor per non-terminal
  header <- defineLangHeader l
  -- one data type per non-terminal, all sharing the language's type parameters
  tyParams <- gets (fmap TH.plainTV . langTyvars)
  nontermDecs <- forM (Map.elems l.langInfo.nonterms) $ \nonterm -> do
    M.lift $ TH.addModFinalizer $ TH.putDoc (TH.DeclDoc nonterm.nontermName.th) $
      "This type is a non-terminal of the t'" ++ fromUpName l.langName.name ++ "' language."
    ctors <- mapM defineProduction (Map.elems nonterm.productions)
    pure $ TH.DataD [] nonterm.nontermName.th tyParams Nothing ctors []
  pure (header : nontermDecs)

-- | Build the language's info type: a data type named after the language with
-- one nullary constructor @\<Lang\>_\<Nonterm\>@ per registered non-terminal
-- (in ascending order of non-terminal name).
--
-- Module finalizers are registered that attach documentation to the type and
-- to each constructor. The type's documentation also records the base
-- language and/or the original program when the language carries them.
--
-- Reads the non-terminal names from the 'Define' state, so they must already
-- have been registered.
defineLangHeader :: Language 'Valid UpName -> Define Dec
defineLangHeader l = do
  nontermKeys <- gets $ Map.keys . nontermNames
  ctors <- forM nontermKeys $ \nontermName -> do
    let ctorName = TH.mkName $ fromUpName l.langName.name ++ "_" ++ fromUpName nontermName
    M.lift $ TH.addModFinalizer $ TH.putDoc (TH.DeclDoc ctorName) $
      "Serves as a reference to the non-terminal of t'" ++ fromUpName nontermName ++ "'s."
    pure $ TH.NormalC ctorName []
  let thName = l.langName.th
  M.lift $ TH.addModFinalizer $ TH.putDoc (TH.DeclDoc thName) $ concat
    [ unlines
      [ "This type is generated by nanopass."
      , "It serves as a reference to the types of syntactic categories in the language."
      , "Nanopass itself uses types like these to read back in a full language that was defined in a separate splice/quasiquote."
      ]
    , provenance
    ]
  -- I'm not sure I need these since this type is just a glorified set of pointers, but here they are for reference
  -- dShow = TH.DerivClause Nothing [TH.ConT ''Show]
  -- dRead = TH.DerivClause Nothing [TH.ConT ''Read]
  pure $ TH.DataD [] thName [] Nothing ctors []
  where
  -- render a program as a Haddock code block (every line prefixed with "> ")
  quoteProgram :: String -> String
  quoteProgram = unlines . fmap ("> " ++) . lines
  -- the part of the documentation describing where the language came from
  provenance :: String
  provenance = case (l.langInfo.baseDefdLang, l.langInfo.originalProgram) of
    (Just baseLang, Just origProg) -> unlines
      [ ""
      , "This language was generated based on the language t'" ++ show baseLang.langName.th ++ "'"
      , "using the following 'Language.Nanopass.deflang' program:"
      , ""
      , quoteProgram origProg
      ]
    (Just baseLang, Nothing) -> unlines
      [ ""
      , "This language was generated based on the language t'" ++ show baseLang.langName.th ++ "'."
      ]
    (Nothing, Just origProg) -> unlines
      [ ""
      , "This language was generated from the following 'Language.Nanopass.deflang' program:"
      , ""
      , quoteProgram origProg
      ]
    (Nothing, Nothing) -> ""

-- | Turn a production into an ordinary (non-record) data constructor whose
-- fields are the production's subterms, in order.
defineProduction :: Production 'Valid -> Define TH.Con
defineProduction production =
  TH.NormalC production.prodName.th <$> mapM defineSubterm production.subterms

-- | Translate one subterm into a constructor field: its Template Haskell
-- type paired with 'noBang'.
defineSubterm :: TypeDesc 'Valid -> Define TH.BangType
defineSubterm typeDesc = (,) noBang <$> subtermType typeDesc

-- | Translate a nanopass type description into the Template Haskell type used
-- for the corresponding constructor field.
--
-- * A reference to a non-terminal becomes that non-terminal's data type
--   applied to all of the language's type variables.
-- * A language parameter becomes a type variable.
-- * Constructor applications, lists, 'NonEmpty', 'Maybe', unit and tuples are
--   translated structurally.
--
-- Fails when a referenced non-terminal has not been registered in the
-- 'Define' state, or when a type variable is not one of the language's
-- parameters.
subtermType :: TypeDesc 'Valid -> Define TH.Type
subtermType (RecursiveType nontermName) =
  gets (Map.lookup nontermName . nontermNames) >>= \case
    Just thName -> do
      -- every non-terminal type takes the full list of language parameters
      params <- gets $ fmap TH.VarT . langTyvars
      pure $ foldl TH.AppT (TH.ConT thName) params
    Nothing -> fail $ concat ["in a nanopass language definition: unknown metavariable ", fromUpName nontermName]
subtermType (VarType vName) =
  gets ((vName.th `elem`) . langTyvars) >>= \case
    True -> pure $ TH.VarT vName.th
    False -> fail $ concat ["in a nanopass language definition: unknown language parameter ", fromLowName vName.name]
subtermType (CtorType cName argDescs) =
  foldl TH.AppT (TH.ConT cName.th) <$> mapM subtermType argDescs
subtermType (ListType argDesc) =
  TH.AppT TH.ListT <$> subtermType argDesc
subtermType (NonEmptyType argDesc) =
  TH.AppT (TH.ConT ''NonEmpty) <$> subtermType argDesc
subtermType (MaybeType argDesc) =
  TH.AppT (TH.ConT ''Maybe) <$> subtermType argDesc
subtermType UnitType = pure $ TH.TupleT 0
subtermType (TupleType t1 t2 ts) =
  -- a tuple description always has at least two components
  foldl TH.AppT (TH.TupleT (2 + length ts)) <$> mapM subtermType (t1:t2:ts)

----------------------------------
------ Language Reification ------
----------------------------------

-- | Read a previously-defined language back in from the types in scope.
--
-- Given the (possibly qualified) name of a language, this
--
-- 1. finds the language info type with that name,
-- 2. decodes each of the info type's nullary constructors into the name of a
--    non-terminal type,
-- 3. reifies each non-terminal type and decodes its constructors into
--    productions.
--
-- The returned language has no original program and no base language.
--
-- Fails when the name does not identify a data type of the expected shape,
-- or when the types found cannot be decoded (the \"corrupt language\" errors).
reifyLang :: UpDotName -> Q (Language 'Valid UpDotName)
reifyLang lName = do
  (langNameTH, nontermPtrs) <- findLangInfo
  -- determine the language's grammar types
  thNonterms <- findRecursiveType `mapM` nontermPtrs
  let sNames = thNonterms <&> \(sName, _, _, _) -> sName
  nontermTypeList <- forM thNonterms $ \(nontermName, nontermNameTH, paramNames, thCtors) -> do
    ctorList <- decodeCtor sNames paramNames `mapM` thCtors
    let prodNames = (.prodName) <$> ctorList
        duplicatePNames = prodNames \\ nub prodNames
    case duplicatePNames of
      [] -> pure Nonterm
        { nontermName = ValidName nontermName nontermNameTH
        , productions = Map.fromList (ctorList <&> \ctor -> (ctor.prodName.name, ctor))
        }
      _ -> fail $ "corrupt language has duplicate production names: " ++ show (nub duplicatePNames)
  -- disallowing duplicates here allows `decodeType.recurse` to produce `RecursiveType`s easily
  let nontermTypes = nontermTypeList <&> \t -> (t.nontermName.name, t)
      nontermNames = fst <$> nontermTypes
      duplicateSNames = nontermNames \\ nub nontermNames
  when (not $ null duplicateSNames) $ fail $
    "corrupt language has duplicate non-terminal names: " ++ show (nub duplicateSNames)
  -- determine the language's type parameters:
  -- every non-terminal must declare the same parameters (after `fixup`)
  langParams <- do
    let f Nothing (_, _, tvs, _) = pure (Just $ fixup <$> tvs)
        f (Just tvs) (_, _, tvs', _)
          | tvs == (fixup <$> tvs') = pure (Just tvs)
          | otherwise = fail $ concat
            [ "corrupt language has differing parameters between syntactic categories. expected:\n"
            , "  " ++ show tvs ++ "\n"
            , "got:\n"
            , "  " ++ show (fixup <$> tvs')
            ]
    rawTvs <- fromMaybe [] <$> foldM f Nothing thNonterms
    forM rawTvs $ \rawTv -> case toLowName rawTv of
      Just tv -> pure $ ValidName tv (TH.mkName $ fromLowName tv)
      Nothing -> fail $ concat
        [ "corrupt language has non-lowercase type parameter: ", show rawTv ]
  -- and we're done
  pure $ Language
    { langName = ValidName lName langNameTH
    , langInfo = LanguageInfo
      { langParams
      , nonterms = Map.fromList nontermTypes
      , originalProgram = Nothing
      , baseDefdLang = Nothing
      }
    }
  where
  -- this is here because TH will add a bunch of garbage on the end of a type variable to ensure it doesn't capture,
  -- but in this case I _want_ it to capture, so I can check name equality across different types.
  -- Only an underscore followed by a (possibly empty) run of digits at the very end is removed;
  -- a name whose trailing digits are not preceded by an underscore is returned unchanged.
  fixup :: TH.Name -> String
  fixup thName = case span isAsciiDigit (reverse shown) of
      (_, '_':rest) -> reverse rest
      _ -> shown
    where
    shown = show thName
    isAsciiDigit c = '0' <= c && c <= '9'
  -- decode one constructor of a non-terminal type into a production;
  -- only ordinary (non-record, non-infix, non-GADT) constructors are accepted
  decodeCtor :: [UpName] -> [TH.Name] -> TH.Con -> Q (Production 'Valid)
  decodeCtor sNames paramNames (TH.NormalC prodNameTH thSubterms) = do
    prodName <- case toUpName (TH.nameBase prodNameTH) of
      Just x -> pure $ ValidName x prodNameTH
      Nothing -> fail $ "corrupt language has illegal production name: " ++ show prodNameTH
    subterms <- forM thSubterms $ \(_, thSubtermType) ->
      decodeType sNames paramNames thSubtermType
    pure $ Production{prodName,subterms}
  decodeCtor _ _ otherCtor = fail $ "corrupt production type:\n" ++ show otherCtor
  -- decode the type of one constructor field, given the names of all the language's
  -- non-terminals and the type parameters of the enclosing non-terminal type
  decodeType :: [UpName] -> [TH.Name] -> TH.Type -> Q (TypeDesc 'Valid)
  decodeType sNames paramNames type0 = recurse type0
    where
    tvs = TH.VarT <$> paramNames
    -- the unit type is what `subtermType` generates for `UnitType`
    recurse (TH.TupleT 0) = pure UnitType
    recurse tuple | Just (t1:t2:ts) <- fromTuple tuple = do
      t1Desc <- recurse t1
      t2Desc <- recurse t2
      tDescs <- recurse `mapM` ts
      pure $ TupleType t1Desc t2Desc tDescs
    recurse (TH.AppT (TH.ConT special) a)
      | special == ''Maybe = MaybeType <$> recurse a
      | special == ''NonEmpty = NonEmptyType <$> recurse a
    recurse (TH.AppT TH.ListT a) = ListType <$> recurse a
    recurse appType
      -- a non-terminal applied to exactly the language's type parameters is a recursive reference
      | (TH.ConT thName, args) <- fromApps appType
      , Just sName <- toUpName (TH.nameBase thName)
      , sName `elem` sNames && args == tvs
        = pure $ RecursiveType sName
      -- any other type constructor application is kept as-is
      | (TH.ConT thName, args) <- fromApps appType
      , Just cName <- toUpDotName (TH.nameBase thName) = do
        decodedArgs <- recurse `mapM` args
        pure $ CtorType (ValidName cName thName) decodedArgs
    recurse (TH.VarT thName)
      | Just tvName <- toLowName (TH.nameBase thName)
        = pure $ VarType (ValidName tvName thName)
    recurse otherType = fail $ "corrupt subterm type:\n" ++ show otherType ++ "\n in type:\n" ++ show type0
    -- the components of a fully-applied tuple type, in order; Nothing for anything else
    fromTuple :: TH.Type -> Maybe [TH.Type]
    fromTuple t0 = case loop t0 of
      Just (0, ts) -> Just (reverse ts)
      _ -> Nothing
      where
      -- counts down from the tuple's arity, so zero means "fully applied"
      loop (TH.TupleT n) = Just (n, [])
      loop (TH.AppT f t)
        | Just (n, ts) <- loop f = Just (n - 1, t:ts)
      loop _ = Nothing
    -- split a nested type application into its head and its arguments, in order
    fromApps :: TH.Type -> (TH.Type, [TH.Type])
    fromApps = second reverse . loop
      where
      loop (TH.AppT inner lastArg) = second (lastArg:) (loop inner)
      loop t = (t, [])
  findLangInfo :: Q (TH.Name, [TH.Con]) -- name and constructors of the info type
  findLangInfo = TH.lookupTypeName (fromUpDotName lName) >>= \case
    Nothing -> fail $ "in a nanopass language extension: could not find base language " ++ fromUpDotName lName
    Just langNameTH -> TH.reify langNameTH >>= \case
      TH.TyConI (TH.DataD [] qualThLangName [] Nothing nontermNames _) -> pure (qualThLangName, nontermNames)
      otherInfo -> fail $ concat
        [ "in a nanopass language extension: base name " ++ show langNameTH ++ " does not identify a language: "
        , "  expecting language name to identify data definition, but got this type:\n"
        , "  " ++ show otherInfo
        ]
  -- follow one constructor of the info type to the non-terminal type it refers to, yielding the
  -- non-terminal's name, its TH name, the names of its type parameters, and its constructors
  findRecursiveType :: TH.Con -> Q (UpName, TH.Name, [TH.Name], [TH.Con])
  findRecursiveType (TH.NormalC thTypePtr []) = do
    let enumPrefix = (fromUpName . upDotBase) lName ++ "_"
    typePtrBase <- case stripPrefix enumPrefix (TH.nameBase thTypePtr) of
      Just base
        | Just it <- toUpName base -> pure it
        | otherwise -> fail $ concat
          [ "in a nanopass language extension: base name " ++ (fromUpName . upDotBase) lName ++ " is illegal: "
          , "  it must be an UpperCaseName, but got: " ++ base
          ]
      Nothing -> fail $ concat
        [ "in a nanopass language extension: base name " ++ (fromUpName . upDotBase) lName ++ " does not identify a language:\n"
        , "  expecting language info enum ctors to start with " ++ enumPrefix ++ ", but got name: "
        , "  " ++ TH.nameBase thTypePtr
        ]
    -- the non-terminal type is looked up under the same qualifier as the language itself
    let typePtr = TH.mkName $ fromUpDotName $ upDotChBase lName typePtrBase
    TH.reify typePtr >>= \case
      TH.TyConI (TH.DataD [] nontermNameTH thParams _ ctors _) -> do
        nontermName <- case toUpName $ TH.nameBase nontermNameTH of
          Just x -> pure x
          Nothing -> fail $ "corrupt language has illegal non-terminal name: " ++ show nontermNameTH
        let thParamNames = thParams <&> \case { TH.PlainTV it _ -> it ; TH.KindedTV it _ _ -> it }
        pure (nontermName, nontermNameTH, thParamNames, ctors)
      otherType -> fail $ "corrupt language non-terminal type:\n" ++ show otherType
  findRecursiveType otherCtor = fail $ concat
    [ "in a nanopass language extension: base name " ++ (fromUpName . upDotBase) lName ++ " does not identify a language: "
    , "  expecting language name to identify an enum, but got this constructor:\n"
    , "  " ++ show otherCtor
    ]

--------------------------------
------ Language Extension ------
--------------------------------

-- | Apply a language modification: reify the base language it names, extend
-- it, and generate the declarations of the resulting language.
--
-- Fails with the pretty-printed error when the extension is rejected.
runModify :: LangMod -> Q [Dec]
runModify lMod = do
  base <- reifyLang lMod.baseLang
  extended <- either (fail . LT.unpack . PP.pShow) pure (extendLang base lMod) -- TODO
  runDefine (defineLang extended)

------------------------
------ TH Helpers ------
------------------------

-- | Field annotation carrying neither an unpackedness nor a strictness
-- annotation; @defineSubterm@ attaches it to every generated field.
noBang :: TH.Bang
noBang = TH.Bang TH.NoSourceUnpackedness TH.NoSourceStrictness
