{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE TemplateHaskell #-}
module Raft.Log.PostgreSQL (
RaftPostgresT(..),
runRaftPostgresT,
runRaftPostgresM,
raftDatabaseName,
raftDatabaseConnInfo,
initConnInfo,
setupDB,
deleteDB
) where
import Protolude hiding (Handler, catches, bracket)
import Control.Concurrent.STM.TMVar
import Control.Monad.Catch
import Control.Monad.Fail
import Control.Monad.Trans.Class
import Data.FileEmbed
import qualified Data.Sequence as Seq
import Data.Serialize (Serialize)
import Database.PostgreSQL.Simple
import Database.PostgreSQL.Simple.Types (Identifier(..))
import Raft.Client
import Raft.RPC
import Raft.Monad
import Raft.Log
import Raft.StateMachine
import Raft.Persistent
import Raft.Types
-- | Environment carried by 'RaftPostgresT': the connection parameters
-- together with a slot for the connection once one has been opened.
data RaftPostgresEnv = RaftPostgresEnv
  { raftPostgresConnInfo :: ConnectInfo
    -- ^ Parameters used whenever a connection has to be opened.
  , raftPostgresConn :: TMVar Connection
    -- ^ Holds the open connection; created empty by the runners in this
    -- module and filled by 'withRaftPostgresConn' or @initializeLog@.
  }
-- | Monad transformer giving the wrapped monad access to a
-- 'RaftPostgresEnv' through a 'ReaderT' layer.
newtype RaftPostgresT m a = RaftPostgresT
  { unRaftPostgresT :: ReaderT RaftPostgresEnv m a
  }
  deriving newtype
    ( Functor
    , Applicative
    , Monad
    , MonadIO
    , MonadFail
    , MonadReader RaftPostgresEnv
    , Alternative
    , MonadPlus
    , MonadTrans
    )

-- The exception classes are derived standalone so that each instance can
-- state its own constraint on the underlying monad.
deriving newtype instance MonadThrow m => MonadThrow (RaftPostgresT m)
deriving newtype instance MonadCatch m => MonadCatch (RaftPostgresT m)
deriving newtype instance MonadMask m => MonadMask (RaftPostgresT m)
-- | Build an environment for the given connection parameters. The
-- connection slot starts out empty; no connection is opened here.
initRaftPostgresEnv :: MonadIO m => ConnectInfo -> m RaftPostgresEnv
initRaftPostgresEnv connInfo = do
  connSlot <- liftIO (atomically newEmptyTMVar)
  pure (RaftPostgresEnv connInfo connSlot)
-- | Run a 'RaftPostgresT' computation with a freshly created environment.
--
-- The connection slot is created empty, so nothing is connected until the
-- computation itself triggers it.
runRaftPostgresT :: MonadIO m => ConnectInfo -> RaftPostgresT m a -> m a
runRaftPostgresT connInfo m =
  liftIO (atomically newEmptyTMVar)
    >>= runReaderT (unRaftPostgresT m) . RaftPostgresEnv connInfo
-- | Run a 'RaftPostgresT' computation against an already existing
-- environment, so that the connection slot is shared with whoever else
-- holds that environment.
runRaftPostgresT' :: MonadIO m => RaftPostgresEnv -> RaftPostgresT m a -> m a
runRaftPostgresT' raftPostgresEnv m =
  runReaderT (unRaftPostgresT m) raftPostgresEnv
-- | 'RaftPostgresT' specialised to 'IO'.
type RaftPostgresM = RaftPostgresT IO

-- | 'runRaftPostgresT' specialised to 'IO'.
runRaftPostgresM :: ConnectInfo -> RaftPostgresM a -> IO a
runRaftPostgresM connInfo action = runRaftPostgresT connInfo action
-- | Errors reported by the PostgreSQL-backed log.
data RaftPostgresError
  = RaftPostgresError PGError
    -- ^ A database action failed with the wrapped 'PGError'.
  | RaftPostgresFailedToConnect PGError
    -- ^ Opening the connection failed.
  | RaftPostgresSerializeError Text
    -- ^ Not constructed anywhere in this part of the module.
  deriving stock (Show)

instance Exception RaftPostgresError
-- | Run a database action against the cached connection, opening and
-- caching a connection first when the slot is still empty.
--
-- All failures come back as values:
--
-- * a failed connection attempt yields 'RaftPostgresFailedToConnect';
-- * an exception raised while running the action yields 'RaftPostgresError';
-- * a 'Left' returned by the action itself is passed through unchanged.
withRaftPostgresConn
  :: MonadIO m
  => (Connection -> IO (Either RaftPostgresError a))
  -> RaftPostgresT m (Either RaftPostgresError a)
withRaftPostgresConn f = do
  connTMVar <- asks raftPostgresConn
  mconn <- liftIO . atomically $ tryReadTMVar connTMVar
  eConn <-
    case mconn of
      Just conn -> pure (Right conn)
      Nothing -> do
        eNewConn <- liftIO . tryPG . connect =<< asks raftPostgresConnInfo
        case eNewConn of
          Left err -> pure (Left (RaftPostgresFailedToConnect err))
          Right newConn -> do
            -- The slot was empty when we looked, but the environment can be
            -- shared between threads, so another thread may have filled it
            -- while we were connecting. A blocking 'putTMVar' would then
            -- wait forever; instead try to store our connection and fall
            -- back to the one that is already there.
            mExisting <- liftIO . atomically $ do
              stored <- tryPutTMVar connTMVar newConn
              if stored
                then pure Nothing
                else Just <$> readTMVar connTMVar
            case mExisting of
              Nothing -> pure (Right newConn)
              Just existing -> do
                -- Our connection is redundant; closing it is best effort, a
                -- failure here must not fail the caller's action.
                _ <- liftIO (tryPG (close newConn))
                pure (Right existing)
  case eConn of
    Left err -> pure (Left err)
    Right conn -> do
      eRes <- liftIO (tryPG (f conn))
      pure $ case eRes of
        Left err -> Left (RaftPostgresError err)
        Right res -> res
-- Log initialisation: set up the database via 'setupDB' and cache the
-- resulting connection in the environment.
instance (MonadIO m) => RaftInitLog (RaftPostgresT m) v where
  type RaftInitLogError (RaftPostgresT m) = RaftPostgresError

  -- The argument is ignored; a 'setupDB' failure is reported as
  -- 'RaftPostgresError'.
  initializeLog _ = do
    connInfo <- asks raftPostgresConnInfo
    setupResult <- liftIO (setupDB connInfo)
    case setupResult of
      Left err -> pure (Left (RaftPostgresError err))
      Right conn -> do
        connSlot <- asks raftPostgresConn
        -- NOTE(review): 'putTMVar' blocks while the slot is occupied, so
        -- this assumes no connection has been cached yet -- confirm.
        liftIO (atomically (putTMVar connSlot conn))
        pure (Right ())
-- Read access to the @entries@ table. Every method goes through
-- 'withRaftPostgresConn', so connection and query failures are returned as
-- 'RaftPostgresError' values.
instance (Typeable v, Serialize v, MonadIO m) => RaftReadLog (RaftPostgresT m) v where
  type RaftReadLogError (RaftPostgresT m) = RaftPostgresError

  -- The entry stored at the given index, or Nothing when no row matches.
  readLogEntry idx =
    withRaftPostgresConn $ \conn ->
      Right . fmap rowTypeToEntry . listToMaybe <$>
        query conn "select * from entries where entryIndex = ?" (Only idx)

  -- All entries at or after the given index. The explicit ORDER BY is
  -- required: without it PostgreSQL may return the rows in any order, and
  -- the resulting sequence would not be guaranteed to follow the log order.
  readLogEntriesFrom idx =
    withRaftPostgresConn $ \conn ->
      Right . Seq.fromList . map rowTypeToEntry <$>
        query conn "select * from entries where entryIndex >= ? order by entryIndex asc" (Only idx)

  -- The entry with the highest index, or Nothing when the table is empty.
  readLastLogEntry =
    withRaftPostgresConn $ \conn ->
      Right . fmap rowTypeToEntry . listToMaybe <$>
        query_ conn "SELECT * FROM entries ORDER BY entryIndex DESC LIMIT 1"
-- Appending entries: each entry becomes one row of the @entries@ table.
instance (Serialize v, MonadIO m) => RaftWriteLog (RaftPostgresT m) v where
  type RaftWriteLogError (RaftPostgresT m) = RaftPostgresError

  -- Inserts all given entries with a single 'executeMany'; the affected-row
  -- count is discarded.
  writeLogEntries entries =
    withRaftPostgresConn $ \conn -> do
      _ <- executeMany conn "INSERT INTO entries VALUES (?,?,?,?,?,?)" rows
      pure (Right ())
    where
      rows = map entryToRowType (toList entries)
-- Truncating the log from a given index onwards.
instance (Serialize v, MonadIO m) => RaftDeleteLog (RaftPostgresT m) v where
  type RaftDeleteLogError (RaftPostgresT m) = RaftPostgresError

  -- Deletes every row whose index is at or after the given one and reports
  -- 'DeleteSuccess' regardless of how many rows were removed.
  deleteLogEntriesFrom idx =
    withRaftPostgresConn $ \conn -> do
      _ <- execute conn "DELETE FROM entries WHERE entryIndex >= ?" (Only idx)
      pure (Right DeleteSuccess)
-- Persistent state is not handled by this layer; every method is delegated
-- to the underlying monad.
instance RaftPersist m => RaftPersist (RaftPostgresT m) where
  type RaftPersistError (RaftPostgresT m) = RaftPersistError m
  initializePersistentState = lift initializePersistentState
  readPersistentState = lift readPersistentState
  writePersistentState persistentState =
    lift (writePersistentState persistentState)
-- Sending RPCs is delegated to the underlying monad.
instance (Monad m, RaftSendRPC m v) => RaftSendRPC (RaftPostgresT m) v where
  sendRPC nid = lift . sendRPC nid
-- Receiving RPCs is delegated to the underlying monad, whose error type is
-- reused as is.
instance (Monad m, RaftRecvRPC m v) => RaftRecvRPC (RaftPostgresT m) v where
  type RaftRecvRPCError (RaftPostgresT m) v = RaftRecvRPCError m v

  receiveRPC = lift receiveRPC
-- Responding to clients is delegated to the underlying monad.
instance (Monad m, RaftSendClient m sm v) => RaftSendClient (RaftPostgresT m) sm v where
  sendClient clientId response = lift (sendClient clientId response)
-- Receiving client requests is delegated to the underlying monad, whose
-- error type is reused as is.
instance (Monad m, RaftRecvClient m v) => RaftRecvClient (RaftPostgresT m) v where
  type RaftRecvClientError (RaftPostgresT m) v = RaftRecvClientError m v

  receiveClient = lift receiveClient
-- State machine operations are delegated to the underlying monad.
instance RaftStateMachine m sm v => RaftStateMachine (RaftPostgresT m) sm v where
  validateCmd cmd = lift (validateCmd cmd)
  askRaftStateMachinePureCtx = lift askRaftStateMachinePureCtx
-- Event channels are those of the underlying monad; every operation is
-- lifted unchanged.
instance MonadRaftChan v m => MonadRaftChan v (RaftPostgresT m) where
  type RaftEventChan v (RaftPostgresT m) = RaftEventChan v m

  readRaftChan chan = lift (readRaftChan chan)
  writeRaftChan chan event = lift (writeRaftChan chan event)
  -- The type applications pin the instance of the underlying monad to use.
  newRaftChan = lift (newRaftChan @v @m)
-- Forking uses the underlying monad's 'raftFork'. The forked computation is
-- run with the caller's environment, so it shares the same connection slot.
instance (MonadIO m, MonadRaftFork m) => MonadRaftFork (RaftPostgresT m) where
  type RaftThreadId (RaftPostgresT m) = RaftThreadId m

  raftFork role action =
    ask >>= \env -> lift (raftFork role (runRaftPostgresT' env action))
-- | Row representation of a log entry as read from and written to the
-- @entries@ table. The field order is the column order used by the
-- generically derived 'ToRow' and 'FromRow' instances.
data EntryRow v = EntryRow
  { entryRowIndex :: Index
  , entryRowTerm :: Term
  , entryRowValueHash :: EntryHash
    -- ^ Hash of the entry, computed with 'hashEntry' when a row is built.
  , entryRowValue :: EntryValue v
  , entryRowIssuer :: EntryIssuer
  , entryRowPrevHash :: EntryHash
  }
  deriving stock (Show, Generic)
  deriving anyclass (ToRow, FromRow)
-- | Convert a log entry into its row form. All fields are copied over and
-- the entry's hash is computed and stored alongside them.
entryToRowType :: Serialize v => Entry v -> EntryRow v
entryToRowType entry =
  EntryRow
    { entryRowIndex = entryIndex entry
    , entryRowTerm = entryTerm entry
    , entryRowValueHash = hashEntry entry
    , entryRowValue = entryValue entry
    , entryRowIssuer = entryIssuer entry
    , entryRowPrevHash = entryPrevHash entry
    }
-- | Rebuild a log entry from its row form. The stored value hash is not
-- part of 'Entry' and is dropped.
rowTypeToEntry :: Serialize v => EntryRow v -> Entry v
rowTypeToEntry row =
  Entry
    { entryIndex = entryRowIndex row
    , entryTerm = entryRowTerm row
    , entryValue = entryRowValue row
    , entryIssuer = entryRowIssuer row
    , entryPrevHash = entryRowPrevHash row
    }
-- | Exceptions caught by 'tryPG', one constructor per exception type it
-- handles specifically plus a catch-all.
data PGError
  = PGSqlError SqlError
  | PGFormatError FormatError
  | PGQueryError QueryError
  | PGResultError ResultError
  | PGUnexpectedError Text
    -- ^ Any other exception, rendered with 'show'.
  deriving stock (Show)
-- | Prefix shared by the names of all databases created by this module.
raftDatabasePrefix :: [Char]
raftDatabasePrefix = "libraft_"

-- | Full database name for a given suffix: the common prefix followed by
-- the suffix.
raftDatabaseName :: [Char] -> [Char]
raftDatabaseName = (raftDatabasePrefix <>)
-- | Connection parameters for a raft database: the library defaults with
-- the user, password and database name replaced. Arguments are user name,
-- password and database name suffix (see 'raftDatabaseName').
raftDatabaseConnInfo :: [Char] -> [Char] -> [Char] -> ConnectInfo
raftDatabaseConnInfo usrnm pswd dbSuffix =
  defaultConnectInfo
    { connectDatabase = raftDatabaseName dbSuffix
    , connectUser = usrnm
    , connectPassword = pswd
    }
-- | Connection parameters used by 'setupDB' for the connection on which the
-- raft database is created.
--
-- NOTE(review): user name and password are hard-coded here; they look like
-- test credentials -- confirm before using outside of tests.
initConnInfo :: ConnectInfo
initConnInfo =
  defaultConnectInfo
    { connectUser = "libraft_test"
    , connectPassword = "libraft_test"
    , connectDatabase = "postgres"
    }
-- | Run an action and turn any exception it throws into a 'PGError'.
--
-- Handlers are tried in order, so the specific exception types take
-- precedence and the 'SomeException' handler only sees what is left over.
tryPG :: IO a -> IO (Either PGError a)
tryPG action =
  (Right <$> action) `catches`
    [ Handler (pure . Left . PGSqlError)
    , Handler (pure . Left . PGFormatError)
    , Handler (pure . Left . PGQueryError)
    , Handler (pure . Left . PGResultError)
    , Handler (\(e :: SomeException) -> pure (Left (PGUnexpectedError (show e))))
    ]
-- | Issue @CREATE DATABASE@ for the given name on the given connection.
-- The name is passed as an 'Identifier' parameter rather than spliced into
-- the statement text.
createDB :: [Char] -> Connection -> IO Int64
createDB dbName conn = execute conn "CREATE DATABASE ?" (Only dbIdent)
  where
    dbIdent = Identifier (toS dbName)
-- | Issue @DROP DATABASE IF EXISTS@ for the given name on the given
-- connection. The name is passed as an 'Identifier' parameter rather than
-- spliced into the statement text.
deleteDB :: [Char] -> Connection -> IO Int64
deleteDB dbName conn = execute conn "DROP DATABASE IF EXISTS ?" (Only dbIdent)
  where
    dbIdent = Identifier (toS dbName)
-- | Contents of @postgres/entries.sql@ (path relative to the project root),
-- embedded at compile time.
raftSchema :: IsString a => a
raftSchema = $(embedStringFile =<< makeRelativeToProject "postgres/entries.sql")
-- | Execute the embedded schema ('raftSchema') on the given connection.
createEntriesTable :: Connection -> IO Int64
createEntriesTable = flip execute_ raftSchema
-- | Create the database named in the given connection parameters, connect
-- to it and create the entries table.
--
-- Steps, each of which reports failure as a 'Left':
--
-- 1. connect with 'initConnInfo', create the database, close that connection;
-- 2. connect with the given parameters;
-- 3. create the entries table on the new connection.
--
-- On success the connection from step 2 is returned and the caller owns it.
-- If step 3 fails that connection is closed before the error is returned,
-- so a failed setup does not leak a connection.
setupDB :: ConnectInfo -> IO (Either PGError Connection)
setupDB connInfo = do
  eCreated <- tryPG $ bracket (connect initConnInfo) close (createDB dbName)
  case eCreated of
    Left err -> pure (Left err)
    Right _ -> do
      eConn <- tryPG (connect connInfo)
      case eConn of
        Left err -> pure (Left err)
        Right conn -> do
          eTable <- tryPG (createEntriesTable conn)
          case eTable of
            Right _ -> pure (Right conn)
            Left err -> do
              -- Best-effort cleanup: the schema error is what the caller
              -- needs to see, not a secondary failure while closing.
              _ <- tryPG (close conn)
              pure (Left err)
  where
    dbName = connectDatabase connInfo