{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards #-}
module Network.MQTT.Client (
MQTTConfig(..), MQTTClient, QoS(..), Topic, mqttConfig, mkLWT, LastWill(..),
ProtocolLevel(..), Property(..), SubOptions(..), subOptions, MessageCallback(..),
waitForClient,
connectURI, isConnected,
disconnect, normalDisconnect,
subscribe, unsubscribe, publish, publishq, pubAliased,
svrProps, connACK, MQTTException(..),
runMQTTConduit, MQTTConduit, isConnectedSTM, connACKSTM,
registerCorrelated, unregisterCorrelated
) where
import Control.Concurrent (myThreadId, threadDelay)
import Control.Concurrent.Async (Async, async, asyncThreadId, cancelWith, link, race_, wait, waitAnyCancel)
import Control.Concurrent.STM (STM, TChan, TVar, atomically, check, modifyTVar', newTChan, newTChanIO,
newTVarIO, orElse, readTChan, readTVar, readTVarIO, registerDelay, retry,
writeTChan, writeTVar)
import Control.DeepSeq (force)
import qualified Control.Exception as E
import Control.Monad (forever, guard, unless, void, when)
import Control.Monad.IO.Class (liftIO)
import Data.Bifunctor (first)
import qualified Data.ByteString.Char8 as BCS
import qualified Data.ByteString.Lazy as BL
import qualified Data.ByteString.Lazy.Char8 as BC
import Data.Conduit (ConduitT, Void, await, runConduit, yield, (.|))
import Data.Conduit.Attoparsec (conduitParser)
import qualified Data.Conduit.Combinators as C
import Data.Conduit.Network (AppData, appSink, appSource, clientSettings, runTCPClient)
import Data.Conduit.Network.TLS (runTLSClient, tlsClientConfig, tlsClientTLSSettings)
import Data.Map.Strict (Map)
import qualified Data.Map.Strict as Map
import Data.Maybe (fromMaybe)
import Data.Text (Text)
import qualified Data.Text.Encoding as TE
import Data.Word (Word16)
import GHC.Conc (labelThread)
import Network.Connection (ConnectionParams (..), TLSSettings (..), connectTo, connectionClose,
connectionGetChunk, connectionPut, initConnectionContext)
import Network.URI (URI (..), unEscapeString, uriPort, uriRegName, uriUserInfo)
import qualified Network.WebSockets as WS
import Network.WebSockets.Stream (makeStream)
import System.IO.Error (catchIOError, isEOFError)
import System.Timeout (timeout)
import Network.MQTT.Topic (Filter, Topic)
import Network.MQTT.Types as T
-- | Lifecycle state of a client connection, stored in the client's '_st'.
data ConnState
  = Starting
    -- ^ Initial state, before any CONNACK has been dispatched.
  | Connected
    -- ^ A CONNACK carrying 'ConnAccepted' was dispatched.
  | Stopped
    -- ^ Written by 'disconnect' after the client thread has finished.
  | Disconnected
    -- ^ The client thread ended while the state was 'Starting' or 'Connected'.
  | DiscoErr DisconnectRequest
    -- ^ A DISCONNECT packet was received from the peer.
  | ConnErr ConnACKFlags
    -- ^ A CONNACK with a result other than 'ConnAccepted' was received.
  deriving (Eq, Show)
-- | The kinds of acknowledgement a caller can wait for.  Together with a
-- packet ID this forms the key of the client's '_acks' map.
data DispatchType
  = DSubACK
  | DUnsubACK
  | DPubACK
  | DPubREC
  | DPubREL
  | DPubCOMP
  deriving (Eq, Show, Ord, Enum, Bounded)
-- | What to run when a published message arrives.
data MessageCallback
  = NoCallback
    -- ^ Run nothing.
  | SimpleCallback (MQTTClient -> Topic -> BL.ByteString -> [Property] -> IO ())
    -- ^ Receives the client, the topic, the message body and the
    -- message properties.
  | LowLevelCallback (MQTTClient -> PublishRequest -> IO ())
    -- ^ Receives the client and the whole 'PublishRequest'.
-- | Shared state of a running client.
data MQTTClient = MQTTClient
  { _ch           :: TChan MQTTPkt
    -- ^ Outbound packets waiting to be written to the connection.
  , _pktID        :: TVar Word16
    -- ^ The next packet ID to hand out.
  , _cb           :: MessageCallback
    -- ^ Callback for messages with no matching correlated callback.
  , _acks         :: TVar (Map (DispatchType,Word16) (TChan MQTTPkt))
    -- ^ Channels of callers waiting for a particular acknowledgement.
  , _inflight     :: TVar (Map Word16 PublishRequest)
    -- ^ Incoming QoS 2 publishes kept until their PUBREL arrives.
  , _st           :: TVar ConnState
    -- ^ Current connection state.
  , _ct           :: TVar (Async ())
    -- ^ The client thread.
  , _outA         :: TVar (Map Topic Word16)
    -- ^ Topic aliases assigned to outgoing topics.
  , _inA          :: TVar (Map Word16 Topic)
    -- ^ Topics registered for incoming topic aliases.
  , _connACKFlags :: TVar ConnACKFlags
    -- ^ The most recently accepted CONNACK.
  , _corr         :: TVar (Map BL.ByteString MessageCallback)
    -- ^ Callbacks keyed by correlation data.
  }
-- | Settings used to establish a connection.  Start from 'mqttConfig'
-- and override what you need.
data MQTTConfig = MQTTConfig
  { _cleanSession   :: Bool
    -- ^ Clean-session flag sent in the CONNECT packet.
  , _lwt            :: Maybe LastWill
    -- ^ Optional last will sent in the CONNECT packet.
  , _msgCB          :: MessageCallback
    -- ^ Callback installed as the client's default message callback.
  , _protocol       :: ProtocolLevel
    -- ^ Protocol level used to encode and parse packets.
  , _connProps      :: [Property]
    -- ^ Properties sent in the CONNECT packet.
  , _hostname       :: String
    -- ^ Host to connect to (replaced by 'connectURI').
  , _port           :: Int
    -- ^ Port to connect to (replaced by 'connectURI').
  , _connID         :: String
    -- ^ Client ID (replaced by 'connectURI').
  , _username       :: Maybe String
    -- ^ Optional user name (replaced by 'connectURI').
  , _password       :: Maybe String
    -- ^ Optional password (replaced by 'connectURI').
  , _connectTimeout :: Int
    -- ^ Limit for the connection attempt, passed to 'timeout'
    -- (microseconds).
  , _tlsSettings    :: TLSSettings
    -- ^ TLS settings used by the TLS and secure WebSocket transports.
  }
-- | Default configuration: clean session, no last will, no callback,
-- protocol 3.1.1, port 1883, a 180 second connect timeout, and empty
-- host name and client ID.
mqttConfig :: MQTTConfig
mqttConfig = MQTTConfig
  { _cleanSession   = True
  , _lwt            = Nothing
  , _msgCB          = NoCallback
  , _protocol       = Protocol311
  , _connProps      = mempty
  , _hostname       = ""
  , _port           = 1883
  , _connID         = ""
  , _username       = Nothing
  , _password       = Nothing
  , _connectTimeout = 180000000
  , _tlsSettings    = TLSSettingsSimple False False False
  }
-- | Connect to the broker described by a URI.
--
-- The scheme selects the transport: @mqtt:@ (TCP), @mqtts:@ (TLS),
-- @ws:@ (WebSocket) or @wss:@ (WebSocket over TLS).  Host and port come
-- from the URI authority; a missing or empty port falls back to the
-- scheme default (1883, 8883, 80, 443).  The user-info part supplies
-- the user name and password, and the fragment supplies the client ID.
-- These values replace the corresponding fields of the given config.
--
-- Throws 'MQTTException' when the scheme is not supported, the URI has
-- no authority, the port is not a number in 1..65535, or the attempt
-- takes longer than '_connectTimeout'.
connectURI :: MQTTConfig -> URI -> IO MQTTClient
connectURI cfg@MQTTConfig{..} uri = do
  -- Validate everything before starting the timed connection attempt so
  -- that malformed URIs are reported as MQTTException rather than as
  -- pattern-match or 'read' failures.
  connector <- case uriScheme uri of
    "mqtt:"  -> pure runClient
    "mqtts:" -> pure runClientTLS
    "ws:"    -> pure (runWS uri False)
    "wss:"   -> pure (runWS uri True)
    us       -> bail ("invalid URI scheme: " <> us)
  auth <- maybe (bail ("URI has no authority: " <> show uri)) pure (uriAuthority uri)
  portNum <- either bail pure (port (uriPort auth) (uriScheme uri))
  let (u, p) = up (uriUserInfo auth)

  v <- namedTimeout "MQTT connect" _connectTimeout $
    connector cfg{Network.MQTT.Client._connID = cid (uriFragment uri),
                  _hostname = uriRegName auth,
                  _port = portNum,
                  Network.MQTT.Client._username = u,
                  Network.MQTT.Client._password = p}

  case v of
    Nothing -> bail ("connection to " <> show uri <> " timed out")
    Just x  -> pure x

  where
    bail :: String -> IO a
    bail = E.throwIO . MQTTException

    defaultPorts :: [(String, Int)]
    defaultPorts = [("mqtt:", 1883), ("mqtts:", 8883), ("ws:", 80), ("wss:", 443)]

    -- The port component is either empty or ":" followed by the number.
    port :: String -> String -> Either String Int
    port spec scheme
      | spec `elem` ["", ":"] =
          maybe (Left ("no default port for scheme: " <> scheme)) Right
                (lookup scheme defaultPorts)
      | otherwise = case reads (drop 1 spec) of
          [(n, "")] | n >= 1 && n <= 65535 -> Right n
          _                                -> Left ("invalid port in URI: " <> spec)

    -- The fragment, without its leading '#', is the client ID.
    cid :: String -> String
    cid ('#':xs) = xs
    cid _        = ""

    -- User info looks like "user@" or "user:password@"; 'init' drops
    -- the trailing separator.
    up :: String -> (Maybe String, Maybe String)
    up "" = (Nothing, Nothing)
    up x = case break (== ':') (init x) of
      (u, ':' : pw) -> (Just (unEscapeString u), Just (unEscapeString pw))
      (u, _)        -> (Just (unEscapeString u), Nothing)
-- | Connect over a plain TCP socket to '_hostname' on '_port'.
runClient :: MQTTConfig -> IO MQTTClient
runClient cfg@MQTTConfig{..} =
  let settings = clientSettings _port (BCS.pack _hostname)
  in tcpCompat (runTCPClient settings) cfg
-- | Connect over TLS to '_hostname' on '_port', using '_tlsSettings'.
runClientTLS :: MQTTConfig -> IO MQTTClient
runClientTLS cfg@MQTTConfig{..} =
  let base    = tlsClientConfig _port (BCS.pack _hostname)
      tlsConf = base {tlsClientTLSSettings = _tlsSettings}
  in tcpCompat (runTLSClient tlsConf) cfg
-- | Run a client on top of an 'AppData'-based connection runner by
-- exposing the connection's source and sink as an 'MQTTConduit'.
tcpCompat :: ((AppData -> IO ()) -> IO ()) -> MQTTConfig -> IO MQTTClient
tcpCompat mkconn =
  runMQTTConduit $ \withConduit ->
    mkconn $ \ad -> withConduit (appSource ad, appSink ad)
-- | Connect over a WebSocket.  The 'Bool' selects TLS.  The endpoint is
-- the URI's path followed by its query, and the @mqtt@ sub-protocol is
-- requested in the handshake headers.
runWS :: URI -> Bool -> MQTTConfig -> IO MQTTClient
runWS URI{uriPath, uriQuery} secure cfg@MQTTConfig{..} =
  runMQTTConduit withSocket cfg

  where
    -- Open the socket and hand the caller a conduit pair built from it.
    withSocket f =
      connector secure _hostname _port endpoint WS.defaultConnectionOptions hdrs
        (\conn -> f (wsSource conn, wsSink conn))

    hdrs = [("Sec-WebSocket-Protocol", "mqtt")]

    endpoint = uriPath <> uriQuery

    connector :: Bool -> String -> Int -> String -> WS.ConnectionOptions -> WS.Headers -> WS.ClientApp () -> IO ()
    connector useTLS = if useTLS then runWSS else WS.runClientWith

    -- Yield every non-empty message received from the socket.
    wsSource :: WS.Connection -> ConduitT () BCS.ByteString IO ()
    wsSource ws = forever $
      liftIO (WS.receiveData ws) >>= \bs -> unless (BCS.null bs) (yield bs)

    -- Send each upstream chunk as one binary message until upstream ends.
    wsSink :: WS.Connection -> ConduitT BCS.ByteString Void IO ()
    wsSink ws = loop
      where
        loop = await >>= \next -> case next of
          Nothing -> pure ()
          Just bs -> liftIO (WS.sendBinaryData ws bs) >> loop

    -- WebSocket client over a TLS connection configured by '_tlsSettings'.
    -- The connection is closed when the application finishes or fails.
    runWSS :: String -> Int -> String -> WS.ConnectionOptions -> WS.Headers -> WS.ClientApp () -> IO ()
    runWSS host port path options hdrs' app = do
      context <- initConnectionContext
      E.bracket (connectTo context params) connectionClose $ \conn -> do
        stream <- makeStream (reader conn) (writer conn)
        WS.runClientWithStream stream host path options hdrs' app

      where
        params = ConnectionParams
          { connectionHostname = host
          , connectionPort = toEnum port
          , connectionUseSecure = Just _tlsSettings
          , connectionUseSocks = Nothing
          }

        -- End of file is reported as end of stream; other I/O errors
        -- are rethrown.
        reader conn = catchIOError (Just <$> connectionGetChunk conn) onIOError

        onIOError e
          | isEOFError e = pure Nothing
          | otherwise    = E.throwIO e

        writer conn = mapM_ (connectionPut conn . BC.toStrict)
-- | Delay between pings, in microseconds (it is passed to 'threadDelay').
pingPeriod :: Int
pingPeriod = 30 * 1000000
-- | Throw an 'MQTTException' carrying the given message.
mqttFail :: String -> a
mqttFail msg = E.throw (MQTTException msg)
-- | Start an action asynchronously and label its thread with the given name.
namedAsync :: String -> IO a -> IO (Async a)
namedAsync s a = do
  p <- async a
  labelThread (asyncThreadId p) s
  pure p
-- | Run an action under 'timeout', first labelling the thread that runs
-- it with the given name.  The result is 'Nothing' when the limit
-- expires before the action finishes.
namedTimeout :: String -> Int -> IO a -> IO (Maybe a)
namedTimeout n to a = timeout to $ do
  tid <- myThreadId
  labelThread tid n
  a
-- | A byte-level connection to a broker: a source of incoming bytes
-- paired with a sink for outgoing bytes.
type MQTTConduit =
  ( ConduitT () BCS.ByteString IO ()
  , ConduitT BCS.ByteString Void IO ()
  )
-- | Set up and run a client on top of an arbitrary 'MQTTConduit'.
--
-- The first argument opens the connection and runs the supplied
-- continuation with the resulting conduit pair.  A client thread is
-- started that sends the CONNECT packet and then runs the outbound
-- writer, the pinger, the watchdog and the inbound parser until any of
-- them finishes.
--
-- This call blocks until the state leaves 'Starting', and throws if the
-- resulting state is not 'Connected'.
runMQTTConduit :: ((MQTTConduit -> IO ()) -> IO ())
               -> MQTTConfig
               -> IO MQTTClient
runMQTTConduit mkconn MQTTConfig{..} = do
  _ch <- newTChanIO
  _pktID <- newTVarIO 1
  _acks <- newTVarIO mempty
  _inflight <- newTVarIO mempty
  _st <- newTVarIO Starting
  -- Placeholder; waitForLaunch stores the real client thread before
  -- this function returns.
  _ct <- newTVarIO undefined
  _outA <- newTVarIO mempty
  _inA <- newTVarIO mempty
  _connACKFlags <- newTVarIO (ConnACKFlags NewSession ConnUnspecifiedError mempty)
  _corr <- newTVarIO mempty
  let _cb = _msgCB
      cli = MQTTClient{..}

  t <- namedAsync "MQTT clientThread" $ clientThread cli
  s <- atomically (waitForLaunch cli t)

  -- If the thread already died, 'wait' rethrows whatever killed it.
  when (s == Disconnected) $ wait t

  atomically $ checkConnected cli

  pure cli

  where
    clientThread cli = E.finally connectAndRun markDisco
      where
        connectAndRun = mkconn $ \ad -> start cli ad >>= run ad
        -- Record that the connection is gone, but keep a more specific
        -- state (ConnErr, DiscoErr, Stopped) if one was already set.
        -- This must be 'when' rather than 'guard': in STM a failed
        -- 'guard' means 'retry', which would block this cleanup handler.
        markDisco = atomically $ do
          st <- readTVar (_st cli)
          when (st == Starting || st == Connected) $
            writeTVar (_st cli) Disconnected

    -- Send the CONNECT packet built from the configuration.
    start c@MQTTClient{..} (_,sink) = do
      void . runConduit $ do
        let req = connectRequest{T._connID=BC.pack _connID,
                                 T._lastWill=_lwt,
                                 T._username=BC.pack <$> _username,
                                 T._password=BC.pack <$> _password,
                                 T._cleanSession=_cleanSession,
                                 T._connProperties=_connProps}
        yield (BL.toStrict $ toByteString _protocol req) .| sink

      pure c

    -- Run the worker threads; when any of them ends, cancel the rest.
    run (src,sink) c@MQTTClient{..} = do
      pch <- newTChanIO
      o <- namedAsync "MQTT out" $ onceConnected >> processOut
      p <- namedAsync "MQTT ping" $ onceConnected >> doPing
      w <- namedAsync "MQTT watchdog" $ watchdog pch
      s <- namedAsync "MQTT in" $ doSrc pch

      void $ waitAnyCancel [o, p, w, s]

      where
        -- Parse incoming bytes into packets and dispatch each one.
        doSrc pch = runConduit $ src
                    .| conduitParser (parsePacket _protocol)
                    .| C.mapM_ (\(_,x) -> liftIO (dispatch c pch x))

        onceConnected = atomically $ check . (== Connected) =<< readTVar _st

        -- Serialize queued packets and write them to the connection.
        processOut = runConduit $
          C.repeatM (liftIO (atomically $ checkConnected c >> readTChan _ch))
          .| C.map (BL.toStrict . toByteString _protocol)
          .| sink

        doPing = forever $ threadDelay pingPeriod >> sendPacketIO c PingPkt

        -- Kill the connection when nothing arrives on the given channel
        -- within three ping periods.
        watchdog ch = forever $ do
          toch <- registerDelay (pingPeriod * 3)
          timedOut <- atomically $ ((check =<< readTVar toch) >> pure True) `orElse` (readTChan ch >> pure False)
          when timedOut $ killConn c Timeout

    -- Publish the client thread, then block until the state has left
    -- 'Starting' and return the new state.
    waitForLaunch MQTTClient{..} t = do
      writeTVar _ct t
      c <- readTVar _st
      if c == Starting then retry else pure c
-- | Block until the client thread has finished.  Afterwards, throw
-- unless the final state is 'Stopped'.
waitForClient :: MQTTClient -> IO ()
waitForClient c@MQTTClient{..} = do
  clientThread <- readTVarIO _ct
  wait clientThread
  problem <- atomically (stateX c Stopped)
  maybe (pure ()) E.throwIO problem
-- | Compare the current state with the wanted one ('Connected' or
-- 'Stopped').  Returns 'Nothing' on a match, otherwise an exception
-- describing the state the client is actually in.
stateX :: MQTTClient -> ConnState -> STM (Maybe E.SomeException)
stateX MQTTClient{..} want = classify <$> readTVar _st

  where
    failure = Just . E.toException . MQTTException

    classify :: ConnState -> Maybe E.SomeException
    classify st = case st of
      Connected
        | want == Connected -> Nothing
        | otherwise         -> failure "unexpectedly connected"
      Stopped
        | want == Stopped   -> Nothing
        | otherwise         -> failure "unexpectedly stopped"
      Disconnected          -> failure "disconnected"
      Starting              -> failure "died while starting"
      DiscoErr x            -> Just (E.toException (Discod x))
      ConnErr e             -> failure (show e)
-- | Exceptions raised by the client.
data MQTTException
  = Timeout
    -- ^ Used to kill the connection when the watchdog or a disconnect
    -- attempt times out.
  | BadData
  | Discod DisconnectRequest
    -- ^ A DISCONNECT packet was received.
  | MQTTException String
    -- ^ Any other failure, with a description.
  deriving (Eq, Show)

instance E.Exception MQTTException
-- | Handle one packet received from the connection.
--
-- The 'TChan' is signalled whenever a ping response arrives.
-- Acknowledgements are forwarded to whoever registered for them in
-- '_acks'; publishes are resolved against the incoming topic aliases
-- and passed to the matching callback.
dispatch :: MQTTClient -> TChan Bool -> MQTTPkt -> IO ()
dispatch c@MQTTClient{..} pch pkt =
  case pkt of
    (ConnACKPkt p)                            -> connACKd p
    (PublishPkt p)                            -> pub p
    (SubACKPkt (SubscribeResponse i _ _))     -> delegate DSubACK i
    (UnsubACKPkt (UnsubscribeResponse i _ _)) -> delegate DUnsubACK i
    (PubACKPkt (PubACK i _ _))                -> delegate DPubACK i
    (PubRELPkt (PubREL i _ _))                -> pubd i
    (PubRECPkt (PubREC i _ _))                -> delegate DPubREC i
    (PubCOMPPkt (PubCOMP i _ _))              -> delegate DPubCOMP i
    (DisconnectPkt req)                       -> disco req
    PongPkt                                   -> atomically . writeTChan pch $ True

    -- Anything else is not handled; it is only printed.
    x                                         -> print x

  where
    -- Accepted: remember the CONNACK and mark the client connected.
    -- Refused: record the error and cancel the client thread.
    connACKd connr@(ConnACKFlags _ val _) = case val of
      ConnAccepted -> atomically $ do
        writeTVar _connACKFlags connr
        writeTVar _st Connected
      _ -> do
        t <- readTVarIO _ct
        atomically $ writeTVar _st (ConnErr connr)
        cancelWith t (MQTTException $ show connr)

    -- Incoming publish.  QoS 0 needs no reply, QoS 1 is answered with a
    -- PUBACK after the callback has run, and QoS 2 is stored and
    -- answered with PUBREC; the callback runs once the PUBREL arrives.
    pub p@PublishRequest{_pubQoS=QoS0} = atomically (resolve p) >>= notify Nothing
    pub p@PublishRequest{_pubQoS=QoS1, _pubPktID} =
      notify (Just (PubACKPkt (PubACK _pubPktID 0 mempty))) =<< atomically (resolve p)
    pub p@PublishRequest{_pubQoS=QoS2} = atomically $ do
      p'@PublishRequest{..} <- resolve p
      modifyTVar' _inflight (Map.insert _pubPktID p')
      sendPacket c (PubRECPkt (PubREC _pubPktID 0 mempty))

    -- PUBREL for an incoming QoS 2 publish.  An unknown packet ID is
    -- answered with reason code 0x92.
    pubd i = do
      mp <- atomically $ do
        r <- Map.lookup i <$> readTVar _inflight
        modifyTVar' _inflight (Map.delete i)
        pure r
      case mp of
        Nothing -> sendPacketIO c (PubCOMPPkt (PubCOMP i 0x92 mempty))
        Just p  -> notify (Just (PubCOMPPkt (PubCOMP i 0 mempty))) p

    -- Pick the callback (by correlation data when registered, else the
    -- default) and run it in its own thread, followed by the pending
    -- acknowledgement, if any.
    notify rpkt p@PublishRequest{..} = do
      atomically $ modifyTVar' _inflight (Map.delete _pubPktID)
      corrs <- readTVarIO _corr
      E.evaluate . force =<< case maybe _cb (\cd -> Map.findWithDefault _cb cd corrs) cdata of
        -- The acknowledgement must go out even when there is nothing
        -- to call; otherwise QoS 1 and 2 messages stay unacknowledged.
        NoCallback         -> respond
        SimpleCallback f   -> call (f c (blToText _pubTopic) _pubBody _pubProps)
        LowLevelCallback f -> call (f c p)

      where
        call a = link =<< namedAsync "notifier" (a >> respond)
        respond = void $ traverse (sendPacketIO c) rpkt
        cdata = foldr f Nothing _pubProps
          where f (PropCorrelationData x) _ = Just x
                f _ o                       = o

    -- Replace the topic of a publish by the one its topic alias stands
    -- for, registering the alias first when a topic is also present.
    resolve p@PublishRequest{..} = do
      topic <- resolveTopic (foldr aliasID Nothing _pubProps)
      pure p{_pubTopic=textToBL topic}

      where
        aliasID (PropTopicAlias x) _ = Just x
        aliasID _ o                  = o

        resolveTopic Nothing = pure (blToText _pubTopic)
        resolveTopic (Just x) = do
          when (_pubTopic /= "") $ modifyTVar' _inA (Map.insert x (blToText _pubTopic))
          m <- readTVar _inA
          case Map.lookup x m of
            Nothing -> mqttFail ("failed to lookup topic alias " <> show x)
            Just t  -> pure t

    -- Forward an acknowledgement to the caller waiting for it.  An
    -- unexpected PUBREC is answered with a PUBREL carrying 0x92; other
    -- unexpected acknowledgements are dropped.
    delegate dt pid = atomically $ do
      m <- readTVar _acks
      case Map.lookup (dt, pid) m of
        Nothing -> nak dt
        Just ch -> writeTChan ch pkt

      where
        nak DPubREC = sendPacket c (PubRELPkt (PubREL pid 0x92 mempty))
        nak _       = pure ()

    -- DISCONNECT received: record it and cancel the client thread.
    disco req = do
      t <- readTVarIO _ct
      atomically $ writeTVar _st (DiscoErr req)
      cancelWith t (Discod req)
