{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeFamilyDependencies #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# Language ConstraintKinds #-}
{-# LANGUAGE AllowAmbiguousTypes #-}
module Raft.Monad (
MonadRaft
, MonadRaftChan(..)
, RaftThreadRole(..)
, MonadRaftFork(..)
, RaftEnv(..)
, initializeRaftEnv
, RaftT
, runRaftT
, Raft.Monad.logInfo
, Raft.Monad.logDebug
, Raft.Monad.logCritical
, Raft.Monad.logAndPanic
) where
import Protolude hiding (STM, TChan, readTChan, writeTChan, newTChan, atomically)
import qualified Control.Monad.Metrics as Metrics
import Control.Monad.Catch
import Control.Monad.Fail
import Control.Monad.Trans.Class
import qualified Control.Monad.Conc.Class as Conc
import Control.Concurrent.Classy.STM.TChan
import Raft.Config
import Raft.Event
import Raft.Logging
import Raft.NodeState
import Test.DejaFu.Conc (ConcIO)
import qualified Test.DejaFu.Types as TDT
-- | Bundle of capabilities a base monad must provide to host a Raft node:
-- event channels ('MonadRaftChan') and thread forking ('MonadRaftFork').
type MonadRaft v m =
  ( MonadRaftChan v m
  , MonadRaftFork m
  )
-- | Monads that can create a channel of Raft 'Event's and read from / write
-- to it.
class Monad m => MonadRaftChan v m where
  -- | Concrete channel type carrying @'Event' v@ values in the monad @m@.
  type RaftEventChan v m

  -- | Allocate a fresh, empty event channel.
  newRaftChan :: m (RaftEventChan v m)

  -- | Take the next event from the channel. (Whether this blocks on an
  -- empty channel is up to the instance; the IO and ConcIO instances in this
  -- module use 'readTChan'.)
  readRaftChan :: RaftEventChan v m -> m (Event v)

  -- | Put an event onto the channel.
  writeRaftChan :: RaftEventChan v m -> Event v -> m ()
-- | 'IO' instance: the event channel is an STM 'TChan', and each channel
-- operation runs as its own 'Conc.atomically' transaction.
instance MonadRaftChan v IO where
  type RaftEventChan v IO = TChan (Conc.STM IO) (Event v)
  newRaftChan = Conc.atomically newTChan
  readRaftChan chan = Conc.atomically (readTChan chan)
  writeRaftChan chan event = Conc.atomically (writeTChan chan event)
-- | Instance for DejaFu's 'ConcIO' monad. It mirrors the 'IO' instance: an
-- STM 'TChan' whose operations each run in a single 'Conc.atomically'
-- transaction.
instance MonadRaftChan v ConcIO where
  type RaftEventChan v ConcIO = TChan (Conc.STM ConcIO) (Event v)
  newRaftChan = Conc.atomically newTChan
  readRaftChan chan = Conc.atomically (readTChan chan)
  writeRaftChan chan event = Conc.atomically (writeTChan chan event)
-- | Tag describing the role of a thread started with 'raftFork'.
-- The 'ConcIO' instance of 'MonadRaftFork' uses its 'show' output as the
-- thread name; the 'IO' instance ignores it.
data RaftThreadRole
  = RPCHandler
  | ClientRequestHandler
  | CustomThreadRole Text
    -- ^ Free-form role with a caller-chosen label.
  deriving stock (Show)
-- | Monads that can start a new thread, tagged with a 'RaftThreadRole'.
class Monad m => MonadRaftFork m where
  -- | Identifier of a thread forked in @m@.
  type RaftThreadId m

  -- | Run an action in a new thread and return that thread's identifier.
  raftFork
    :: RaftThreadRole      -- ^ role tag for the new thread
    -> m ()                -- ^ action the new thread runs
    -> m (RaftThreadId m)
-- | 'IO' instance: a plain 'forkIO'. The role tag is not used.
instance MonadRaftFork IO where
  type RaftThreadId IO = Protolude.ThreadId
  raftFork _role action = forkIO action
-- | 'ConcIO' instance: forks a named thread whose name is the 'show' of
-- the role tag.
instance MonadRaftFork ConcIO where
  type RaftThreadId ConcIO = TDT.ThreadId
  raftFork role action = Conc.forkN (show role) action
-- | Read-only environment available to every 'RaftT' computation
-- (via 'MonadReader').
data RaftEnv v m = RaftEnv
  { eventChan           :: RaftEventChan v m
    -- ^ Channel of this node's 'Event's.
  , resetElectionTimer  :: m ()
    -- ^ Caller-supplied action; by its name, resets the election timer.
  , resetHeartbeatTimer :: m ()
    -- ^ Caller-supplied action; by its name, resets the heartbeat timer.
  , raftNodeConfig      :: RaftNodeConfig
    -- ^ Node configuration; the node id for log context is read from it.
  , raftNodeLogCtx      :: LogCtx (RaftT v m)
    -- ^ Logging context passed to the @log*IO@ functions.
  , raftNodeMetrics     :: Metrics.Metrics
    -- ^ Metrics store exposed through 'Metrics.getMetrics'.
  }
-- | The Raft monad transformer: a read-only 'RaftEnv' ('ReaderT') layered
-- over mutable 'RaftNodeState' ('StateT') on top of the base monad @m@.
-- Every instance below is inherited from that underlying stack.
newtype RaftT v m a = RaftT
  { unRaftT :: ReaderT (RaftEnv v m) (StateT (RaftNodeState v) m) a
  }
  deriving newtype (Functor, Applicative, Monad)
  deriving newtype (MonadReader (RaftEnv v m), MonadState (RaftNodeState v))
  deriving newtype (MonadFail, Alternative, MonadPlus)
-- | Lift a base-monad action through both the 'StateT' and 'ReaderT' layers.
instance MonadTrans (RaftT v) where
  lift action = RaftT (lift (lift action))
-- Exception-handling capabilities, each available only when the base monad
-- @m@ provides it; the implementation is reused from the underlying stack.
deriving newtype instance (MonadThrow m) => MonadThrow (RaftT v m)
deriving newtype instance (MonadCatch m) => MonadCatch (RaftT v m)
deriving newtype instance (MonadMask m) => MonadMask (RaftT v m)

-- Arbitrary IO, available when the base monad supports it.
deriving newtype instance (MonadIO m) => MonadIO (RaftT v m)
-- | Fork a 'RaftT' action using the base monad's 'raftFork'.
--
-- The child runs with the parent's 'RaftEnv' and a snapshot of the parent's
-- current 'RaftNodeState'. 'runRaftT' uses 'evalStateT', so the child's
-- final state is discarded. State changes made in either thread after the
-- fork are never seen by the other.
instance MonadRaftFork m => MonadRaftFork (RaftT v m) where
  type RaftThreadId (RaftT v m) = RaftThreadId m
  raftFork role action = do
    env <- ask
    snapshot <- get
    let child = runRaftT snapshot env action
    lift (raftFork role child)
-- | Logging context for 'RaftT': the node id from 'raftNodeConfig' paired
-- with the node's current 'RaftNodeState'.
instance Monad m => RaftLogger v (RaftT v m) where
  loggerCtx = do
    nid <- asks (raftConfigNodeId . raftNodeConfig)
    st <- get
    pure (nid, st)
-- | Metrics are read from the 'raftNodeMetrics' field of the environment.
instance Monad m => Metrics.MonadMetrics (RaftT v m) where
  getMetrics = raftNodeMetrics <$> ask
-- | Assemble a 'RaftEnv' from caller-supplied parts, allocating a fresh
-- 'Metrics.Metrics' store for the node.
--
-- The parameters are named so they do not shadow the 'RaftEnv' field
-- selectors (@eventChan@, @resetElectionTimer@, @resetHeartbeatTimer@),
-- which would trigger @-Wname-shadowing@. Haskell arguments are positional,
-- so callers are unaffected.
initializeRaftEnv
  :: MonadIO m
  => RaftEventChan v m   -- ^ event channel, stored as 'eventChan'
  -> m ()                -- ^ stored as 'resetElectionTimer'
  -> m ()                -- ^ stored as 'resetHeartbeatTimer'
  -> RaftNodeConfig      -- ^ stored as 'raftNodeConfig'
  -> LogCtx (RaftT v m)  -- ^ stored as 'raftNodeLogCtx'
  -> m (RaftEnv v m)
initializeRaftEnv chan resetElection resetHeartbeat nodeConfig logCtx = do
  -- Metrics.initialize runs in IO; lift it into the caller's monad.
  metrics <- liftIO Metrics.initialize
  pure RaftEnv
    { eventChan = chan
    , resetElectionTimer = resetElection
    , resetHeartbeatTimer = resetHeartbeat
    , raftNodeConfig = nodeConfig
    , raftNodeLogCtx = logCtx
    , raftNodeMetrics = metrics
    }
-- | Run a 'RaftT' computation from an initial node state and environment.
-- Returns the computation's result and discards the final node state.
runRaftT
  :: Monad m
  => RaftNodeState v
  -> RaftEnv v m
  -> RaftT v m a
  -> m a
runRaftT raftNodeState raftEnv action =
  evalStateT (runReaderT (unRaftT action) raftEnv) raftNodeState
-- | Log an informational message through 'logInfoIO', using the
-- environment's 'raftNodeLogCtx'.
logInfo :: MonadIO m => Text -> RaftT v m ()
logInfo msg = do
  ctx <- asks raftNodeLogCtx
  logInfoIO ctx msg
-- | Log a debug message through 'logDebugIO', using the environment's
-- 'raftNodeLogCtx'.
logDebug :: MonadIO m => Text -> RaftT v m ()
logDebug msg = do
  ctx <- asks raftNodeLogCtx
  logDebugIO ctx msg
-- | Log a critical message through 'logCriticalIO', using the environment's
-- 'raftNodeLogCtx'.
logCritical :: MonadIO m => Text -> RaftT v m ()
logCritical msg = do
  ctx <- asks raftNodeLogCtx
  logCriticalIO ctx msg
-- | Delegate to 'logAndPanicIO' with the environment's 'raftNodeLogCtx'.
-- NOTE(review): the fully polymorphic result type @a@ suggests this never
-- returns normally (presumably it panics after logging); confirm in
-- "Raft.Logging".
logAndPanic :: MonadIO m => Text -> RaftT v m a
logAndPanic msg = do
  ctx <- asks raftNodeLogCtx
  logAndPanicIO ctx msg