{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE UndecidableInstances #-}
module Raft.Client
(
RaftSendClient(..)
, RaftRecvClient(..)
, SerialNum(..)
, ClientRequest(..)
, ClientReq(..)
, ClientReqType(..)
, ClientReadReq(..)
, ReadEntriesSpec(..)
, ClientWriteReq(..)
, ClientMetricsReq(..)
, ClientResponse(..)
, ClientRespSpec(..)
, ClientReadRespSpec(..)
, ClientWriteRespSpec(..)
, ClientReadResp(..)
, ClientWriteResp(..)
, ClientRedirResp(..)
, ClientMetricsResp(..)
, RaftClientSend(..)
, RaftClientRecv(..)
, RaftClientState(..)
, RaftClientEnv(..)
, initRaftClientState
, RaftClientT
, runRaftClientT
, RaftClientError(..)
, clientRead
, clientReadFrom
, clientReadTimeout
, clientWrite
, clientWriteTo
, clientWriteTimeout
, clientQueryNodeMetrics
, clientQueryNodeMetricsTimeout
, clientSendRead
, clientSendWrite
, clientSendMetricsReqTo
, clientSend
, clientRecv
, clientTimeout
, retryOnRedirect
, clientAddNode
, clientGetNodes
) where
import Protolude hiding (threadDelay, STM, )
import Control.Concurrent.Lifted (threadDelay)
import Control.Monad.Base
import Control.Monad.Catch
import Control.Monad.Fail
import Control.Monad.Trans.Class
import Control.Monad.Trans.Control
import qualified Data.Set as Set
import qualified Data.Serialize as S
import System.Random
import System.Console.Haskeline.MonadException (MonadException(..), RunIO(..))
import System.Timeout.Lifted (timeout)
import Raft.Log (Entry, Entries, ReadEntriesSpec)
import Raft.StateMachine
import Raft.Metrics (RaftNodeMetrics)
import Raft.Types
-- | Monads that can deliver a 'ClientResponse' to the client identified by a
-- 'ClientId'. The state machine / command pair must satisfy
-- 'RaftStateMachinePure'.
class RaftStateMachinePure sm v => RaftSendClient m sm v where
  -- | Send the given response to the given client.
  sendClient :: ClientId -> ClientResponse sm v -> m ()

-- | Monads that can receive a 'ClientRequest'. Failures are reported with the
-- associated, 'Show'-able 'RaftRecvClientError' type.
class Show (RaftRecvClientError m v) => RaftRecvClient m v where
  -- | Error produced when receiving a client request fails.
  type RaftRecvClientError m v
  -- | Receive a client request, or the error that prevented it.
  receiveClient :: m (Either (RaftRecvClientError m v) (ClientRequest v))
-- | A client request tagged with the 'ClientId' of the client sending it.
data ClientRequest v
  = ClientRequest ClientId (ClientReq v)
  deriving stock (Show, Generic)

-- Generic-based (default-method) serialisation.
deriving anyclass instance S.Serialize v => S.Serialize (ClientRequest v)

-- | The body of a client request: a read, a write carrying a command of type
-- @v@, or a metrics query.
data ClientReq v
  = ClientReadReq ClientReadReq
  | ClientWriteReq (ClientWriteReq v)
  | ClientMetricsReq ClientMetricsReq
  deriving stock (Show, Generic)

-- Generic-based (default-method) serialisation.
deriving anyclass instance S.Serialize v => S.Serialize (ClientReq v)
-- | What a client asks to read: log entries selected by a 'ReadEntriesSpec',
-- or the state machine itself.
data ClientReadReq
  = ClientReadEntries ReadEntriesSpec
  | ClientReadStateMachine
  deriving stock (Show, Generic)
  deriving anyclass (S.Serialize)

-- | A command of type @v@ to be written, tagged with the client's current
-- 'SerialNum' (see 'clientSendWrite').
data ClientWriteReq v
  = ClientCmdReq SerialNum v
  deriving stock (Show, Generic)
  deriving anyclass (S.Serialize)

-- | A request for all of a node's metrics.
data ClientMetricsReq
  = ClientAllMetricsReq
  deriving stock (Show, Generic)
  deriving anyclass (S.Serialize)
-- | Method-less marker class relating a request payload type @a@ to a command
-- type @v@.
-- NOTE(review): no use of this class is visible in this module; confirm its
-- purpose at the call sites.
class ClientReqType a v

-- Read and metrics requests carry no command, so they relate to any @v@.
instance ClientReqType ClientReadReq v
instance ClientReqType ClientMetricsReq v

-- A write request relates only to the command type it carries.
instance ClientReqType (ClientWriteReq v) v
-- | Description of a client response. Unlike 'ClientResponse', the read case
-- may hold a 'ReadEntriesSpec' rather than the entries themselves.
-- NOTE(review): presumably resolved into a 'ClientResponse' by the node before
-- sending — confirm against the node code.
data ClientRespSpec sm v
  = ClientReadRespSpec (ClientReadRespSpec sm)
  | ClientWriteRespSpec (ClientWriteRespSpec sm v)
  | ClientRedirRespSpec CurrentLeader
  | ClientMetricsRespSpec RaftNodeMetrics
  deriving stock (Generic)

deriving stock instance
  (Show sm, Show v, Show (RaftStateMachinePureError sm v))
  => Show (ClientRespSpec sm v)

deriving anyclass instance
  (S.Serialize sm, S.Serialize v, S.Serialize (RaftStateMachinePureError sm v))
  => S.Serialize (ClientRespSpec sm v)

-- | Read part of a 'ClientRespSpec': which entries to read, or the state
-- machine value.
data ClientReadRespSpec sm
  = ClientReadRespSpecEntries ReadEntriesSpec
  | ClientReadRespSpecStateMachine sm
  deriving stock (Show, Generic)
  deriving anyclass (S.Serialize)

-- | Write part of a 'ClientRespSpec'; mirrors the shape of 'ClientWriteResp'.
data ClientWriteRespSpec sm v
  = ClientWriteRespSpecSuccess Index SerialNum
  | ClientWriteRespSpecFail SerialNum (RaftStateMachinePureError sm v)
  deriving stock (Generic)

deriving stock instance
  (Show sm, Show v, Show (RaftStateMachinePureError sm v))
  => Show (ClientWriteRespSpec sm v)

deriving anyclass instance
  (S.Serialize sm, S.Serialize v, S.Serialize (RaftStateMachinePureError sm v))
  => S.Serialize (ClientWriteRespSpec sm v)
-- | A response received by a client.
data ClientResponse sm v
  = ClientReadResponse (ClientReadResp sm v)
    -- ^ Result of a read request.
  | ClientWriteResponse (ClientWriteResp sm v)
    -- ^ Result of a write request.
  | ClientRedirectResponse ClientRedirResp
    -- ^ Points the client at the current leader ('clientRecv' records it).
  | ClientMetricsResponse ClientMetricsResp
    -- ^ Result of a metrics request.
  deriving stock (Generic)

deriving stock instance
  (Show sm, Show v, Show (ClientWriteResp sm v))
  => Show (ClientResponse sm v)

deriving anyclass instance
  (S.Serialize sm, S.Serialize v, S.Serialize (ClientWriteResp sm v))
  => S.Serialize (ClientResponse sm v)
-- | Payload of a read response.
data ClientReadResp sm v
  = ClientReadRespStateMachine sm
    -- ^ The state machine value.
  | ClientReadRespEntry (Entry v)
    -- ^ A single log entry.
  | ClientReadRespEntries (Entries v)
    -- ^ Several log entries.
  deriving stock (Show, Generic)
  deriving anyclass (S.Serialize)

-- | Payload of a write response. Both cases carry the 'SerialNum' of the
-- request they answer; 'clientRecv' compares it with the client's own.
data ClientWriteResp sm v
  = ClientWriteRespSuccess Index SerialNum
    -- ^ The write succeeded. NOTE(review): the 'Index' is presumably the log
    -- index of the written entry — confirm.
  | ClientWriteRespFail SerialNum (RaftStateMachinePureError sm v)
    -- ^ The write failed with the given state machine error.
  deriving stock (Generic)

deriving stock instance
  (Show sm, Show v, Show (RaftStateMachinePureError sm v))
  => Show (ClientWriteResp sm v)

deriving anyclass instance
  (S.Serialize sm, S.Serialize v, S.Serialize (RaftStateMachinePureError sm v))
  => S.Serialize (ClientWriteResp sm v)
-- | Redirect carrying the responder's view of the current leader;
-- 'clientRecv' stores it as 'raftClientCurrentLeader'.
data ClientRedirResp
  = ClientRedirResp CurrentLeader
  deriving stock (Show, Generic)
  deriving anyclass (S.Serialize)

-- | Metrics reported by a node in answer to a 'ClientMetricsReq'.
data ClientMetricsResp
  = ClientMetricsResp RaftNodeMetrics
  deriving stock (Show, Generic)
  deriving anyclass (S.Serialize)
-- | Monads that can send a 'ClientRequest' to a given Raft node.
class Monad m => RaftClientSend m v where
  -- | Error produced when sending fails.
  type RaftClientSendError m v
  -- | Send a request to the node with the given 'NodeId'.
  raftClientSend
    :: NodeId
    -> ClientRequest v
    -> m (Either (RaftClientSendError m v) ())

-- | Monads that can receive a 'ClientResponse'. The functional dependency
-- @m sm -> v@ determines the command type from the monad and state machine.
class Monad m => RaftClientRecv m sm v | m sm -> v where
  -- | Error produced when receiving fails.
  type RaftClientRecvError m sm
  -- | Receive a response, or the error that prevented it.
  raftClientRecv
    :: m (Either (RaftClientRecvError m sm) (ClientResponse sm v))
-- | Mutable state threaded through 'RaftClientT'.
data RaftClientState = RaftClientState
  { raftClientCurrentLeader :: CurrentLeader
    -- ^ Leader believed to be current: first target of 'clientSend', updated
    -- from redirect responses by 'clientRecv'.
  , raftClientSerialNum :: SerialNum
    -- ^ Serial number attached to write requests; 'clientRecv' advances it
    -- when a successful write response carrying it arrives.
  , raftClientRaftNodes :: Set NodeId
    -- ^ Known Raft nodes; 'clientSendRandom' picks from these.
  , raftClientRandomGen :: StdGen
    -- ^ Generator used by 'clientSendRandom' to pick a node.
  }

-- | Read-only environment of 'RaftClientT'.
data RaftClientEnv = RaftClientEnv
  { raftClientId :: ClientId
    -- ^ Identifier attached to every outgoing 'ClientRequest'.
  }
-- | Initial client state for the given known nodes and random generator: no
-- known leader and serial number 0.
initRaftClientState :: Set NodeId -> StdGen -> RaftClientState
initRaftClientState nodes gen = RaftClientState
  { raftClientCurrentLeader = NoLeader
  , raftClientSerialNum = 0
  , raftClientRaftNodes = nodes
  , raftClientRandomGen = gen
  }
-- | The client monad transformer: a 'ReaderT' over 'RaftClientEnv' stacked on
-- a 'StateT' over 'RaftClientState'. The parameters @s@ (state machine) and
-- @v@ (command) do not occur in the representation; they only select the
-- 'RaftClientSend' / 'RaftClientRecv' instances.
newtype RaftClientT s v m a = RaftClientT
  { unRaftClientT :: ReaderT RaftClientEnv (StateT RaftClientState m) a
  }
  deriving newtype (Functor, Applicative, Monad, MonadIO)
  deriving newtype (MonadState RaftClientState, MonadReader RaftClientEnv)
  deriving newtype (MonadFail, Alternative, MonadPlus)

-- Exception and base-monad capabilities are inherited from the inner stack.
deriving newtype instance MonadThrow m => MonadThrow (RaftClientT s v m)
deriving newtype instance MonadCatch m => MonadCatch (RaftClientT s v m)
deriving newtype instance MonadMask m => MonadMask (RaftClientT s v m)
deriving newtype instance MonadBase IO m => MonadBase IO (RaftClientT s v m)

-- | Lift through both the 'StateT' and the 'ReaderT' layer.
instance MonadTrans (RaftClientT s v) where
  lift action = RaftClientT (lift (lift action))
-- | Control lifting through both transformer layers of 'RaftClientT', built
-- with monad-control's two-layer helpers.
instance MonadTransControl (RaftClientT s v) where
  -- The captured state composes the 'StT' of the 'StateT' layer inside the
  -- 'StT' of the 'ReaderT' layer.
  type StT (RaftClientT s v) a = StT (ReaderT RaftClientEnv) (StT (StateT RaftClientState) a)
  liftWith = defaultLiftWith2 RaftClientT unRaftClientT
  restoreT = defaultRestoreT2 RaftClientT

-- | Lets 'IO'-level control operations run 'RaftClientT' actions; this is
-- what 'timeout' in 'clientTimeout' relies on.
instance (MonadBaseControl IO m) => MonadBaseControl IO (RaftClientT s v m) where
  type StM (RaftClientT s v m) a = ComposeSt (RaftClientT s v) m a
  liftBaseWith = defaultLiftBaseWith
  restoreM = defaultRestoreM

-- | Haskeline's exception-lifting class, implemented by unwrapping to the
-- base monad with the environment @r@ and state @s@ current at entry.
instance MonadException m => MonadException (RaftClientT s v m) where
  controlIO f =
    RaftClientT $ ReaderT $ \r -> StateT $ \s ->
      controlIO $ \(RunIO run) ->
        -- run' runs a RaftClientT action from r and s in the base monad (via
        -- runStateT/runReaderT), then re-wraps the resulting base-monad action
        -- as a RaftClientT that ignores its incoming environment and state
        -- (const), so the state produced inside is the one that survives.
        -- The action returned by f is itself run from r and s.
        let run' = RunIO (fmap (RaftClientT . ReaderT . const . StateT . const) . run . flip runStateT s . flip runReaderT r . unRaftClientT)
        in fmap (flip runStateT s . flip runReaderT r . unRaftClientT) $ f run'