-- | Cancel the client thread with the given exception.
killConn :: E.Exception e => MQTTClient -> e -> IO ()
killConn MQTTClient{..} e = do
  clientThread <- readTVarIO _ct
  cancelWith clientThread e
-- | Throw unless the client is currently in the 'Connected' state.
checkConnected :: MQTTClient -> STM ()
checkConnected mc = do
  problem <- stateX mc Connected
  case problem of
    Nothing -> pure ()
    Just e  -> E.throw e
-- | Whether the client is currently in the 'Connected' state.
isConnected :: MQTTClient -> IO Bool
isConnected mc = atomically (isConnectedSTM mc)
-- | 'STM' version of 'isConnected'.
isConnectedSTM :: MQTTClient -> STM Bool
isConnectedSTM MQTTClient{..} = do
  st <- readTVar _st
  pure (Connected == st)
-- | Queue a packet for sending.  Throws if the client is not connected.
sendPacket :: MQTTClient -> MQTTPkt -> STM ()
sendPacket c@MQTTClient{..} p = do
  checkConnected c
  writeTChan _ch p
-- | 'IO' version of 'sendPacket', run as its own transaction.
sendPacketIO :: MQTTClient -> MQTTPkt -> IO ()
sendPacketIO c p = atomically (sendPacket c p)
-- | Encode text as a lazy UTF-8 byte string.
textToBL :: Text -> BL.ByteString
textToBL t = BL.fromStrict (TE.encodeUtf8 t)
-- | Decode a lazy byte string as UTF-8 text, using 'TE.decodeUtf8'.
blToText :: BL.ByteString -> Text
blToText bytes = TE.decodeUtf8 (BL.toStrict bytes)
-- | Take the next packet ID and register one fresh channel under it for
-- each of the given acknowledgement kinds.  The counter wraps from
-- 'maxBound' back to 1.  Throws if the client is not connected.
reservePktID :: MQTTClient -> [DispatchType] -> STM (TChan MQTTPkt, Word16)
reservePktID c@MQTTClient{..} dts = do
  checkConnected c
  ch <- newTChan
  pid <- readTVar _pktID
  let nextID = if pid == maxBound then 1 else succ pid
      registrations = Map.fromList [((t, pid), ch) | t <- dts]
  writeTVar _pktID $! nextID
  modifyTVar' _acks (\known -> Map.union registrations known)
  pure (ch,pid)
