module Network.MQTT.Broker.Server
( serveConnection
, MQTT ()
, MqttServerTransportStack (..)
, SS.Server ( .. )
, SS.ServerConfig ( .. )
, SS.ServerConnection ( .. )
, SS.ServerException ( .. )
) where
import Control.Concurrent
import Control.Concurrent.Async
import qualified Control.Exception as E
import Control.Monad
import qualified Data.Binary.Get as SG
import qualified Data.ByteString as BS
import Data.Int
import Data.IORef
import Data.Typeable
import qualified Network.Stack.Server as SS
import qualified Network.WebSockets as WS
import qualified System.Log.Logger as Log
import qualified System.Socket as S
import Network.MQTT.Message
import Network.MQTT.Broker.Authentication
import qualified Network.MQTT.Broker.Internal as Session
import qualified Network.MQTT.Broker as Broker
import qualified Network.MQTT.Broker.Session as Session
-- | Type-level tag for the MQTT protocol layer stacked on top of a stream
-- transport. It has no values; it only selects the MQTT instances of the
-- server-stack classes.
data MQTT transport

-- | Allows the MQTT layer's server exceptions to be thrown with 'E.throwIO'
-- and caught like any other exception.
instance Typeable transport => E.Exception (SS.ServerException (MQTT transport))
-- | Transport stacks that can describe an accepted connection as the
-- transport-level part of an MQTT 'ConnectionRequest' (e.g. peer address,
-- TLS status and certificate chain, HTTP upgrade request).
class SS.ServerStack stack => MqttServerTransportStack stack where
  -- | Derive a connection request from the stack's connection info.
  getConnectionRequest :: SS.ServerConnectionInfo stack -> IO ConnectionRequest
-- | A plain socket only knows the peer's address. The request records that
-- address in numeric form and is marked insecure. The client identifier,
-- clean-session flag and credentials are neutral placeholders here.
instance (Typeable f, Typeable t, Typeable p, S.Family f, S.Protocol p, S.Type t, S.HasNameInfo f) => MqttServerTransportStack (S.Socket f t p) where
  getConnectionRequest (SS.SocketServerConnectionInfo addr) =
    toRequest . S.hostName <$> S.getNameInfo addr numericOnly
    where
      -- Numeric host and service: avoids reverse DNS lookups.
      numericOnly = S.niNumericHost `mappend` S.niNumericService
      toRequest peer = ConnectionRequest
        { requestClientIdentifier = ClientIdentifier mempty
        , requestSecure = False
        , requestCleanSession = True
        , requestCredentials = Nothing
        , requestHttp = Nothing
        , requestCertificateChain = Nothing
        , requestRemoteAddress = Just peer
        }
-- | WebSocket connections extend the underlying transport's request with the
-- path and headers of the HTTP upgrade request.
instance (SS.StreamServerStack a, MqttServerTransportStack a) => MqttServerTransportStack (SS.WebSocket a) where
  getConnectionRequest (SS.WebSocketServerConnectionInfo tci rh) =
    withHttp <$> getConnectionRequest tci
    where
      withHttp inner = inner { requestHttp = Just (WS.requestPath rh, WS.requestHeaders rh) }
-- | TLS connections mark the request as secure and attach the client's
-- certificate chain, if one was presented, to the underlying request.
instance (SS.StreamServerStack a, MqttServerTransportStack a) => MqttServerTransportStack (SS.TLS a) where
  getConnectionRequest (SS.TlsServerConnectionInfo tci mcc) =
    withTls <$> getConnectionRequest tci
    where
      withTls inner = inner
        { requestSecure = True
        , requestCertificateChain = mcc
        }
