{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeFamilyDependencies #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# Language ConstraintKinds #-}
{-# LANGUAGE AllowAmbiguousTypes #-}
module Raft.Monad (
  MonadRaft
, MonadRaftChan(..)
, RaftThreadRole(..)
, MonadRaftFork(..)
, RaftEnv(..)
, initializeRaftEnv
, RaftT
, runRaftT
, Raft.Monad.logInfo
, Raft.Monad.logDebug
, Raft.Monad.logCritical
, Raft.Monad.logAndPanic
) where
import Protolude hiding (STM, TChan, readTChan, writeTChan, newTChan, atomically)
import qualified Control.Monad.Metrics as Metrics
import Control.Monad.Catch
import Control.Monad.Fail
import Control.Monad.Trans.Class
import qualified Control.Monad.Conc.Class as Conc
import Control.Concurrent.Classy.STM.TChan
import Raft.Config
import Raft.Event
import Raft.Logging
import Raft.NodeState
import Test.DejaFu.Conc (ConcIO)
import qualified Test.DejaFu.Types as TDT
-- | Convenience constraint: a monad that can both manage Raft event
-- channels ('MonadRaftChan') and fork Raft threads ('MonadRaftFork').
type MonadRaft v m =
  ( MonadRaftChan v m
  , MonadRaftFork m
  )
-- | Monads that can create, read from, and write to a channel of Raft
-- 'Event's.
--
-- NOTE(review): 'RaftEventChan' is an associated type synonym without an
-- injectivity annotation, so @v@ is not recoverable from
-- @RaftEventChan v m@ alone (hence @AllowAmbiguousTypes@). Callers of
-- 'newRaftChan' probably need a type application to fix @v@; confirm
-- against the call sites.
class Monad m => MonadRaftChan v m where
  -- | Concrete channel type used in monad @m@ for events over @v@.
  type RaftEventChan v m

  -- | Read an event from the channel.
  readRaftChan
    :: RaftEventChan v m
    -> m (Event v)

  -- | Write an event to the channel.
  writeRaftChan
    :: RaftEventChan v m
    -> Event v
    -> m ()

  -- | Allocate a new channel.
  newRaftChan
    :: m (RaftEventChan v m)
-- | 'IO' instance: the channel is a classy 'TChan' over @IO@'s STM, and
-- each operation runs as its own 'Conc.atomically' transaction.
instance MonadRaftChan v IO where
  type RaftEventChan v IO = TChan (Conc.STM IO) (Event v)
  readRaftChan chan = Conc.atomically (readTChan chan)
  writeRaftChan chan ev = Conc.atomically (writeTChan chan ev)
  newRaftChan = Conc.atomically newTChan
-- | Instance for dejafu's 'ConcIO' monad. It mirrors the 'IO' instance,
-- using a 'TChan' over 'ConcIO''s STM with one 'Conc.atomically'
-- transaction per operation.
instance MonadRaftChan v ConcIO where
  type RaftEventChan v ConcIO = TChan (Conc.STM ConcIO) (Event v)
  readRaftChan chan = Conc.atomically (readTChan chan)
  writeRaftChan chan ev = Conc.atomically (writeTChan chan ev)
  newRaftChan = Conc.atomically newTChan
-- | Label passed to 'raftFork' describing the thread being started.
--
-- The 'ConcIO' instance of 'MonadRaftFork' uses the 'show' rendering of
-- this value as the thread name; the 'IO' instance ignores it.
data RaftThreadRole
  = RPCHandler
  | ClientRequestHandler
  | CustomThreadRole Text
    -- ^ A role identified by caller-supplied 'Text'.
  deriving stock (Show)
-- | Monads that can start a new thread running an action in the same
-- monad.
class Monad m => MonadRaftFork m where
  -- | Identifier of a thread forked in @m@.
  type RaftThreadId m

  -- | Run the given action in a new thread and return that thread's id.
  -- The role is only a descriptive label; instances may ignore it.
  raftFork :: RaftThreadRole -> m () -> m (RaftThreadId m)
-- | Forks with 'forkIO'; the role label is not used.
instance MonadRaftFork IO where
  type RaftThreadId IO = Protolude.ThreadId
  raftFork _role action = forkIO action
-- | Forks with the named 'Conc.forkN', using the role's 'show' rendering
-- as the thread name.
instance MonadRaftFork ConcIO where
  type RaftThreadId ConcIO = TDT.ThreadId
  raftFork role action = Conc.forkN (show role) action
-- | Environment read (through 'ReaderT') by 'RaftT' computations.
data RaftEnv v m = RaftEnv
  { eventChan           :: RaftEventChan v m
    -- ^ Channel carrying the node's 'Event's.
  , resetElectionTimer  :: m ()
    -- ^ Action that resets the election timer (exact semantics live in
    -- the caller-supplied action; presumably restarts its countdown).
  , resetHeartbeatTimer :: m ()
    -- ^ Action that resets the heartbeat timer (exact semantics live in
    -- the caller-supplied action; presumably restarts its countdown).
  , raftNodeConfig      :: RaftNodeConfig
    -- ^ Node configuration. The node id placed in the log context is
    -- read from here.
  , raftNodeLogCtx      :: LogCtx (RaftT v m)
    -- ^ Logging context consulted by this module's logging helpers.
  , raftNodeMetrics     :: Metrics.Metrics
    -- ^ Metrics store returned by 'Metrics.getMetrics' for 'RaftT'.
  }
-- | The Raft monad transformer: a read-only 'RaftEnv' layered over a
-- mutable 'RaftNodeState', on top of the base monad @m@.
newtype RaftT v m a = RaftT
  { unRaftT :: ReaderT (RaftEnv v m) (StateT (RaftNodeState v) m) a
  }
  deriving newtype
    ( Functor
    , Applicative
    , Monad
    , MonadReader (RaftEnv v m)
    , MonadState (RaftNodeState v)
    , MonadFail
    , Alternative
    , MonadPlus
    )
-- | Lifts a base-monad action through both the 'StateT' and 'ReaderT'
-- layers of 'RaftT'.
instance MonadTrans (RaftT v) where
  lift action = RaftT (lift (lift action))
-- Effect classes forwarded from the base monad @m@: each instance below is
-- available whenever @m@ has the corresponding instance, and reuses the
-- underlying ReaderT/StateT stack's implementation via newtype deriving.
--
-- NOTE: because 'StateT' is part of the stack, a 'catch' handler resumes
-- from the 'RaftNodeState' held at the point 'catch' was entered.
-- State changes made by the action that threw are therefore discarded.
deriving newtype instance MonadIO m => MonadIO (RaftT v m)
deriving newtype instance MonadThrow m => MonadThrow (RaftT v m)
deriving newtype instance MonadCatch m => MonadCatch (RaftT v m)
deriving newtype instance MonadMask m => MonadMask (RaftT v m)
-- | Forks a 'RaftT' computation as a thread of the base monad.
--
-- The child starts from a snapshot of the caller's current
-- 'RaftNodeState' and shares the same 'RaftEnv'. It runs through
-- 'runRaftT', which discards the final state. As a result, state updates
-- made by the child never reach the parent, and the parent's later
-- updates never reach the child.
instance MonadRaftFork m => MonadRaftFork (RaftT v m) where
  type RaftThreadId (RaftT v m) = RaftThreadId m
  raftFork role action = do
    env <- ask
    snapshot <- get
    lift . raftFork role $ runRaftT snapshot env action
-- | Logger context for 'RaftT': the node id from the configuration,
-- paired with the current node state.
instance Monad m => RaftLogger v (RaftT v m) where
  loggerCtx = do
    ownId <- asks (raftConfigNodeId . raftNodeConfig)
    nodeState <- get
    pure (ownId, nodeState)
-- | Metrics come from the store held in the 'RaftEnv'.
instance Monad m => Metrics.MonadMetrics (RaftT v m) where
  getMetrics = raftNodeMetrics <$> ask
-- | Build a 'RaftEnv' from its externally supplied parts, allocating a
-- fresh 'Metrics.Metrics' store in 'IO'.
--
-- The parameter names deliberately differ from the 'RaftEnv' field names
-- so they do not shadow the record selectors (@-Wname-shadowing@).
initializeRaftEnv
  :: MonadIO m
  => RaftEventChan v m
  -- ^ Event channel for the node.
  -> m ()
  -- ^ Action stored as 'resetElectionTimer'.
  -> m ()
  -- ^ Action stored as 'resetHeartbeatTimer'.
  -> RaftNodeConfig
  -- ^ Node configuration.
  -> LogCtx (RaftT v m)
  -- ^ Logging context used by the module's logging helpers.
  -> m (RaftEnv v m)
initializeRaftEnv chan electionTimerReset heartbeatTimerReset nodeConfig logCtx = do
  metrics <- liftIO Metrics.initialize
  pure RaftEnv
    { eventChan = chan
    , resetElectionTimer = electionTimerReset
    , resetHeartbeatTimer = heartbeatTimerReset
    , raftNodeConfig = nodeConfig
    , raftNodeLogCtx = logCtx
    , raftNodeMetrics = metrics
    }
-- | Run a 'RaftT' computation from an initial node state and an
-- environment, returning only its result. The final 'RaftNodeState' is
-- discarded.
runRaftT
  :: Monad m
  => RaftNodeState v
  -> RaftEnv v m
  -> RaftT v m a
  -> m a
runRaftT initialState env action =
  evalStateT (runReaderT (unRaftT action) env) initialState
-- | Log a message through 'logInfoIO', using the log context stored in
-- the 'RaftEnv'.
logInfo :: MonadIO m => Text -> RaftT v m ()
logInfo msg = do
  ctx <- asks raftNodeLogCtx
  logInfoIO ctx msg
-- | Log a message through 'logDebugIO', using the log context stored in
-- the 'RaftEnv'.
logDebug :: MonadIO m => Text -> RaftT v m ()
logDebug msg = do
  ctx <- asks raftNodeLogCtx
  logDebugIO ctx msg
-- | Log a message through 'logCriticalIO', using the log context stored
-- in the 'RaftEnv'.
logCritical :: MonadIO m => Text -> RaftT v m ()
logCritical msg = do
  ctx <- asks raftNodeLogCtx
  logCriticalIO ctx msg
-- | Log a message through 'logAndPanicIO', using the log context stored
-- in the 'RaftEnv'.
--
-- The result type is fully polymorphic, so this cannot return a value
-- normally. Presumably 'logAndPanicIO' logs and then aborts; confirm in
-- "Raft.Logging".
logAndPanic :: MonadIO m => Text -> RaftT v m a
logAndPanic msg = do
  ctx <- asks raftNodeLogCtx
  logAndPanicIO ctx msg