module Network.Legion.Runtime (
forkLegionary,
runLegionary,
StartupMode(..),
) where
import Control.Concurrent (forkIO)
import Control.Concurrent.Chan (writeChan, newChan, Chan)
import Control.Concurrent.MVar (newEmptyMVar, takeMVar, putMVar)
import Control.Monad (void, forever, join)
import Control.Monad.Catch (catchAll, try, SomeException, throwM, bracket,
  finally)
import Control.Monad.IO.Class (liftIO)
import Control.Monad.Logger (logWarn, logError, logInfo, LoggingT,
  MonadLoggerIO, runLoggingT, askLoggerIO)
import Control.Monad.Trans.Class (lift)
import Data.Binary (encode)
import Data.Conduit (Source, ($$), (=$=), yield, await, awaitForever,
  transPipe, ConduitM, runConduit)
import Data.Conduit.Network (sourceSocket)
import Data.Conduit.Serialization.Binary (conduitDecode)
import Data.Map (Map)
import Data.Text (pack)
import Network.Legion.Admin (runAdmin)
import Network.Legion.Application (LegionConstraints, Legionary,
  RequestMsg)
import Network.Legion.BSockAddr (BSockAddr(BSockAddr))
import Network.Legion.ClusterState (ClusterPowerState)
import Network.Legion.Conduit (merge, chanToSink, chanToSource)
import Network.Legion.ConnectionManager (newConnectionManager, send,
  newPeers)
import Network.Legion.Distribution (Peer, newPeer)
import Network.Legion.Fork (forkC)
import Network.Legion.LIO (LIO)
import Network.Legion.PartitionKey (PartitionKey)
import Network.Legion.Settings (LegionarySettings(LegionarySettings,
  adminHost, adminPort, peerBindAddr, joinBindAddr))
import Network.Legion.StateMachine (stateMachine, LInput(J, P, R,
  A), JoinRequest(JoinRequest), JoinResponse(JoinOk, JoinRejected),
  LOutput(Send, NewPeers), AdminMessage, NodeState, PeerMessage,
  newNodeState)
import Network.Legion.UUID (getUUID)
import Network.Socket (Family(AF_INET, AF_INET6, AF_UNIX, AF_CAN),
  SocketOption(ReuseAddr), SocketType(Stream), accept, bindSocket,
  defaultProtocol, listen, setSocketOption, socket, SockAddr(SockAddrInet,
  SockAddrInet6, SockAddrUnix, SockAddrCan), connect, getPeerName, Socket,
  close)
import Network.Socket.ByteString.Lazy (sendAll)
import qualified Data.Conduit.List as CL
import qualified Network.Legion.ClusterState as C
-- | Run a legion node in the calling thread.
--
-- In order, this starts the peer message listener, builds the initial node
-- state according to the 'StartupMode', creates the connection manager from
-- the initially known peers, starts the admin service, and prepares the join
-- request source. The four input streams (join requests, peer messages, user
-- requests, admin messages) are then merged, tagged as 'LInput' values, and
-- fed through the state machine. Every 'LOutput' the state machine emits is
-- handed to the connection manager.
runLegionary :: (LegionConstraints i o s)
  => Legionary i o s
  -> LegionarySettings
  -> StartupMode
  -> Source IO (RequestMsg i o)
  -> LoggingT IO ()
runLegionary
    legionary
    settings@LegionarySettings {adminHost, adminPort}
    startupMode
    requestSource
  = do
    peerMessages <- loggingC =<< startPeerListener settings
    (initialState, knownPeers) <- makeNodeState settings startupMode
    connections <- newConnectionManager knownPeers
    $(logInfo) (pack ("The initial node state is: " ++ show initialState))
    adminMessages <- loggingC =<< runAdmin adminPort adminHost
    joinRequests <- loggingC (joinMsgSource settings)
    runConduit $
      ( joinRequests
          `merge` (peerMessages `merge` (requestSource `merge` adminMessages))
      )
      =$= CL.map toMessage
      =$= stateMachine legionary initialState
      =$= dispatch connections
  where
    -- Carry out each state machine output using the connection manager.
    dispatch connections = awaitForever $ \output -> lift $
      case output of
        Send peer message -> send connections peer message
        NewPeers peers -> newPeers connections peers

    -- Tag a value from the merged input stream with the 'LInput'
    -- constructor that matches the source it came from. The nesting of
    -- 'Either's mirrors the nesting of the 'merge' calls above.
    toMessage
      :: Either
          (JoinRequest, JoinResponse -> LIO ())
          (Either
            (PeerMessage i o s)
            (Either (RequestMsg i o) (AdminMessage i o s)))
      -> LInput i o s
    toMessage = either J (either P (either R A))
