{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiWayIf #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE ViewPatterns #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}

module Database.Beam.AutoMigrate.Postgres
  ( getSchema,
  )
where

import Control.Monad.State
import Data.Bits (shiftR, (.&.))
import Data.ByteString (ByteString)
import Data.Foldable (asum, foldlM)
import Data.Map (Map)
import qualified Data.Map.Strict as M
import Data.Maybe (fromMaybe)
import Data.Set (Set)
import qualified Data.Set as S
import Data.String
import Data.Text (Text)
import qualified Data.Text as T
import qualified Data.Text.Encoding as TE
import qualified Data.Vector as V
import Database.Beam.AutoMigrate.Types
import Database.Beam.Backend.SQL hiding (tableName)
import qualified Database.Beam.Backend.SQL.AST as AST
import qualified Database.PostgreSQL.Simple as Pg
import Database.PostgreSQL.Simple.FromField (FromField (..), fromField, returnError)
import Database.PostgreSQL.Simple.FromRow (FromRow (..), field)
import qualified Database.PostgreSQL.Simple.TypeInfo.Static as Pg
import qualified Database.PostgreSQL.Simple.Types as Pg

--
-- Necessary types to make working with the underlying raw SQL a bit more pleasant
--

data SqlRawOtherConstraintType
  = SQL_raw_pk
  | SQL_raw_unique
  deriving (Show, Eq)

data SqlOtherConstraint = SqlOtherConstraint
  { sqlCon_name :: Text,
    sqlCon_constraint_type :: SqlRawOtherConstraintType,
    sqlCon_table :: TableName,
    sqlCon_fk_colums :: V.Vector ColumnName
  }
  deriving (Show, Eq)

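-- Field order here mirrors the columns selected by 'otherConstraintsQ'
-- (constraint_name, constraint_type, "table", columns).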
instance Pg.FromRow SqlOtherConstraint where
  fromRow =
    SqlOtherConstraint <$> field
      <*> field
      <*> fmap TableName field
      <*> fmap (V.map ColumnName) field

data SqlForeignConstraint = SqlForeignConstraint
  { sqlFk_foreign_table :: TableName,
    sqlFk_primary_table :: TableName,
    -- | The columns in the /foreign/ table.
    sqlFk_fk_columns :: V.Vector ColumnName,
    -- | The columns in the /current/ table.
    sqlFk_pk_columns :: V.Vector ColumnName,
    sqlFk_name :: Text
  }
  deriving (Show, Eq)

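-- Field order here mirrors the columns selected by 'foreignKeysQ'
-- (foreign_table, primary_table, fk_columns, pk_columns, cname).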
instance Pg.FromRow SqlForeignConstraint where
  fromRow =
    SqlForeignConstraint <$> fmap TableName field
      <*> fmap TableName field
      <*> fmap (V.map ColumnName) field
      <*> fmap (V.map ColumnName) field
      <*> field

instance FromField TableName where
  fromField f dat = TableName <$> fromField f dat

instance FromField ColumnName where
  fromField f dat = ColumnName <$> fromField f dat

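-- Decodes the single-character codes of @pg_constraint.contype@ as selected by 'otherConstraintsQ'
-- (@p@ for a primary key, @u@ for a unique constraint).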
instance FromField SqlRawOtherConstraintType where
  fromField f dat = do
    t :: String <- fromField f dat
    case t of
      "p" -> pure SQL_raw_pk
      "u" -> pure SQL_raw_unique
      _ -> returnError Pg.ConversionFailed f t

--
-- Postgres queries to extract the schema out of the DB
--

-- | A SQL query to select all of the user's tables, skipping any beam-related tables
-- (e.g. leftovers from beam-migrate).
userTablesQ :: Pg.Query
userTablesQ =
  fromString $
    unlines
      [ "SELECT cl.oid, relname FROM pg_catalog.pg_class \"cl\" join pg_catalog.pg_namespace \"ns\" ",
        "on (ns.oid = relnamespace) where nspname = any (current_schemas(false)) and relkind='r' ",
        "and relname NOT LIKE 'beam_%'"
      ]

-- | Get information about default values for /all/ tables.
defaultsQ :: Pg.Query
defaultsQ =
  fromString $
    unlines
      [ "SELECT col.table_name::text, col.column_name::text, col.column_default::text, col.data_type::text ",
        "FROM information_schema.columns col ",
        "WHERE col.column_default IS NOT NULL ",
        "AND col.table_schema NOT IN('information_schema', 'pg_catalog') ",
        "ORDER BY col.table_name"
      ]

-- | Get information about the columns of a single table. Because this query is executed once for /each/
-- table, it is important to keep it as light as possible for decent performance. The single @?@
-- placeholder is filled with the table's OID (see 'getTable').
tableColumnsQ :: Pg.Query
tableColumnsQ =
  fromString $
    unlines
      [ "SELECT attname, atttypid, atttypmod, attnotnull, pg_catalog.format_type(atttypid, atttypmod) ",
        "FROM pg_catalog.pg_attribute att ",
        "WHERE att.attrelid=? AND att.attnum>0 AND att.attisdropped='f' "
      ]

-- | Get the enumeration data for all enum types in the database.
enumerationsQ :: Pg.Query
enumerationsQ =
  fromString $
    unlines
      [ "SELECT t.typname, t.oid, array_agg(e.enumlabel ORDER BY e.enumsortorder)",
        "FROM pg_enum e JOIN pg_type t ON t.oid = e.enumtypid",
        "GROUP BY t.typname, t.oid"
      ]

-- | Get the names of all sequences in the database.
sequencesQ :: Pg.Query
sequencesQ = fromString "SELECT c.relname FROM pg_class c WHERE c.relkind = 'S'"

-- | Return all foreign key constraints for /all/ 'Table's.
foreignKeysQ :: Pg.Query
foreignKeysQ =
  fromString $
    unlines
      [ "SELECT kcu.table_name::text as foreign_table,",
        "       rel_kcu.table_name::text as primary_table,",
        "       array_agg(kcu.column_name::text)::text[] as fk_columns,",
        "       array_agg(rel_kcu.column_name::text)::text[] as pk_columns,",
        "       kcu.constraint_name as cname",
        "FROM information_schema.table_constraints tco",
        "JOIN information_schema.key_column_usage kcu",
        "          on tco.constraint_schema = kcu.constraint_schema",
        "          and tco.constraint_name = kcu.constraint_name",
        "JOIN information_schema.referential_constraints rco",
        "          on tco.constraint_schema = rco.constraint_schema",
        "          and tco.constraint_name = rco.constraint_name",
        "JOIN information_schema.key_column_usage rel_kcu",
        "          on rco.unique_constraint_schema = rel_kcu.constraint_schema",
        "          and rco.unique_constraint_name = rel_kcu.constraint_name",
        "          and kcu.ordinal_position = rel_kcu.ordinal_position",
        "GROUP BY foreign_table, primary_table, cname"
      ]

-- | Return /all other constraints that are not FKs/ (i.e. 'PRIMARY KEY', 'UNIQUE', etc) for all the tables.
otherConstraintsQ :: Pg.Query
otherConstraintsQ =
  fromString $
    unlines
      [ "SELECT c.conname                                AS constraint_name,",
        "  c.contype                                     AS constraint_type,",
        "  tbl.relname                                   AS \"table\",",
        "  ARRAY_AGG(col.attname ORDER BY u.attposition) AS columns",
        "FROM pg_constraint c",
        "     JOIN LATERAL UNNEST(c.conkey) WITH ORDINALITY AS u(attnum, attposition) ON TRUE",
        "     JOIN pg_class tbl ON tbl.oid = c.conrelid",
        "     JOIN pg_namespace sch ON sch.oid = tbl.relnamespace",
        "     JOIN pg_attribute col ON (col.attrelid = tbl.oid AND col.attnum = u.attnum)",
        "WHERE c.contype = 'u' OR c.contype = 'p'",
        "GROUP BY constraint_name, constraint_type, \"table\"",
        "ORDER BY c.contype"
      ]

-- | Return all \"action types\" for /all/ the constraints.
referenceActionsQ :: Pg.Query
referenceActionsQ =
  fromString $
    unlines
      [ "SELECT c.conname, c. confdeltype, c.confupdtype FROM ",
        "(SELECT r.conrelid, r.confrelid, unnest(r.conkey) AS conkey, unnest(r.confkey) AS confkey, r.conname, r.confupdtype, r.confdeltype ",
        "FROM pg_catalog.pg_constraint r WHERE r.contype = 'f') AS c ",
        "INNER JOIN pg_attribute a_parent ON a_parent.attnum = c.confkey AND a_parent.attrelid = c.confrelid ",
        "INNER JOIN pg_class cl_parent ON cl_parent.oid = c.confrelid ",
        "INNER JOIN pg_namespace sch_parent ON sch_parent.oid = cl_parent.relnamespace ",
        "INNER JOIN pg_attribute a_child ON a_child.attnum = c.conkey AND a_child.attrelid = c.conrelid ",
        "INNER JOIN pg_class cl_child ON cl_child.oid = c.conrelid ",
        "INNER JOIN pg_namespace sch_child ON sch_child.oid = cl_child.relnamespace ",
        "WHERE sch_child.nspname = current_schema() ORDER BY c.conname "
      ]

-- | Connects to a running PostgreSQL database and extracts the relevant 'Schema' from it.
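--
-- A minimal usage sketch (the connection string is illustrative, and printing relies on the
-- 'Show' instance of 'Schema'):
--
-- > do conn <- Pg.connectPostgreSQL "host=localhost dbname=mydb"
-- >    schema <- getSchema conn
-- >    print schema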
getSchema :: Pg.Connection -> IO Schema
getSchema conn = do
  allTableConstraints <- getAllConstraints conn
  allDefaults <- getAllDefaults conn
  enumerationData <- Pg.fold_ conn enumerationsQ mempty getEnumeration
  sequences <- Pg.fold_ conn sequencesQ mempty getSequence
  tables <-
    Pg.fold_ conn userTablesQ mempty (getTable allDefaults enumerationData allTableConstraints)
  pure $ Schema tables (M.fromList $ M.elems enumerationData) sequences
  where
    getEnumeration ::
      Map Pg.Oid (EnumerationName, Enumeration) ->
      (Text, Pg.Oid, V.Vector Text) ->
      IO (Map Pg.Oid (EnumerationName, Enumeration))
    getEnumeration allEnums (enumName, oid, V.toList -> vals) =
      pure $ M.insert oid (EnumerationName enumName, Enumeration vals) allEnums

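    -- 'getSequence' assumes a "<table>___<column>___seq" naming convention: e.g. (name
    -- illustrative) a sequence called "orders___id___seq" is recorded as
    -- @Sequence (TableName "orders") (ColumnName "id")@, while anything else is skipped.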
    getSequence ::
      Sequences ->
      Pg.Only Text ->
      IO Sequences
    getSequence allSeqs (Pg.Only seqName) =
      case T.splitOn "___" seqName of
        [tName, cName, "seq"] ->
          pure $ M.insert (SequenceName seqName) (Sequence (TableName tName) (ColumnName cName)) allSeqs
        _ -> pure allSeqs

    getTable ::
      AllDefaults ->
      Map Pg.Oid (EnumerationName, Enumeration) ->
      AllTableConstraints ->
      Tables ->
      (Pg.Oid, Text) ->
      IO Tables
    getTable allDefaults enumData allTableConstraints allTables (oid, TableName -> tName) = do
      pgColumns <- Pg.query conn tableColumnsQ (Pg.Only oid)
      newTable <-
        Table (fromMaybe noTableConstraints (M.lookup tName allTableConstraints))
          <$> foldlM (getColumns tName enumData allDefaults) mempty pgColumns
      pure $ M.insert tName newTable allTables

    getColumns ::
      TableName ->
      Map Pg.Oid (EnumerationName, Enumeration) ->
      AllDefaults ->
      Columns ->
      (ByteString, Pg.Oid, Int, Bool, ByteString) ->
      IO Columns
    getColumns tName enumData defaultData c (attname, atttypid, atttypmod, attnotnull, format_type) = do
      -- /NOTA BENE(adn)/: The atttypmod - 4 was originally taken from 'beam-migrate'
      -- (see: https://github.com/tathougies/beam/blob/d87120b58373df53f075d92ce12037a98ca709ab/beam-postgres/Database/Beam/Postgres/Migrate.hs#L343)
      -- but there are cases where this is not correct, for example in the case of bitstrings.
      -- See for example: https://stackoverflow.com/questions/52376045/why-does-atttypmod-differ-from-character-maximum-length
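      -- As an illustration (values per Postgres' typmod conventions, not taken from this codebase):
      -- a @VARCHAR(255)@ column reports @atttypmod = 259@ (the declared length plus the 4-byte
      -- VARHDRSZ header), so subtracting 4 recovers 255, whereas a @BIT(8)@ column reports
      -- @atttypmod = 8@ directly, hence the special cases below.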
      let mbPrecision =
            if
                | atttypmod == -1 -> Nothing
                | Pg.typoid Pg.bit == atttypid -> Just atttypmod
                | Pg.typoid Pg.varbit == atttypid -> Just atttypmod
                | otherwise -> Just (atttypmod - 4)

      let columnName = ColumnName (TE.decodeUtf8 attname)

      let mbDefault = do
            x <- M.lookup tName defaultData
            M.lookup columnName x

      case asum
        [ pgSerialTyColumnType atttypid mbDefault,
          pgTypeToColumnType atttypid mbPrecision,
          pgEnumTypeToColumnType enumData atttypid
        ] of
        Just cType -> do
          let nullConstraint = if attnotnull then S.fromList [NotNull] else mempty
          let inferredConstraints = nullConstraint <> fromMaybe mempty (S.singleton <$> mbDefault)
          let newColumn = Column cType inferredConstraints
          pure $ M.insert columnName newColumn c
        Nothing ->
          fail $
            "Couldn't convert pgType "
              <> show format_type
              <> " of field "
              <> show attname
              <> " into a valid ColumnType."

--
-- Postgres type mapping
--

pgEnumTypeToColumnType ::
  Map Pg.Oid (EnumerationName, Enumeration) ->
  Pg.Oid ->
  Maybe ColumnType
pgEnumTypeToColumnType enumData oid =
  (\(n, _) -> PgSpecificType (PgEnumeration n)) <$> M.lookup oid enumData

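-- | Detects @SERIAL@-like columns: Postgres reports a @SERIAL@ column as a plain @int4@ whose
-- default looks like @nextval('some_table_some_col_seq'::regclass)@ (name illustrative), which is
-- what the @nextval@ / @seq@ infix checks below match on.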
pgSerialTyColumnType ::
  Pg.Oid ->
  Maybe ColumnConstraint ->
  Maybe ColumnType
pgSerialTyColumnType oid (Just (Default d)) = do
  guard $ (Pg.typoid Pg.int4 == oid && "nextval" `T.isInfixOf` d && "seq" `T.isInfixOf` d)
  pure $ SqlStdType intType
pgSerialTyColumnType _ _ = Nothing

-- | Tries to convert a Postgres 'Oid' into a 'ColumnType'.
-- Mostly taken from [beam-migrate](Database.Beam.Postgres.Migrate).
pgTypeToColumnType :: Pg.Oid -> Maybe Int -> Maybe ColumnType
pgTypeToColumnType oid width
  | Pg.typoid Pg.int2 == oid =
    Just (SqlStdType smallIntType)
  | Pg.typoid Pg.int4 == oid =
    Just (SqlStdType intType)
  | Pg.typoid Pg.int8 == oid =
    Just (SqlStdType bigIntType)
  | Pg.typoid Pg.bpchar == oid =
    Just (SqlStdType $ charType (fromIntegral <$> width) Nothing)
  | Pg.typoid Pg.varchar == oid =
    Just (SqlStdType $ varCharType (fromIntegral <$> width) Nothing)
  | Pg.typoid Pg.bit == oid =
    Just (SqlStdType $ bitType (fromIntegral <$> width))
  | Pg.typoid Pg.varbit == oid =
    Just (SqlStdType $ varBitType (fromIntegral <$> width))
  | Pg.typoid Pg.numeric == oid =
    let decimals = fromMaybe 0 width .&. 0xFFFF
        prec = (fromMaybe 0 width `shiftR` 16) .&. 0xFFFF
     in case (prec, decimals) of
          (0, 0) -> Just (SqlStdType $ numericType Nothing)
          (p, 0) -> Just (SqlStdType $ numericType $ Just (fromIntegral p, Nothing))
          _ -> Just (SqlStdType $ numericType (Just (fromIntegral prec, Just (fromIntegral decimals))))
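  -- Worked example (hypothetical column): @NUMERIC(10,2)@ has @atttypmod = (10 * 65536) + 2 + 4@;
  -- 'getColumns' passes @655362@ here after subtracting 4, so @prec = 10@ and @decimals = 2@.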
  | Pg.typoid Pg.float4 == oid =
    Just (SqlStdType realType)
  | Pg.typoid Pg.float8 == oid =
    Just (SqlStdType doubleType)
  | Pg.typoid Pg.date == oid =
    Just (SqlStdType dateType)
  | Pg.typoid Pg.text == oid =
    Just (SqlStdType characterLargeObjectType)
  -- I am not sure if this is a bug in beam-core, but both 'characterLargeObjectType' and 'binaryLargeObjectType'
  -- get mapped into 'AST.DataTypeCharacterLargeObject', which yields TEXT, whereas we want the latter to
  -- yield bytea.
  | Pg.typoid Pg.bytea == oid =
    Just (SqlStdType AST.DataTypeBinaryLargeObject)
  | Pg.typoid Pg.bool == oid =
    Just (SqlStdType booleanType)
  | Pg.typoid Pg.time == oid =
    Just (SqlStdType $ timeType Nothing False)
  | Pg.typoid Pg.timestamp == oid =
    Just (SqlStdType $ timestampType Nothing False)
  | Pg.typoid Pg.timestamptz == oid =
    Just (SqlStdType $ timestampType Nothing True)
  -- json types
  | Pg.typoid Pg.json == oid =
    Just (PgSpecificType PgJson)
  | Pg.typoid Pg.jsonb == oid =
    Just (PgSpecificType PgJsonB)
  -- range types
  | Pg.typoid Pg.int4range == oid =
    Just (PgSpecificType PgRangeInt4)
  | Pg.typoid Pg.int8range == oid =
    Just (PgSpecificType PgRangeInt8)
  | Pg.typoid Pg.numrange == oid =
    Just (PgSpecificType PgRangeNum)
  | Pg.typoid Pg.tsrange == oid =
    Just (PgSpecificType PgRangeTs)
  | Pg.typoid Pg.tstzrange == oid =
    Just (PgSpecificType PgRangeTsTz)
  | Pg.typoid Pg.daterange == oid =
    Just (PgSpecificType PgRangeDate)
  | Pg.typoid Pg.uuid == oid =
    Just (PgSpecificType PgUuid)
  | otherwise =
    Nothing

--
-- Constraints discovery
--

type AllTableConstraints = Map TableName (Set TableConstraint)

type AllDefaults = Map TableName Defaults

type Defaults = Map ColumnName ColumnConstraint

-- Get all default values for /all/ the columns.
-- FIXME(adn) __IMPORTANT:__ This function currently __always__ attaches an explicit type annotation to the
-- default value, by reading its 'data_type' field, to resolve potential ambiguities.
-- The reason for this is that we cannot reliably guarantee a conversion between default values as read
-- by Postgres and the values we infer on the 'Schema' side (using the 'beam-core' machinery). In theory we
-- wouldn't need to explicitly annotate the types before generating a 'Default' constraint on the 'Schema'
-- side, but this doesn't always work. For example, if we **always** specify a \"::numeric\" annotation for
-- an 'Int', Postgres might yield \"-1::integer\" for non-positive values and simply \"-1\" for all the rest.
-- To complicate the situation, /even if/ we explicitly specify the cast
-- (i.e. \"SET DEFAULT '?::character varying'\"), Postgres will ignore it when reading the default back.
-- What we do here is obviously not optimal, but on the other hand it's not clear to me how to solve this
-- in a meaningful and non-invasive way, for a number of reasons:
--
-- * \"beam-migrate\" seems to resort to using explicit serialisation for the types, although I couldn't
--   find an explicit trace of whether that also applies to defaults
--   (cfr. the \"Database.Beam.AutoMigrate.Serialization\" module in \"beam-migrate\").
--
-- * Another big problem is __rounding__: for example, if we insert the following as \"double precision\":
--   Default "'-0.22030397057804563'", Postgres will round the value and return Default "'-0.220303970578046'".
--   Again, it's not clear to me how to prevent users from shooting themselves in the foot here.
--
-- * Another quirk is with dates: \"beam\" renders a date like \'1864-05-10\' (note the single quotes) but
--   Postgres strips those when reading the default value back.
--
-- * Range types are also tricky to infer: 'beam-core' escapes the range type name when rendering its
--   default value, whereas Postgres annotates each individual field and yields the unquoted identifier.
--   Compare:
--
--   1. Beam:     \""numrange"(0, 2, '[)')\"
--   2. Postgres: \"numrange((0)::numeric, (2)::numeric, '[)'::text)\"
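--
-- Concretely (values illustrative): an @integer@ column whose default is read back as @-1@ is
-- stored as @Default "'-1'::integer"@, whereas a default that already carries a cast, such as
-- @nextval('foo_seq'::regclass)@, is kept verbatim (see 'addDefault' below).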
getAllDefaults :: Pg.Connection -> IO AllDefaults
getAllDefaults conn = Pg.fold_ conn defaultsQ mempty (\acc -> pure . addDefault acc)
  where
    addDefault :: AllDefaults -> (TableName, ColumnName, Text, Text) -> AllDefaults
    addDefault m (tName, colName, defValue, dataType) =
      let cleanedDefault = case T.breakOn "::" defValue of
            (uncasted, defMb)
              | T.null defMb ->
                "'" <> T.dropAround ((==) '\'') uncasted <> "'::" <> dataType
            _ -> defValue
          entry = M.singleton colName (Default cleanedDefault)
       in M.alter
            ( \case
                Nothing -> Just entry
                Just ss -> Just $ ss <> entry
            )
            tName
            m

getAllConstraints :: Pg.Connection -> IO AllTableConstraints
getAllConstraints conn = do
  allActions <- mkActions <$> Pg.query_ conn referenceActionsQ
  allForeignKeys <- Pg.fold_ conn foreignKeysQ mempty (\acc -> pure . addFkConstraint allActions acc)
  Pg.fold_ conn otherConstraintsQ allForeignKeys (\acc -> pure . addOtherConstraint acc)
  where
    addFkConstraint ::
      ReferenceActions ->
      AllTableConstraints ->
      SqlForeignConstraint ->
      AllTableConstraints
    addFkConstraint actions st SqlForeignConstraint {..} = flip execState st $ do
      let currentTable = sqlFk_foreign_table
      let columnSet = S.fromList $ zip (V.toList sqlFk_fk_columns) (V.toList sqlFk_pk_columns)
      let (onDelete, onUpdate) =
            case M.lookup sqlFk_name (getActions actions) of
              Nothing -> (NoAction, NoAction)
              Just a -> (actionOnDelete a, actionOnUpdate a)
      addTableConstraint currentTable (ForeignKey sqlFk_name sqlFk_primary_table columnSet onDelete onUpdate)

    addOtherConstraint ::
      AllTableConstraints ->
      SqlOtherConstraint ->
      AllTableConstraints
    addOtherConstraint st SqlOtherConstraint {..} = flip execState st $ do
      let currentTable = sqlCon_table
      let columnSet = S.fromList . V.toList $ sqlCon_fk_colums
      case sqlCon_constraint_type of
        SQL_raw_unique -> addTableConstraint currentTable (Unique sqlCon_name columnSet)
        SQL_raw_pk -> addTableConstraint currentTable (PrimaryKey sqlCon_name columnSet)

newtype ReferenceActions = ReferenceActions {getActions :: Map Text Actions}

newtype RefEntry = RefEntry {unRefEntry :: (Text, ReferenceAction, ReferenceAction)}

mkActions :: [RefEntry] -> ReferenceActions
mkActions = ReferenceActions . M.fromList . map ((\(a, b, c) -> (a, Actions b c)) . unRefEntry)

instance Pg.FromRow RefEntry where
  fromRow =
    fmap
      RefEntry
      ( (,,) <$> field
          <*> fmap mkAction field
          <*> fmap mkAction field
      )

data Actions = Actions
  { actionOnDelete :: ReferenceAction,
    actionOnUpdate :: ReferenceAction
  }

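-- | Maps the single-character action codes of @pg_constraint.confdeltype@ / @confupdtype@
-- (as selected by 'referenceActionsQ') to 'ReferenceAction's, e.g. @c@ stands for @CASCADE@
-- and @n@ for @SET NULL@.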
mkAction :: Text -> ReferenceAction
mkAction c = case c of
  "a" -> NoAction
  "r" -> Restrict
  "c" -> Cascade
  "n" -> SetNull
  "d" -> SetDefault
  _ -> error . T.unpack $ "unknown reference action type: " <> c

--
-- Useful combinators to add a constraint to a table's constraint set, creating the entry if the
-- table has none yet.
--

addTableConstraint ::
  TableName ->
  TableConstraint ->
  State AllTableConstraints ()
addTableConstraint tName cns =
  modify'
    ( M.alter
        ( \case
            Nothing -> Just $ S.singleton cns
            Just ss -> Just $ S.insert cns ss
        )
        tName
    )