{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}

module Raft
  (
  -- * State machine type class
    RSMP(..)
  , RSM(..)

  -- * Networking type classes
  , RaftSendRPC(..)
  , RaftRecvRPC(..)
  , RaftSendClient(..)
  , RaftRecvClient(..)
  , RaftPersist(..)

  , EventChan

  , RaftEnv(..)
  , runRaftNode
  , runRaftT

  , handleEventLoop

  -- * Client data types
  , ClientRequest(..)
  , ClientReq(..)
  , ClientResponse(..)
  , ClientReadResp(..)
  , ClientWriteResp(..)
  , ClientRedirResp(..)

  -- * Configuration
  , NodeConfig(..)

  -- * Events
  , Event(..)
  , Timeout(..)
  , MessageEvent(..)

  -- * Log
  , Entry(..)
  , Entries
  , RaftWriteLog(..)
  , DeleteSuccess(..)
  , RaftDeleteLog(..)
  , RaftReadLog (..)
  , RaftLog
  , RaftLogError(..)
  , RaftLogExceptions(..)

  -- * Logging
  , LogDest(..)
  , Severity(..)

  -- * Raft node states
  , Mode(..)
  , RaftNodeState(..)
  , NodeState(..)
  , CurrentLeader(..)
  , FollowerState(..)
  , CandidateState(..)
  , LeaderState(..)
  , initRaftNodeState
  , isFollower
  , isCandidate
  , isLeader
  , setLastLogEntryData
  , getLastLogEntryData
  , getLastAppliedAndCommitIndex

  -- * Persistent state
  , PersistentState(..)
  , initPersistentState

  -- * Basic types
  , NodeId
  , NodeIds
  , ClientId(..)
  , LeaderId(..)
  , Term(..)
  , Index(..)
  , term0
  , index0

  -- * RPC
  , RPC(..)
  , RPCType(..)
  , RPCMessage(..)
  , AppendEntries(..)
  , AppendEntriesResponse(..)
  , RequestVote(..)
  , RequestVoteResponse(..)
  , AppendEntriesData(..)
  ) where

import Protolude hiding (STM, TChan, newTChan, readTChan, writeTChan, atomically)

import Control.Monad.Conc.Class
import Control.Concurrent.STM.Timer
import Control.Concurrent.Classy.STM.TChan
import Control.Concurrent.Classy.Async

import Control.Monad.Fail
import Control.Monad.Catch
import Control.Monad.Trans.Class

import qualified Data.Map as Map
import Data.Sequence (Seq(..), singleton)

import Raft.Action
import Raft.Client
import Raft.Config
import Raft.Event
import Raft.Handle
import Raft.Log
import Raft.Logging hiding (logInfo, logDebug, logCritical)
import Raft.Monad hiding (logInfo, logDebug)
import Raft.NodeState
import Raft.Persistent
import Raft.RPC
import Raft.Types


type EventChan m v = TChan (STM m) (Event v)

-- | The raft server environment composed of the concurrent variables used in
-- the effectful raft layer.
data RaftEnv v m = RaftEnv
  { eventChan :: EventChan m v
  , resetElectionTimer :: m ()
  , resetHeartbeatTimer :: m ()
  , raftNodeConfig :: NodeConfig
  , raftNodeLogDest :: LogDest
  }

newtype RaftT v m a = RaftT
  { unRaftT :: ReaderT (RaftEnv v m) (StateT RaftNodeState m) a
  } deriving (Functor, Applicative, Monad, MonadReader (RaftEnv v m), MonadState RaftNodeState, MonadFail, Alternative, MonadPlus)

instance MonadTrans (RaftT v) where
  lift = RaftT . lift . lift

deriving instance MonadIO m => MonadIO (RaftT v m)
deriving instance MonadThrow m => MonadThrow (RaftT v m)
deriving instance MonadCatch m => MonadCatch (RaftT v m)
deriving instance MonadMask m => MonadMask (RaftT v m)
deriving instance MonadConc m => MonadConc (RaftT v m)

instance Monad m => RaftLogger (RaftT v m) where
  loggerNodeId = asks (configNodeId . raftNodeConfig)
  loggerNodeState = get

runRaftT
  :: MonadConc m
  => RaftNodeState
  -> RaftEnv v m
  -> RaftT v m ()
  -> m ()
runRaftT raftNodeState raftEnv =
  flip evalStateT raftNodeState . flip runReaderT raftEnv . unRaftT

------------------------------------------------------------------------------

logDebug :: MonadIO m => Text -> RaftT v m ()
logDebug msg = flip logDebugIO msg =<< asks raftNodeLogDest

logCritical :: MonadIO m => Text -> RaftT v m ()
logCritical msg = flip logCriticalIO msg =<< asks raftNodeLogDest