-- | The MQTT layer wraps the underlying stream transport's server, config,
-- connection and connection info. Each connection additionally owns an 'MVar'
-- holding bytes that were read from the transport but not yet consumed by the
-- packet decoder (initially empty).
instance SS.StreamServerStack transport => SS.ServerStack (MQTT transport) where
  data Server (MQTT transport) = MqttServer
    { mqttTransportServer :: SS.Server transport
    }
  data ServerConfig (MQTT transport) = MqttServerConfig
    { mqttTransportConfig :: SS.ServerConfig transport
    }
  data ServerConnection (MQTT transport) = MqttServerConnection
    { mqttTransportConnection :: SS.ServerConnection transport
    , mqttTransportLeftover :: MVar BS.ByteString
    }
  data ServerConnectionInfo (MQTT transport) = MqttServerConnectionInfo
    { mqttTransportServerConnectionInfo :: SS.ServerConnectionInfo transport
    }
  data ServerException (MQTT transport)
    = ProtocolViolation String
    | MessageTooLong
    | ConnectionRejected RejectReason
    | KeepAliveTimeoutException
    deriving (Eq, Ord, Show, Typeable)

  -- Delegate to the transport and wrap the resulting server.
  withServer config useServer =
    SS.withServer (mqttTransportConfig config) (useServer . MqttServer)

  -- Delegate to the transport, then give every accepted connection a fresh,
  -- empty leftover buffer before handing it to the caller.
  withConnection server useConnection =
    SS.withConnection (mqttTransportServer server) $ \connection info -> do
      leftover <- newMVar mempty
      useConnection (MqttServerConnection connection leftover) (MqttServerConnectionInfo info)