-- | Sending is delegated to the underlying monad; errors are its errors.
instance RaftClientSend m v => RaftClientSend (RaftClientT s v m) v where
  type RaftClientSendError (RaftClientT s v m) v = RaftClientSendError m v
  raftClientSend nid = lift . raftClientSend nid

-- | Receiving is delegated to the underlying monad; errors are its errors.
instance RaftClientRecv m s v => RaftClientRecv (RaftClientT s v m) s v where
  type RaftClientRecvError (RaftClientT s v m) s = RaftClientRecvError m s
  raftClientRecv = lift raftClientRecv
-- | Run a client computation from the given environment and initial state,
-- returning its result and discarding the final 'RaftClientState'.
runRaftClientT :: Monad m => RaftClientEnv -> RaftClientState -> RaftClientT s v m a -> m a
runRaftClientT env st action =
  evalStateT (runReaderT (unRaftClientT action) env) st
-- | Errors reported by the high-level client operations.
data RaftClientError s v m where
  -- | Sending the request failed.
  RaftClientSendError :: RaftClientSendError m v -> RaftClientError s v m
  -- | Receiving the response failed.
  RaftClientRecvError :: RaftClientRecvError m s -> RaftClientError s v m
  -- | The operation named by the 'Text' did not finish in time
  -- (see 'clientTimeout').
  RaftClientTimeout :: Text -> RaftClientError s v m
  -- | A read response arrived where another kind was expected.
  RaftClientUnexpectedReadResp :: ClientReadResp s v -> RaftClientError s v m
  -- | A write response arrived where another kind was expected.
  RaftClientUnexpectedWriteResp :: ClientWriteResp s v -> RaftClientError s v m
  -- | A metrics response arrived where another kind was expected.
  RaftClientUnexpectedMetricsResp :: ClientMetricsResp -> RaftClientError s v m
  -- | The node answered with a redirect (see 'retryOnRedirect').
  RaftClientUnexpectedRedirect :: ClientRedirResp -> RaftClientError s v m