-- | Convert a conduit whose base monad is 'LIO' into one whose base monad is
-- 'IO', by capturing the current logging function and using it to run every
-- 'LIO' step of the conduit.
loggingC :: ConduitM i o LIO r -> LIO (ConduitM i o IO r)
loggingC conduit = fmap unwrap askLoggerIO
  where
    -- Run each underlying LIO action with the captured logger.
    unwrap logger = transPipe (\action -> runLoggingT action logger) conduit
-- | How a node should obtain its initial cluster state when it starts up.
data StartupMode
  = NewCluster
    -- ^ Create a brand new cluster, with a fresh cluster id, in which this
    -- node is the initial peer.
  | JoinCluster SockAddr
    -- ^ Join an existing cluster by sending a join request to the given
    -- address.
  deriving (Show, Eq)
-- | Start the service that receives messages from other peers.
--
-- A stream socket is bound to 'peerBindAddr' (with @ReuseAddr@ set and a
-- listen backlog of 5) and an acceptor thread is forked. The returned source
-- produces every peer message that is decoded from any accepted connection;
-- messages are handed over through an internal channel.
--
-- If setting up the listening socket fails, the error is logged and then
-- rethrown.
startPeerListener :: (LegionConstraints i o s)
  => LegionarySettings
  -> LIO (Source LIO (PeerMessage i o s))
startPeerListener LegionarySettings {peerBindAddr} =
    catchAll (do
        (inputChan, so) <- lift $ do
          inputChan <- newChan
          so <- socket (fam peerBindAddr) Stream defaultProtocol
          setSocketOption so ReuseAddr 1
          bindSocket so peerBindAddr
          listen so 5
          return (inputChan, so)
        forkC "peer socket acceptor" $ acceptLoop so inputChan
        return (chanToSource inputChan)
      ) (\err -> do
        $(logError) . pack
          $ "Couldn't start incomming peer message service, because of: "
          ++ show (err :: SomeException)
        throwM err
      )
  where
    -- Accept connections forever, handling each one in its own thread. An
    -- exception in the loop itself is logged and rethrown.
    acceptLoop :: (LegionConstraints i o s)
      => Socket
      -> Chan (PeerMessage i o s)
      -> LIO ()
    acceptLoop so inputChan =
        catchAll (
          forever $ do
            (conn, _) <- lift $ accept so
            remoteAddr <- lift $ getPeerName conn
            logging <- askLoggerIO
            let
              runSocket =
                sourceSocket conn
                =$= conduitDecode
                $$ msgSink
              handler = runLoggingT (logErrors remoteAddr runSocket) logging
            -- Always release the connection's socket once its handler is
            -- done, whether the stream ended normally or the handler
            -- crashed. Otherwise every peer connection leaks a descriptor.
            void . lift . forkIO $ handler `finally` close conn
        ) (\err -> do
          $(logError) . pack
            $ "error in peer message accept loop: "
            ++ show (err :: SomeException)
          throwM err
        )
      where
        msgSink = chanToSink inputChan

        -- Run a connection handler, logging (rather than propagating) any
        -- exception it throws.
        logErrors :: SockAddr -> LIO () -> LIO ()
        logErrors remoteAddr io = do
          result <- try io
          case result of
            Left err ->
              $(logWarn) . pack
                $ "Incomming peer connection (" ++ show remoteAddr
                ++ ") crashed because of: " ++ show (err :: SomeException)
            Right v -> return v
-- | Build the initial node state, along with the initially known peers and
-- their addresses, according to the startup mode.
--
-- * 'NewCluster': allocate a new peer id and cluster id, and create a new
--   cluster state containing this node at 'peerBindAddr'.
--
-- * 'JoinCluster': send a 'JoinRequest' carrying our 'peerBindAddr' to the
--   given address and build the cluster state from the 'JoinOk' response.
--   A missing response or a 'JoinRejected' response results in a call to
--   'fail'.
makeNodeState :: (LegionConstraints i o s)
  => LegionarySettings
  -> StartupMode
  -> LIO (NodeState i o s, Map Peer BSockAddr)
makeNodeState LegionarySettings {peerBindAddr} NewCluster = do
  self <- newPeer
  clusterId <- getUUID
  let cluster = C.new clusterId self peerBindAddr
  nodeState <- newNodeState self cluster
  return (nodeState, C.getPeers cluster)
makeNodeState LegionarySettings {peerBindAddr} (JoinCluster addr) = do
    $(logInfo) "Trying to join an existing cluster."
    (self, clusterPS) <- joinCluster (JoinRequest (BSockAddr peerBindAddr))
    let cluster = C.initProp self clusterPS
    nodeState <- newNodeState self cluster
    return (nodeState, C.getPeers cluster)
  where
    -- Perform the join handshake over a short-lived connection. 'bracket'
    -- guarantees the socket is closed once the exchange is over, whether it
    -- succeeded, was rejected, or threw. Previously the socket was never
    -- closed.
    joinCluster :: JoinRequest -> LIO (Peer, ClusterPowerState)
    joinCluster joinMsg = liftIO $
      bracket (socket (fam addr) Stream defaultProtocol) close $ \so -> do
        connect so addr
        sendAll so (encode joinMsg)
        sourceSocket so =$= conduitDecode $$ do
          response <- await
          case response of
            Nothing -> fail
              $ "Couldn't join a cluster because there was no response "
              ++ "to our join request!"
            Just (JoinOk self cps) ->
              return (self, cps)
            Just (JoinRejected reason) -> fail
              $ "The cluster would not allow us to re-join. "
              ++ "The reason given was: " ++ show reason