-- | Remove one acknowledgement registration.  Throws if the client is
-- not connected.
releasePktID :: MQTTClient -> (DispatchType,Word16) -> STM ()
releasePktID c@MQTTClient{..} k = do
  checkConnected c
  modifyTVar' _acks (Map.delete k)
-- | Remove several acknowledgement registrations at once.  Throws if
-- the client is not connected.
releasePktIDs :: MQTTClient -> [(DispatchType,Word16)] -> STM ()
releasePktIDs c@MQTTClient{..} ks = do
  checkConnected c
  modifyTVar' _acks (\registered -> foldr Map.delete registered ks)
-- | Send a packet built from a freshly reserved packet ID and return
-- the packet delivered for the given acknowledgement kind.
sendAndWait :: MQTTClient -> DispatchType -> (Word16 -> MQTTPkt) -> IO MQTTPkt
sendAndWait c@MQTTClient{..} dt f = do
  -- Reserve and enqueue in one transaction so the registration exists
  -- before the packet can be written.
  (ch,pid) <- atomically $ do
    reserved@(_, newID) <- reservePktID c [dt]
    sendPacket c (f newID)
    pure reserved

  -- Wait for the response in a separate transaction.
  atomically $ do
    st <- readTVar _st
    unless (st == Connected) $ mqttFail "disconnected waiting for response"
    releasePktID c (dt,pid)
    readTChan ch