deriving stock instance
  ( Show s
  , Show v
  , Show (RaftClientSendError m v)
  , Show (RaftClientRecvError m s)
  , Show (RaftStateMachinePureError s v)
  ) => Show (RaftClientError s v m)
-- | Send a read request with 'clientSendRead' and, if sending succeeded,
-- receive the read response. A send failure is reported as
-- 'RaftClientSendError'.
clientRead
  :: (RaftClientSend m v, RaftClientRecv m s v)
  => ClientReadReq
  -> RaftClientT s v m (Either (RaftClientError s v m) (ClientReadResp s v))
clientRead crr =
  clientSendRead crr
    >>= either (pure . Left . RaftClientSendError) (const clientRecvRead)
-- | Like 'clientRead', but sends the request to the given node rather than
-- to the current leader.
clientReadFrom
  :: (RaftClientSend m v, RaftClientRecv m s v)
  => NodeId
  -> ClientReadReq
  -> RaftClientT s v m (Either (RaftClientError s v m) (ClientReadResp s v))
clientReadFrom nid crr = do
  sent <- clientSendReadTo nid crr
  either (pure . Left . RaftClientSendError) (const clientRecvRead) sent
-- | 'clientRead' bounded by a timeout of @t@ microseconds; on expiry the
-- result is 'RaftClientTimeout' tagged @"clientRead"@.
clientReadTimeout
  :: (MonadBaseControl IO m, RaftClientSend m v, RaftClientRecv m s v)
  => Int
  -> ClientReadReq
  -> RaftClientT s v m (Either (RaftClientError s v m) (ClientReadResp s v))
