{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE TypeApplications #-}
module Raft
(
RaftStateMachinePure(..)
, RaftStateMachine(..)
, RaftSendRPC(..)
, RaftRecvRPC(..)
, RaftSendClient(..)
, RaftRecvClient(..)
, RaftPersist(..)
, RaftEnv(..)
, runRaftNode
, runRaftT
, handleEventLoop
, ClientRequest(..)
, ClientReq(..)
, ClientResponse(..)
, ClientReadResp(..)
, ClientWriteResp(..)
, ClientRedirResp(..)
, RaftNodeConfig(..)
, Event(..)
, Timeout(..)
, MessageEvent(..)
, Entry(..)
, Entries
, RaftWriteLog(..)
, DeleteSuccess(..)
, RaftDeleteLog(..)
, RaftReadLog (..)
, RaftLog
, RaftLogError(..)
, RaftLogExceptions(..)
, LogCtx(..)
, LogDest(..)
, Severity(..)
, Mode(..)
, RaftNodeState(..)
, NodeState(..)
, CurrentLeader(..)
, FollowerState(..)
, CandidateState(..)
, LeaderState(..)
, initRaftNodeState
, isFollower
, isCandidate
, isLeader
, setLastLogEntry
, getLastLogEntry
, getLastAppliedAndCommitIndex
, PersistentState(..)
, initPersistentState
, NodeId
, NodeIds
, ClientId(..)
, LeaderId(..)
, Term(..)
, Index(..)
, term0
, index0
, RPC(..)
, RPCType(..)
, RPCMessage(..)
, AppendEntries(..)
, AppendEntriesResponse(..)
, RequestVote(..)
, RequestVoteResponse(..)
, AppendEntriesData(..)
) where
import Protolude hiding (STM, TChan, newRaftChan, readBoundedChan, writeBoundedChan, atomically)
import Control.Concurrent.STM.Timer
import Control.Monad.Fail
import Control.Monad.Catch
import qualified Data.Map as Map
import Data.Serialize (Serialize)
import Data.Sequence (Seq(..), singleton)
import Data.Time.Clock.System (getSystemTime)
import Raft.Action
import Raft.Client
import Raft.Config
import Raft.Event
import Raft.Handle
import Raft.Monad
import Raft.Log
import Raft.Logging hiding (logInfo, logDebug, logCritical, logAndPanic)
import Raft.Transition hiding (logInfo, logDebug)
import Raft.NodeState
import Raft.Persistent
import Raft.RPC
import Raft.StateMachine
import Raft.Types
import qualified Raft.Metrics as Metrics
import qualified System.Remote.Monitoring as EKG
-- | Run a Raft node in the base monad @m@.
--
-- Steps, in order:
--
-- 1. Resolve the optional metrics port and the timer seed from the optional
--    configuration.
-- 2. Initialize persistent state and the log when the configured storage
--    state is 'New'; do nothing for 'Existing'.
-- 3. Create the event channel plus the election and heartbeat timers, and
--    build the 'RaftEnv'.
-- 4. Inside 'runRaftT': fork the metrics server if a port was resolved, fork
--    the two timeout-timer threads, the RPC handler and the client request
--    handler, then enter 'handleEventLoop' with the initial state machine.
--
-- Errors returned by storage initialization are rethrown with 'throwM'.
runRaftNode
  :: forall m sm v.
     ( Typeable m, Show v, Show sm, Serialize v, Show (Action sm v), Show (RaftLogError m), Show (RaftStateMachinePureError sm v)
     , MonadIO m, MonadCatch m, MonadFail m, MonadMask m, MonadRaft v m
     , RaftStateMachine m sm v
     , RaftSendRPC m v
     , RaftRecvRPC m v
     , RaftSendClient m sm v
     , RaftRecvClient m v
     , RaftLog m v
     , RaftLogExceptions m
     , RaftPersist m
     , Exception (RaftPersistError m)
     )
  => RaftNodeConfig
  -> OptionalRaftNodeConfig
  -> LogCtx (RaftT v m)
  -> sm
  -> m ()
runRaftNode nodeConfig@RaftNodeConfig{..} optConfig logCtx initStateMachine = do
  metricsPort <- liftIO (resolveMetricsPort (raftConfigMetricsPort optConfig))
  timerSeed <- liftIO (resolveTimerSeed (raftConfigTimerSeed optConfig))

  initializeStorage raftConfigStorageState

  eventChan <- newRaftChan @v
  electionTimer <- liftIO (newTimerRange timerSeed raftConfigElectionTimeout)
  heartbeatTimer <- liftIO (newTimer raftConfigHeartbeatTimeout)

  -- Timer reset actions handed to the environment; the result of
  -- 'resetTimer' is discarded.
  let resetElectionTimer = liftIO (void (resetTimer electionTimer))
      resetHeartbeatTimer = liftIO (void (resetTimer heartbeatTimer))

  raftEnv <- initializeRaftEnv eventChan resetElectionTimer resetHeartbeatTimer nodeConfig logCtx

  runRaftT initRaftNodeState raftEnv $ do
    -- The metrics server is only started when a port was resolved.
    forM_ metricsPort $ \port -> do
      logInfo ("Forking metrics server on port " <> show port <> "...")
      metricsStore <- Metrics.getMetricsStore
      void (liftIO (EKG.forkServerWith metricsStore "localhost" (fromIntegral port)))

    logInfo ("Initialized election timer with seed " <> show timerSeed <> "...")

    raftFork (CustomThreadRole "Election Timeout Timer")
      (lift (electionTimeoutTimer @m @v eventChan electionTimer))
    raftFork (CustomThreadRole "Heartbeat Timeout Timer")
      (lift (heartbeatTimeoutTimer @m @v eventChan heartbeatTimer))
    raftFork RPCHandler (rpcHandler @m @v eventChan)
    raftFork ClientRequestHandler (clientReqHandler @m @v eventChan)

    handleEventLoop initStateMachine
  where
    -- Fresh storage needs both the persistent state and the log initialized;
    -- the first failure is rethrown and the second step is then skipped.
    initializeStorage storageState =
      case storageState of
        Existing -> pure ()
        New -> do
          initializePersistentState >>= either throwM (const (pure ()))
          initializeLog (Proxy :: Proxy v) >>= either throwM (const (pure ()))
-- | The main event loop of a Raft node.
--
-- Before looping it seeds the node state with the last log entry read from
-- the log (if any) and reads the persistent state; a failure in either read
-- is rethrown with 'throwM'.
--
-- Each iteration then:
--
-- 1. Increments the events-handled counter (this happens before the event is
--    read from the channel).
-- 2. Reads one event from the event channel and, when this node is a leader
--    and the event is a client write command, validates the command against
--    the current state machine. An invalid command is answered with a failure
--    response and the event is dropped.
-- 3. Runs the pure 'Raft.Handle.handleEvent' transition.
-- 4. Writes the resulting persistent state only if it differs from the
--    previous one, stores the new node state, emits the log messages,
--    performs the resulting actions, and applies newly committed log entries
--    to the state machine.
handleEventLoop
  :: forall sm v m.
     ( Show v, Serialize v, Show sm, Show (Action sm v), Show (RaftLogError m), Typeable m
     , MonadIO m, MonadRaft v m, MonadFail m, MonadThrow m, MonadMask m
     , RaftStateMachine m sm v
     , Show (RaftStateMachinePureError sm v)
     , RaftPersist m
     , RaftSendRPC m v
     , RaftSendClient m sm v
     , RaftLog m v
     , RaftLogExceptions m
     , Exception (RaftPersistError m)
     )
  => sm
  -> RaftT v m ()
handleEventLoop initStateMachine = do
  setInitLastLogEntry
  ePersistentState <- lift readPersistentState
  case ePersistentState of
    Left err -> throwM err
    Right pstate -> handleEventLoop' initStateMachine pstate
  where
    -- Read the next event and hand it to the continuation, unless it is a
    -- client write command received by a leader that fails validation against
    -- the current state machine. In that case the client is sent a failure
    -- response and 'Nothing' is returned.
    withValidatedEvent :: sm -> (Event v -> RaftT v m a) -> RaftT v m (Maybe a)
    withValidatedEvent stateMachine f = do
      event <- lift . readRaftChan =<< asks eventChan
      RaftNodeState raftNodeState <- get
      case (raftNodeState, event) of
        (NodeLeaderState _, MessageEvent (ClientRequestEvent (ClientRequest cid (ClientWriteReq (ClientCmdReq serial cmd))))) -> do
          eRes <- lift (applyLogCmd MonadicValidation stateMachine cmd)
          case eRes of
            Left err -> do
              Metrics.incrInvalidCmdCounter
              let clientWriteRespSpec = ClientWriteRespSpec (ClientWriteRespSpecFail serial err)
                  clientFailRespAction = RespondToClient cid clientWriteRespSpec
              handleAction clientFailRespAction
              pure Nothing
            Right _ -> Just <$> f event
        -- Every other combination of node state and event is passed through
        -- without validation.
        _ -> Just <$> f event

    -- One iteration of the loop; recurses with the updated state machine and
    -- persistent state.
    handleEventLoop' :: sm -> PersistentState -> RaftT v m ()
    handleEventLoop' stateMachine persistentState = do
      Metrics.incrEventsHandledCounter
      mRes <-
        withValidatedEvent stateMachine $ \event -> do
          loadLogEntryTermAtAePrevLogIndex event
          raftNodeState@(RaftNodeState nodeState) <- get
          Metrics.setNodeStateLabel (nodeMode raftNodeState)
          Metrics.setCommitIndexGauge (getCommitIndex nodeState)
          logDebug $ "[Event]: " <> show event
          logDebug $ "[NodeState]: " <> show raftNodeState
          logDebug $ "[State Machine]: " <> show stateMachine
          logDebug $ "[Persistent State]: " <> show persistentState
          nodeConfig <- asks raftNodeConfig
          raftNodeMetrics <- Metrics.getRaftNodeMetrics
          let transitionEnv = TransitionEnv nodeConfig stateMachine raftNodeState raftNodeMetrics
          pure (Raft.Handle.handleEvent raftNodeState transitionEnv persistentState event)
      case mRes of
        -- The event was rejected during validation; nothing changed.
        Nothing -> handleEventLoop' stateMachine persistentState
        Just (resRaftNodeState, resPersistentState, actions, logMsgs) -> do
          -- Only write (and only log that we are writing) when the
          -- persistent state actually changed.
          when (resPersistentState /= persistentState) $ do
            logDebug "Writing PersistentState to disk..."
            eRes <- lift $ writePersistentState resPersistentState
            case eRes of
              Left err -> throwM err
              Right _ -> pure ()
          put resRaftNodeState
          handleLogs logMsgs
          handleActions actions
          resRaftStateMachine <- applyLogEntries stateMachine
          handleEventLoop' resRaftStateMachine resPersistentState

    -- When a follower receives an AppendEntries RPC, read the log entry at the
    -- RPC's previous log index and store its term (if the entry exists) in the
    -- follower state, so the pure handler can consult it.
    loadLogEntryTermAtAePrevLogIndex :: Event v -> RaftT v m ()
    loadLogEntryTermAtAePrevLogIndex event =
      case event of
        MessageEvent (RPCMessageEvent (RPCMessage _ (AppendEntriesRPC ae))) -> do
          RaftNodeState rns <- get
          case rns of
            NodeFollowerState fs -> do
              eEntry <- lift $ readLogEntry (aePrevLogIndex ae)
              case eEntry of
                Left err -> throwM err
                Right (mEntry :: Maybe (Entry v)) ->
                  put $ RaftNodeState $ NodeFollowerState fs
                    { fsTermAtAEPrevIndex = entryTerm <$> mEntry }
            _ -> pure ()
        _ -> pure ()

    -- Seed the node state with the last entry of the log, if there is one.
    setInitLastLogEntry :: RaftT v m ()
    setInitLastLogEntry = do
      RaftNodeState rns <- get
      eLogEntry <- lift readLastLogEntry
      case eLogEntry of
        Left err -> throwM err
        Right Nothing -> pure ()
        Right (Just e) ->
          put (RaftNodeState (setLastLogEntry rns (singleton e)))
-- | Perform a list of actions in order by running 'handleAction' on each.
handleActions
  :: ( Show v, Show sm, Show (Action sm v), Show (RaftLogError m), Typeable m
     , MonadIO m, MonadRaft v m, MonadThrow m, MonadMask m
     , RaftStateMachine m sm v
     , RaftSendRPC m v
     , RaftSendClient m sm v
     , RaftLog m v
     , RaftLogExceptions m
     )
  => [Action sm v]
  -> RaftT v m ()
handleActions actions =
  forM_ actions handleAction
-- | Perform a single action produced by the pure event handler.
--
-- * 'SendRPC' / 'SendRPCs' / 'BroadcastRPC': build the RPC message(s) and send
--   each one from a freshly forked thread.
-- * 'RespondToClient': build the client response and send it from a forked
--   thread.
-- * 'ResetTimeoutTimer': run the matching reset action from the environment.
-- * 'AppendLogEntries': update the log (panicking on failure), update the
--   last-log-entry-index gauge when an index is returned, and record the
--   entries as the last log entry in the node state.
-- * 'UpdateClientReqCacheFrom': on a leader, read the log from the given
--   index, respond to the clients whose requests appear in those entries and
--   merge them into the leader's client request cache. Panics on non-leaders.
handleAction
  :: forall sm v m.
     ( Show v, Show sm, Show (Action sm v), Show (RaftLogError m), Typeable m
     , MonadIO m, MonadRaft v m, MonadThrow m
     , RaftStateMachine m sm v
     , RaftSendRPC m v
     , RaftSendClient m sm v
     , RaftLog m v
     , RaftLogExceptions m
     )
  => Action sm v
  -> RaftT v m ()
handleAction action = do
  logDebug $ "Handling [Action]: " <> show action
  case action of
    SendRPC nid sendRpcAction ->
      mkRPCfromSendRPCAction sendRpcAction >>= sendRPCThread nid

    SendRPCs rpcMap ->
      forM_ (Map.toList rpcMap) $ \(nid, sendRpcAction) ->
        mkRPCfromSendRPCAction sendRpcAction >>= sendRPCThread nid

    BroadcastRPC nids sendRpcAction -> do
      -- The message is built once and shared by all recipients.
      rpcMsg <- mkRPCfromSendRPCAction sendRpcAction
      forM_ nids $ \nid ->
        raftFork (CustomThreadRole "RPC Broadcast") (lift (sendRPC nid rpcMsg))

    RespondToClient cid cr ->
      respondToClient cid cr

    ResetTimeoutTimer tout ->
      case tout of
        ElectionTimeout -> ask >>= lift . resetElectionTimer
        HeartbeatTimeout -> ask >>= lift . resetHeartbeatTimer

    AppendLogEntries entries -> do
      eRes <- lift (updateLog entries)
      case eRes of
        Left err -> logAndPanic (show err)
        Right midx -> do
          forM_ midx Metrics.setLastLogEntryIndexGauge
          modify $ \(RaftNodeState ns) ->
            RaftNodeState (setLastLogEntry ns entries)

    UpdateClientReqCacheFrom idx -> do
      RaftNodeState ns <- get
      case ns of
        NodeLeaderState ls -> do
          eentries <- lift (readLogEntriesFrom idx)
          case eentries of
            Left err -> throwM err
            Right (entries :: Entries v) -> do
              let committedClientReqs = clientReqData entries
              unless (Map.null committedClientReqs) $ do
                mapM_ respondToClientWrite (Map.toList committedClientReqs)
                let creqMap = Map.map (second Just) committedClientReqs
                -- Left-biased union: the freshly read entries win over
                -- whatever the cache already holds for the same client.
                put $ RaftNodeState $ NodeLeaderState
                  ls { lsClientReqCache = creqMap `Map.union` lsClientReqCache ls }
        _ -> logAndPanic "Only the leader should update the client request cache..."
  where
    -- Send one RPC message to one node from a forked thread.
    sendRPCThread :: NodeId -> RPCMessage v -> RaftT v m ()
    sendRPCThread nid rpcMsg =
      void . raftFork (CustomThreadRole "Send RPC") . lift $ sendRPC nid rpcMsg

    -- Send a successful write response for one client request.
    respondToClientWrite :: (ClientId, (SerialNum, Index)) -> RaftT v m ()
    respondToClientWrite (cid, (sn, idx)) =
      respondToClient cid $
        ClientWriteRespSpec @sm (ClientWriteRespSpecSuccess idx sn)

    -- Build the response inside the forked thread and send it to the client.
    respondToClient :: ClientId -> ClientRespSpec sm v -> RaftT v m ()
    respondToClient cid crs =
      void . raftFork (CustomThreadRole "Respond to Client") $
        mkClientResp crs >>= lift . sendClient cid

    -- Turn a response specification into a concrete client response; entry
    -- read specifications are resolved by reading from the log.
    mkClientResp :: ClientRespSpec sm v -> RaftT v m (ClientResponse sm v)
    mkClientResp crs =
      case crs of
        ClientReadRespSpec crrs ->
          ClientReadResponse <$>
            case crrs of
              ClientReadRespSpecEntries readSpec -> do
                eRes <- lift (readEntries readSpec)
                case eRes of
                  Left err -> throwM err
                  Right (OneEntry e) -> pure (ClientReadRespEntry e)
                  Right (ManyEntries es) -> pure (ClientReadRespEntries es)
              ClientReadRespSpecStateMachine sm ->
                pure (ClientReadRespStateMachine sm)
        ClientWriteRespSpec cwrs ->
          ClientWriteResponse <$>
            case cwrs of
              ClientWriteRespSpecSuccess idx sn ->
                pure (ClientWriteRespSuccess idx sn)
              ClientWriteRespSpecFail sn err ->
                pure (ClientWriteRespFail sn err)
        ClientRedirRespSpec cl ->
          pure (ClientRedirectResponse (ClientRedirResp cl))
        ClientMetricsRespSpec rnm ->
          pure (ClientMetricsResponse (ClientMetricsResp rnm))

    -- Build the outgoing RPC message, stamped with this node's id. For
    -- AppendEntries the entries and the previous log index/term are resolved
    -- here, reading from the log where needed.
    mkRPCfromSendRPCAction :: SendRPCAction v -> RaftT v m (RPCMessage v)
    mkRPCfromSendRPCAction sendRPCAction = do
      RaftNodeState ns <- get
      nodeConfig <- asks raftNodeConfig
      let selfId = raftConfigNodeId nodeConfig
      RPCMessage selfId <$>
        case sendRPCAction of
          SendAppendEntriesRPC aeData -> do
            (entries, prevLogIndex, prevLogTerm, aeReadReq) <-
              case aedEntriesSpec aeData of
                FromIndex idx -> attachNothing <$> mkPrevLogEntryDataByIdx idx
                FromClientWriteReq e -> attachNothing <$> mkPrevEntryDataByEntry e
                FromNewLeader e -> attachNothing <$> mkPrevEntryDataByEntry e
                NoEntries spec -> do
                  let readReq =
                        case spec of
                          FromClientReadReq n -> Just n
                          FromHeartbeat -> Nothing
                  (lastIdx, lastTerm) <-
                    case getLastLogEntry ns of
                      NoLogEntries -> pure (index0, term0)
                      LastLogEntry e -> do
                        (_, pIdx, pTerm) <- mkPrevEntryDataByEntry e
                        pure (pIdx, pTerm)
                  pure (Empty, lastIdx, lastTerm, readReq)
            pure . toRPC $
              AppendEntries
                { aeTerm = aedTerm aeData
                , aeLeaderId = LeaderId selfId
                , aePrevLogIndex = prevLogIndex
                , aePrevLogTerm = prevLogTerm
                , aeEntries = entries
                , aeLeaderCommit = aedLeaderCommit aeData
                , aeReadRequest = aeReadReq
                }
          SendAppendEntriesResponseRPC aer -> pure (toRPC aer)
          SendRequestVoteRPC rv -> pure (toRPC rv)
          SendRequestVoteResponseRPC rvr -> pure (toRPC rvr)

    -- Extend a triple with an absent read request.
    attachNothing (x, y, z) = (x, y, z, Nothing)

    -- Read the log starting one index before @idx@. When at least two entries
    -- come back and @idx@ is not 1, the first is the previous entry and the
    -- rest are the entries to send; in every other case the whole read result
    -- is sent with the zero index and term as the previous entry data.
    mkPrevLogEntryDataByIdx idx = do
      eLogEntries <- lift (readLogEntriesFrom (decrIndexWithDefault0 idx))
      case eLogEntries of
        Left err -> throwM err
        Right logEntries ->
          case logEntries of
            pe :<| entries@(_ :<| _)
              | idx == 1 -> pure (logEntries, index0, term0)
              | otherwise -> pure (entries, entryIndex pe, entryTerm pe)
            _ -> pure (logEntries, index0, term0)

    -- Pair a single entry with the index and term of the entry preceding it
    -- in the log, falling back to the zero index and term when the entry is
    -- at index 1 or no preceding entry is found.
    mkPrevEntryDataByEntry e
      | entryIndex e == Index 1 = pure (singleton e, index0, term0)
      | otherwise = do
          let prevLogEntryIdx = decrIndexWithDefault0 (entryIndex e)
          eLogEntry <- lift (readLogEntry prevLogEntryIdx)
          case eLogEntry of
            Left err -> throwM err
            Right Nothing -> pure (singleton e, index0, term0)
            Right (Just (prevEntry :: Entry v)) ->
              pure (singleton e, entryIndex prevEntry, entryTerm prevEntry)
-- | Apply committed but not yet applied log entries to the state machine.
--
-- While the commit index is greater than the last applied index, the last
-- applied index in the node state is incremented, the entry at that index is
-- read from the log and applied, and the process repeats with the resulting
-- state machine. Returns the state machine once nothing is left to apply.
--
-- A log read error is rethrown with 'throwM'. A missing entry, or a failure
-- to apply an entry, ends in 'logAndPanic'.
applyLogEntries
  :: forall sm m v.
     ( Show v, Show sm, Show (RaftStateMachinePureError sm v)
     , MonadIO m, MonadThrow m, MonadRaft v m
     , RaftReadLog m v, Exception (RaftReadLogError m)
     , RaftStateMachine m sm v
     )
  => sm
  -> RaftT v m sm
applyLogEntries stateMachine = do
  RaftNodeState nodeState <- get
  if commitIndex nodeState > lastApplied nodeState
    then applyNext nodeState
    else pure stateMachine
  where
    -- Advance the last applied index, then read and apply the entry found at
    -- the new index. The node state is updated before the entry is read.
    applyNext :: NodeState ns v -> RaftT v m sm
    applyNext nodeState = do
      let resNodeState = incrLastApplied nodeState
          newLastAppliedIndex = lastApplied resNodeState
      put (RaftNodeState resNodeState)
      eLogEntry <- lift (readLogEntry newLastAppliedIndex)
      case eLogEntry of
        Left err -> throwM err
        Right Nothing ->
          logAndPanic . mconcat $
            [ "No log entry at 'newLastAppliedIndex := "
            , show newLastAppliedIndex <> "'"
            ]
        Right (Just logEntry) -> do
          logDebug $ "[Applying Log Entry]: " <> show logEntry
          eRes <- lift (applyLogEntry NoMonadicValidation stateMachine logEntry)
          case eRes of
            Left err -> logAndPanic $ "Failed to apply committed log entry: " <> show err
            Right nsm -> applyLogEntries nsm

    -- Increment the last applied index stored in whichever state the node
    -- is in.
    incrLastApplied :: NodeState ns v -> NodeState ns v
    incrLastApplied nodeState =
      case nodeState of
        NodeFollowerState fs ->
          NodeFollowerState fs { fsLastApplied = incrIndex (fsLastApplied fs) }
        NodeCandidateState cs ->
          NodeCandidateState cs { csLastApplied = incrIndex (csLastApplied cs) }
        NodeLeaderState ls ->
          NodeLeaderState ls { lsLastApplied = incrIndex (lsLastApplied ls) }

    lastApplied :: NodeState ns v -> Index
    lastApplied = fst . getLastAppliedAndCommitIndex

    commitIndex :: NodeState ns v -> Index
    commitIndex = snd . getLastAppliedAndCommitIndex
-- | Send each log message, in order, to the destination described by the log
-- context taken from the environment.
handleLogs
  :: (MonadIO m, MonadRaft v m)
  => [LogMsg]
  -> RaftT v m ()
handleLogs logs =
  asks raftNodeLogCtx >>= \logCtx ->
    forM_ logs (logToDest logCtx)
-- | Loop forever receiving RPC messages and writing each one to the event
-- channel as a message event.
--
-- Any exception thrown by 'receiveRPC' (caught as 'SomeException') and any
-- 'Left' result it returns are logged at critical level; the loop then
-- continues with the next receive.
rpcHandler
  :: forall m v. (MonadIO m, MonadRaft v m, MonadCatch m, Show v, RaftRecvRPC m v)
  => RaftEventChan v m
  -> RaftT v m ()
rpcHandler eventChan =
  forever $ do
    received <- lift (Control.Monad.Catch.try receiveRPC)
    case received of
      Left (thrown :: SomeException) -> logCritical (show thrown)
      Right (Left recvErr) -> logCritical (show recvErr)
      Right (Right rpcMsg) ->
        lift (writeRaftChan @v @m eventChan (MessageEvent (RPCMessageEvent rpcMsg)))
-- | Loop forever receiving client requests and writing each one to the event
-- channel as a message event.
--
-- Any exception thrown by 'receiveClient' (caught as 'SomeException') and any
-- 'Left' result it returns are logged at critical level; the loop then
-- continues with the next receive.
clientReqHandler
  :: forall m v. (MonadIO m, MonadRaft v m, MonadCatch m, RaftRecvClient m v)
  => RaftEventChan v m
  -> RaftT v m ()
clientReqHandler eventChan =
  forever $ do
    received <- lift (Control.Monad.Catch.try receiveClient)
    case received of
      Left (thrown :: SomeException) -> logCritical (show thrown)
      Right (Left recvErr) -> logCritical (show recvErr)
      Right (Right clientReq) ->
        lift (writeRaftChan @v @m eventChan (MessageEvent (ClientRequestEvent clientReq)))
-- | Loop forever: start the timer, wait for it, then write an
-- 'ElectionTimeout' event to the event channel. Panics if the timer reports
-- that it could not be started.
electionTimeoutTimer :: forall m v. (MonadIO m, MonadRaft v m) => RaftEventChan v m -> Timer -> m ()
electionTimeoutTimer eventChan timer =
  forever $ do
    started <- liftIO (startTimer timer)
    unless started $
      panic "[Failed invariant]: Election timeout timer failed to start."
    liftIO (waitTimer timer)
    writeTimeoutEvent @m @v eventChan ElectionTimeout
-- | Loop forever: start the timer, wait for it, then write a
-- 'HeartbeatTimeout' event to the event channel. Panics if the timer reports
-- that it could not be started.
heartbeatTimeoutTimer :: forall m v. (MonadIO m, MonadRaft v m) => RaftEventChan v m -> Timer -> m ()
heartbeatTimeoutTimer eventChan timer =
  forever $ do
    started <- liftIO (startTimer timer)
    unless started $
      panic "[Failed invariant]: Heartbeat timeout timer failed to start."
    liftIO (waitTimer timer)
    writeTimeoutEvent @m @v eventChan HeartbeatTimeout
-- | Write a timeout event, stamped with the current system time, to the
-- event channel.
writeTimeoutEvent :: forall m v. (MonadIO m , MonadRaft v m) => RaftEventChan v m -> Timeout -> m ()
writeTimeoutEvent eventChan timeout =
  liftIO getSystemTime >>= \now ->
    writeRaftChan @v @m eventChan (TimeoutEvent now timeout)