-- | A source of join requests from nodes that want to join the cluster.
--
-- When the source is run, a stream socket is bound to 'joinBindAddr' (with
-- @ReuseAddr@ set and a listen backlog of 5) and an acceptor thread is
-- forked. Each item produced is a decoded 'JoinRequest' paired with a
-- responder; calling the responder causes the given 'JoinResponse' to be
-- written back on the connection the request arrived on.
--
-- If setting up the listening socket fails, the error is logged and then
-- rethrown.
joinMsgSource
  :: LegionarySettings
  -> Source LIO (JoinRequest, JoinResponse -> LIO ())
joinMsgSource LegionarySettings {joinBindAddr} = join . lift $
    catchAll (do
        (chan, so) <- lift $ do
          chan <- newChan
          so <- socket (fam joinBindAddr) Stream defaultProtocol
          setSocketOption so ReuseAddr 1
          bindSocket so joinBindAddr
          listen so 5
          return (chan, so)
        forkC "join socket acceptor" $ acceptLoop so chan
        return (chanToSource chan)
      ) (\err -> do
        $(logError) . pack
          $ "Couldn't start join request service, because of: "
          ++ show (err :: SomeException)
        throwM err
      )
  where
    -- Accept connections forever, handling each one in its own thread. An
    -- exception in the loop itself is logged and rethrown.
    acceptLoop :: Socket -> Chan (JoinRequest, JoinResponse -> LIO ()) -> LIO ()
    acceptLoop so chan =
        catchAll (
          forever $ do
            (conn, _) <- lift $ accept so
            logging <- askLoggerIO
            let
              serve =
                sourceSocket conn
                =$= conduitDecode
                =$= attachResponder conn
                $$ chanToSink chan
            -- Always release the connection's socket once its handler is
            -- done, whether the stream ended normally or the handler
            -- crashed. Otherwise every join connection leaks a descriptor.
            void . lift . forkIO $
              runLoggingT (logErrors serve) logging `finally` close conn
        ) (\err -> do
          $(logError) . pack
            $ "error in join request accept loop: "
            ++ show (err :: SomeException)
          throwM err
        )
      where
        -- Run a connection handler, logging (rather than propagating) any
        -- exception it throws.
        logErrors :: LIO () -> LIO ()
        logErrors io = do
          result <- try io
          case result of
            Left err ->
              $(logWarn) . pack
                $ "Incomming join connection crashed because of: "
                ++ show (err :: SomeException)
            Right v -> return v

    -- Pair every incoming request with a responder. After yielding the
    -- pair downstream, block until the responder has been called, then
    -- write the encoded response back on the same connection before
    -- reading the next request.
    attachResponder
      :: Socket
      -> ConduitM JoinRequest (JoinRequest, JoinResponse -> LIO ()) LIO ()
    attachResponder conn = awaitForever (\msg -> do
        mvar <- liftIO newEmptyMVar
        yield (msg, lift . putMVar mvar)
        response <- liftIO $ takeMVar mvar
        liftIO $ sendAll conn (encode response)
      )
-- | Determine the socket address family that goes with a socket address, so
-- that a socket of the matching family can be created for it.
fam :: SockAddr -> Family
fam addr = case addr of
  SockAddrInet {} -> AF_INET
  SockAddrInet6 {} -> AF_INET6
  SockAddrUnix {} -> AF_UNIX
  SockAddrCan {} -> AF_CAN
-- | Start the legion node in a background thread and return a function for
-- submitting requests to it.
--
-- The returned function writes the partition key and request to the node's
-- request channel together with a callback that fills a fresh 'MVar', and
-- then blocks until that 'MVar' has been filled with the response.
forkLegionary :: (LegionConstraints i o s, MonadLoggerIO io)
  => Legionary i o s
  -> LegionarySettings
  -> StartupMode
  -> io (PartitionKey -> i -> IO o)
forkLegionary legionary settings startupMode = do
    logging <- askLoggerIO
    liftIO (runLoggingT startNode logging)
  where
    -- Create the request channel, fork the node so that it reads its user
    -- requests from that channel, and hand back the submit function.
    startNode = do
      requests <- liftIO newChan
      forkC "main legion thread" $
        runLegionary legionary settings startupMode (chanToSource requests)
      return (submit requests)

    -- Enqueue one request and wait for its response.
    submit requests key request = do
      reply <- newEmptyMVar
      writeChan requests ((key, request), putMVar reply)
      takeMVar reply