clientReadTimeout t crr = clientTimeout "clientRead" t (clientRead crr)
-- | Send a write of @cmd@ with 'clientSendWrite' and, if sending succeeded,
-- receive the write response. A send failure is reported as
-- 'RaftClientSendError'.
clientWrite
  :: (RaftClientSend m v, RaftClientRecv m s v)
  => v
  -> RaftClientT s v m (Either (RaftClientError s v m) (ClientWriteResp s v))
clientWrite cmd = do
  sent <- clientSendWrite cmd
  case sent of
    Right _ -> clientRecvWrite
    Left sendErr -> pure (Left (RaftClientSendError sendErr))
-- | Like 'clientWrite', but sends the write to the given node rather than to
-- the current leader.
clientWriteTo
  :: (RaftClientSend m v, RaftClientRecv m s v)
  => NodeId
  -> v
  -> RaftClientT s v m (Either (RaftClientError s v m) (ClientWriteResp s v))
clientWriteTo nid cmd =
  clientSendWriteTo nid cmd
    >>= either (pure . Left . RaftClientSendError) (const clientRecvWrite)
-- | 'clientWrite' bounded by a timeout of @t@ microseconds; on expiry the
-- result is 'RaftClientTimeout' tagged @"clientWrite"@.
clientWriteTimeout
  :: (MonadBaseControl IO m, RaftClientSend m v, RaftClientRecv m s v)
  => Int
  -> v
  -> RaftClientT s v m (Either (RaftClientError s v m) (ClientWriteResp s v))
