{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE UndecidableInstances #-}
-- | Client-facing API of the Raft implementation.
--
-- This module defines the request and response types exchanged between
-- clients and Raft nodes, the type classes abstracting over how those
-- messages are sent and received, and the 'RaftClientT' monad transformer
-- with helper actions for issuing reads, writes and metrics queries.
module Raft.Client
(
  -- * Node-side client communication
  RaftSendClient(..)
, RaftRecvClient(..)
  -- * Requests
, SerialNum(..)
, ClientRequest(..)
, ClientReq(..)
, ClientReqType(..)
, ClientReadReq(..)
, ReadEntriesSpec(..)
, ClientWriteReq(..)
, ClientMetricsReq(..)
  -- * Responses
, ClientResponse(..)
, ClientRespSpec(..)
, ClientReadRespSpec(..)
, ClientWriteRespSpec(..)
, ClientReadResp(..)
, ClientWriteResp(..)
, ClientRedirResp(..)
, ClientMetricsResp(..)
  -- * Client-side communication classes
, RaftClientSend(..)
, RaftClientRecv(..)
  -- * Client monad transformer
, RaftClientState(..)
, RaftClientEnv(..)
, initRaftClientState
, RaftClientT
, runRaftClientT
, RaftClientError(..)
  -- * Client actions
, clientRead
, clientReadFrom
, clientReadTimeout
, clientWrite
, clientWriteTo
, clientWriteTimeout
, clientQueryNodeMetrics
, clientQueryNodeMetricsTimeout
, clientSendRead
, clientSendWrite
, clientSendMetricsReqTo
, clientSend
, clientRecv
, clientTimeout
, retryOnRedirect
, clientAddNode
, clientGetNodes
) where
import Protolude hiding (threadDelay, STM, )
import Control.Concurrent.Lifted (threadDelay)
import Control.Monad.Base
import Control.Monad.Catch
import Control.Monad.Fail
import Control.Monad.Trans.Class
import Control.Monad.Trans.Control
import qualified Data.Set as Set
import qualified Data.Serialize as S
import System.Random
import System.Console.Haskeline.MonadException (MonadException(..), RunIO(..))
import System.Timeout.Lifted (timeout)
import Raft.Log (Entry, Entries, ReadEntriesSpec)
import Raft.StateMachine
import Raft.Metrics (RaftNodeMetrics)
import Raft.Types
-- | Capability of sending a response to a client from within monad @m@.
--
-- The 'RaftStateMachinePure' superclass relates the state machine type @sm@
-- to the command type @v@ that appear in the response.
class RaftStateMachinePure sm v => RaftSendClient m sm v where
  -- | Send the given 'ClientResponse' to the client with the given 'ClientId'.
  sendClient :: ClientId -> ClientResponse sm v -> m ()
-- | Capability of receiving a client request from within monad @m@.
--
-- The associated error type is required to have a 'Show' instance.
class Show (RaftRecvClientError m v) => RaftRecvClient m v where
  -- | Error produced when receiving a client request fails.
  type RaftRecvClientError m v
  -- | Receive a 'ClientRequest', or report why receiving failed.
  receiveClient :: m (Either (RaftRecvClientError m v) (ClientRequest v))
-- | A request issued by a client: the issuing client's 'ClientId' paired
-- with the request payload.
data ClientRequest v
  = ClientRequest ClientId (ClientReq v)
  deriving (Show, Generic)
-- | Serialization via the 'Generic' default implementation.
instance S.Serialize v => S.Serialize (ClientRequest v)
-- | The payload of a client request; one constructor per request category.
data ClientReq v
  = ClientReadReq ClientReadReq 
    -- ^ A read request
  | ClientWriteReq (ClientWriteReq v) 
    -- ^ A write request carrying a command of type @v@
  | ClientMetricsReq ClientMetricsReq 
    -- ^ A request for node metrics
  deriving (Show, Generic)
-- | Serialization via the 'Generic' default implementation.
instance S.Serialize v => S.Serialize (ClientReq v)
-- | The kinds of read a client can request.
data ClientReadReq
  = ClientReadEntries ReadEntriesSpec
    -- ^ Read entries selected by a 'ReadEntriesSpec' (defined in "Raft.Log")
  | ClientReadStateMachine
    -- ^ Read the state machine
  deriving (Show, Generic, S.Serialize)
-- | A write request: a command @v@ tagged with the 'SerialNum' the client
-- assigned to it. 'clientRecv' compares the serial number echoed back in
-- write responses against the client's current one.
data ClientWriteReq v
  = ClientCmdReq SerialNum v 
  deriving (Show, Generic, S.Serialize)
-- | A metrics request. The only request currently defined asks for all
-- metrics.
data ClientMetricsReq
  = ClientAllMetricsReq
  deriving (Show, Generic, S.Serialize)
-- | Marker class (it has no methods) relating a request payload type @a@ to
-- the command type @v@. Instances exist for the three request payload types.
class ClientReqType a v
instance ClientReqType ClientReadReq v
instance ClientReqType (ClientWriteReq v) v
instance ClientReqType ClientMetricsReq v
-- | A specification of a client response, with one constructor per response
-- category (read, write, redirect, metrics).
--
-- NOTE(review): presumably this is turned into a concrete 'ClientResponse'
-- elsewhere; that conversion is not visible in this module.
data ClientRespSpec sm v
  = ClientReadRespSpec (ClientReadRespSpec sm)
  | ClientWriteRespSpec (ClientWriteRespSpec sm v)
  | ClientRedirRespSpec CurrentLeader
  | ClientMetricsRespSpec RaftNodeMetrics
  deriving (Generic)
-- Standalone deriving is used because the instance context mentions the
-- type family application 'RaftStateMachinePureError'.
deriving instance (Show sm, Show v, Show (RaftStateMachinePureError sm v)) => Show (ClientRespSpec sm v)
deriving instance (S.Serialize sm, S.Serialize v, S.Serialize (RaftStateMachinePureError sm v)) => S.Serialize (ClientRespSpec sm v)
-- | Specification of a read response: either a 'ReadEntriesSpec' describing
-- which entries are wanted, or a state machine value.
data ClientReadRespSpec sm
  = ClientReadRespSpecEntries ReadEntriesSpec
  | ClientReadRespSpecStateMachine sm
  deriving (Show, Generic, S.Serialize)
-- | Specification of a write response.
data ClientWriteRespSpec sm v
  = ClientWriteRespSpecSuccess Index SerialNum
    -- ^ Success, carrying an 'Index' and the request's 'SerialNum'
  | ClientWriteRespSpecFail SerialNum (RaftStateMachinePureError sm v)
    -- ^ Failure, carrying the request's 'SerialNum' and the state machine error
  deriving (Generic)
-- Standalone deriving is used because the instance context mentions the
-- type family application 'RaftStateMachinePureError'.
deriving instance (Show sm, Show v, Show (RaftStateMachinePureError sm v)) => Show (ClientWriteRespSpec sm v)
deriving instance (S.Serialize sm, S.Serialize v, S.Serialize (RaftStateMachinePureError sm v)) => S.Serialize (ClientWriteRespSpec sm v)
-- | A response delivered to a client.
data ClientResponse sm v
  = ClientReadResponse (ClientReadResp sm v)
    -- ^ Response to a read request
  | ClientWriteResponse (ClientWriteResp sm v)
    -- ^ Response to a write request
  | ClientRedirectResponse ClientRedirResp
    -- ^ Redirect carrying a 'CurrentLeader' value; 'clientRecv' stores it as
    -- the client's current leader
  | ClientMetricsResponse ClientMetricsResp
    -- ^ Response to a metrics request
  deriving (Generic)
-- Standalone deriving: the contexts require instances for the write
-- response payload in addition to @sm@ and @v@.
deriving instance (Show sm, Show v, Show (ClientWriteResp sm v)) => Show (ClientResponse sm v)
deriving instance (S.Serialize sm, S.Serialize v, S.Serialize (ClientWriteResp sm v)) => S.Serialize (ClientResponse sm v)
-- | Payload of a read response: a state machine value, a single entry, or a
-- collection of entries ('Entry' and 'Entries' come from "Raft.Log").
data ClientReadResp sm v
  = ClientReadRespStateMachine sm
  | ClientReadRespEntry (Entry v)
  | ClientReadRespEntries (Entries v)
  deriving (Show, Generic, S.Serialize)
-- | Payload of a write response. Both constructors echo the 'SerialNum' of
-- the request, which 'clientRecv' checks against the client's current one.
data ClientWriteResp sm v
  = ClientWriteRespSuccess Index SerialNum
  -- ^ The write succeeded; carries an 'Index' and the request's serial number
  | ClientWriteRespFail SerialNum (RaftStateMachinePureError sm v)
  -- ^ The write failed with a state machine error
  deriving (Generic)
-- Standalone deriving is used because the instance context mentions the
-- type family application 'RaftStateMachinePureError'.
deriving instance (Show sm, Show v, Show (RaftStateMachinePureError sm v)) => Show (ClientWriteResp sm v)
deriving instance (S.Serialize sm, S.Serialize v, S.Serialize (RaftStateMachinePureError sm v)) => S.Serialize (ClientWriteResp sm v)
-- | Payload of a redirect response: the 'CurrentLeader' value reported by
-- the responding node.
data ClientRedirResp
  = ClientRedirResp CurrentLeader
  deriving (Show, Generic, S.Serialize)
-- | Payload of a metrics response: the 'RaftNodeMetrics' of a node.
data ClientMetricsResp
  = ClientMetricsResp RaftNodeMetrics
  deriving (Show, Generic, S.Serialize)
-- | Client-side capability of sending a request to a particular node.
class Monad m => RaftClientSend m v where
  -- | Error produced when sending fails.
  type RaftClientSendError m v
  -- | Send a 'ClientRequest' to the node with the given 'NodeId'.
  raftClientSend :: NodeId -> ClientRequest v -> m (Either (RaftClientSendError m v) ())
-- | Client-side capability of receiving a response.
--
-- The functional dependency @m sm -> v@ means the monad and state machine
-- type together determine the command type.
class Monad m => RaftClientRecv m sm v | m sm -> v where
  -- | Error produced when receiving fails.
  type RaftClientRecvError m sm
  -- | Receive a 'ClientResponse', or report why receiving failed.
  raftClientRecv :: m (Either (RaftClientRecvError m sm) (ClientResponse sm v))
-- | Mutable state carried by a 'RaftClientT' computation.
data RaftClientState = RaftClientState
  { raftClientCurrentLeader :: CurrentLeader
    -- ^ Leader the client currently believes in; updated by 'clientRecv'
    -- whenever a redirect response arrives
  , raftClientSerialNum :: SerialNum
    -- ^ Serial number attached to outgoing writes; incremented by
    -- 'clientRecv' on a successful write response with a matching number
  , raftClientRaftNodes :: Set NodeId
    -- ^ Nodes known to the client; a random one is picked when no leader is
    -- known or sending to the leader fails
  , raftClientRandomGen :: StdGen
    -- ^ Generator used for that random pick
  }
-- | Read-only environment of a 'RaftClientT' computation.
data RaftClientEnv = RaftClientEnv
  { raftClientId :: ClientId
    -- ^ Identifier attached to every 'ClientRequest' this client sends
  }
-- | Build the initial client state from the set of known nodes and a random
-- generator: no leader is known yet and the serial number starts at 0.
initRaftClientState :: Set NodeId -> StdGen -> RaftClientState
initRaftClientState knownNodes gen = RaftClientState
  { raftClientCurrentLeader = NoLeader
  , raftClientSerialNum = 0
  , raftClientRaftNodes = knownNodes
  , raftClientRandomGen = gen
  }
-- | Monad transformer for Raft clients: a reader over 'RaftClientEnv' on top
-- of state over 'RaftClientState'.
--
-- The parameters @s@ (state machine) and @v@ (command) are phantom: they do
-- not occur in the underlying representation and only fix the types of the
-- requests and responses handled by the client actions.
newtype RaftClientT s v m a = RaftClientT
  { unRaftClientT :: ReaderT RaftClientEnv (StateT RaftClientState m) a
  } deriving newtype (Functor, Applicative, Monad, MonadIO, MonadState RaftClientState, MonadReader RaftClientEnv, MonadFail, Alternative, MonadPlus)
-- Exception-related instances are inherited from the underlying stack.
deriving newtype instance MonadThrow m => MonadThrow (RaftClientT s v m)
deriving newtype instance MonadCatch m => MonadCatch (RaftClientT s v m)
deriving newtype instance MonadMask m => MonadMask (RaftClientT s v m)
-- | Lifting goes through both layers of the stack (state, then reader).
instance MonadTrans (RaftClientT s v) where
  lift = RaftClientT . lift . lift
-- | 'MonadBase' is inherited from the underlying stack.
deriving newtype instance MonadBase IO m => MonadBase IO (RaftClientT s v m)
-- | 'MonadTransControl' defined with the two-layer default helpers from
-- @monad-control@, matching the reader-over-state representation.
instance MonadTransControl (RaftClientT s v) where
    -- The monadic state is that of the reader layer wrapped around that of
    -- the state layer.
    type StT (RaftClientT s v) a = StT (ReaderT RaftClientEnv) (StT (StateT RaftClientState) a)
    liftWith = defaultLiftWith2 RaftClientT unRaftClientT
    restoreT = defaultRestoreT2 RaftClientT
-- | 'MonadBaseControl' derived from the 'MonadTransControl' instance via the
-- standard defaults; this is what allows lifted operations such as
-- 'timeout' to run 'RaftClientT' actions.
instance (MonadBaseControl IO m) => MonadBaseControl IO (RaftClientT s v m) where
    type StM (RaftClientT s v m) a = ComposeSt (RaftClientT s v) m a
    liftBaseWith    = defaultLiftBaseWith
    restoreM        = defaultRestoreM
-- | Haskeline's 'MonadException' for 'RaftClientT'.
--
-- The computation is unwrapped to its environment @r@ and incoming state
-- @s@ and then delegated to the base monad's 'controlIO'.
instance MonadException m => MonadException (RaftClientT s v m) where
  controlIO f =
    RaftClientT $ ReaderT $ \r -> StateT $ \s ->
      controlIO $ \(RunIO run) ->
        -- run' executes a RaftClientT action with the captured r and s, then
        -- rewraps its (result, state) pair as an action that ignores whatever
        -- environment and state it is later run with (the two 'const's).
        let run' = RunIO (fmap (RaftClientT . ReaderT . const . StateT . const) . run . flip runStateT s . flip runReaderT r . unRaftClientT)
         -- The action returned by f is likewise unwrapped with r and s.
         in fmap (flip runStateT s . flip runReaderT r . unRaftClientT) $ f run'
-- | Sending from 'RaftClientT' delegates to the underlying monad and reuses
-- its error type.
instance RaftClientSend m v => RaftClientSend (RaftClientT s v m) v where
  type RaftClientSendError (RaftClientT s v m) v = RaftClientSendError m v
  raftClientSend nid = lift . raftClientSend nid
-- | Receiving in 'RaftClientT' delegates to the underlying monad and reuses
-- its error type.
instance RaftClientRecv m s v => RaftClientRecv (RaftClientT s v m) s v where
  type RaftClientRecvError (RaftClientT s v m) s = RaftClientRecvError m s
  raftClientRecv = lift raftClientRecv
-- | Run a 'RaftClientT' computation with the given environment and initial
-- state, returning only its result; the final client state is discarded.
runRaftClientT :: Monad m => RaftClientEnv -> RaftClientState -> RaftClientT s v m a -> m a
runRaftClientT raftClientEnv raftClientState action =
  evalStateT (runReaderT (unRaftClientT action) raftClientEnv) raftClientState
-- | Errors that the high-level client actions can return.
data RaftClientError s v m where
  -- | Sending the request failed
  RaftClientSendError  :: RaftClientSendError m v -> RaftClientError s v m
  -- | Receiving the response failed
  RaftClientRecvError  :: RaftClientRecvError m s -> RaftClientError s v m
  -- | The action did not finish in time; the 'Text' is the label passed to
  -- 'clientTimeout'
  RaftClientTimeout    :: Text -> RaftClientError s v m
  -- | A read response arrived where another kind was expected
  RaftClientUnexpectedReadResp  :: ClientReadResp s v -> RaftClientError s v m
  -- | A write response arrived where another kind was expected
  RaftClientUnexpectedWriteResp :: ClientWriteResp s v -> RaftClientError s v m
  -- | A metrics response arrived where another kind was expected
  RaftClientUnexpectedMetricsResp :: ClientMetricsResp -> RaftClientError s v m
  -- | A redirect arrived instead of the expected response; 'retryOnRedirect'
  -- retries on this error
  RaftClientUnexpectedRedirect  :: ClientRedirResp -> RaftClientError s v m
-- Standalone deriving is required because the context mentions type family
-- applications.
deriving instance (Show s, Show v, Show (RaftClientSendError m v), Show (RaftClientRecvError m s), Show (RaftStateMachinePureError s v)) => Show (RaftClientError s v m)
-- | Issue a read request (routed by 'clientSendRead') and wait for the read
-- response.
--
-- A send failure is returned as 'RaftClientSendError'; otherwise the result
-- is whatever the read-response receiver yields.
clientRead
  :: (RaftClientSend m v, RaftClientRecv m s v)
  => ClientReadReq
  -> RaftClientT s v m (Either (RaftClientError s v m) (ClientReadResp s v))
clientRead crr =
  clientSendRead crr >>=
    either (pure . Left . RaftClientSendError) (const clientRecvRead)
-- | Like 'clientRead', but the request is sent to the given node instead of
-- being routed via the client's current leader.
clientReadFrom
  :: (RaftClientSend m v, RaftClientRecv m s v)
  => NodeId
  -> ClientReadReq
  -> RaftClientT s v m (Either (RaftClientError s v m) (ClientReadResp s v))
clientReadFrom nid crr =
  clientSendReadTo nid crr >>=
    either (pure . Left . RaftClientSendError) (const clientRecvRead)
-- | 'clientRead' bounded by a timeout. The 'Int' is passed unchanged to
-- 'clientTimeout'; on expiry the result is
-- @Left (RaftClientTimeout "clientRead")@.
clientReadTimeout
  :: (MonadBaseControl IO m, RaftClientSend m v, RaftClientRecv m s v)
  => Int
  -> ClientReadReq
  -> RaftClientT s v m (Either (RaftClientError s v m) (ClientReadResp s v))
clientReadTimeout t crr = clientTimeout "clientRead" t (clientRead crr)
-- | Issue a write of the given command (routed by 'clientSendWrite') and
-- wait for the write response.
--
-- A send failure is returned as 'RaftClientSendError'; otherwise the result
-- is whatever the write-response receiver yields.
clientWrite
  :: (RaftClientSend m v, RaftClientRecv m s v)
  => v
  -> RaftClientT s v m (Either (RaftClientError s v m) (ClientWriteResp s v))
clientWrite cmd =
  clientSendWrite cmd >>=
    either (pure . Left . RaftClientSendError) (const clientRecvWrite)
-- | Like 'clientWrite', but the request is sent to the given node instead of
-- being routed via the client's current leader.
clientWriteTo
  :: (RaftClientSend m v, RaftClientRecv m s v)
  => NodeId
  -> v
  -> RaftClientT s v m (Either (RaftClientError s v m) (ClientWriteResp s v))
clientWriteTo nid cmd =
  clientSendWriteTo nid cmd >>=
    either (pure . Left . RaftClientSendError) (const clientRecvWrite)
-- | 'clientWrite' bounded by a timeout. The 'Int' is passed unchanged to
-- 'clientTimeout'; on expiry the result is
-- @Left (RaftClientTimeout "clientWrite")@.
clientWriteTimeout
  :: (MonadBaseControl IO m, RaftClientSend m v, RaftClientRecv m s v)
  => Int
  -> v
  -> RaftClientT s v m (Either (RaftClientError s v m) (ClientWriteResp s v))
clientWriteTimeout t cmd = clientTimeout "clientWrite" t writeAction
  where
    writeAction = clientWrite cmd
-- | Ask the given node for its metrics and wait for the metrics response.
--
-- A send failure is returned as 'RaftClientSendError'; otherwise the result
-- is whatever the metrics-response receiver yields.
--
-- Only sending and receiving are needed here, so no 'MonadBaseControl'
-- constraint is required; this matches 'clientRead' and 'clientWrite'.
clientQueryNodeMetrics
  :: (RaftClientSend m v, RaftClientRecv m s v)
  => NodeId
  -> RaftClientT s v m (Either (RaftClientError s v m) ClientMetricsResp)
clientQueryNodeMetrics nid =
  clientSendMetricsReqTo nid >>=
    either (pure . Left . RaftClientSendError) (const clientRecvMetrics)
-- | 'clientQueryNodeMetrics' bounded by a timeout. The 'Int' is passed
-- unchanged to 'clientTimeout'; on expiry the result is
-- @Left (RaftClientTimeout "clientQueryNodeMetrics")@.
clientQueryNodeMetricsTimeout
  :: (MonadBaseControl IO m, RaftClientSend m v, RaftClientRecv m s v)
  => Int
  -> NodeId
  -> RaftClientT s v m (Either (RaftClientError s v m) ClientMetricsResp)
clientQueryNodeMetricsTimeout t nid =
  clientTimeout "clientQueryNodeMetrics" t $ clientQueryNodeMetrics nid
-- | Run a client action under the lifted 'timeout'.
--
-- The 'Int' is handed directly to 'timeout'. If the action does not finish
-- in time the result is @Left (RaftClientTimeout fnm)@, where @fnm@ is the
-- supplied label; otherwise the action's own result is returned unchanged.
clientTimeout
  :: (MonadBaseControl IO m, RaftClientSend m v, RaftClientRecv m s v)
  => Text
  -> Int
  -> RaftClientT s v m (Either (RaftClientError s v m) r)
  -> RaftClientT s v m (Either (RaftClientError s v m) r)
clientTimeout fnm t r =
  fromMaybe (Left (RaftClientTimeout fnm)) <$> timeout t r
-- | Run the action and, whenever it fails with 'RaftClientUnexpectedRedirect',
-- pause briefly and run it again. Any other outcome is returned as is.
--
-- There is no bound on the number of retries: the loop only ends once the
-- action returns something other than a redirect error.
retryOnRedirect
  :: MonadBaseControl IO m
  => RaftClientT s v m (Either (RaftClientError s v m) r)
  -> RaftClientT s v m (Either (RaftClientError s v m) r)
retryOnRedirect action =
  action >>= \outcome ->
    case outcome of
      Left (RaftClientUnexpectedRedirect _) ->
        threadDelay redirectBackoff >> retryOnRedirect action
      _ -> pure outcome
  where
    -- Pause between attempts, passed unchanged to 'threadDelay'.
    redirectBackoff = 10000
-- | Send a read request without waiting for a response; routing is done by
-- 'clientSend'.
clientSendRead
  :: RaftClientSend m v
  => ClientReadReq
  -> RaftClientT s v m (Either (RaftClientSendError m v) ())
clientSendRead = clientSend . ClientReadReq
-- | Send a read request to a specific node without waiting for a response.
clientSendReadTo
  :: RaftClientSend m v
  => NodeId
  -> ClientReadReq
  -> RaftClientT s v m (Either (RaftClientSendError m v) ())
clientSendReadTo nid = clientSendTo nid . ClientReadReq
-- | Send a write request for the given command without waiting for a
-- response. The request carries the client's current serial number and is
-- routed by 'clientSend'.
clientSendWrite
  :: RaftClientSend m v
  => v
  -> RaftClientT s v m (Either (RaftClientSendError m v) ())
clientSendWrite v = do
  serial <- gets raftClientSerialNum
  clientSend (ClientWriteReq (ClientCmdReq serial v))
-- | Send a write request for the given command to a specific node without
-- waiting for a response. The request carries the client's current serial
-- number.
clientSendWriteTo
  :: RaftClientSend m v
  => NodeId
  -> v
  -> RaftClientT s v m (Either (RaftClientSendError m v) ())
clientSendWriteTo nid v = do
  serial <- gets raftClientSerialNum
  clientSendTo nid . ClientWriteReq $ ClientCmdReq serial v
-- | Send an all-metrics request to a specific node without waiting for a
-- response.
clientSendMetricsReqTo
  :: RaftClientSend m v
  => NodeId
  -> RaftClientT s v m (Either (RaftClientSendError m v) ())
clientSendMetricsReqTo nid = clientSendTo nid metricsReq
  where
    metricsReq = ClientMetricsReq ClientAllMetricsReq
-- | Send a request, choosing the destination from the client state.
--
-- With no known leader the request goes to a randomly chosen known node.
-- With a known leader it goes to that leader; if that send fails, the error
-- is dropped and the request is sent to a randomly chosen known node
-- instead. The stored leader is not changed here.
clientSend
  :: (RaftClientSend m v)
  => ClientReq v
  -> RaftClientT s v m (Either (RaftClientSendError m v) ())
clientSend creq =
  gets raftClientCurrentLeader >>= \leader ->
    case leader of
      NoLeader -> clientSendRandom creq
      CurrentLeader (LeaderId nid) ->
        clientSendTo nid creq >>=
          either (const (clientSendRandom creq)) (pure . Right)
-- | Send a request to a specific node, tagging it with this client's id
-- taken from the environment.
clientSendTo
  :: RaftClientSend m v
  => NodeId
  -> ClientReq v
  -> RaftClientT s v m (Either (RaftClientSendError m v) ())
clientSendTo nid creq =
  asks raftClientId >>= \cid ->
    raftClientSend nid (ClientRequest cid creq)
-- | Send a request to a node chosen uniformly at random from the set of
-- known nodes, advancing the stored random generator.
--
-- Panics with @"No raft nodes known by client"@ when the client knows no
-- nodes.
clientSendRandom
  :: RaftClientSend m v
  => ClientReq v
  -> RaftClientT s v m (Either (RaftClientSendError m v) ())
clientSendRandom creq = do
  knownNodes <- gets raftClientRaftNodes
  gen <- gets raftClientRandomGen
  let (pick, gen') = randomR (0, Set.size knownNodes - 1) gen
  -- The lookup fails for every index when the set is empty.
  case atMay (Set.toAscList knownNodes) pick of
    Nothing -> panic "No raft nodes known by client"
    Just target -> do
      modify $ \st -> st { raftClientRandomGen = gen' }
      clientSendTo target creq
-- | Receive one response via 'clientRecv' and require it to be a write
-- response. A receive failure becomes 'RaftClientRecvError'; any other kind
-- of response becomes the corresponding @RaftClientUnexpected*@ error.
clientRecvWrite
  :: (RaftClientSend m v, RaftClientRecv m s v)
  => RaftClientT s v m (Either (RaftClientError s v m) (ClientWriteResp s v))
clientRecvWrite = do
  received <- clientRecv
  pure $ case received of
    Left recvErr -> Left (RaftClientRecvError recvErr)
    Right (ClientWriteResponse writeResp) -> Right writeResp
    Right (ClientRedirectResponse redir) -> Left (RaftClientUnexpectedRedirect redir)
    Right (ClientReadResponse readResp) -> Left (RaftClientUnexpectedReadResp readResp)
    Right (ClientMetricsResponse metricsResp) -> Left (RaftClientUnexpectedMetricsResp metricsResp)
-- | Receive one response via 'clientRecv' and require it to be a read
-- response. A receive failure becomes 'RaftClientRecvError'; any other kind
-- of response becomes the corresponding @RaftClientUnexpected*@ error.
clientRecvRead
  :: (RaftClientSend m v, RaftClientRecv m s v)
  => RaftClientT s v m (Either (RaftClientError s v m) (ClientReadResp s v))
clientRecvRead = do
  received <- clientRecv
  pure $ case received of
    Left recvErr -> Left (RaftClientRecvError recvErr)
    Right (ClientReadResponse readResp) -> Right readResp
    Right (ClientRedirectResponse redir) -> Left (RaftClientUnexpectedRedirect redir)
    Right (ClientWriteResponse writeResp) -> Left (RaftClientUnexpectedWriteResp writeResp)
    Right (ClientMetricsResponse metricsResp) -> Left (RaftClientUnexpectedMetricsResp metricsResp)
-- | Receive one response via 'clientRecv' and require it to be a metrics
-- response. A receive failure becomes 'RaftClientRecvError'; any other kind
-- of response becomes the corresponding @RaftClientUnexpected*@ error.
clientRecvMetrics
  :: (RaftClientSend m v, RaftClientRecv m s v)
  => RaftClientT s v m (Either (RaftClientError s v m) ClientMetricsResp)
clientRecvMetrics = do
  received <- clientRecv
  pure $ case received of
    Left recvErr -> Left (RaftClientRecvError recvErr)
    Right (ClientMetricsResponse metricsResp) -> Right metricsResp
    Right (ClientRedirectResponse redir) -> Left (RaftClientUnexpectedRedirect redir)
    Right (ClientWriteResponse writeResp) -> Left (RaftClientUnexpectedWriteResp writeResp)
    Right (ClientReadResponse readResp) -> Left (RaftClientUnexpectedReadResp readResp)
-- | Receive the next response and update the client state accordingly.
--
-- * Write responses are checked against the client's current serial number.
--   On a match the response is returned; a successful write also advances
--   the serial number. A response with a lower serial number is dropped and
--   the next response is awaited. A higher serial number causes a panic.
-- * Redirect responses store the reported leader and are then returned.
-- * All other responses, and receive errors, are returned untouched.
clientRecv
  :: RaftClientRecv m s v
  => RaftClientT s v m (Either (RaftClientRecvError m s) (ClientResponse s v))
clientRecv = do
  received <- raftClientRecv
  case received of
    Left recvErr -> pure (Left recvErr)
    Right resp@(ClientWriteResponse writeResp) ->
      case writeResp of
        ClientWriteRespSuccess _ (SerialNum got) ->
          expecting got $ \expected -> do
            modify $ \st -> st
              { raftClientSerialNum = SerialNum (succ expected) }
            pure (Right resp)
        ClientWriteRespFail (SerialNum got) _ ->
          expecting got $ \_ -> pure (Right resp)
    Right resp@(ClientRedirectResponse (ClientRedirResp newLeader)) -> do
      modify $ \st -> st
        { raftClientCurrentLeader = newLeader }
      pure (Right resp)
    Right resp -> pure (Right resp)
  where
    -- Compare a response's serial number with the client's current one and
    -- either continue with the given action (equal), wait for the next
    -- response (lower), or give up (higher).
    expecting got onMatch = do
      SerialNum expected <- gets raftClientSerialNum
      case compare got expected of
        EQ -> onMatch expected
        LT -> clientRecv
        GT -> panic $ "Received invalid serial number response: Expected " <> show expected <> " but got " <> show got
-- | Add a node to the set of nodes known to the client.
clientAddNode :: Monad m => NodeId -> RaftClientT s v m ()
clientAddNode nid = modify withNode
  where
    withNode st = st { raftClientRaftNodes = Set.insert nid (raftClientRaftNodes st) }
-- | The set of nodes currently known to the client.
clientGetNodes :: Monad m => RaftClientT s v m (Set NodeId)
clientGetNodes = raftClientRaftNodes <$> get