------------------------------------------------------------------------------

-- | Run timers, RPC and client request handlers and start event loop.
-- It should run forever
runRaftNode
  :: ( Show v, Show sm, Show (Action sm v)
     , MonadIO m, MonadConc m, MonadFail m
     , RSM sm v m
     , Show (RSMPError sm v)
     , RaftSendRPC m v
     , RaftRecvRPC m v
     , RaftSendClient m sm
     , RaftRecvClient m v
     , RaftLog m v
     , RaftLogExceptions m
     , RaftPersist m
     , Exception (RaftPersistError m)
     )
   => NodeConfig           -- ^ Node configuration
   -> LogDest              -- ^ Logs destination
   -> Int                  -- ^ Timer seed
   -> sm                   -- ^ Initial state machine state
   -> m ()
runRaftNode nodeConfig@NodeConfig{..} logDest timerSeed initRSM = do
  eventChan <- atomically newTChan

  electionTimer <- newTimerRange timerSeed configElectionTimeout
  heartbeatTimer <- newTimer configHeartbeatTimeout

  let resetElectionTimer = resetTimer electionTimer
      resetHeartbeatTimer = resetTimer heartbeatTimer
      raftEnv = RaftEnv eventChan resetElectionTimer resetHeartbeatTimer nodeConfig logDest

  runRaftT initRaftNodeState raftEnv $ do

    -- Fork all event producers to run concurrently
    lift $ fork (electionTimeoutTimer electionTimer eventChan)
    lift $ fork (heartbeatTimeoutTimer heartbeatTimer eventChan)
    fork (rpcHandler eventChan)
    fork (clientReqHandler eventChan)

    -- Start the main event handling loop
    handleEventLoop initRSM

handleEventLoop
  :: forall sm v m.
     ( Show v, Show sm, Show (Action sm v)
     , MonadIO m, MonadConc m, MonadFail m
     , RSM sm v m
     , Show (RSMPError sm v)
     , RaftPersist m
     , RaftSendRPC m v
     , RaftSendClient m sm
     , RaftLog m v
     , RaftLogExceptions m
     , RaftPersist m
     , Exception (RaftPersistError m)
     )
  => sm
  -> RaftT v m ()
handleEventLoop initRSM = do
    ePersistentState <- lift readPersistentState
    case ePersistentState of
      Left err -> throw err
      Right pstate -> handleEventLoop' initRSM pstate
  where
    handleEventLoop' :: sm -> PersistentState -> RaftT v m ()
    handleEventLoop' stateMachine persistentState = do
      event <- atomically . readTChan =<< asks eventChan
      loadLogEntryTermAtAePrevLogIndex event
      raftNodeState <- get
      logDebug $ "[Event]: " <> show event
      logDebug $ "[NodeState]: " <> show raftNodeState
      Right log :: Either (RaftReadLogError m) (Entries v) <- lift $ readLogEntriesFrom index0
      logDebug $ "[Log]: " <> show log
      logDebug $ "[State Machine]: " <> show stateMachine
      logDebug $ "[Persistent State]: " <> show persistentState
      -- Perform core state machine transition, handling the current event
      nodeConfig <- asks raftNodeConfig
      let transitionEnv = TransitionEnv nodeConfig stateMachine raftNodeState
          (resRaftNodeState, resPersistentState, actions, logMsgs) =
            Raft.Handle.handleEvent raftNodeState transitionEnv persistentState event
      -- Write persistent state to disk
      eRes <- lift $ writePersistentState resPersistentState
      case eRes of
        Left err -> throw err
        Right _ -> pure ()
      -- Update raft node state with the resulting node state
      put resRaftNodeState
      -- Handle logs producek by core state machine
      handleLogs logMsgs
      -- Handle actions produced by core state machine
      handleActions nodeConfig actions
      -- Apply new log entries to the state machine
      resRSM <- applyLogEntries stateMachine
      handleEventLoop' resRSM resPersistentState

    -- In the case that a node is a follower receiving an AppendEntriesRPC
    -- Event, read the log at the aePrevLogIndex
    loadLogEntryTermAtAePrevLogIndex :: Event v -> RaftT v m ()
    loadLogEntryTermAtAePrevLogIndex event =
      case event of
        MessageEvent (RPCMessageEvent (RPCMessage _ (AppendEntriesRPC ae))) -> do
          RaftNodeState rns <- get
          case rns of
            NodeFollowerState fs -> do
              eEntry <- lift $ readLogEntry (aePrevLogIndex ae)
              case eEntry of
                Left err -> throw err
                Right (mEntry :: Maybe (Entry v)) -> put $
                  RaftNodeState $ NodeFollowerState fs
                    { fsTermAtAEPrevIndex = entryTerm <$> mEntry }
            _ -> pure ()
        _ -> pure ()