clientWriteTimeout t = clientTimeout "clientWrite" t . clientWrite
-- | Ask the given node for its metrics and receive the metrics response.
-- A send failure is reported as 'RaftClientSendError'.
--
-- No 'MonadBaseControl' constraint is needed (unlike the timeout variant),
-- matching 'clientRead' / 'clientWrite'; the type-variable order (m, v, s) is
-- unchanged for callers using type applications.
clientQueryNodeMetrics
  :: (RaftClientSend m v, RaftClientRecv m s v)
  => NodeId
  -> RaftClientT s v m (Either (RaftClientError s v m) ClientMetricsResp)
clientQueryNodeMetrics nid = do
  eQueryMetrics <- clientSendMetricsReqTo nid
  case eQueryMetrics of
    Left err -> pure (Left (RaftClientSendError err))
    Right () -> clientRecvMetrics
-- | 'clientQueryNodeMetrics' bounded by a timeout of @t@ microseconds; on
-- expiry the result is 'RaftClientTimeout' tagged @"clientQueryNodeMetrics"@.
clientQueryNodeMetricsTimeout
  :: (MonadBaseControl IO m, RaftClientSend m v, RaftClientRecv m s v)
  => Int
  -> NodeId
  -> RaftClientT s v m (Either (RaftClientError s v m) ClientMetricsResp)
