{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE OverloadedStrings #-}

module Convex.Schema.Parser
  ( parseSchema,
    ParsedFile (..),
    Schema (..),
    Table (..),
    Index (..),
    Field (..),
    ConvexType (..),
    ParserState (..),
    initialState,
    getLiteralString,
    isLiteral,
  )
where

import Control.Monad
import Data.Map (Map)
import qualified Data.Map as Map
import Text.Parsec
import qualified Text.Parsec.Language as Token
import qualified Text.Parsec.Token as Token

-- | The parser monad used throughout this module: it consumes a 'String',
-- threads a 'ParserState' as user state, and runs over 'IO'.
type SchemaParser = ParsecT String ParserState IO

-- | User state threaded through parsing.
data ParserState = ParserState
  { psConstants :: Map String ConvexType
    -- ^ Top-level @const@ and @type@ declarations recorded so far, keyed by
    -- their declared name.
  }
  deriving (Show, Eq)

-- | The state parsing starts from: no declarations recorded yet.
initialState :: ParserState
initialState = ParserState Map.empty

-- | Everything 'parseSchema' extracts from one schema file.
data ParsedFile = ParsedFile
  { parsedConstants :: Map String ConvexType
    -- ^ Top-level @const@ and @type@ declarations found by the definitions
    -- pass, keyed by declared name.
  , parsedSchema :: Schema
    -- ^ The tables declared inside @defineSchema(...)@.
  }
  deriving (Show, Eq)

-- | The tables declared in the object passed to @defineSchema(...)@, in the
-- order they appear in the source.
newtype Schema = Schema
  { getTables :: [Table]
  }
  deriving (Show, Eq)

-- | One @.index("name", ["field", ...])@ call chained onto a table.
data Index = Index
  { indexName :: String
    -- ^ First argument: the index name.
  , indexFields :: [String]
    -- ^ Second argument: the indexed field names, in order.
  }
  deriving (Show, Eq)

-- | One @name: defineTable(...)@ entry of the schema.
data Table = Table
  { tableName :: String
    -- ^ The key the table is declared under (bare word or quoted string).
  , tableFields :: [Field]
    -- ^ Columns, taken from the inline object literal or from the object
    -- constant named as the @defineTable@ argument.
  , tableIndexes :: [Index]
    -- ^ Chained @.index(...)@ calls, in source order.
  }
  deriving (Show, Eq)

-- | One @key: validator@ entry of an object literal.
data Field = Field
  { fieldName :: String
    -- ^ The key, written bare or quoted.
  , fieldType :: ConvexType
    -- ^ The validator the key is declared with.
  }
  deriving (Show, Eq)

-- | A Convex validator as written with the @v@ builder in a schema file.
--
-- Constructor order is significant: it determines the derived 'Ord'.
data ConvexType
  = VString -- ^ @v.string()@
  | VNumber -- ^ @v.number()@
  | VInt64 -- ^ @v.int64()@
  | VFloat64 -- ^ @v.float64()@
  | VBoolean -- ^ @v.boolean()@
  | VBytes -- ^ @v.bytes()@
  | VNull -- ^ @v.null()@
  | VAny -- ^ @v.any()@
  | VId String -- ^ @v.id("...")@, holding its string argument (a table name in Convex's API)
  | VArray ConvexType -- ^ @v.array(t)@
  | VObject [(String, ConvexType)] -- ^ @v.object({...})@ or a bare object literal; entries in source order
  | VOptional ConvexType -- ^ @v.optional(t)@
  | VUnion [ConvexType] -- ^ @v.union(a, b, ...)@
  | VLiteral String -- ^ @v.literal("...")@; only string literals are parsed
  | VReference String -- ^ A bare identifier, presumably naming a declared constant; left unresolved
  | VVoid -- ^ Not produced by the parsers in this module
  deriving (Show, Eq, Ord)

-- | Return the string carried by a 'VLiteral'.
--
-- Partial: passing any other constructor is a programming error and raises
-- 'error'. The message names the function and the offending value so the
-- failure can be traced. Guard calls with 'isLiteral'.
getLiteralString :: ConvexType -> String
getLiteralString (VLiteral str) = str
getLiteralString other =
  error $ "getLiteralString: expected a literal type, got " ++ show other

-- | 'True' exactly when the validator is a 'VLiteral'.
isLiteral :: ConvexType -> Bool
isLiteral ty = case ty of
  VLiteral _ -> True
  _ -> False

-- | Lexical rules for the TypeScript subset read by this module.
--
-- Block comments are @/* ... */@ and, as in ECMAScript, do not nest. The
-- first @*/@ closes the comment even if a @/*@ appeared inside it, e.g. in a
-- glob such as @src/*.ts@. With nesting enabled, such a comment would
-- swallow the code that follows it. Line comments start with @//@.
--
-- 'Token.reservedNames' only affects 'identifier', which refuses to return
-- these words. 'reserved' matches its argument whether or not it is listed.
langDef :: Token.GenLanguageDef String ParserState IO
langDef =
  Token.LanguageDef
    { Token.commentStart = "/*",
      Token.nestedComments = False,
      Token.commentEnd = "*/",
      Token.commentLine = "//",
      Token.opStart = operatorChar,
      Token.opLetter = operatorChar,
      Token.reservedOpNames = [],
      Token.identStart = letter <|> char '_',
      Token.identLetter = alphaNum <|> char '_',
      Token.reservedNames =
        [ "defineSchema",
          "defineSchema(",
          "defineTable",
          "v",
          "export",
          "default",
          "import",
          "from",
          "const",
          "type",
          "keyof",
          "typeof"
        ],
      Token.caseSensitive = True
    }
  where
    -- Operators may start with and continue with the same characters.
    operatorChar :: ParsecT String ParserState IO Char
    operatorChar = oneOf ":!#$%&*+./<=>?@\\^|-~"

-- | Token parsers generated from 'langDef'. The definitions below expose its
-- fields at the 'SchemaParser' type. The token parsers among them
-- ('identifier', 'stringLiteral', 'reserved' and the bracket combinators)
-- skip whitespace and comments after their token.
lexer :: Token.GenTokenParser String ParserState IO
lexer = Token.makeTokenParser langDef

-- | Skip any amount of whitespace and comments.
whiteSpace :: SchemaParser ()
whiteSpace = Token.whiteSpace lexer

-- | Run a parser, then skip trailing whitespace and comments.
lexeme :: SchemaParser a -> SchemaParser a
lexeme p = Token.lexeme lexer p

-- | An identifier as defined by 'langDef' that is not one of its reserved
-- names.
identifier :: SchemaParser String
identifier = Token.identifier lexer

-- | A string literal in Parsec's Haskell-style syntax.
-- NOTE(review): this is double-quoted only, so TypeScript's single-quoted
-- strings are not accepted.
stringLiteral :: SchemaParser String
stringLiteral = Token.stringLiteral lexer

-- | Match the given word, failing if it is immediately followed by an
-- identifier character.
reserved :: String -> SchemaParser ()
reserved keyword = Token.reserved lexer keyword

-- | Run a parser between @(@ and @)@.
parens :: SchemaParser a -> SchemaParser a
parens p = Token.parens lexer p

-- | Run a parser between @{@ and @}@.
braces :: SchemaParser a -> SchemaParser a
braces p = Token.braces lexer p

-- | Run a parser between @[@ and @]@.
brackets :: SchemaParser a -> SchemaParser a
brackets p = Token.brackets lexer p

-- | Consume the optional @;@ ending a top-level statement, then any
-- whitespace or comments.
topLevelStatementEnd :: SchemaParser ()
topLevelStatementEnd = do
  optional (lexeme (char ';'))
  whiteSpace

-- | Consume the optional separators that may follow an item, in this order:
-- a @,@, then a @;@, then any whitespace or comments.
itemEnd :: SchemaParser ()
itemEnd = optional (lexeme (char ',')) *> optional (lexeme (char ';')) *> whiteSpace

-- | Turn a 'Field' into the @(name, validator)@ pair stored in a 'VObject'.
fieldToTuple :: Field -> (String, ConvexType)
fieldToTuple Field {fieldName = key, fieldType = validator} = (key, validator)

-- | Parse one @key: validator@ entry of an object literal.
--
-- The key may be a bare word or a double-quoted string. Bare keys are read
-- with 'propertyName' rather than 'identifier'. 'identifier' rejects every
-- word in the lexer's reserved names, but in an object literal words such as
-- @type@, @default@ or @from@ are ordinary property keys. Fields named
-- @type@ are common in Convex schemas.
--
-- The separator after the entry is not consumed here.
fieldParser :: SchemaParser Field
fieldParser = lexeme $ do
  key <- propertyName <|> stringLiteral
  void $ lexeme $ char ':'
  value <- convexTypeParser
  return $ Field key value
  where
    -- Same character classes as identStart/identLetter in 'langDef', but
    -- without the reserved-word check. A key is always followed by ':', so
    -- accepting reserved words here is unambiguous.
    propertyName :: SchemaParser String
    propertyName = lexeme ((:) <$> (letter <|> char '_') <*> many (alphaNum <|> char '_'))

-- | Parse one chained @.index("name", ["field", ...])@ call.
--
-- Only @.index@ is recognised. Because the leading @.@ is consumed before
-- the method name is checked, any other method (e.g. @.searchIndex@) makes
-- this parser fail after consuming input rather than backtrack.
indexParser :: SchemaParser Index
indexParser = lexeme $ do
  _ <- char '.'
  reserved "index"
  parens $ do
    idxName <- stringLiteral
    _ <- lexeme (char ',')
    Index idxName <$> brackets (sepEndBy stringLiteral (lexeme (char ',')))

-- | Parse one @name: defineTable(...)@ entry of the schema object. This
-- includes any chained @.index(...)@ calls and a trailing separator.
--
-- The argument of @defineTable@ is either an inline object literal or the
-- name of a constant recorded in the parser state. Such a constant must
-- hold a 'VObject', otherwise parsing fails.
tableParser :: SchemaParser Table
tableParser = lexeme $ do
  name <- identifier <|> stringLiteral
  _ <- lexeme (char ':')
  reserved "defineTable"
  columns <- parens (try inlineColumns <|> (identifier >>= columnsOf name))
  indexes <- many indexParser
  itemEnd
  pure (Table name columns indexes)
  where
    -- defineTable({ field: validator, ... })
    inlineColumns :: SchemaParser [Field]
    inlineColumns = braces (sepEndBy fieldParser (lexeme (char ',')))

    -- defineTable(someConstant): resolve the name against the parser state.
    columnsOf :: String -> String -> SchemaParser [Field]
    columnsOf tName refName = do
      known <- psConstants <$> getState
      case Map.lookup refName known of
        Just (VObject pairs) -> pure (map (uncurry Field) pairs)
        _ -> fail $ "Table '" ++ tName ++ "' references an unknown or non-object constant: " ++ refName

-- | Parse a bare object literal @{ key: validator, ... }@ into a 'VObject'
-- whose entries keep source order.
--
-- After the closing brace this also consumes an optional @,@ and/or @;@
-- (via 'itemEnd').
-- NOTE(review): when this is the value of a field inside another object
-- literal, that comma is consumed here. The enclosing 'sepEndBy' then no
-- longer sees its separator, so its list ends after this field.
structParser :: SchemaParser ConvexType
structParser = do
  entries <- braces (sepEndBy fieldParser (lexeme (char ',')))
  itemEnd
  pure (VObject (map fieldToTuple entries))

-- | Parse one validator expression. This is a @v.<kind>(...)@ builder call,
-- a bare object literal, or an identifier referring to a constant (kept
-- unresolved as 'VReference').
--
-- Each alternative is wrapped in 'try', so an error inside one alternative
-- makes the parser fall through to the next.
--
-- Single-argument builders and @v.union@ accept a trailing comma before the
-- closing parenthesis. Formatters such as Prettier emit one when a call's
-- arguments are split across lines.
convexTypeParser :: SchemaParser ConvexType
convexTypeParser =
  choice . map try $
    [ vParser,
      structParser,
      referenceParser
    ]
  where
    comma :: SchemaParser Char
    comma = lexeme (char ',')

    -- v.kind() with no arguments.
    noArgs :: SchemaParser ()
    noArgs = parens (return ())

    -- v.kind(arg), tolerating a trailing comma after the argument.
    singleArg :: SchemaParser a -> SchemaParser a
    singleArg p = parens (p <* optional comma)

    vParser = do
      reserved "v"
      void $ lexeme $ char '.'
      typeName <- identifier
      case typeName of
        "string" -> VString <$ noArgs
        "number" -> VNumber <$ noArgs
        "boolean" -> VBoolean <$ noArgs
        "bytes" -> VBytes <$ noArgs
        "int64" -> VInt64 <$ noArgs
        "float64" -> VFloat64 <$ noArgs
        "null" -> VNull <$ noArgs
        "any" -> VAny <$ noArgs
        "id" -> VId <$> singleArg stringLiteral
        "array" -> VArray <$> singleArg convexTypeParser
        -- structParser already consumes an optional trailing ',' itself.
        "object" -> parens structParser
        "optional" -> VOptional <$> singleArg convexTypeParser
        "union" -> VUnion <$> parens (sepEndBy convexTypeParser comma)
        -- Only string literals are supported.
        "literal" -> VLiteral <$> singleArg stringLiteral
        _ -> fail $ "Unknown v-dot type: " ++ typeName

    referenceParser = VReference <$> identifier

-- | Parse @[export] const NAME = VALUE[;]@ and record NAME in the parser
-- state's constants.
--
-- VALUE is one of two forms. @defineTable({...})@ is stored as a 'VObject'
-- of its fields; chained calls after the closing parenthesis are not
-- consumed here. Anything else must be a validator accepted by
-- 'convexTypeParser'.
topLevelConstParser :: SchemaParser ()
topLevelConstParser = lexeme $ do
  void $ optional (reserved "export")
  reserved "const"
  constName <- identifier
  void $ lexeme $ char '='
  constType <-
    try (reserved "defineTable" *> (VObject . map fieldToTuple <$> parens (braces tableFieldList)))
      <|> convexTypeParser
  topLevelStatementEnd
  modifyState (\s -> s {psConstants = Map.insert constName constType (psConstants s)})
  where
    -- Entries are comma-separated, with an optional trailing comma, as in
    -- 'tableParser' and 'structParser'. 'fieldParser' does not consume the
    -- ',' after a v.* value, so a plain 'many' would stop at the first
    -- comma. Every table constant with two or more fields would then fail.
    tableFieldList :: SchemaParser [Field]
    tableFieldList = sepEndBy fieldParser (lexeme $ char ',')

-- | Parse @[export] type NAME = [typeof] VALUE[;]@ and record NAME in the
-- parser state's constants, in the same map 'topLevelConstParser' writes to.
topLevelTypeParser :: SchemaParser ()
topLevelTypeParser = lexeme $ do
  optional (reserved "export")
  reserved "type"
  aliasName <- identifier
  _ <- lexeme (char '=')
  optional (reserved "typeof")
  aliased <- convexTypeParser
  topLevelStatementEnd
  modifyState $ \st ->
    st {psConstants = Map.insert aliasName aliased (psConstants st)}

-- | Parse the source text of a Convex @schema.ts@ file.
--
-- The input is read twice:
--
-- 1. A definitions pass walks the file one character at a time. It records
--    every top-level @const@ / @type@ declaration it can parse and skips
--    anything else.
-- 2. A schema pass skips to the first @defineSchema(@ and parses the table
--    definitions. It resolves @defineTable(someConstant)@ against the
--    constants collected by the first pass.
--
-- A parse failure in either pass is returned as 'Left', as the result type
-- promises. Previously it was thrown as an 'IO' exception, so 'Left' was
-- never produced.
parseSchema :: String -> IO (Either ParseError ParsedFile)
parseSchema input = do
  definitions <- runParserT (definitionsPass *> getState) initialState fileLabel input
  case definitions of
    Left err -> return (Left err)
    Right constsState -> do
      schema <- runParserT schemaPass constsState fileLabel input
      return (ParsedFile (psConstants constsState) <$> schema)
  where
    fileLabel :: SourceName
    fileLabel = "(schema.ts)"

    -- Try a declaration at every position; otherwise drop one character.
    definitionsPass :: SchemaParser [()]
    definitionsPass =
      many (try topLevelConstParser <|> try topLevelTypeParser <|> (anyChar >> return ()))

    -- Skip everything before "defineSchema(", then parse its object argument.
    schemaPass :: SchemaParser Schema
    schemaPass = do
      _ <- manyTill anyChar (lookAhead (try (reserved "defineSchema(")))
      reserved "defineSchema"
      Schema <$> parens (braces (many tableParser))