-- | Subscribe to a list of filters with the given options and
-- properties.  Returns the per-filter results and the properties of
-- the SUBACK.
subscribe :: MQTTClient -> [(Filter, SubOptions)] -> [Property] -> IO ([Either SubErr QoS], [Property])
subscribe c@MQTTClient{..} ls props =
  unpack <$> sendAndWait c DSubACK request

  where
    request pid = SubscribePkt (SubscribeRequest pid encoded props)

    encoded = [(textToBL flt, opts) | (flt, opts) <- ls]

    -- Deliberately lazy: the reply is only taken apart when the result
    -- is inspected.
    unpack ~(SubACKPkt (SubscribeResponse _ rs aprops)) = (rs, aprops)
-- | Unsubscribe from a list of filters.  Returns the per-filter
-- statuses and the properties of the UNSUBACK.
unsubscribe :: MQTTClient -> [Filter] -> [Property] -> IO ([UnsubStatus], [Property])
unsubscribe c@MQTTClient{..} ls props = do
  let request pid = UnsubscribePkt (UnsubscribeRequest pid (textToBL <$> ls) props)
  (UnsubACKPkt (UnsubscribeResponse _ ackProps statuses)) <- sendAndWait c DUnsubACK request
  pure (statuses, ackProps)
-- | Publish a message at QoS 0 with no properties.
publish :: MQTTClient
        -> Topic         -- ^ Topic
        -> BL.ByteString -- ^ Message body
        -> Bool          -- ^ Retain flag
        -> IO ()
