{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE LambdaCase #-}

module Database.Beam.AutoMigrate.Diff
  ( Diffable (..),
    Diff,
    Priority (..),
    WithPriority (..),

    -- * Reference implementation, for model-testing purposes
    diffColumnReferenceImplementation,
    diffTablesReferenceImplementation,
    diffTableReferenceImplementation,
    diffReferenceImplementation,

    -- * Hopefully-efficient implementation
    diffColumn,
    diffTables,
    diffTable,
    sortEdits,
  )
where

import Control.Exception (assert)
import Control.Monad
import Data.DList (DList)
import qualified Data.DList as D
import Data.Foldable (foldlM)
import Data.List (foldl', (\\))
import qualified Data.List as L
import Data.Map.Merge.Strict
import qualified Data.Map.Strict as M
import Data.Maybe
import qualified Data.Set as S
import Data.Text (Text)
import Data.Word (Word8)
import Database.Beam.AutoMigrate.Types

--
-- Simple typeclass to diff things
--

-- | Execution rank of an edit. Edits are sorted by ascending value, so a
-- /smaller/ number means the edit runs /earlier/ (i.e. has a higher priority).
newtype Priority = Priority Word8
  deriving (Show, Eq, Ord)

-- | A value paired with the 'Priority' it should be executed with.
newtype WithPriority a = WithPriority {unPriority :: (a, Priority)}
  deriving (Show, Eq, Ord)

-- | The execution 'Priority' of an 'EditAction': edits with a lower value are
-- executed first (see 'sortEdits').
--
-- When a table or column constraint changes, the diff emits a removal of the
-- old constraint and an addition of the new one (the two set differences).
-- These must run remove-first. Databases such as PostgreSQL allow only one
-- PRIMARY KEY per table and one DEFAULT per column. Adding the replacement
-- first therefore either fails, or is undone by the later drop (the
-- default-value case assumes 'ColumnConstraint' models defaults — confirm).
-- Removals are therefore scheduled before any type change or constraint
-- addition. Foreign keys are dropped before other table constraints, because
-- they may depend on a UNIQUE / PRIMARY KEY constraint dropped in the same
-- migration.
editPriority :: EditAction -> Priority
editPriority = \case
  -- Operations that create tables, sequences or enums have top priority
  EnumTypeAdded {} -> Priority 0
  SequenceAdded {} -> Priority 1
  TableAdded {} -> Priority 2
  -- We cannot create a column if the relevant table (or enum type) is not there.
  ColumnAdded {} -> Priority 3
  -- Drop stale constraints before reshaping columns or adding replacements.
  -- Foreign keys first: they may reference a UNIQUE/PK constraint dropped below.
  TableConstraintRemoved _ ForeignKey {} -> Priority 4
  TableConstraintRemoved {} -> Priority 5
  ColumnConstraintRemoved {} -> Priority 6
  -- Operations that change the shape of a type
  ColumnTypeChanged {} -> Priority 7
  EnumTypeValueAdded {} -> Priority 8
  -- Foreign keys need to go last among additions, as the referenced columns
  -- need to be either UNIQUE or have PKs.
  TableConstraintAdded _ Unique {} -> Priority 9
  TableConstraintAdded _ PrimaryKey {} -> Priority 10
  TableConstraintAdded _ ForeignKey {} -> Priority 11
  ColumnConstraintAdded {} -> Priority 12
  -- Destructive operations go last
  ColumnRemoved {} -> Priority 13
  TableRemoved {} -> Priority 14
  EnumTypeRemoved {} -> Priority 15
  SequenceRemoved {} -> Priority 16

-- TODO: This needs to support adding conditional queries.

-- | Turn an 'EditAction' into an executable 'Edit' (via 'defMkEdit'), tagged
-- with the execution 'Priority' assigned by 'editPriority'.
mkEdit :: EditAction -> WithPriority Edit
mkEdit action =
  let edit = defMkEdit action
      prio = editPriority action
   in WithPriority (edit, prio)

-- | Put edits in execution order, so that no edit references something that
-- has not been created yet. Edits are compared by 'Priority' only (lowest
-- first). The sort is stable, so edits sharing a priority keep their relative
-- input order.
sortEdits :: [WithPriority Edit] -> [WithPriority Edit]
sortEdits = L.sortBy byPriority
  where
    byPriority (WithPriority (_, p1)) (WithPriority (_, p2)) = compare p1 p2

-- | The outcome of a diff: either a 'DiffError', or the prioritised 'Edit's
-- collected in the container @t@.
type DiffA t = Either DiffError (t (WithPriority Edit))

-- | 'DiffA' specialised to plain lists.
type Diff = Either DiffError [WithPriority Edit]

-- NOTE(adn) Accumulate all the errors independently instead of short circuiting?

-- | Types whose values can be compared to produce a list of prioritised 'Edit's.
class Diffable a where
  -- | Compare two values, failing with a 'DiffError' or returning the edits.
  diff ::
    a ->
    a ->
    Diff

-- | Computes the diff between two 'Schema's. The result is either a 'DiffError',
-- or the list of 'Edit's that turn the /second/ schema (@dbSchema@) into the
-- /first/ (@hsSchema@). Objects present only in the first are added, objects
-- present only in the second are removed, and changed column types go from the
-- second schema's type to the first's.
--
-- Tables are diffed first, then enumerations, then sequences. The first
-- 'DiffError' encountered aborts the whole diff.
instance Diffable Schema where
  diff hsSchema dbSchema = do
    tableDiffs <- diff (schemaTables hsSchema) (schemaTables dbSchema)
    enumDiffs <- diff (schemaEnumerations hsSchema) (schemaEnumerations dbSchema)
    sequenceDiffs <- diff (schemaSequences hsSchema) (schemaSequences dbSchema)
    pure $ tableDiffs <> enumDiffs <> sequenceDiffs

-- | Tables are diffed with 'diffTables'; the result is flattened to a list.
instance Diffable Tables where
  diff hsTables dbTables = D.toList <$> diffTables hsTables dbTables

-- | Enumerations are diffed with 'diffEnums'; the result is flattened to a list.
instance Diffable Enumerations where
  diff hsEnums dbEnums = D.toList <$> diffEnums hsEnums dbEnums

-- | Sequences are diffed with 'diffSequences'; the result is flattened to a list.
instance Diffable Sequences where
  diff hsSeqs dbSeqs = D.toList <$> diffSequences hsSeqs dbSeqs

--
-- Reference implementation
--

-- | Reference counterpart of 'diff' on 'Schema's, for model testing.
--
-- Tables are diffed with the naive 'diffTablesReferenceImplementation' rather
-- than the optimised 'diffTables'. Otherwise both sides of a comparison would
-- run the same code. There is no separate reference implementation for
-- enumerations and sequences, so those reuse the regular 'diff'.
--
-- Results are concatenated in the same order as the 'Schema' instance (tables,
-- enumerations, sequences). The order of edits /within/ the table group differs
-- from 'diff', so compare results up to ordering.
diffReferenceImplementation :: Schema -> Schema -> Diff
diffReferenceImplementation hsSchema dbSchema = do
  tableDiffs <- diffTablesReferenceImplementation (schemaTables hsSchema) (schemaTables dbSchema)
  enumDiffs <- diff (schemaEnumerations hsSchema) (schemaEnumerations dbSchema)
  sequenceDiffs <- diff (schemaSequences hsSchema) (schemaSequences dbSchema)
  pure $ tableDiffs <> enumDiffs <> sequenceDiffs

-- | A deliberately naive implementation of table diffing, kept as an oracle
-- for QuickCheck comparisons with 'diffTables'.
--
-- The result lists, in this order:
--
-- * for tables only in @hsTables@: 'TableAdded' plus one 'TableConstraintAdded'
--   per constraint;
-- * for tables only in @dbTables@: 'TableRemoved' plus one
--   'TableConstraintRemoved' per constraint;
-- * for tables in both: the edits from 'diffTableReferenceImplementation',
--   in ascending key order.
--
-- The first 'DiffError' aborts the diff.
diffTablesReferenceImplementation :: Tables -> Tables -> Diff
diffTablesReferenceImplementation hsTables dbTables =
  (\common -> added <> removed <> common)
    <$> foldlM step [] (zip (M.toList hsShared) (M.toList dbShared))
  where
    added :: [WithPriority Edit]
    added =
      concatMap
        (addEdit TableAdded TableConstraintAdded tableConstraints)
        (M.toList (M.difference hsTables dbTables))

    removed :: [WithPriority Edit]
    removed =
      concatMap
        (addEdit (\name _ -> TableRemoved name) TableConstraintRemoved tableConstraints)
        (M.toList (M.difference dbTables hsTables))

    -- Both intersections have the same key set, so zipping their ascending
    -- association lists pairs up entries with equal keys.
    hsShared = M.intersection hsTables dbTables
    dbShared = M.intersection dbTables hsTables

    step :: [WithPriority Edit] -> ((TableName, Table), (TableName, Table)) -> Diff
    step acc ((hsName, hsTable), (dbName, dbTable)) =
      assert (hsName == dbName) $
        (acc <>) <$> diffTableReferenceImplementation hsName hsTable dbTable

-- | Build the edits for a single map entry. The result is the edit produced by
-- @onValue@ for the entry itself, followed by one @onConstr@ edit per element
-- of the set returned by @getConstr@, in ascending set order.
addEdit ::
  (k -> v -> EditAction) ->
  (k -> c -> EditAction) ->
  (v -> S.Set c) ->
  (k, v) ->
  [WithPriority Edit]
addEdit onValue onConstr getConstr (k, v) =
  map mkEdit (onValue k v : [onConstr k c | c <- S.toList (getConstr v)])

-- | Naive diff of two versions of the table @tName@, for model testing.
--
-- The result lists, in this order:
--
-- * table constraints only in @hsTable@ ('TableConstraintAdded');
-- * table constraints only in @dbTable@ ('TableConstraintRemoved');
-- * columns only in @hsTable@ ('ColumnAdded' plus their 'ColumnConstraintAdded's);
-- * columns only in @dbTable@ ('ColumnRemoved' plus their 'ColumnConstraintRemoved's);
-- * the 'diffColumnReferenceImplementation' edits for columns present on both
--   sides, in ascending key order.
diffTableReferenceImplementation :: TableName -> Table -> Table -> Diff
diffTableReferenceImplementation tName hsTable dbTable = do
  common <- foldlM step [] (zip (M.toList hsShared) (M.toList dbShared))
  pure $ constraintEdits <> addedCols <> removedCols <> common
  where
    hsCols = tableColumns hsTable
    dbCols = tableColumns dbTable
    hsConstrs = tableConstraints hsTable
    dbConstrs = tableConstraints dbTable

    -- Both intersections have the same key set, so zipping their ascending
    -- association lists pairs up entries with equal keys.
    hsShared = M.intersection hsCols dbCols
    dbShared = M.intersection dbCols hsCols

    constraintEdits :: [WithPriority Edit]
    constraintEdits =
      [mkEdit (TableConstraintAdded tName c) | c <- S.toList (S.difference hsConstrs dbConstrs)]
        <> [mkEdit (TableConstraintRemoved tName c) | c <- S.toList (S.difference dbConstrs hsConstrs)]

    addedCols :: [WithPriority Edit]
    addedCols =
      concatMap
        (addEdit (ColumnAdded tName) (ColumnConstraintAdded tName) columnConstraints)
        (M.toList (M.difference hsCols dbCols))

    removedCols :: [WithPriority Edit]
    removedCols =
      concatMap
        (addEdit (\k _ -> ColumnRemoved tName k) (ColumnConstraintRemoved tName) columnConstraints)
        (M.toList (M.difference dbCols hsCols))

    step :: [WithPriority Edit] -> ((ColumnName, Column), (ColumnName, Column)) -> Diff
    step acc ((hsName, hsCol), (dbName, dbCol)) =
      assert (hsName == dbName) $
        (acc <>) <$> diffColumnReferenceImplementation tName hsName hsCol dbCol

-- | Naive diff of two versions of column @colName@ of table @tName@.
--
-- The result lists the column constraints only in @hsColumn@
-- ('ColumnConstraintAdded'), then those only in @dbColumn@
-- ('ColumnConstraintRemoved'). It ends with a single 'ColumnTypeChanged'
-- (from the db type to the hs type) when the two column types differ.
-- It never fails.
diffColumnReferenceImplementation :: TableName -> ColumnName -> Column -> Column -> Diff
diffColumnReferenceImplementation tName colName hsColumn dbColumn =
  pure . concat $ catMaybes [addedEdits, removedEdits, typeEdits]
  where
    hsConstrs = columnConstraints hsColumn
    dbConstrs = columnConstraints dbColumn

    -- 'Nothing' for an empty set, otherwise one edit per constraint.
    editsFor mk constrs = do
      guard (not (S.null constrs))
      Just [mkEdit (mk c) | c <- S.toList constrs]

    addedEdits = editsFor (ColumnConstraintAdded tName colName) (S.difference hsConstrs dbConstrs)
    removedEdits = editsFor (ColumnConstraintRemoved tName colName) (S.difference dbConstrs hsConstrs)

    typeEdits = do
      guard (columnType hsColumn /= columnType dbColumn)
      Just [mkEdit (ColumnTypeChanged tName colName (columnType dbColumn) (columnType hsColumn))]

--
-- Actual implementation
--

--
-- Diffing enums together
--

-- | Diff the Haskell-side enumerations against the database ones. The three
-- cases are:
--
-- * enums only on the Haskell side yield 'EnumTypeAdded';
-- * enums only in the database yield 'EnumTypeRemoved';
-- * enums on both sides are compared with 'diffEnumeration', whose
--   'DiffError' aborts the diff.
diffEnums :: Enumerations -> Enumerations -> DiffA DList
diffEnums hsEnums dbEnums =
  M.foldr D.append D.empty <$> mergeA onlyInHs onlyInDb inBoth hsEnums dbEnums
  where
    onlyInHs :: WhenMissing (Either DiffError) EnumerationName Enumeration (DList (WithPriority Edit))
    onlyInHs = traverseMissing $ \name enum ->
      Right (D.singleton (mkEdit (EnumTypeAdded name enum)))

    onlyInDb :: WhenMissing (Either DiffError) EnumerationName Enumeration (DList (WithPriority Edit))
    onlyInDb = traverseMissing $ \name _ ->
      Right (D.singleton (mkEdit (EnumTypeRemoved name)))

    inBoth :: WhenMatched (Either DiffError) EnumerationName Enumeration Enumeration (DList (WithPriority Edit))
    inBoth = zipWithAMatched diffEnumeration

-- | Diff two versions of the enumeration @eName@. The outcome depends on
-- whether any database value is missing from the Haskell side:
--
-- * if so, the diff fails with 'ValuesRemovedFromEnum', listing those values
--   (multiset difference);
-- * otherwise, it returns the value additions computed by 'computeEnumEdit'.
diffEnumeration :: EnumerationName -> Enumeration -> Enumeration -> DiffA DList
diffEnumeration eName (Enumeration hsEnum) (Enumeration dbEnum) =
  case dbEnum \\ hsEnum of
    [] -> Right . D.fromList $ computeEnumEdit eName hsEnum dbEnum
    removed -> Left (ValuesRemovedFromEnum eName removed)

-- | Compute the 'EnumTypeValueAdded' edits that turn the database value list
-- (third argument) into the Haskell value list (second argument).
--
-- Both lists are walked in parallel. A Haskell value that differs from the
-- current database value is inserted 'Before' it. Equal values are consumed
-- from both lists. When the last database value is matched, the remaining
-- Haskell values are appended 'After' it via 'appendAfter'.
--
-- 'diffEnumeration' calls this only after checking that no database value is
-- missing from the Haskell list.
--
-- NOTE(review): when the database list is empty from the start, the first
-- Haskell value is only used as the 'appendAfter' anchor and is never added
-- itself. Confirm whether an empty database enum can occur.
-- NOTE(review): values present on both sides but in a different order are not
-- detected. E.g. Haskell @[b, a]@ vs database @[a, b]@ yields an edit adding
-- @b@ 'Before' @a@ although @b@ already exists. Confirm this cannot happen, or
-- report it as an error.
computeEnumEdit :: EnumerationName -> [Text] -> [Text] -> [WithPriority Edit]
computeEnumEdit _ [] [] = mempty
-- Haskell list exhausted while database values remain: nothing more to add.
computeEnumEdit _ [] (_ : _) = mempty
-- Only reachable when the database list is empty from the start (see note above).
computeEnumEdit eName (x : xs) [] = appendAfter eName xs x
-- Last database value: once matched, everything left goes after it.
computeEnumEdit eName (x : xs) [y] =
  if x == y
    then appendAfter eName xs y
    else mkEdit (EnumTypeValueAdded eName x Before y) : computeEnumEdit eName xs [y]
-- Here @ys@ is non-empty, because the single-element case matched above.
computeEnumEdit eName (x : xs) (y : ys) =
  if x == y
    then computeEnumEdit eName xs ys
    else mkEdit (EnumTypeValueAdded eName x Before y) : computeEnumEdit eName xs (y : ys)

-- | @appendAfter eName vals anchor@ emits one 'EnumTypeValueAdded' edit per
-- element of @vals@. The first element is placed 'After' @anchor@, and each
-- later element 'After' its predecessor in @vals@. The result is empty when
-- @vals@ is empty.
appendAfter :: EnumerationName -> [Text] -> Text -> [WithPriority Edit]
appendAfter eName vals anchor =
  zipWith (\prev val -> mkEdit (EnumTypeValueAdded eName val After prev)) (anchor : vals) vals

--
-- Diffing sequences together
--

-- | Diff the Haskell-side sequences against the database ones. Sequences only
-- on the Haskell side yield 'SequenceAdded'; sequences only in the database
-- yield 'SequenceRemoved'. This never fails.
diffSequences :: Sequences -> Sequences -> DiffA DList
diffSequences hsSeqs dbSeqs =
  M.foldr D.append D.empty <$> mergeA onlyInHs onlyInDb inBoth hsSeqs dbSeqs
  where
    onlyInHs :: WhenMissing (Either DiffError) SequenceName Sequence (DList (WithPriority Edit))
    onlyInHs = traverseMissing $ \name sq ->
      Right (D.singleton (mkEdit (SequenceAdded name sq)))

    onlyInDb :: WhenMissing (Either DiffError) SequenceName Sequence (DList (WithPriority Edit))
    onlyInDb = traverseMissing $ \name _ ->
      Right (D.singleton (mkEdit (SequenceRemoved name)))

    -- Sequences present on both sides are not compared and never yield edits.
    -- NOTE(review): 'Sequence' has two fields that are ignored here; confirm
    -- neither can change between the two sides.
    inBoth :: WhenMatched (Either DiffError) SequenceName Sequence Sequence (DList (WithPriority Edit))
    inBoth = zipWithAMatched (\_ (Sequence _ _) (Sequence _ _) -> Right D.empty)

--
-- Diffing tables together
--

-- | Diff the Haskell-side tables against the database ones, visiting keys in
-- ascending order. There are three cases:
--
-- * tables only on the Haskell side yield 'TableAdded' plus one
--   'TableConstraintAdded' per constraint;
-- * tables only in the database yield 'TableRemoved' plus one
--   'TableConstraintRemoved' per constraint;
-- * tables on both sides are compared with 'diffTable', whose 'DiffError'
--   aborts the diff.
diffTables :: Tables -> Tables -> DiffA DList
diffTables hsTables dbTables =
  M.foldr D.append D.empty <$> mergeA onlyInHs onlyInDb inBoth hsTables dbTables
  where
    onlyInHs :: WhenMissing (Either DiffError) TableName Table (DList (WithPriority Edit))
    onlyInHs = traverseMissing $ \name tbl ->
      Right . D.fromList $
        addEdit TableAdded TableConstraintAdded tableConstraints (name, tbl)

    onlyInDb :: WhenMissing (Either DiffError) TableName Table (DList (WithPriority Edit))
    onlyInDb = traverseMissing $ \name tbl ->
      Right . D.fromList $
        addEdit (\n _ -> TableRemoved n) TableConstraintRemoved tableConstraints (name, tbl)

    inBoth :: WhenMatched (Either DiffError) TableName Table Table (DList (WithPriority Edit))
    inBoth = zipWithAMatched diffTable

-- | Diff two versions of the table @tName@.
--
-- The result lists the table constraints only in @hsTable@
-- ('TableConstraintAdded'), then those only in @dbTable@
-- ('TableConstraintRemoved'). These are followed by the column edits in
-- ascending column-name order:
--
-- * added columns ('ColumnAdded' plus their 'ColumnConstraintAdded's);
-- * removed columns ('ColumnRemoved' plus their 'ColumnConstraintRemoved's);
-- * 'diffColumn' results for columns present on both sides.
diffTable :: TableName -> Table -> Table -> DiffA DList
diffTable tName hsTable dbTable =
  (constraintEdits <>) . M.foldr D.append D.empty
    <$> mergeA onColumnAdded onColumnRemoved onColumnKept (tableColumns hsTable) (tableColumns dbTable)
  where
    hsConstrs = tableConstraints hsTable
    dbConstrs = tableConstraints dbTable

    constraintEdits :: DList (WithPriority Edit)
    constraintEdits =
      D.fromList (map (mkEdit . TableConstraintAdded tName) (S.toList (S.difference hsConstrs dbConstrs)))
        <> D.fromList (map (mkEdit . TableConstraintRemoved tName) (S.toList (S.difference dbConstrs hsConstrs)))

    onColumnAdded :: WhenMissing (Either DiffError) ColumnName Column (DList (WithPriority Edit))
    onColumnAdded = traverseMissing $ \name col ->
      Right . D.fromList $
        addEdit (ColumnAdded tName) (ColumnConstraintAdded tName) columnConstraints (name, col)

    onColumnRemoved :: WhenMissing (Either DiffError) ColumnName Column (DList (WithPriority Edit))
    onColumnRemoved = traverseMissing $ \name col ->
      Right . D.fromList $
        addEdit (\n _ -> ColumnRemoved tName n) (ColumnConstraintRemoved tName) columnConstraints (name, col)

    onColumnKept :: WhenMatched (Either DiffError) ColumnName Column Column (DList (WithPriority Edit))
    onColumnKept = zipWithAMatched (diffColumn tName)

-- | Diff two versions of column @colName@ of table @tName@.
--
-- The result lists the column constraints only in @hsColumn@
-- ('ColumnConstraintAdded'), then those only in @dbColumn@
-- ('ColumnConstraintRemoved'). It ends with a single 'ColumnTypeChanged'
-- (from the db type to the hs type) when the two column types differ.
-- It never fails.
diffColumn :: TableName -> ColumnName -> Column -> Column -> DiffA DList
diffColumn tName colName hsColumn dbColumn =
  Right . foldl' D.append D.empty . catMaybes $
    [ nonEmptyEdits (ColumnConstraintAdded tName colName) hsOnly,
      nonEmptyEdits (ColumnConstraintRemoved tName colName) dbOnly,
      typeEdit
    ]
  where
    hsConstrs = columnConstraints hsColumn
    dbConstrs = columnConstraints dbColumn
    hsOnly = S.difference hsConstrs dbConstrs
    dbOnly = S.difference dbConstrs hsConstrs

    -- 'Nothing' for an empty set, otherwise one edit per constraint.
    nonEmptyEdits mk constrs = do
      guard (not (S.null constrs))
      pure . D.fromList . map (mkEdit . mk) $ S.toList constrs

    typeEdit = do
      guard (columnType hsColumn /= columnType dbColumn)
      pure . D.singleton . mkEdit $
        ColumnTypeChanged tName colName (columnType dbColumn) (columnType hsColumn)