{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE TemplateHaskell #-}

module Raft.Log.PostgreSQL (
  RaftPostgresT(..),
  runRaftPostgresT,
  runRaftPostgresM,

  raftDatabaseName,
  raftDatabaseConnInfo,
  initConnInfo,

  setupDB,
  deleteDB

) where

import Protolude hiding (Handler, catches, bracket)

import Control.Concurrent.STM.TMVar

import Control.Monad.Catch
import Control.Monad.Fail
import Control.Monad.Trans.Class

import Data.FileEmbed
import qualified Data.Sequence as Seq
import Data.Serialize (Serialize)

import Database.PostgreSQL.Simple
import Database.PostgreSQL.Simple.Types (Identifier(..))

import Raft.Client
import Raft.RPC
import Raft.Monad
import Raft.Log
import Raft.StateMachine
import Raft.Persistent
import Raft.Types

-- | Read-only environment carried by 'RaftPostgresT'.
data RaftPostgresEnv = RaftPostgresEnv
  { raftPostgresConnInfo :: ConnectInfo
    -- ^ Parameters used whenever a connection to the database must be opened.
  , raftPostgresConn :: TMVar Connection
    -- ^ Slot for the open connection. It is created empty and filled once a
    -- connection has been established.
  }

-- | A single threaded PostgreSQL storage monad transformer.
--
-- It is a thin wrapper around @ReaderT RaftPostgresEnv@; all of the listed
-- instances are obtained from the underlying 'ReaderT' via newtype deriving.
newtype RaftPostgresT m a = RaftPostgresT { unRaftPostgresT :: ReaderT RaftPostgresEnv m a }
  deriving newtype (Functor, Applicative, Monad, MonadIO, MonadFail, MonadReader RaftPostgresEnv, Alternative, MonadPlus, MonadTrans)

-- Exception-handling instances are derived standalone because each one needs
-- the matching constraint on the base monad @m@.
deriving newtype instance MonadCatch m => MonadCatch (RaftPostgresT m)
deriving newtype instance MonadThrow m => MonadThrow (RaftPostgresT m)
deriving newtype instance MonadMask m => MonadMask (RaftPostgresT m)

-- | Build a fresh environment for the given connection info. The connection
-- slot starts out empty; no connection is opened here.
initRaftPostgresEnv :: MonadIO m => ConnectInfo -> m RaftPostgresEnv
initRaftPostgresEnv connInfo = liftIO $ do
  emptyConnSlot <- newEmptyTMVarIO
  pure (RaftPostgresEnv connInfo emptyConnSlot)

-- | Run a 'RaftPostgresT' computation against the database described by the
-- supplied connection info. A new environment (with an empty connection
-- slot) is created for every call.
runRaftPostgresT :: MonadIO m => ConnectInfo -> RaftPostgresT m a -> m a
runRaftPostgresT connInfo m =
  initRaftPostgresEnv connInfo >>= \postgresEnv ->
    runRaftPostgresT' postgresEnv m

-- | Run a 'RaftPostgresT' computation in an already constructed environment.
runRaftPostgresT' :: MonadIO m => RaftPostgresEnv -> RaftPostgresT m a -> m a
runRaftPostgresT' raftPostgresEnv action =
  runReaderT (unRaftPostgresT action) raftPostgresEnv

-- | 'RaftPostgresT' specialised to 'IO'.
type RaftPostgresM = RaftPostgresT IO

-- | 'runRaftPostgresT' specialised to 'IO'.
runRaftPostgresM :: ConnectInfo -> RaftPostgresM a -> IO a
runRaftPostgresM connInfo action = runRaftPostgresT connInfo action

-- | Errors reported by the PostgreSQL-backed raft log.
data RaftPostgresError
  = RaftPostgresError PGError
    -- ^ An exception was caught while running a database action.
  | RaftPostgresFailedToConnect PGError
    -- ^ An exception was caught while opening the connection.
  | RaftPostgresSerializeError Text
    -- ^ NOTE(review): not constructed anywhere in this module; presumably
    -- reserved for (de)serialisation failures -- confirm.
  deriving (Show)

-- The instance uses the class defaults, which rely on the 'Show' instance.
instance Exception (RaftPostgresError)

-- | Helper function for RaftXLog typeclasses such that we can easily catch all
-- @PGError@s and connect to the DB if we haven't already.
--
-- The connection stored in the environment is reused when present. Otherwise
-- a new connection is opened with the environment's 'ConnectInfo' and stored
-- for later calls.
--
-- Returns:
--
-- * @Left (RaftPostgresFailedToConnect e)@ if opening the connection threw;
-- * @Left (RaftPostgresError e)@ if the supplied action threw;
-- * otherwise whatever the supplied action returned.
withRaftPostgresConn
  :: MonadIO m
  => (Connection -> IO (Either RaftPostgresError a))
  -> RaftPostgresT m (Either RaftPostgresError a)
withRaftPostgresConn f = do
    env <- ask
    liftIO $ do
      eConn <- acquireConn env
      case eConn of
        Left err -> pure (Left err)
        Right conn -> do
          eRes <- tryPG (f conn)
          pure $ case eRes of
            Left err -> Left (RaftPostgresError err)
            Right res -> res
  where
    -- Fetch the shared connection, opening and publishing one if necessary.
    acquireConn :: RaftPostgresEnv -> IO (Either RaftPostgresError Connection)
    acquireConn env = do
      let connTMVar = raftPostgresConn env
      mconn <- atomically (tryReadTMVar connTMVar)
      case mconn of
        Just conn -> pure (Right conn)
        Nothing -> do
          eNew <- tryPG (connect (raftPostgresConnInfo env))
          case eNew of
            Left err -> pure (Left (RaftPostgresFailedToConnect err))
            Right newConn -> do
              -- 'putTMVar' would block forever if another thread sharing
              -- this environment filled the slot after our read above, so
              -- publish with the non-blocking variant instead.
              stored <- atomically (tryPutTMVar connTMVar newConn)
              if stored
                then pure (Right newConn)
                else do
                  -- Lost the race: drop our redundant connection (ignoring
                  -- any failure while closing) and use the published one.
                  _ <- tryPG (close newConn)
                  Right <$> atomically (readTMVar connTMVar)

--------------------------------------------------------------------------------
-- Raft instances
--------------------------------------------------------------------------------

-- Initialising the log creates the database and entries table via 'setupDB'
-- and stores the resulting connection in the environment.
instance (MonadIO m) => RaftInitLog (RaftPostgresT m) v where
  type RaftInitLogError (RaftPostgresT m) = RaftPostgresError
  initializeLog _ = do
    env <- ask
    liftIO $ do
      setupResult <- setupDB (raftPostgresConnInfo env)
      case setupResult of
        Left err -> pure (Left (RaftPostgresError err))
        -- NOTE(review): 'putTMVar' blocks while the slot is already full.
        Right conn ->
          Right <$> atomically (putTMVar (raftPostgresConn env) conn)

-- Read access to the @entries@ table. Every method runs through
-- 'withRaftPostgresConn', so database exceptions surface as
-- 'RaftPostgresError' values rather than being thrown.
instance (Typeable v, Serialize v, MonadIO m) => RaftReadLog (RaftPostgresT m) v where
  type RaftReadLogError (RaftPostgresT m) = RaftPostgresError

  -- Entry stored at exactly the given index; 'Nothing' when no row matches.
  readLogEntry idx =
    withRaftPostgresConn $ \conn ->
      Right . fmap rowTypeToEntry . listToMaybe <$>
        query conn "select * from entries where entryIndex = ?" (Only idx)

  -- All entries whose index is at least the given one, in ascending index
  -- order. The explicit ORDER BY matters: without it PostgreSQL is free to
  -- return the rows in any order, and the resulting sequence would not be
  -- guaranteed to follow log order.
  readLogEntriesFrom idx =
    withRaftPostgresConn $ \conn ->
      Right . Seq.fromList . map rowTypeToEntry <$>
        query conn "select * from entries where entryIndex >= ? order by entryIndex asc" (Only idx)

  -- Entry with the greatest index; 'Nothing' when the table is empty.
  readLastLogEntry =
    withRaftPostgresConn $ \conn ->
      Right . fmap rowTypeToEntry . listToMaybe <$>
        query_ conn "SELECT * FROM entries ORDER BY entryIndex DESC LIMIT 1"

-- Appends entries by inserting one row per entry into the @entries@ table.
instance (Serialize v, MonadIO m) => RaftWriteLog (RaftPostgresT m) v where
  type RaftWriteLogError (RaftPostgresT m) = RaftPostgresError
  writeLogEntries entries =
    withRaftPostgresConn $ \conn -> do
      let rows = entryToRowType <$> toList entries
      -- The affected-row count is deliberately discarded.
      Right () <$ executeMany conn "INSERT INTO entries VALUES (?,?,?,?,?,?)" rows

-- Deletes every entry whose index is at least the given one.
instance (Serialize v, MonadIO m) => RaftDeleteLog (RaftPostgresT m) v where
  type RaftDeleteLogError (RaftPostgresT m) = RaftPostgresError
  deleteLogEntriesFrom idx =
    withRaftPostgresConn $ \conn ->
      -- The number of deleted rows is not reported to the caller.
      Right DeleteSuccess <$
        execute conn "DELETE FROM entries WHERE entryIndex >= ?" (Only idx)

--------------------------------------------------------------------------------
-- Inherited Raft instances
--------------------------------------------------------------------------------

-- Persistent state is delegated unchanged to the base monad.
instance RaftPersist m => RaftPersist (RaftPostgresT m) where
  type RaftPersistError (RaftPostgresT m) = RaftPersistError m
  initializePersistentState = lift initializePersistentState
  readPersistentState = lift readPersistentState
  writePersistentState persistentState =
    lift (writePersistentState persistentState)

-- Sending RPCs is delegated unchanged to the base monad.
instance (Monad m, RaftSendRPC m v) => RaftSendRPC (RaftPostgresT m) v where
  sendRPC nid = lift . sendRPC nid

-- Receiving RPCs is delegated unchanged to the base monad, including its
-- error type.
instance (Monad m, RaftRecvRPC m v) => RaftRecvRPC (RaftPostgresT m) v where
  type RaftRecvRPCError (RaftPostgresT m) v = RaftRecvRPCError m v
  receiveRPC = lift receiveRPC

-- Client responses are sent by the base monad.
instance (Monad m, RaftSendClient m sm v) => RaftSendClient (RaftPostgresT m) sm v where
  sendClient clientId response = lift (sendClient clientId response)

-- Client requests are received by the base monad, including its error type.
instance (Monad m, RaftRecvClient m v) => RaftRecvClient (RaftPostgresT m) v where
  type RaftRecvClientError (RaftPostgresT m) v = RaftRecvClientError m v
  receiveClient = lift receiveClient

-- State machine operations are delegated unchanged to the base monad.
instance RaftStateMachine m sm v => RaftStateMachine (RaftPostgresT m) sm v where
  validateCmd cmd = lift (validateCmd cmd)
  askRaftStateMachinePureCtx = lift askRaftStateMachinePureCtx

-- Event channels are those of the base monad; every operation is lifted.
instance MonadRaftChan v m => MonadRaftChan v (RaftPostgresT m) where
  type RaftEventChan v (RaftPostgresT m) = RaftEventChan v m
  readRaftChan chan = lift (readRaftChan chan)
  writeRaftChan chan event = lift (writeRaftChan chan event)
  -- Type applications pin down which base-monad instance creates the channel.
  newRaftChan = lift (newRaftChan @v @m)

-- Forking runs the child in the base monad with the parent's environment, so
-- parent and child share the same connection slot.
instance (MonadIO m, MonadRaftFork m) => MonadRaftFork (RaftPostgresT m) where
  type RaftThreadId (RaftPostgresT m) = RaftThreadId m
  raftFork r m =
    ask >>= \sharedEnv ->
      lift (raftFork r (runRaftPostgresT' sharedEnv m))

--------------------------------------------------------------------------------

-- | Row representation of a log entry as stored in the @entries@ table.
--
-- 'ToRow' and 'FromRow' are derived generically. The six fields correspond
-- to the six placeholders of the insert statement used when writing entries.
-- NOTE(review): the field order presumably has to match the column order of
-- the table schema -- confirm against @postgres/entries.sql@.
data EntryRow v = EntryRow
  { entryRowIndex :: Index
  , entryRowTerm :: Term
  , entryRowValueHash :: EntryHash -- ^ Hash of the entry, computed on write.
  , entryRowValue :: EntryValue v
  , entryRowIssuer :: EntryIssuer
  , entryRowPrevHash :: EntryHash
  } deriving (Show, Generic, ToRow, FromRow)

-- | Convert a log entry into its row representation, computing the entry's
-- hash with 'hashEntry' along the way.
entryToRowType :: Serialize v => Entry v -> EntryRow v
entryToRowType entry =
  EntryRow
    { entryRowIndex = entryIndex entry
    , entryRowTerm = entryTerm entry
    , entryRowValueHash = hashEntry entry
    , entryRowValue = entryValue entry
    , entryRowIssuer = entryIssuer entry
    , entryRowPrevHash = entryPrevHash entry
    }

-- | Rebuild a log entry from its row representation. The stored value hash
-- is not part of 'Entry' and is dropped.
rowTypeToEntry :: Serialize v => EntryRow v -> Entry v
rowTypeToEntry row =
  Entry
    { entryIndex = entryRowIndex row
    , entryTerm = entryRowTerm row
    , entryValue = entryRowValue row
    , entryIssuer = entryRowIssuer row
    , entryPrevHash = entryRowPrevHash row
    }

--------------------------------------------------------------------------------
-- Create Database for Entry Values
--------------------------------------------------------------------------------

-- | Exceptions caught by 'tryPG', one constructor per exception type it
-- handles.
data PGError
  = PGSqlError SqlError
  | PGFormatError FormatError
  | PGQueryError QueryError
  | PGResultError ResultError
  | PGUnexpectedError Text
    -- ^ Any other exception, rendered with 'show'.
  deriving (Show)

-- | Prefix shared by every database name built with 'raftDatabaseName'.
raftDatabasePrefix :: [Char]
raftDatabasePrefix = "libraft_"

-- | Build a raft database name by prepending 'raftDatabasePrefix' to the
-- given suffix.
raftDatabaseName :: [Char] -> [Char]
raftDatabaseName = (raftDatabasePrefix ++)

-- | Connection info for a raft database: the library defaults with the user,
-- password and database name overridden. Arguments are, in order, user name,
-- password and database suffix (see 'raftDatabaseName').
raftDatabaseConnInfo :: [Char] -> [Char] -> [Char] -> ConnectInfo
raftDatabaseConnInfo usrnm pswd dbSuffix = withCredentials
  where
    withDatabase =
      defaultConnectInfo { connectDatabase = raftDatabaseName dbSuffix }
    withCredentials =
      withDatabase { connectUser = usrnm, connectPassword = pswd }

-- | Connection info used by 'setupDB' to create the raft database: it
-- targets the @postgres@ database with hard-coded @libraft_test@
-- credentials.
initConnInfo :: ConnectInfo
initConnInfo = defaultConnectInfo
  { connectDatabase = "postgres"
  , connectUser = "libraft_test"
  , connectPassword = "libraft_test"
  }

-- | Run an action and turn any exception it throws into a 'PGError'.
--
-- Handlers are tried in order, so the catch-all 'SomeException' handler must
-- stay last.
tryPG :: IO a -> IO (Either PGError a)
tryPG action =
    fmap Right action `catches`
      [ Handler (\(e :: SqlError) -> failWith (PGSqlError e))
      , Handler (\(e :: FormatError) -> failWith (PGFormatError e))
      , Handler (\(e :: QueryError) -> failWith (PGQueryError e))
      , Handler (\(e :: ResultError) -> failWith (PGResultError e))
      , Handler (\(e :: SomeException) -> failWith (PGUnexpectedError (show e)))
      ]
  where
    failWith :: PGError -> IO (Either PGError b)
    failWith err = pure (Left err)

-- | Issue @CREATE DATABASE@ for the given name. The name is passed as an
-- 'Identifier' parameter rather than spliced into the query text.
createDB :: [Char] -> Connection -> IO Int64
createDB dbName conn = execute conn "CREATE DATABASE ?" (Only dbIdent)
  where
    dbIdent = Identifier (toS dbName)

-- | Issue @DROP DATABASE IF EXISTS@ for the given name. The name is passed
-- as an 'Identifier' parameter rather than spliced into the query text.
deleteDB :: [Char] -> Connection -> IO Int64
deleteDB dbName conn = execute conn "DROP DATABASE IF EXISTS ?" (Only dbIdent)
  where
    dbIdent = Identifier (toS dbName)

-- | Load the raft schema at compile time using TH.
--
-- The contents of @postgres/entries.sql@ (path resolved relative to the
-- project) are embedded as a string literal, so the file is not read at
-- run time.
raftSchema :: IsString a => a
raftSchema = $(makeRelativeToProject "postgres/entries.sql" >>= embedStringFile)

-- | Run the embedded schema ('raftSchema') on the given connection to build
-- the raft database.
createEntriesTable :: Connection -> IO Int64
createEntriesTable conn = conn `execute_` raftSchema

-- | Create the libraft database to store all log entries.
--
-- Steps, each guarded by 'tryPG':
--
-- 1. connect with 'initConnInfo' and create the database named in the
--    supplied 'ConnectInfo' (that bootstrap connection is always closed);
-- 2. connect to the new database with the supplied 'ConnectInfo';
-- 3. create the entries table on that connection.
--
-- Returns the connection from step 2 on success. If any step throws, the
-- corresponding 'PGError' is returned; a connection opened in step 2 is
-- closed before returning so it does not leak when step 3 fails.
setupDB :: ConnectInfo -> IO (Either PGError Connection)
setupDB connInfo = do
    -- Create the database through the bootstrap connection ('initConnInfo').
    created <- tryPG $
      bracket (connect initConnInfo) close (createDB dbName)
    case created of
      Left err -> pure (Left err)
      Right _ -> do
        -- Connect to the DB with the provided connection info
        connected <- tryPG (connect connInfo)
        case connected of
          Left err -> pure (Left err)
          Right conn -> do
            -- Create the log entries table with the resulting connection
            tableCreated <- tryPG (createEntriesTable conn)
            case tableCreated of
              Right _ -> pure (Right conn)
              Left err -> do
                -- Nobody else holds this connection yet; release it and
                -- report the original failure even if closing also fails.
                _ <- tryPG (close conn)
                pure (Left err)
  where
    dbName = connectDatabase connInfo