publish c t m r = publishq c t m r QoS0 mempty
-- | Publish a message at the requested QoS and block until the
-- corresponding handshake has completed: nothing for QoS 0, PUBACK for
-- QoS 1, and PUBREC followed by PUBCOMP for QoS 2.
--
-- Throws 'MQTTException' when an acknowledgement carries a non-zero
-- status, when an unexpected packet arrives, or when the client is not
-- connected.  The reserved packet ID is released in every case.
publishq :: MQTTClient
         -> Topic         -- ^ Topic
         -> BL.ByteString -- ^ Message body
         -> Bool          -- ^ Retain flag
         -> QoS           -- ^ QoS
         -> [Property]    -- ^ Properties
         -> IO ()
publishq c t m r q props = do
  (ch,pid) <- atomically $ reservePktID c types
  E.finally (publishAndWait ch pid) (atomically $ releasePktIDs c [(t',pid) | t' <- types])

  where
    types = [DPubACK, DPubREC, DPubCOMP]

    publishAndWait ch pid = do
      sendPacketIO c (pkt pid)
      when (q > QoS0) $ satisfyQoS ch pid

    pkt pid = PublishPkt $ PublishRequest {_pubDup = False,
                                           _pubQoS = q,
                                           _pubPktID = pid,
                                           _pubRetain = r,
                                           _pubTopic = textToBL t,
                                           _pubBody = m,
                                           _pubProps = props}

    satisfyQoS ch pid
      | q == QoS0 = pure ()
      | q == QoS1 = void $ do
          (PubACKPkt (PubACK _ st pprops)) <- atomically $ checkConnected c >> readTChan ch
          when (st /= 0) $ mqttFail ("qos 1 publish error: " <> show st <> " " <> show pprops)
      | q == QoS2 = waitRec
      | otherwise = error "invalid QoS"

      where
        waitRec = do
          rpkt <- atomically $ checkConnected c >> readTChan ch
          case rpkt of
            PubRECPkt (PubREC _ st recprops) -> do
              when (st /= 0) $ mqttFail ("qos 2 REC publish error: " <> show st <> " " <> show recprops)
              sendPacketIO c (PubRELPkt $ PubREL pid 0 mempty)
              -- The exchange is only finished once the PUBCOMP is in;
              -- keep the packet ID reserved until then.
              waitRec
            PubCOMPPkt (PubCOMP _ st' compprops) ->
              when (st' /= 0) $ mqttFail ("qos 2 COMP publish error: " <> show st' <> " " <> show compprops)
            wtf -> mqttFail ("unexpected packet received in QoS2 publish: " <> show wtf)
-- | Send a DISCONNECT with the given reason and properties, wait for
-- the client thread to end and mark the client 'Stopped'.  If that has
-- not happened after ten seconds, the client thread is killed with
-- 'Timeout' instead.
disconnect :: MQTTClient -> DiscoReason -> [Property] -> IO ()
disconnect c@MQTTClient{..} reason props = race_ shutdown giveUp
  where
    shutdown = do
      sendPacketIO c (DisconnectPkt (DisconnectRequest reason props))
      clientThread <- readTVarIO _ct
      wait clientThread
      atomically (writeTVar _st Stopped)

    giveUp = do
      threadDelay 10000000
      killConn c Timeout
-- | 'disconnect' with 'DiscoNormalDisconnection' and no properties.
normalDisconnect :: MQTTClient -> IO ()
normalDisconnect c = disconnect c DiscoNormalDisconnection []
-- | Build a QoS 0 last will without properties from a topic, a message
-- body and a retain flag.
mkLWT :: Topic -> BL.ByteString -> Bool -> T.LastWill
mkLWT t m r = T.LastWill
  { T._willRetain = r
  , T._willQoS    = QoS0
  , T._willTopic  = textToBL t
  , T._willMsg    = m
  , T._willProps  = mempty
  }
-- | The properties carried by the stored CONNACK.
svrProps :: MQTTClient -> IO [Property]
svrProps mc = do
  ConnACKFlags _ _ props <- atomically (connACKSTM mc)
  pure props
-- | Read the stored CONNACK flags.
connACKSTM :: MQTTClient -> STM ConnACKFlags
connACKSTM mc = readTVar (_connACKFlags mc)
-- | 'IO' version of 'connACKSTM'.
connACK :: MQTTClient -> IO ConnACKFlags
connACK mc = atomically (connACKSTM mc)
-- | The first 'PropTopicAliasMaximum' among the server properties, or 0
-- when there is none.
maxAliases :: MQTTClient -> IO Word16
maxAliases mc = do
  props <- svrProps mc
  pure $ case [n | PropTopicAliasMaximum n <- props] of
    (n:_) -> n
    []    -> 0
-- | Publish using a topic alias where possible.
--
-- A topic seen for the first time gets the next free alias, provided
-- the alias maximum from the server properties is not exceeded, and is
-- published with both topic and alias.  A topic that already has an
-- alias is published with an empty topic and the alias only.  When no
-- alias can be assigned, the message is published without one.
pubAliased :: MQTTClient
           -> Topic         -- ^ Topic
           -> BL.ByteString -- ^ Message body
           -> Bool          -- ^ Retain flag
           -> QoS           -- ^ QoS
           -> [Property]    -- ^ Properties
           -> IO ()
pubAliased c@MQTTClient{..} t m r q props = do
  limit <- maxAliases c
  (topic, n) <- alias limit
  publishq c topic m r q (props <> [PropTopicAlias n | n /= 0])

  where
    -- Returns the topic to send and the alias to attach (0 for none).
    alias limit = atomically $ do
      known <- readTVar _outA
      let remember v = when (v > 0) $ writeTVar _outA (Map.insert t v known)
      case Map.lookup t known of
        Just existing -> do
          remember existing
          pure ("", existing)
        Nothing -> do
          let candidate = toEnum (length known + 1)
              assigned  = if candidate > limit then 0 else candidate
          remember assigned
          pure (t, assigned)
-- | Register a callback for messages carrying the given correlation
-- data, replacing any callback registered for it before.
registerCorrelated :: MQTTClient -> BL.ByteString -> MessageCallback -> STM ()
registerCorrelated mc bs cb = modifyTVar' (_corr mc) (Map.insert bs cb)
-- | Remove the callback registered for the given correlation data.
unregisterCorrelated :: MQTTClient -> BL.ByteString -> STM ()
unregisterCorrelated mc bs = modifyTVar' (_corr mc) (Map.delete bs)