clientQueryNodeMetricsTimeout t =
  clientTimeout "clientQueryNodeMetrics" t . clientQueryNodeMetrics
-- | Run a client action under 'timeout' with a limit of @t@ microseconds.
-- If the limit elapses, the result is @'RaftClientTimeout' fnm@; otherwise
-- the action's own result (success or error) is returned unchanged.
--
-- Only 'MonadBaseControl' is required ('timeout' needs it, as in
-- 'retryOnRedirect'); the send/receive constraints were never used. The
-- explicit @forall@ keeps the original type-application order (m, v, s, r).
clientTimeout
  :: forall m v s r. MonadBaseControl IO m
  => Text
  -> Int
  -> RaftClientT s v m (Either (RaftClientError s v m) r)
  -> RaftClientT s v m (Either (RaftClientError s v m) r)
clientTimeout fnm t r =
  fromMaybe (Left (RaftClientTimeout fnm)) <$> timeout t r
-- | Run @action@ and, for as long as it fails with
-- 'RaftClientUnexpectedRedirect', sleep 10000 microseconds ('threadDelay') and
-- run it again. Any other result, success or error, is returned as is.
-- There is no retry limit.
--
-- When @action@ sends via 'clientSend', each retry targets the leader that
-- 'clientRecv' recorded from the redirect, or a random node when the redirect
-- carried 'NoLeader'.
retryOnRedirect
  :: MonadBaseControl IO m
  => RaftClientT s v m (Either (RaftClientError s v m) r)
  -> RaftClientT s v m (Either (RaftClientError s v m) r)
retryOnRedirect action = attempt
  where
    attempt = do
      result <- action
      case result of
        Left (RaftClientUnexpectedRedirect _) -> threadDelay 10000 >> attempt
        other -> pure other
-- | Send a read request via 'clientSend' (current leader, else a random node).
clientSendRead
  :: RaftClientSend m v
  => ClientReadReq
  -> RaftClientT s v m (Either (RaftClientSendError m v) ())
clientSendRead = clientSend . ClientReadReq
-- | Send a read request to the given node.
clientSendReadTo
  :: RaftClientSend m v
  => NodeId
  -> ClientReadReq
  -> RaftClientT s v m (Either (RaftClientSendError m v) ())
clientSendReadTo nid = clientSendTo nid . ClientReadReq
-- | Send a write of command @v@ via 'clientSend', tagged with the current
-- serial number. The serial number is not advanced here; 'clientRecv' does
-- that when the matching successful response arrives.
clientSendWrite
  :: RaftClientSend m v
  => v
  -> RaftClientT s v m (Either (RaftClientSendError m v) ())
clientSendWrite cmd = do
  serial <- gets raftClientSerialNum
  clientSend (ClientWriteReq (ClientCmdReq serial cmd))
-- | Like 'clientSendWrite', but sends to the given node.
clientSendWriteTo
  :: RaftClientSend m v
  => NodeId
  -> v
  -> RaftClientT s v m (Either (RaftClientSendError m v) ())