-- | Framing of MQTT packets on top of a stream transport.
--
-- Outgoing packets are rendered with 'serverPacketBuilder' and written through
-- the transport in 8192-byte chunks. Incoming packets are decoded
-- incrementally with 'clientPacketParser' from 4096-byte reads. Bytes read past
-- the end of a packet are kept in the connection's leftover 'MVar' for the
-- next call.
--
-- Size limit: a packet is rejected with 'MessageTooLong' when it is known to
-- be larger than @maxMsgSize@. That is the case when either
--
-- * the decoder still needs more input although more than @maxMsgSize@ bytes
--   have already been fed to it (all of them belong to the current packet), or
-- * the decoder finished after consuming more than @maxMsgSize@ bytes.
--
-- Bytes that were read together with a packet but belong to the following
-- (pipelined) packets are not counted against it.
--
-- Malformed input, including end of stream in the middle of a packet, raises
-- 'ProtocolViolation'.
instance SS.StreamServerStack transport => SS.MessageServerStack (MQTT transport) where
  type ClientMessage (MQTT transport) = ClientPacket
  type ServerMessage (MQTT transport) = ServerPacket

  sendMessage connection =
    SS.sendStreamBuilder (mqttTransportConnection connection) 8192 . serverPacketBuilder

  -- foldMap concatenates the builders without the thunk chain of a lazy foldl.
  sendMessages connection msgs =
    SS.sendStreamBuilder (mqttTransportConnection connection) 8192 $ foldMap serverPacketBuilder msgs

  receiveMessage connection maxMsgSize =
    modifyMVar (mqttTransportLeftover connection) start
    where
      fetch = SS.receiveStream (mqttTransportConnection connection) 4096
      decode = SG.runGetIncremental clientPacketParser
      -- Feed the buffered bytes first; they already count towards the packet.
      start leftover =
        execute (fromIntegral (BS.length leftover)) (SG.pushChunk decode leftover)
      execute received result = case result of
        SG.Partial continuation
          -- The decoder consumed everything it was given and still wants
          -- more, so the packet is strictly larger than 'received'.
          | received > maxMsgSize ->
              E.throwIO (MessageTooLong :: SS.ServerException (MQTT transport))
          | otherwise -> do
              bs <- fetch
              if BS.null bs
                then execute received (continuation Nothing)
                else execute (received + fromIntegral (BS.length bs)) (continuation $ Just bs)
        SG.Fail _ _ failure ->
          E.throwIO (ProtocolViolation failure :: SS.ServerException (MQTT transport))
        SG.Done leftover' consumed msg
          | fromIntegral consumed > maxMsgSize ->
              E.throwIO (MessageTooLong :: SS.ServerException (MQTT transport))
          | otherwise ->
              pure (leftover', msg)

  consumeMessages connection maxMsgSize consume =
    modifyMVar_ (mqttTransportLeftover connection) start
    where
      fetch = SS.receiveStream (mqttTransportConnection connection) 4096
      decode = SG.runGetIncremental clientPacketParser
      -- Begin decoding a new packet from the given buffered bytes.
      start leftover =
        execute (fromIntegral (BS.length leftover)) (SG.pushChunk decode leftover)
      execute received result = case result of
        SG.Partial continuation
          -- All bytes fed so far belong to the unfinished packet.
          | received > maxMsgSize ->
              E.throwIO (MessageTooLong :: SS.ServerException (MQTT transport))
          | otherwise -> do
              bs <- fetch
              if BS.null bs
                then execute received (continuation Nothing)
                else execute (received + fromIntegral (BS.length bs)) (continuation $ Just bs)
        SG.Fail _ _ failure ->
          E.throwIO (ProtocolViolation failure :: SS.ServerException (MQTT transport))
        SG.Done leftover' consumed msg
          | fromIntegral consumed > maxMsgSize ->
              E.throwIO (MessageTooLong :: SS.ServerException (MQTT transport))
          | otherwise -> do
              -- 'consume' returns True to stop; the unconsumed bytes are then
              -- stored back into the leftover buffer.
              done <- consume msg
              if done
                then pure leftover'
                else start leftover'
-- | Derived 'Show' for the MQTT connection info; available whenever the
-- wrapped transport's connection info is showable.
deriving instance
  (Show (SS.ServerConnectionInfo transport))
  => Show (SS.ServerConnectionInfo (MQTT transport))
-- | Serve a single accepted MQTT client connection.
--
-- The first packet (at most 'maxInitialPacketSize' bytes) decides what happens:
--
-- * CONNECT with an unsupported protocol version: answered with
--   @ServerConnectionRejected UnacceptableProtocolVersion@.
-- * CONNECT: the transport-level 'ConnectionRequest' is completed with the
--   client identifier, clean-session flag and credentials from the packet and
--   passed to 'Broker.withSession'.
--
--     * On rejection, @ServerConnectionRejected reason@ is sent.
--     * On acceptance, @ServerConnectionAccepted@ is sent. Input handling,
--       output handling and keep-alive supervision then race each other, and
--       the first worker to return or throw ends the session. Exceptions are
--       logged and re-thrown.
--
-- * Any other first packet: logged, and the connection is closed without a
--   reply.
serveConnection :: forall transport auth. (SS.StreamServerStack transport, MqttServerTransportStack transport, Authenticator auth) => Broker.Broker auth -> SS.ServerConnection (MQTT transport) -> SS.ServerConnectionInfo (MQTT transport) -> IO ()
serveConnection broker conn connInfo = do
  -- Shared between the input handler (sets it on every packet) and the
  -- keep-alive supervisor (clears and polls it).
  recentActivity <- newIORef True
  req <- getConnectionRequest (mqttTransportServerConnectionInfo connInfo)
  msg <- SS.receiveMessage conn maxInitialPacketSize
  case msg of
    ClientConnectUnsupported -> do
      Log.warningM "Server.connection" $ "Connection from "
        ++ show (requestRemoteAddress req) ++ " rejected: UnacceptableProtocolVersion"
      void $ SS.sendMessage conn (ServerConnectionRejected UnacceptableProtocolVersion)
    ClientConnect {} -> do
      let CleanSession cleanSession = connectCleanSession msg
          request = req
            { requestClientIdentifier = connectClientIdentifier msg
            , requestCleanSession = cleanSession
            , requestCredentials = connectCredentials msg
            }
          sessionRejectHandler reason = do
            Log.warningM "Server.connection" $ "Connection rejected: " ++ show reason
            void $ SS.sendMessage conn (ServerConnectionRejected reason)
          sessionAcceptHandler session sessionPresent@(SessionPresent sp) = do
            principal <- Session.getPrincipal session
            Log.infoM "Server.connection" $ "Connection accepted: Associated "
              ++ show principal ++ (if sp then " with existing session "
              ++ show (Session.sessionIdentifier session) ++ "." else " with new session.")
            void $ SS.sendMessage conn (ServerConnectionAccepted sessionPresent)
            runWorkers recentActivity (connectKeepAlive msg) session
              `E.catch` (\e -> do
                Log.warningM "Server.connection" $ "Session " ++ show (Session.sessionIdentifier session)
                  ++ ": Connection terminated with exception: " ++ show (e :: E.SomeException)
                E.throwIO e)
            Log.infoM "Server.connection" $
              "Session " ++ show (Session.sessionIdentifier session) ++ ": Graceful disconnect."
      Log.infoM "Server.connection" $ "Connection request: " ++ show request
      Broker.withSession broker request sessionRejectHandler sessionAcceptHandler
    _ ->
      -- The first packet of a connection must be CONNECT; anything else ends
      -- the connection without a reply.
      Log.warningM "Server.connection" $ "Connection from "
        ++ show (requestRemoteAddress req) ++ " rejected: First packet is not CONNECT."
  where
    maxInitialPacketSize :: Int64
    maxInitialPacketSize = 65535

    -- Run the per-session workers concurrently. 'race_' cancels the remaining
    -- workers as soon as one of them returns or throws.
    runWorkers :: IORef Bool -> KeepAliveInterval -> Session.Session auth -> IO ()
    runWorkers recentActivity interval session =
      handleInput recentActivity session
        `race_` handleOutput session
        `race_` keepAlive recentActivity interval session

    -- Keep-alive supervision. The activity flag is cleared and then checked
    -- after each half interval. Three consecutive half intervals without any
    -- incoming packet (about 1.5x the negotiated keep-alive) raise
    -- 'KeepAliveTimeoutException'; a warning is logged after the second one.
    keepAlive :: IORef Bool -> KeepAliveInterval -> Session.Session auth -> IO ()
    keepAlive recentActivity (KeepAliveInterval interval) session
      | interval == 0 =
          -- A keep-alive of zero turns the mechanism off (MQTT 3.1.1,
          -- section 3.1.2.10). This worker must still never return: returning
          -- would win the race in 'runWorkers' and close the connection.
          -- 10^9 microseconds also fits a 32-bit Int.
          forever (threadDelay 1000000000)
      | otherwise = forever $ do
          writeIORef recentActivity False
          threadDelay regularInterval
          activity <- readIORef recentActivity
          unless activity $ do
            threadDelay regularInterval
            activity' <- readIORef recentActivity
            unless activity' $ do
              Log.warningM "Server.connection.keepAlive" $ "Session " ++ show (Session.sessionIdentifier session) ++ ": Client is overdue."
              threadDelay regularInterval
              activity'' <- readIORef recentActivity
              unless activity'' $ E.throwIO (KeepAliveTimeoutException :: SS.ServerException (MQTT transport))
      where
        -- Half of the keep-alive interval (given in seconds), in microseconds.
        regularInterval = fromIntegral interval * 500000

    -- Dispatch incoming packets to the session until the client sends
    -- DISCONNECT (the consumer returns True to stop). A second CONNECT is a
    -- protocol violation. Packet size is bounded by the principal's quota.
    handleInput :: IORef Bool -> Session.Session auth -> IO ()
    handleInput recentActivity session = do
      maxPacketSize <- fromIntegral . quotaMaxPacketSize . principalQuota <$> Session.getPrincipal session
      SS.consumeMessages conn maxPacketSize $ \packet -> do
        -- Any packet counts as client activity for the keep-alive supervisor.
        writeIORef recentActivity True
        case packet of
          ClientConnect {} ->
            E.throwIO (ProtocolViolation "Unexpected CONN packet." :: SS.ServerException (MQTT transport))
          ClientConnectUnsupported ->
            E.throwIO (ProtocolViolation "Unexpected CONN packet (of unsupported protocol version)." :: SS.ServerException (MQTT transport))
          ClientPublish pid dup message -> do
            Session.processPublish session pid dup message
            pure False
          ClientPublishAcknowledged pid -> do
            Session.processPublishAcknowledged session pid
            pure False
          ClientPublishReceived pid -> do
            Session.processPublishReceived session pid
            pure False
          ClientPublishRelease pid -> do
            Session.processPublishRelease session pid
            pure False
          ClientPublishComplete pid -> do
            Session.processPublishComplete session pid
            pure False
          ClientSubscribe pid filters -> do
            Session.subscribe session pid filters
            pure False
          ClientUnsubscribe pid filters -> do
            Session.unsubscribe session pid filters
            pure False
          ClientPingRequest -> do
            Session.enqueuePingResponse session
            pure False
          ClientDisconnect ->
            pure True

    -- Forever wait for pending output in the session and send it as one batch.
    handleOutput :: Session.Session auth -> IO ()
    handleOutput session = forever $ do
      Session.waitPending session
      msgs <- Session.dequeue session
      SS.sendMessages conn msgs