handleActions
  :: ( Show v, Show sm, Show (Action sm v)
     , MonadIO m, MonadConc m
     , RSM sm v m
     , RaftSendRPC m v
     , RaftSendClient m sm
     , RaftLog m v
     , RaftLogExceptions m
     )
  => NodeConfig
  -> [Action sm v]
  -> RaftT v m ()
handleActions = mapM_ . handleAction

handleAction
  :: forall sm v m.
     ( Show v, Show sm, Show (Action sm v)
     , MonadIO m, MonadConc m
     , RSM sm v m
     , RaftSendRPC m v
     , RaftSendClient m sm
     , RaftLog m v
     , RaftLogExceptions m
     )
  => NodeConfig
  -> Action sm v
  -> RaftT v m ()
handleAction nodeConfig action = do
  logDebug $ "[Action]: " <> show action
  case action of
    SendRPC nid sendRpcAction -> do
      rpcMsg <- mkRPCfromSendRPCAction sendRpcAction
      lift (sendRPC nid rpcMsg)
    SendRPCs rpcMap ->
      forConcurrently_ (Map.toList rpcMap) $ \(nid, sendRpcAction) -> do
        rpcMsg <- mkRPCfromSendRPCAction sendRpcAction
        lift (sendRPC nid rpcMsg)
    BroadcastRPC nids sendRpcAction -> do
      rpcMsg <- mkRPCfromSendRPCAction sendRpcAction
      mapConcurrently_ (lift . flip sendRPC rpcMsg) nids
    RespondToClient cid cr -> lift $ sendClient cid cr
    ResetTimeoutTimer tout ->
      case tout of
        ElectionTimeout -> lift . resetElectionTimer =<< ask
        HeartbeatTimeout -> lift . resetHeartbeatTimer =<< ask
    AppendLogEntries entries -> do
      lift (updateLog entries)
      -- Update the last log entry data
      modify $ \(RaftNodeState ns) ->
        RaftNodeState (setLastLogEntryData ns entries)

  where
    mkRPCfromSendRPCAction :: SendRPCAction v -> RaftT v m (RPCMessage v)
    mkRPCfromSendRPCAction sendRPCAction = do
      RaftNodeState ns <- get
      RPCMessage (configNodeId nodeConfig) <$>
        case sendRPCAction of
          SendAppendEntriesRPC aeData -> do
            (entries, prevLogIndex, prevLogTerm, aeReadReq) <-
              case aedEntriesSpec aeData of
                FromIndex idx -> do
                  eLogEntries <- lift (readLogEntriesFrom (decrIndexWithDefault0 idx))
                  case eLogEntries of
                    Left err -> throw err
                    Right log ->
                      case log of
                        pe :<| entries@(e :<| _)
                          | idx == 1 -> pure (log, index0, term0, Nothing)
                          | otherwise -> pure (entries, entryIndex pe, entryTerm pe, Nothing)
                        _ -> pure (log, index0, term0, Nothing)
                FromClientWriteReq e -> prevEntryData e
                FromNewLeader e -> prevEntryData e
                NoEntries spec -> do
                  let readReq =
                        case spec of
                          FromClientReadReq n -> Just n
                          _ -> Nothing
                      (lastLogIndex, lastLogTerm) = getLastLogEntryData ns
                  pure (Empty, lastLogIndex, lastLogTerm, readReq)
            let leaderId = LeaderId (configNodeId nodeConfig)
            pure . toRPC $
              AppendEntries
                { aeTerm = aedTerm aeData
                , aeLeaderId = leaderId
                , aePrevLogIndex = prevLogIndex
                , aePrevLogTerm = prevLogTerm
                , aeEntries = entries
                , aeLeaderCommit = aedLeaderCommit aeData
                , aeReadRequest = aeReadReq
                }
          SendAppendEntriesResponseRPC aer -> pure (toRPC aer)
          SendRequestVoteRPC rv -> pure (toRPC rv)
          SendRequestVoteResponseRPC rvr -> pure (toRPC rvr)

    prevEntryData e = do
      (x,y,z) <- prevEntryData' e
      pure (x,y,z,Nothing)

    prevEntryData' e
      | entryIndex e == Index 1 = pure (singleton e, index0, term0)
      | otherwise = do
          let prevLogEntryIdx = decrIndexWithDefault0 (entryIndex e)
          eLogEntry <- lift $ readLogEntry prevLogEntryIdx
          case eLogEntry of
            Left err -> throw err
            Right Nothing -> pure (singleton e, index0, term0)
            Right (Just (prevEntry :: Entry v)) ->
              pure (singleton e, entryIndex prevEntry, entryTerm prevEntry)