clientSendWriteTo nid cmd = do
  serial <- gets raftClientSerialNum
  clientSendTo nid . ClientWriteReq $ ClientCmdReq serial cmd
-- | Send a request for all metrics to the given node.
clientSendMetricsReqTo
  :: RaftClientSend m v
  => NodeId
  -> RaftClientT s v m (Either (RaftClientSendError m v) ())
clientSendMetricsReqTo nid = clientSendTo nid allMetrics
  where
    allMetrics = ClientMetricsReq ClientAllMetricsReq
-- | Send a request to the current leader if one is known, otherwise to a
-- random known node ('clientSendRandom').
--
-- If sending to the known leader fails, that error is deliberately discarded
-- and the request is sent to a random node instead; only the outcome of that
-- second attempt is returned.
clientSend
  :: (RaftClientSend m v)
  => ClientReq v
  -> RaftClientT s v m (Either (RaftClientSendError m v) ())
clientSend creq = do
  currLeader <- gets raftClientCurrentLeader
  case currLeader of
    NoLeader -> clientSendRandom creq
    CurrentLeader (LeaderId nid) -> do
      eRes <- clientSendTo nid creq
      case eRes of
        Left _ -> clientSendRandom creq
        ok@(Right _) -> pure ok
-- | Send a request, tagged with this client's 'raftClientId', to the given
-- node.
clientSendTo
  :: RaftClientSend m v
  => NodeId
  -> ClientReq v
  -> RaftClientT s v m (Either (RaftClientSendError m v) ())
clientSendTo nid creq =
  asks raftClientId >>= \cid -> raftClientSend nid (ClientRequest cid creq)
-- | Send a request to a known node chosen with 'randomR' from the state's
-- generator. The stored generator is advanced only when a node is chosen.
-- Calls 'panic' when no Raft nodes are known.
clientSendRandom
  :: RaftClientSend m v
  => ClientReq v
  -> RaftClientT s v m (Either (RaftClientSendError m v) ())
