{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeFamilyDependencies #-}
{-# LANGUAGE DefaultSignatures #-}
{-# LANGUAGE GADTs #-}

module Raft.Transition where

import Protolude hiding (pass)

import Control.Arrow ((&&&))
import Control.Monad.RWS

import qualified Data.Set as Set

import Raft.Action
import Raft.Client
import Raft.Config
import Raft.Event
import Raft.Log
import Raft.Persistent
import Raft.NodeState
import Raft.RPC
import Raft.Types
import Raft.Logging (RaftLogger, runRaftLoggerT, RaftLoggerT(..), LogMsg)
import qualified Raft.Logging as Logging


--------------------------------------------------------------------------------
-- Raft Transition Monad
--------------------------------------------------------------------------------

tellAction :: Action sm v -> TransitionM sm v ()
tellAction a = tell [a]

tellActions :: [Action sm v] -> TransitionM sm v ()
tellActions as = tell as

data TransitionEnv sm v = TransitionEnv
  { nodeConfig :: RaftNodeConfig
  , stateMachine :: sm
  , nodeState :: RaftNodeState v
  }

newtype TransitionM sm v a = TransitionM
  { unTransitionM :: RaftLoggerT v (RWS (TransitionEnv sm v) [Action sm v] PersistentState) a
  } deriving (Functor, Applicative, Monad)

instance MonadWriter [Action sm v] (TransitionM sm v) where
  tell = TransitionM . RaftLoggerT . tell
  listen = TransitionM . RaftLoggerT . listen . unRaftLoggerT . unTransitionM
  pass = TransitionM . RaftLoggerT . pass . unRaftLoggerT . unTransitionM

instance MonadReader (TransitionEnv sm v) (TransitionM sm v) where
  ask = TransitionM . RaftLoggerT $ ask
  local f = TransitionM . RaftLoggerT . local f . unRaftLoggerT . unTransitionM

instance MonadState PersistentState (TransitionM sm v) where
  get = TransitionM . RaftLoggerT $ lift get
  put = TransitionM . RaftLoggerT . lift . put

instance RaftLogger v (RWS (TransitionEnv sm v) [Action sm v] PersistentState) where
  loggerCtx = asks ((configNodeId . nodeConfig) &&& nodeState)

runTransitionM
  :: TransitionEnv sm v
  -> PersistentState
  -> TransitionM sm v a
  -> ((a, [LogMsg]), PersistentState, [Action sm v])
runTransitionM transEnv persistentState transitionM =
  runRWS (runRaftLoggerT (unTransitionM transitionM)) transEnv persistentState

askNodeId :: TransitionM sm v NodeId
askNodeId = asks (configNodeId . nodeConfig)

--------------------------------------------------------------------------------
-- Handlers
--------------------------------------------------------------------------------

type RPCHandler ns sm r v = (RPCType r v, Show v) => NodeState ns v -> NodeId -> r -> TransitionM sm v (ResultState ns v)
type TimeoutHandler ns sm v = Show v => NodeState ns v -> Timeout -> TransitionM sm v (ResultState ns v)
type ClientReqHandler ns sm v = Show v => NodeState ns v -> ClientRequest v -> TransitionM sm v (ResultState ns v)

--------------------------------------------------------------------------------
-- RWS Helpers
--------------------------------------------------------------------------------

broadcast :: SendRPCAction v -> TransitionM sm v ()
broadcast sendRPC = do
  selfNodeId <- askNodeId
  tellAction =<<
    flip BroadcastRPC sendRPC
      <$> asks (Set.filter (selfNodeId /=) . configNodeIds . nodeConfig)

send :: NodeId -> SendRPCAction v -> TransitionM sm v ()
send nodeId sendRPC = tellAction (SendRPC nodeId sendRPC)

-- | Resets the election timeout.
resetElectionTimeout :: TransitionM sm v ()
resetElectionTimeout = tellAction (ResetTimeoutTimer ElectionTimeout)

resetHeartbeatTimeout :: TransitionM sm v ()
resetHeartbeatTimeout = tellAction (ResetTimeoutTimer HeartbeatTimeout)

redirectClientToLeader :: ClientId -> CurrentLeader -> TransitionM sm v ()
redirectClientToLeader clientId currentLeader = do
  let clientRedirRespSpec = ClientRedirRespSpec currentLeader
  tellAction (RespondToClient clientId clientRedirRespSpec)

respondClientRead :: ClientId -> ClientReadReq -> TransitionM sm v ()
respondClientRead clientId readReq = do
  readReqData <-
    case readReq of
      ClientReadEntries res -> pure (ClientReadRespSpecEntries res)
      ClientReadStateMachine -> do
        sm <- asks stateMachine
        pure (ClientReadRespSpecStateMachine sm)
  tellAction . RespondToClient clientId . ClientReadRespSpec $ readReqData


respondClientWrite :: ClientId -> Index -> SerialNum -> TransitionM sm v ()
respondClientWrite cid entryIdx sn =
  tellAction (RespondToClient cid (ClientWriteRespSpec entryIdx sn))

respondClientRedir :: ClientId -> CurrentLeader -> TransitionM sm v ()
respondClientRedir cid cl =
  tellAction (RespondToClient cid (ClientRedirRespSpec cl))

appendLogEntries :: Show v => Seq (Entry v) -> TransitionM sm v ()
appendLogEntries = tellAction . AppendLogEntries

updateClientReqCacheFromIdx :: Index -> TransitionM sm v ()
updateClientReqCacheFromIdx = tellAction . UpdateClientReqCacheFrom

--------------------------------------------------------------------------------

startElection
  :: Index
  -> Index
  -> LastLogEntry v
  -> ClientWriteReqCache
  -> TransitionM sm v (CandidateState v)
startElection commitIndex lastApplied lastLogEntry clientReqCache  = do
    incrementTerm
    voteForSelf
    resetElectionTimeout
    broadcast =<< requestVoteMessage
    selfNodeId <- askNodeId
    -- Return new candidate state
    pure CandidateState
      { csCommitIndex = commitIndex
      , csLastApplied = lastApplied
      , csVotes = Set.singleton selfNodeId
      , csLastLogEntry = lastLogEntry
      , csClientReqCache = clientReqCache
      }
  where
    requestVoteMessage = do
      term <- currentTerm <$> get
      selfNodeId <- askNodeId
      pure $ SendRequestVoteRPC
        RequestVote
          { rvTerm = term
          , rvCandidateId = selfNodeId
          , rvLastLogIndex = lastLogEntryIndex lastLogEntry
          , rvLastLogTerm = lastLogEntryTerm lastLogEntry
          }

    incrementTerm = do
      psNextTerm <- incrTerm . currentTerm <$> get
      modify $ \pstate ->
        pstate { currentTerm = psNextTerm
               , votedFor = Nothing
               }

    voteForSelf = do
      selfNodeId <- askNodeId
      modify $ \pstate ->
        pstate { votedFor = Just selfNodeId }

--------------------------------------------------------------------------------
-- Logging
--------------------------------------------------------------------------------

logInfo = TransitionM . Logging.logInfo
logDebug = TransitionM . Logging.logDebug