{-# LANGUAGE TypeSynonymInstances #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Examples.Raft.Socket.Node where
import Protolude
import Control.Monad.Fail
import Control.Monad.Catch
import Control.Monad.Trans.Class
import Control.Concurrent.STM.TChan
import Control.Concurrent.STM.TVar
import Control.Concurrent.STM.TMVar
import qualified Data.Map as Map
import qualified Data.Serialize as S
import qualified Network.Simple.TCP as NS
import Network.Simple.TCP (HostName, ServiceName)
import Examples.Raft.Socket.Common
import Raft.Client
import Raft.Event
import Raft.Log
import Raft.Monad
import Raft.Persistent
import Raft.RPC
import Raft.StateMachine
import Raft.Types
-- | The value a client-connection handler receives in its per-client
-- response slot.
data ResponseSignal sm v
  = OkResponse (ClientResponse sm v)
    -- ^ The node's response, to be written back to the client.
  | DeadResponse
    -- ^ The request was superseded by a newer request from the same client;
    -- nothing is sent back for it.
-- | Shared state linking the socket server with the node's
-- send/receive instances for 'RaftSocketT'.
data NodeSocketEnv sm v = NodeSocketEnv
  { nsMsgQueue :: TChan (RPCMessage v)
    -- ^ RPC messages received from other nodes, drained by 'receiveRPC'.
  , nsClientReqQueue :: TChan (ClientRequest v)
    -- ^ Client requests received over sockets, drained by 'receiveClient'.
  , nsClientReqResps :: TVar (Map ClientId (TMVar (ResponseSignal sm v)))
    -- ^ One response slot per client: filled by 'sendClient' and waited on
    -- by the connection handler serving that client.
  }
-- | Monad transformer giving node code access to a 'NodeSocketEnv'.
-- It is a thin newtype over 'ReaderT', so its instances are derived from it.
newtype RaftSocketT sm v m a = RaftSocketT
  { unRaftSocketT :: ReaderT (NodeSocketEnv sm v) m a
  } deriving
    ( Functor
    , Applicative
    , Monad
    , MonadIO
    , MonadFail
    , MonadReader (NodeSocketEnv sm v)
    , Alternative
    , MonadPlus
    , MonadTrans
    )

-- Exception-handling capabilities are inherited from the underlying monad.
deriving instance MonadThrow m => MonadThrow (RaftSocketT sm v m)
deriving instance MonadCatch m => MonadCatch (RaftSocketT sm v m)
deriving instance MonadMask m => MonadMask (RaftSocketT sm v m)
-- | A client response is delivered by filling that client's slot in
-- 'nsClientReqResps'. If no slot is registered for the client, a message is
-- printed and the response is dropped.
instance
  ( RaftStateMachinePure sm v
  , MonadMask m
  , MonadCatch m
  , MonadIO m
  , S.Serialize sm
  , S.Serialize v
  ) => RaftSendClient (RaftSocketT sm v m) sm v where
  sendClient clientId msg = do
    slotsVar <- asks nsClientReqResps
    slots <- liftIO (atomically (readTVar slotsVar))
    liftIO $ maybe
      (putText "sendClient: response lookup failed")
      (\respVar -> atomically (putTMVar respVar (OkResponse msg)))
      (Map.lookup clientId slots)
-- | Client requests are taken, blocking, from 'nsClientReqQueue'.
-- Receiving never fails.
instance (MonadIO m, S.Serialize v) => RaftRecvClient (RaftSocketT sm v m) v where
  type RaftRecvClientError (RaftSocketT sm v m) v = Text
  receiveClient = do
    queue <- asks nsClientReqQueue
    req <- liftIO (atomically (readTChan queue))
    pure (Right req)
-- | Each RPC opens a fresh TCP connection to the target node and sends the
-- serialized message. Any exception during connect or send is caught and
-- printed, so a failed send never propagates to the caller.
instance (MonadCatch m, MonadMask m, MonadIO m, S.Serialize v, Show v) => RaftSendRPC (RaftSocketT sm v m) v where
  sendRPC nid msg =
      Control.Monad.Catch.try transmit >>= either logFailure pure
    where
      (host, port) = nidToHostPort nid
      transmit =
        NS.connect host port $ \(sock, _) ->
          NS.send sock (S.encode (RPCMessageEvent msg))
      logFailure (err :: SomeException) =
        putText ("Failed to send RPC: " <> show err)
-- | RPC messages are taken, blocking, from 'nsMsgQueue'.
-- Receiving never fails.
instance (MonadIO m, Show v) => RaftRecvRPC (RaftSocketT sm v m) v where
  type RaftRecvRPCError (RaftSocketT sm v m) v = Text
  receiveRPC = do
    queue <- asks nsMsgQueue
    rpc <- liftIO (atomically (readTChan queue))
    pure (Right rpc)
-- | Run a 'RaftSocketT' computation against the given socket environment.
runRaftSocketT :: MonadIO m => NodeSocketEnv sm v -> RaftSocketT sm v m a -> m a
runRaftSocketT nodeSocketEnv action = runReaderT (unRaftSocketT action) nodeSocketEnv
-- | Accept connections on @host:port@ forever (via 'NS.serve') and dispatch
-- the first message read from each connection.
--
-- * An RPC message from another node is pushed onto 'nsMsgQueue'.
-- * A client request first registers a response slot for its client in
--   'nsClientReqResps', then is pushed onto 'nsClientReqQueue'. The handler
--   then waits on that slot. On 'OkResponse' it writes the response back over
--   the same socket; on 'DeadResponse' it replies with nothing.
-- * A newer request from the same client signals any older, still-waiting
--   handler with 'DeadResponse'.
--
-- The response slot is always unregistered when the handler finishes, even
-- if sending throws. It is unregistered only if it has not already been
-- replaced by a newer request from the same client.
acceptConnections
  :: forall sm v m.
     ( S.Serialize sm
     , S.Serialize v
     , S.Serialize (RaftStateMachinePureError sm v)
     , Show (RaftStateMachinePureError sm v)
     , MonadIO m
     )
  => HostName
  -> ServiceName
  -> RaftSocketT sm v m ()
acceptConnections host port = do
  NodeSocketEnv{..} <- ask
  NS.serve (NS.Host host) port $ \(sock, _) -> do
    mVal <- recvSerialized sock
    case mVal of
      Nothing -> putText "Socket was closed on the other end"
      Just (ClientRequestEvent req@(ClientRequest clientId _)) -> do
        let -- Install a fresh response slot for this client. Any older slot
            -- still registered is told to give up. 'tryPutTMVar' never
            -- blocks: if the old slot already holds a response, its handler
            -- simply sends that one instead.
            registerRespVar = do
              newRespVar <- newEmptyTMVar
              clientReqResps <- readTVar nsClientReqResps
              case Map.lookup clientId clientReqResps of
                Nothing -> pure ()
                Just oldRespVar -> void (tryPutTMVar oldRespVar DeadResponse)
              writeTVar nsClientReqResps (Map.insert clientId newRespVar clientReqResps)
              pure newRespVar
            -- Remove our slot, but only if it is still ours. Deleting a
            -- slot that a newer request installed would make 'sendClient'
            -- miss it and leave the newer handler waiting forever.
            unregisterRespVar respVar =
              modifyTVar' nsClientReqResps $
                Map.update (\v -> if v == respVar then Nothing else Just v) clientId
            -- Hand the request to the node and relay its eventual response.
            awaitAndReply respVar = do
              atomically $ writeTChan nsClientReqQueue req
              respMsg <- atomically $ takeTMVar respVar
              case respMsg of
                OkResponse okResp -> NS.send sock (S.encode okResp)
                DeadResponse -> pure ()
        -- 'bracket' guarantees the slot is unregistered even if the send
        -- throws (e.g. the client disconnected).
        Control.Monad.Catch.bracket
          (atomically registerRespVar)
          (atomically . unregisterRespVar)
          awaitAndReply
      Just (RPCMessageEvent msg) ->
        atomically $ writeTChan nsMsgQueue msg
-- | Persistent state is delegated unchanged to the underlying monad.
instance (MonadIO m, RaftPersist m) => RaftPersist (RaftSocketT sm v m) where
  type RaftPersistError (RaftSocketT sm v m) = RaftPersistError m
  initializePersistentState = lift initializePersistentState
  writePersistentState = lift . writePersistentState
  readPersistentState = lift readPersistentState
-- | Log initialization is delegated unchanged to the underlying monad.
instance (MonadIO m, RaftInitLog m v) => RaftInitLog (RaftSocketT sm v m) v where
  type RaftInitLogError (RaftSocketT sm v m) = RaftInitLogError m
  initializeLog = lift . initializeLog
-- | Log writes are delegated unchanged to the underlying monad.
instance RaftWriteLog m v => RaftWriteLog (RaftSocketT sm v m) v where
  type RaftWriteLogError (RaftSocketT sm v m) = RaftWriteLogError m
  writeLogEntries = lift . writeLogEntries
-- | Log reads are delegated unchanged to the underlying monad.
instance RaftReadLog m v => RaftReadLog (RaftSocketT sm v m) v where
  type RaftReadLogError (RaftSocketT sm v m) = RaftReadLogError m
  readLogEntry = lift . readLogEntry
  readLastLogEntry = lift readLastLogEntry
-- | Log deletion is delegated unchanged to the underlying monad.
instance RaftDeleteLog m v => RaftDeleteLog (RaftSocketT sm v m) v where
  type RaftDeleteLogError (RaftSocketT sm v m) = RaftDeleteLogError m
  deleteLogEntriesFrom = lift . deleteLogEntriesFrom
-- | Command validation and the pure state-machine context come from the
-- underlying monad.
instance RaftStateMachine m sm v => RaftStateMachine (RaftSocketT sm v m) sm v where
  validateCmd cmd = lift (validateCmd cmd)
  askRaftStateMachinePureCtx = lift askRaftStateMachinePureCtx
-- | The event channel is the underlying monad's channel; every operation is
-- lifted through.
instance MonadRaftChan v m => MonadRaftChan v (RaftSocketT sm v m) where
  type RaftEventChan v (RaftSocketT sm v m) = RaftEventChan v m
  readRaftChan chan = lift (readRaftChan chan)
  writeRaftChan chan evt = lift (writeRaftChan chan evt)
  -- Explicit type applications: the result type alone does not fix @v@.
  newRaftChan = lift (newRaftChan @v @m)
-- | Forking runs the child computation in the underlying monad. The current
-- 'NodeSocketEnv' is captured so the child shares the same queues and
-- response slots.
instance (MonadIO m, MonadRaftFork m) => MonadRaftFork (RaftSocketT sm v m) where
  type RaftThreadId (RaftSocketT sm v m) = RaftThreadId m
  raftFork r action =
    ask >>= \env -> lift (raftFork r (runRaftSocketT env action))