clientSendRandom creq = do
  nodes <- gets raftClientRaftNodes
  gen <- gets raftClientRandomGen
  let (pick, gen') = randomR (0, Set.size nodes - 1) gen
  maybe
    (panic "No raft nodes known by client")
    (sendVia gen')
    (atMay (Set.toList nodes) pick)
  where
    -- Commit the advanced generator, then send to the chosen node.
    sendVia gen' target = do
      modify $ \st -> st { raftClientRandomGen = gen' }
      cid <- asks raftClientId
      raftClientSend target (ClientRequest cid creq)
-- | Receive a response via 'clientRecv' and expect a write response. A receive
-- failure becomes 'RaftClientRecvError'; any other kind of response becomes
-- the matching @RaftClientUnexpected*@ error.
--
-- Only receiving is required, so the unused 'RaftClientSend' constraint was
-- dropped.
clientRecvWrite
  :: RaftClientRecv m s v
  => RaftClientT s v m (Either (RaftClientError s v m) (ClientWriteResp s v))
clientRecvWrite = do
  eRes <- clientRecv
  case eRes of
    Left err -> pure (Left (RaftClientRecvError err))
    Right (ClientWriteResponse cwr) -> pure (Right cwr)
    Right (ClientRedirectResponse crr) -> pure (Left (RaftClientUnexpectedRedirect crr))
    Right (ClientReadResponse crr) -> pure (Left (RaftClientUnexpectedReadResp crr))
    Right (ClientMetricsResponse cmr) -> pure (Left (RaftClientUnexpectedMetricsResp cmr))
-- | Receive a response via 'clientRecv' and expect a read response. A receive
-- failure becomes 'RaftClientRecvError'; any other kind of response becomes
-- the matching @RaftClientUnexpected*@ error.
--
-- Only receiving is required, so the unused 'RaftClientSend' constraint was
-- dropped.
clientRecvRead
  :: RaftClientRecv m s v
  => RaftClientT s v m (Either (RaftClientError s v m) (ClientReadResp s v))
clientRecvRead = do
  eRes <- clientRecv
  case eRes of
    Left err -> pure (Left (RaftClientRecvError err))
    Right (ClientReadResponse crr) -> pure (Right crr)
    Right (ClientRedirectResponse crr) -> pure (Left (RaftClientUnexpectedRedirect crr))
    Right (ClientWriteResponse cwr) -> pure (Left (RaftClientUnexpectedWriteResp cwr))
    Right (ClientMetricsResponse cmr) -> pure (Left (RaftClientUnexpectedMetricsResp cmr))
-- | Receive a response via 'clientRecv' and expect a metrics response. A
-- receive failure becomes 'RaftClientRecvError'; any other kind of response
-- becomes the matching @RaftClientUnexpected*@ error.
--
-- Only receiving is required, so the unused 'RaftClientSend' constraint was
-- dropped.
clientRecvMetrics
  :: RaftClientRecv m s v
  => RaftClientT s v m (Either (RaftClientError s v m) ClientMetricsResp)
clientRecvMetrics = do
  eRes <- clientRecv
  case eRes of
    Left err -> pure (Left (RaftClientRecvError err))
    Right (ClientMetricsResponse cmr) -> pure (Right cmr)
    Right (ClientRedirectResponse crr) -> pure (Left (RaftClientUnexpectedRedirect crr))
    Right (ClientWriteResponse cwr) -> pure (Left (RaftClientUnexpectedWriteResp cwr))
    Right (ClientReadResponse crr) -> pure (Left (RaftClientUnexpectedReadResp crr))
-- | Receive a response with 'raftClientRecv' and update the client state:
--
-- * A write response whose serial number equals the client's current one is
--   returned; if it is a success, the serial number is advanced first. A
--   failed write leaves the serial number unchanged.
-- * A write response with a lower serial number is treated as a stale reply
--   to an earlier request: it is dropped and 'clientRecv' is called again.
-- * A write response with a higher serial number aborts with 'panic'.
-- * A redirect response stores its leader in 'raftClientCurrentLeader'.
-- * Any other response, and any receive error, is returned unchanged.
clientRecv
  :: RaftClientRecv m s v
  => RaftClientT s v m (Either (RaftClientRecvError m s) (ClientResponse s v))
clientRecv = do
  ecresp <- raftClientRecv
  case ecresp of
    Left err -> pure (Left err)
    Right cresp@(ClientWriteResponse cwr) ->
      case cwr of
        ClientWriteRespSuccess _ serial ->
          withCurrentSerial serial $ \m -> do
            modify $ \s -> s { raftClientSerialNum = SerialNum (succ m) }
            pure (Right cresp)
        ClientWriteRespFail serial _ ->
          withCurrentSerial serial $ \_ -> pure (Right cresp)
    Right cresp@(ClientRedirectResponse (ClientRedirResp currLdr)) -> do
      modify $ \s -> s { raftClientCurrentLeader = currLdr }
      pure (Right cresp)
    Right cresp -> pure (Right cresp)
  where
    -- Compare a write response's serial number @n@ with the client's current
    -- one @m@: run the continuation on a match, skip stale (lower) responses
    -- by receiving again, and panic on a serial number from the future.
    withCurrentSerial (SerialNum n) onMatch = do
      SerialNum m <- gets raftClientSerialNum
      case compare n m of
        EQ -> onMatch m
        LT -> clientRecv
        GT -> panic $
          "Received invalid serial number response: Expected "
            <> show m <> " but got " <> show n
-- | Add a node to the set of Raft nodes known to the client.
clientAddNode :: Monad m => NodeId -> RaftClientT s v m ()
clientAddNode nid = modify addNode
  where
    addNode st = st { raftClientRaftNodes = Set.insert nid (raftClientRaftNodes st) }
-- | The set of Raft nodes currently known to the client.
clientGetNodes :: Monad m => RaftClientT s v m (Set NodeId)
clientGetNodes = raftClientRaftNodes <$> get