{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE TypeApplications #-}
module Raft
(
RaftStateMachinePure(..)
, RaftStateMachine(..)
, RaftSendRPC(..)
, RaftRecvRPC(..)
, RaftSendClient(..)
, RaftRecvClient(..)
, RaftPersist(..)
, RaftEnv(..)
, runRaftNode
, runRaftT
, handleEventLoop
, ClientRequest(..)
, ClientReq(..)
, ClientResponse(..)
, ClientReadResp(..)
, ClientWriteResp(..)
, ClientRedirResp(..)
, RaftNodeConfig(..)
, Event(..)
, Timeout(..)
, MessageEvent(..)
, Entry(..)
, Entries
, RaftWriteLog(..)
, DeleteSuccess(..)
, RaftDeleteLog(..)
, RaftReadLog (..)
, RaftLog
, RaftLogError(..)
, RaftLogExceptions(..)
, LogCtx(..)
, LogDest(..)
, Severity(..)
, Mode(..)
, RaftNodeState(..)
, NodeState(..)
, CurrentLeader(..)
, FollowerState(..)
, CandidateState(..)
, LeaderState(..)
, initRaftNodeState
, isFollower
, isCandidate
, isLeader
, setLastLogEntry
, getLastLogEntry
, getLastAppliedAndCommitIndex
, PersistentState(..)
, initPersistentState
, NodeId
, NodeIds
, ClientId(..)
, LeaderId(..)
, Term(..)
, Index(..)
, term0
, index0
, RPC(..)
, RPCType(..)
, RPCMessage(..)
, AppendEntries(..)
, AppendEntriesResponse(..)
, RequestVote(..)
, RequestVoteResponse(..)
, AppendEntriesData(..)
) where
import Protolude hiding (STM, TChan, newRaftChan, readBoundedChan, writeBoundedChan, atomically)
import Control.Concurrent.STM.Timer
import Control.Exception (SomeAsyncException, fromException)
import Control.Monad.Fail
import Control.Monad.Catch
import qualified Data.Map as Map
import Data.Serialize (Serialize)
import Data.Sequence (Seq(..), singleton)
import Data.Time.Clock.System (getSystemTime)
import Raft.Action
import Raft.Client
import Raft.Config
import Raft.Event
import Raft.Handle
import Raft.Monad
import Raft.Log
import Raft.Logging hiding (logInfo, logDebug, logCritical, logAndPanic)
import Raft.Transition hiding (logInfo, logDebug)
import Raft.NodeState
import Raft.Persistent
import Raft.RPC
import Raft.StateMachine
import Raft.Types
-- | Boot a Raft node and hand control to its event loop.
--
-- Steps, in order:
--
--   1. If 'configStorageState' is 'New', initialise the persistent state and
--      then the log. A 'Left' from either initialiser is rethrown with
--      'throwM'. For 'Existing' storage nothing is initialised.
--   2. Create the event channel, the election timer (randomised from
--      @timerSeed@ over 'configElectionTimeout') and the heartbeat timer.
--   3. Inside 'runRaftT', fork four producer threads that write into the
--      event channel: election timer, heartbeat timer, RPC receiver and
--      client-request receiver.
--   4. Run 'handleEventLoop' on the calling thread.
runRaftNode
  :: forall m sm v.
     ( Typeable m, Show v, Show sm, Serialize v, Show (Action sm v), Show (RaftLogError m), Show (RaftStateMachinePureError sm v)
     , MonadIO m, MonadCatch m, MonadFail m, MonadRaft v m
     , RaftStateMachine m sm v
     , RaftSendRPC m v
     , RaftRecvRPC m v
     , RaftSendClient m sm v
     , RaftRecvClient m v
     , RaftLog m v
     , RaftLogExceptions m
     , RaftPersist m
     , Exception (RaftPersistError m)
     )
  => RaftNodeConfig
  -> LogCtx (RaftT v m)
  -> Int
  -> sm
  -> m ()
runRaftNode nodeConfig@RaftNodeConfig{..} logCtx timerSeed initRaftStateMachine = do
  prepareStorage
  chan <- newRaftChan @v
  electionTmr <- liftIO (newTimerRange timerSeed configElectionTimeout)
  heartbeatTmr <- liftIO (newTimer configHeartbeatTimeout)
  -- The two reset actions discard the result of 'resetTimer'.
  let env =
        RaftEnv
          chan
          (liftIO (void (resetTimer electionTmr)))
          (liftIO (void (resetTimer heartbeatTmr)))
          nodeConfig
          logCtx
  runRaftT initRaftNodeState env $ do
    -- Producer threads: each one writes events into 'chan'.
    raftFork (CustomThreadRole "Election Timeout Timer") $
      lift (electionTimeoutTimer @m @v chan electionTmr)
    raftFork (CustomThreadRole "Heartbeat Timeout Timer") $
      lift (heartbeatTimeoutTimer @m @v chan heartbeatTmr)
    raftFork RPCHandler (rpcHandler @m @v chan)
    raftFork ClientRequestHandler (clientReqHandler @m @v chan)
    -- Consumer: the event loop runs on the current thread.
    handleEventLoop initRaftStateMachine
  where
    -- Only fresh ('New') storage needs initialising. Persistent state is
    -- initialised first; a failure there skips log initialisation.
    prepareStorage =
      case configStorageState of
        Existing -> pure ()
        New -> do
          initializePersistentState >>= either throwM (const (pure ()))
          initializeLog (Proxy :: Proxy v) >>= either throwM (const (pure ()))
-- | The node's main loop.
--
-- Before looping, it seeds the node state's last log entry from the log
-- ('setInitLastLogEntry') and reads the persistent state from storage.
-- It then repeatedly:
--
--   1. reads the next event from the environment's event channel;
--   2. for AppendEntries RPCs received as a follower, caches the term of the
--      local entry at @aePrevLogIndex@ in the node state;
--   3. logs the event and the current states at debug level;
--   4. runs the pure transition 'Raft.Handle.handleEvent';
--   5. writes the persistent state back only if it changed;
--   6. stores the new node state, emits the transition's log messages and
--      performs its actions;
--   7. applies any newly committed log entries to the state machine.
--
-- Errors reading or writing the log or persistent state are rethrown with
-- 'throwM'. The loop recurses unconditionally, so it only ends by throwing.
handleEventLoop
  :: forall sm v m.
     ( Show v, Serialize v, Show sm, Show (Action sm v), Show (RaftLogError m), Typeable m
     , MonadIO m, MonadRaft v m, MonadFail m, MonadThrow m
     , RaftStateMachine m sm v
     , Show (RaftStateMachinePureError sm v)
     , RaftPersist m
     , RaftSendRPC m v
     , RaftSendClient m sm v
     , RaftLog m v
     , RaftLogExceptions m
     , Exception (RaftPersistError m)
     )
  => sm
  -> RaftT v m ()
handleEventLoop initRaftStateMachine = do
  setInitLastLogEntry
  ePersistentState <- lift readPersistentState
  case ePersistentState of
    Left err -> throwM err
    Right pstate -> handleEventLoop' initRaftStateMachine pstate
  where
    -- One iteration per event; recurses with the updated state machine and
    -- persistent state.
    handleEventLoop' :: sm -> PersistentState -> RaftT v m ()
    handleEventLoop' stateMachine persistentState = do
      event <- lift . readRaftChan =<< asks eventChan
      loadLogEntryTermAtAePrevLogIndex event
      raftNodeState <- get
      logDebug $ "[Event]: " <> show event
      logDebug $ "[NodeState]: " <> show raftNodeState
      logDebug $ "[State Machine]: " <> show stateMachine
      logDebug $ "[Persistent State]: " <> show persistentState
      nodeConfig <- asks raftNodeConfig
      let transitionEnv = TransitionEnv nodeConfig stateMachine raftNodeState
          (resRaftNodeState, resPersistentState, actions, logMsgs) =
            Raft.Handle.handleEvent raftNodeState transitionEnv persistentState event
      -- Avoid a storage write when the transition left persistent state untouched.
      when (resPersistentState /= persistentState) $ do
        eRes <- lift $ writePersistentState resPersistentState
        case eRes of
          Left err -> throwM err
          Right _ -> pure ()
      put resRaftNodeState
      handleLogs logMsgs
      handleActions actions
      resRaftStateMachine <- applyLogEntries stateMachine
      handleEventLoop' resRaftStateMachine resPersistentState

    -- For an AppendEntries RPC received while in the follower state, read
    -- the local log entry at the RPC's 'aePrevLogIndex' and store its term
    -- (or 'Nothing' if there is no such entry) in 'fsTermAtAEPrevIndex'.
    -- Presumably this lets the pure transition code, which cannot read the
    -- log itself, do the prevLog consistency check. TODO confirm.
    -- Every other event/state combination is left untouched.
    loadLogEntryTermAtAePrevLogIndex :: Event v -> RaftT v m ()
    loadLogEntryTermAtAePrevLogIndex event =
      case event of
        MessageEvent (RPCMessageEvent (RPCMessage _ (AppendEntriesRPC ae))) -> do
          RaftNodeState rns <- get
          case rns of
            NodeFollowerState fs -> do
              eEntry <- lift $ readLogEntry (aePrevLogIndex ae)
              case eEntry of
                Left err -> throwM err
                Right (mEntry :: Maybe (Entry v)) -> put $
                  RaftNodeState $ NodeFollowerState fs
                    { fsTermAtAEPrevIndex = entryTerm <$> mEntry }
            _ -> pure ()
        _ -> pure ()

    -- Seed the node state's last log entry from the log. Nothing changes
    -- when the log is empty.
    setInitLastLogEntry :: RaftT v m ()
    setInitLastLogEntry = do
      RaftNodeState rns <- get
      eLogEntry <- lift readLastLogEntry
      case eLogEntry of
        Left err -> throwM err
        Right Nothing -> pure ()
        Right (Just e) ->
          put (RaftNodeState (setLastLogEntry rns (singleton e)))
-- | Perform each action in list order, one after another, via 'handleAction'.
handleActions
  :: ( Show v, Show sm, Show (Action sm v), Show (RaftLogError m), Typeable m
     , MonadIO m, MonadRaft v m, MonadThrow m
     , RaftStateMachine m sm v
     , RaftSendRPC m v
     , RaftSendClient m sm v
     , RaftLog m v
     , RaftLogExceptions m
     )
  => [Action sm v]
  -> RaftT v m ()
handleActions actions = forM_ actions handleAction
-- | Perform one 'Action' emitted by the pure transition function.
--
-- * 'SendRPC' builds the message and sends it on the current thread.
-- * 'SendRPCs' forks one thread per recipient. Each thread builds and sends
--   its own message.
-- * 'BroadcastRPC' builds the message once, then forks one send per node.
-- * 'RespondToClient' builds the response on the current thread and sends it
--   from a forked thread.
-- * 'ResetTimeoutTimer' runs the matching reset action from the environment.
-- * 'AppendLogEntries' writes the entries to the log and, on success, records
--   them as the node's last log entry.
-- * 'UpdateClientReqCacheFrom' is leader-only. It reads entries from the
--   given index, answers each client write request found in them, and merges
--   those requests into 'lsClientReqCache'.
handleAction
  :: forall sm v m.
     ( Show v, Show sm, Show (Action sm v), Show (RaftLogError m), Typeable m
     , MonadIO m, MonadRaft v m, MonadThrow m
     , RaftStateMachine m sm v
     , RaftSendRPC m v
     , RaftSendClient m sm v
     , RaftLog m v
     , RaftLogExceptions m
     )
  => Action sm v
  -> RaftT v m ()
handleAction action = do
  logDebug $ "[Action]: " <> show action
  case action of
    SendRPC nid sendRpcAction -> do
      rpcMsg <- mkRPCfromSendRPCAction sendRpcAction
      lift (sendRPC nid rpcMsg)
    SendRPCs rpcMap ->
      flip mapM_ (Map.toList rpcMap) $ \(nid, sendRpcAction) ->
        raftFork (CustomThreadRole "Send RPC") $ do
          rpcMsg <- mkRPCfromSendRPCAction sendRpcAction
          lift (sendRPC nid rpcMsg)
    BroadcastRPC nids sendRpcAction -> do
      rpcMsg <- mkRPCfromSendRPCAction sendRpcAction
      mapM_ (raftFork (CustomThreadRole "RPC Broadcast Thread") . lift . flip sendRPC rpcMsg) nids
    RespondToClient cid cr -> do
      clientResp <- mkClientResp cr
      void $ raftFork (CustomThreadRole "Respond to Client") $ lift $ sendClient cid clientResp
    ResetTimeoutTimer tout -> do
      case tout of
        ElectionTimeout -> lift . resetElectionTimer =<< ask
        HeartbeatTimeout -> lift . resetHeartbeatTimer =<< ask
    AppendLogEntries entries -> do
      eRes <- lift (updateLog entries)
      case eRes of
        -- NOTE(review): a failed log write panics here, while the other
        -- log errors in this function are rethrown with 'throwM'. Confirm
        -- the difference is intended.
        Left err -> logAndPanic (show err)
        Right _ -> do
          modify $ \(RaftNodeState ns) ->
            RaftNodeState (setLastLogEntry ns entries)
    UpdateClientReqCacheFrom idx -> do
      RaftNodeState ns <- get
      case ns of
        NodeLeaderState ls@LeaderState{..} -> do
          eentries <- lift (readLogEntriesFrom idx)
          case eentries of
            Left err -> throwM err
            Right (entries :: Entries v) -> do
              let committedClientReqs = clientReqData entries
              unless (Map.null committedClientReqs) $ do
                mapM_ respondClientWrite (Map.toList committedClientReqs)
                let creqMap = Map.map (second Just) committedClientReqs
                -- 'Map.union' is left-biased: freshly committed requests
                -- replace existing cache entries for the same client.
                put $ RaftNodeState $ NodeLeaderState
                  ls { lsClientReqCache = creqMap `Map.union` lsClientReqCache }
        _ -> logAndPanic "Only the leader should update the client request cache..."
  where
    -- Answer one committed client write request.
    respondClientWrite :: (ClientId, (SerialNum, Index)) -> RaftT v m ()
    respondClientWrite (cid, (sn,idx)) =
      handleAction $ RespondToClient cid (ClientWriteRespSpec idx sn :: ClientRespSpec sm)

    -- Turn a response specification into a concrete client response. Read
    -- requests for log entries are answered by reading the log.
    mkClientResp :: ClientRespSpec sm -> RaftT v m (ClientResponse sm v)
    mkClientResp crs =
      case crs of
        ClientReadRespSpec crrs ->
          ClientReadResponse <$>
            case crrs of
              ClientReadRespSpecEntries res -> do
                eRes <- lift (readEntries res)
                case eRes of
                  Left err -> throwM err
                  Right readResult ->
                    case readResult of
                      OneEntry e -> pure (ClientReadRespEntry e)
                      ManyEntries es -> pure (ClientReadRespEntries es)
              ClientReadRespSpecStateMachine sm ->
                pure (ClientReadRespStateMachine sm)
        ClientWriteRespSpec idx sn ->
          pure (ClientWriteResponse (ClientWriteResp idx sn))
        ClientRedirRespSpec cl ->
          pure (ClientRedirectResponse (ClientRedirResp cl))

    -- Build the outgoing RPC message, stamped with this node's id. For
    -- AppendEntries this means choosing the entries to send together with
    -- the matching prevLogIndex / prevLogTerm.
    mkRPCfromSendRPCAction :: SendRPCAction v -> RaftT v m (RPCMessage v)
    mkRPCfromSendRPCAction sendRPCAction = do
      RaftNodeState ns <- get
      nodeConfig <- asks raftNodeConfig
      RPCMessage (configNodeId nodeConfig) <$>
        case sendRPCAction of
          SendAppendEntriesRPC aeData -> do
            (entries, prevLogIndex, prevLogTerm, aeReadReq) <-
              case aedEntriesSpec aeData of
                FromIndex idx -> do
                  -- Read from the entry before 'idx' (via
                  -- 'decrIndexWithDefault0'), so the first entry read can
                  -- supply prevLogIndex / prevLogTerm.
                  eLogEntries <- lift (readLogEntriesFrom (decrIndexWithDefault0 idx))
                  case eLogEntries of
                    Left err -> throwM err
                    Right log ->
                      case log of
                        pe :<| entries@(_ :<| _)
                          | idx == 1 -> pure (log, index0, term0, Nothing)
                          | otherwise -> pure (entries, entryIndex pe, entryTerm pe, Nothing)
                        -- NOTE(review): this branch also matches a log holding
                        -- exactly one entry when idx > 1. That entry is then sent
                        -- with prevLogIndex index0, rather than an empty batch
                        -- anchored at it. Confirm this is intended.
                        _ -> pure (log, index0, term0, Nothing)
                FromClientWriteReq e -> prevEntryData e
                FromNewLeader e -> prevEntryData e
                NoEntries spec -> do
                  let readReq =
                        case spec of
                          FromClientReadReq n -> Just n
                          _ -> Nothing
                      (lastLogIndex, lastLogTerm) = lastLogEntryIndexAndTerm (getLastLogEntry ns)
                  pure (Empty, lastLogIndex, lastLogTerm, readReq)
            let leaderId = LeaderId (configNodeId nodeConfig)
            pure . toRPC $
              AppendEntries
                { aeTerm = aedTerm aeData
                , aeLeaderId = leaderId
                , aePrevLogIndex = prevLogIndex
                , aePrevLogTerm = prevLogTerm
                , aeEntries = entries
                , aeLeaderCommit = aedLeaderCommit aeData
                , aeReadRequest = aeReadReq
                }
          SendAppendEntriesResponseRPC aer -> pure (toRPC aer)
          SendRequestVoteRPC rv -> pure (toRPC rv)
          SendRequestVoteResponseRPC rvr -> pure (toRPC rvr)

    -- 'prevEntryData'' with no read request attached.
    prevEntryData e = do
      (x,y,z) <- prevEntryData' e
      pure (x,y,z,Nothing)

    -- A single entry paired with the index/term of the entry just before
    -- it. Falls back to (index0, term0) for the first entry, or when the
    -- previous entry is missing from the log.
    prevEntryData' e
      | entryIndex e == Index 1 = pure (singleton e, index0, term0)
      | otherwise = do
          let prevLogEntryIdx = decrIndexWithDefault0 (entryIndex e)
          eLogEntry <- lift $ readLogEntry prevLogEntryIdx
          case eLogEntry of
            Left err -> throwM err
            Right Nothing -> pure (singleton e, index0, term0)
            Right (Just (prevEntry :: Entry v)) ->
              pure (singleton e, entryIndex prevEntry, entryTerm prevEntry)
-- | Apply committed log entries to the state machine, one at a time, until
-- lastApplied catches up with commitIndex. Returns the resulting state
-- machine, or the input unchanged if nothing is pending.
--
-- On each step, lastApplied is incremented in the node state before the
-- entry at the new index is read and applied.
--
-- Failure handling:
--
-- * a log read error is rethrown with 'throwM';
-- * a missing entry at the new lastApplied index is fatal ('logAndPanic');
-- * an error from 'applyLogEntry' is fatal ('logAndPanic').
applyLogEntries
  :: forall sm m v.
     ( Show sm, Show (RaftStateMachinePureError sm v)
     , MonadIO m, MonadThrow m, MonadRaft v m
     , RaftReadLog m v, Exception (RaftReadLogError m)
     , RaftStateMachine m sm v
     )
  => sm
  -> RaftT v m sm
applyLogEntries stateMachine = do
  RaftNodeState nodeState <- get
  if commitIndex nodeState > lastApplied nodeState
    then do
      let resNodeState = incrLastApplied nodeState
      put $ RaftNodeState resNodeState
      let newLastAppliedIndex = lastApplied resNodeState
      eLogEntry <- lift $ readLogEntry newLastAppliedIndex
      case eLogEntry of
        Left err -> throwM err
        Right Nothing -> logAndPanic $ "No log entry at 'newLastAppliedIndex := " <> show newLastAppliedIndex <> "'"
        Right (Just logEntry) -> do
          eRes <- lift (applyLogEntry stateMachine logEntry)
          case eRes of
            Left err -> logAndPanic $ "Failed to apply committed log entry: " <> show err
            Right nsm -> applyLogEntries nsm
    else pure stateMachine
  where
    -- Bump lastApplied by one in whichever role the node is in.
    incrLastApplied :: NodeState ns v -> NodeState ns v
    incrLastApplied nodeState =
      case nodeState of
        NodeFollowerState fs ->
          let lastApplied' = incrIndex (fsLastApplied fs)
          in NodeFollowerState $ fs { fsLastApplied = lastApplied' }
        NodeCandidateState cs ->
          let lastApplied' = incrIndex (csLastApplied cs)
          in NodeCandidateState $ cs { csLastApplied = lastApplied' }
        NodeLeaderState ls ->
          let lastApplied' = incrIndex (lsLastApplied ls)
          in NodeLeaderState $ ls { lsLastApplied = lastApplied' }

    lastApplied :: NodeState ns v -> Index
    lastApplied = fst . getLastAppliedAndCommitIndex

    commitIndex :: NodeState ns v -> Index
    commitIndex = snd . getLastAppliedAndCommitIndex
-- | Send every log message produced by a transition to the destination in
-- the environment's log context, in order.
handleLogs
  :: (MonadIO m, MonadRaft v m)
  => [LogMsg]
  -> RaftT v m ()
handleLogs logs =
  asks raftNodeLogCtx >>= \ctx -> mapM_ (logToDest ctx) logs
-- | Forever receive RPC messages and push each one onto the event channel
-- as an 'RPCMessageEvent'.
--
-- Synchronous failures of 'receiveRPC' are logged at critical level and the
-- loop continues. This covers both thrown exceptions and returned 'Left'
-- errors. Asynchronous exceptions (such as those raised by 'killThread')
-- are deliberately not caught and are rethrown, so this thread can still be
-- cancelled.
rpcHandler
  :: forall m v. (MonadIO m, MonadRaft v m, MonadCatch m, Show v, RaftRecvRPC m v)
  => RaftEventChan v m
  -> RaftT v m ()
rpcHandler eventChan =
  forever $ do
    eRpcMsg <- lift $ Control.Monad.Catch.tryJust syncOnly receiveRPC
    case eRpcMsg of
      Left (err :: SomeException) -> logCritical (show err)
      Right (Left err) -> logCritical (show err)
      Right (Right rpcMsg) -> do
        let rpcMsgEvent = MessageEvent (RPCMessageEvent rpcMsg)
        lift $ writeRaftChan @v @m eventChan rpcMsgEvent
  where
    -- Keep only synchronous exceptions. Returning 'Nothing' makes 'tryJust'
    -- rethrow the exception unchanged.
    syncOnly :: SomeException -> Maybe SomeException
    syncOnly e =
      case fromException e of
        Just (_ :: SomeAsyncException) -> Nothing
        Nothing -> Just e
-- | Forever receive client requests and push each one onto the event
-- channel as a 'ClientRequestEvent'.
--
-- Synchronous failures of 'receiveClient' are logged at critical level and
-- the loop continues. This covers both thrown exceptions and returned
-- 'Left' errors. Asynchronous exceptions (such as those raised by
-- 'killThread') are deliberately not caught and are rethrown, so this
-- thread can still be cancelled.
clientReqHandler
  :: forall m v. (MonadIO m, MonadRaft v m, MonadCatch m, RaftRecvClient m v)
  => RaftEventChan v m
  -> RaftT v m ()
clientReqHandler eventChan =
  forever $ do
    eClientReq <- lift $ Control.Monad.Catch.tryJust syncOnly receiveClient
    case eClientReq of
      Left (err :: SomeException) -> logCritical (show err)
      Right (Left err) -> logCritical (show err)
      Right (Right clientReq) -> do
        let clientReqEvent = MessageEvent (ClientRequestEvent clientReq)
        lift $ writeRaftChan @v @m eventChan clientReqEvent
  where
    -- Keep only synchronous exceptions. Returning 'Nothing' makes 'tryJust'
    -- rethrow the exception unchanged.
    syncOnly :: SomeException -> Maybe SomeException
    syncOnly e =
      case fromException e of
        Just (_ :: SomeAsyncException) -> Nothing
        Nothing -> Just e
-- | Loop forever: start the election timer, block until it fires, then
-- enqueue an 'ElectionTimeout' event. Panics if 'startTimer' reports that
-- the timer could not be started.
electionTimeoutTimer :: forall m v. (MonadIO m, MonadRaft v m) => RaftEventChan v m -> Timer -> m ()
electionTimeoutTimer eventChan timer = forever $ do
  started <- liftIO (startTimer timer)
  unless started $
    panic "[Failed invariant]: Election timeout timer failed to start."
  liftIO (waitTimer timer)
  writeTimeoutEvent @m @v eventChan ElectionTimeout
-- | Loop forever: start the heartbeat timer, block until it fires, then
-- enqueue a 'HeartbeatTimeout' event. Panics if 'startTimer' reports that
-- the timer could not be started.
heartbeatTimeoutTimer :: forall m v. (MonadIO m, MonadRaft v m) => RaftEventChan v m -> Timer -> m ()
heartbeatTimeoutTimer eventChan timer = forever $ do
  started <- liftIO (startTimer timer)
  unless started $
    panic "[Failed invariant]: Heartbeat timeout timer failed to start."
  liftIO (waitTimer timer)
  writeTimeoutEvent @m @v eventChan HeartbeatTimeout
-- | Enqueue a 'TimeoutEvent' for the given timeout, stamped with the
-- current system time.
writeTimeoutEvent :: forall m v. (MonadIO m , MonadRaft v m) => RaftEventChan v m -> Timeout -> m ()
writeTimeoutEvent eventChan timeout =
  liftIO getSystemTime >>= \now ->
    writeRaftChan @v @m eventChan (TimeoutEvent now timeout)