-- If commitIndex > lastApplied: increment lastApplied, apply
-- log[lastApplied] to state machine (Section 5.3) until the state machine
-- is up to date with all the committed log entries
applyLogEntries
  :: forall sm m v.
     ( Show sm
     , MonadConc m
     , RaftReadLog m v
     , Exception (RaftReadLogError m)
     , RSM sm v m
     , Show (RSMPError sm v)
     )
  => sm
  -> RaftT v m sm
applyLogEntries stateMachine = do
    raftNodeState@(RaftNodeState nodeState) <- get
    if commitIndex nodeState > lastApplied nodeState
      then do
        let resNodeState = incrLastApplied nodeState
        put $ RaftNodeState resNodeState
        let newLastAppliedIndex = lastApplied resNodeState
        eLogEntry <- lift $ readLogEntry newLastAppliedIndex
        case eLogEntry of
          Left err -> throw err
          Right Nothing -> panic "No log entry at 'newLastAppliedIndex'"
          Right (Just logEntry) -> do
            -- The command should be verified by the leader, thus all node
            -- attempting to apply the committed log entry should not fail when
            -- doing so; failure here means something has gone very wrong.
            eRes <- lift (applyEntryRSM stateMachine logEntry)
            case eRes of
              Left err -> panic $ "Failed to apply committed log entry: " <> show err
              Right nsm -> applyLogEntries nsm
      else pure stateMachine
  where
    incrLastApplied :: NodeState ns -> NodeState ns
    incrLastApplied nodeState =
      case nodeState of
        NodeFollowerState fs ->
          let lastApplied' = incrIndex (fsLastApplied fs)
           in NodeFollowerState $ fs { fsLastApplied = lastApplied' }
        NodeCandidateState cs ->
          let lastApplied' = incrIndex (csLastApplied cs)
           in  NodeCandidateState $ cs { csLastApplied = lastApplied' }
        NodeLeaderState ls ->
          let lastApplied' = incrIndex (lsLastApplied ls)
           in NodeLeaderState $ ls { lsLastApplied = lastApplied' }

    lastApplied :: NodeState ns -> Index
    lastApplied = fst . getLastAppliedAndCommitIndex

    commitIndex :: NodeState ns -> Index
    commitIndex = snd . getLastAppliedAndCommitIndex

handleLogs
  :: (MonadIO m, MonadConc m)
  => [LogMsg]
  -> RaftT v m ()
handleLogs logs = do
  logDest <- asks raftNodeLogDest
  mapM_ (logToDest logDest) logs

------------------------------------------------------------------------------
-- Event Producers
------------------------------------------------------------------------------

-- | Producer for rpc message events
rpcHandler
  :: (MonadIO m, MonadConc m, Show v, RaftRecvRPC m v)
  => TChan (STM m) (Event v)
  -> RaftT v m ()
rpcHandler eventChan =
  forever $ do
    eRpcMsg <- lift $ Control.Monad.Catch.try receiveRPC
    case eRpcMsg of
      Left (err :: SomeException) -> logCritical (show err)
      Right (Left err) -> logCritical (show err)
      Right (Right rpcMsg) -> do
        let rpcMsgEvent = MessageEvent (RPCMessageEvent rpcMsg)
        atomically $ writeTChan eventChan rpcMsgEvent

-- | Producer for rpc message events
clientReqHandler
  :: (MonadIO m, MonadConc m, RaftRecvClient m v)
  => TChan (STM m) (Event v)
  -> RaftT v m ()
clientReqHandler eventChan =
  forever $ do
    eClientReq <- lift $ Control.Monad.Catch.try receiveClient
    case eClientReq of
      Left (err :: SomeException) -> logCritical (show err)
      Right (Left err) -> logCritical (show err)
      Right (Right clientReq) -> do
        let clientReqEvent = MessageEvent (ClientRequestEvent clientReq)
        atomically $ writeTChan eventChan clientReqEvent

-- | Producer for the election timeout event
electionTimeoutTimer :: MonadConc m => Timer m -> TChan (STM m) (Event v) -> m ()
electionTimeoutTimer timer eventChan =
  forever $ do
    startTimer timer >> waitTimer timer
    atomically $ writeTChan eventChan (TimeoutEvent ElectionTimeout)

-- | Producer for the heartbeat timeout event
heartbeatTimeoutTimer :: MonadConc m => Timer m -> TChan (STM m) (Event v) -> m ()
heartbeatTimeoutTimer timer eventChan =
  forever $ do
    startTimer timer >> waitTimer timer
    atomically $ writeTChan eventChan (TimeoutEvent HeartbeatTimeout)