{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE OverloadedStrings #-}
{-|
Module:      Network.MQTT.Parser
Copyright:   Lukas Braun 2014-2016
License:     GPL-3
Maintainer:  koomi+mqtt@hackerspace-bamberg.de

Parsers for MQTT messages.
-}
module Network.MQTT.Parser where

import Control.Monad
import Control.Monad.Loops
import Control.Monad.State.Strict
import Control.Applicative
import Data.Attoparsec.ByteString
import Data.Bits
import qualified Data.ByteString as BS
import Data.Text.Encoding (decodeUtf8With)
import Data.Text.Encoding.Error (lenientDecode)
import Data.Word
import Prelude hiding (takeWhile, take)

import Network.MQTT.Types hiding (body)

-- | Type of a parser that also keeps track of the remaining length.
type MessageParser a = StateT Word32 Parser a

-- | Parse any MQTT message.
--
-- Reads the fixed header, then the \'remaining length\' field, and finally
-- a body whose type is selected by the message type found in the header.
message :: Parser SomeMessage
message = do
  (msgType, header) <- mqttHeader
  remaining <- parseRemaining
  -- Promote the runtime message type to its singleton so 'mqttBody' can
  -- return a body indexed by that type.
  withSomeSingI msgType $ \sMsgType ->
    SomeMessage . Message header <$> mqttBody header sMsgType remaining

---------------------------------
-- * Fixed Header
---------------------------------

-- | Parser for the fixed header part of a MQTT message.
--
-- The first byte is laid out as: message type in bits 7-4, DUP flag in
-- bit 3, QoS level in bits 2-1 and RETAIN flag in bit 0. Fails on a QoS
-- value of 3 or a message type outside 1..14.
mqttHeader :: Parser (MsgType, MqttHeader)
mqttHeader = ctxt "mqttHeader" $ do
  byte1 <- anyWord8
  qos <- toQoS $ 3 .&. shiftR byte1 1
  let retain = testBit byte1 0
      dup = testBit byte1 3
      msgType = shiftR byte1 4
  msgType' <- case msgType of
    1  -> return CONNECT
    2  -> return CONNACK
    3  -> return PUBLISH
    4  -> return PUBACK
    5  -> return PUBREC
    6  -> return PUBREL
    7  -> return PUBCOMP
    8  -> return SUBSCRIBE
    9  -> return SUBACK
    10 -> return UNSUBSCRIBE
    11 -> return UNSUBACK
    12 -> return PINGREQ
    13 -> return PINGRESP
    14 -> return DISCONNECT
    x  -> fail $ "Invalid message type: " ++ show x
  return (msgType', Header dup qos retain)

-- | Parse the \'remaining length\' field that indicates how long the rest of
-- the message is.
--
-- The value is encoded in 1 to 4 bytes of 7 bits each, least significant
-- group first; a set high bit means another byte follows.
parseRemaining :: Parser Word32
parseRemaining = ctxt "parseRemaining" $ do
  bytes <- takeWhile (> 0x7f) -- bytes with first bit set
  when (BS.length bytes > 3) $
    fail "'Remaining length' field must not be longer than 4 bytes"
  stopByte <- anyWord8
  let (factor, acc) = BS.foldl' f (1, 0) bytes
  return $ acc + factor * fromIntegral stopByte
  where
    -- Each continuation byte contributes its low 7 bits at the current
    -- weight; the weight grows by 128 per byte.
    f (factor, acc) byte =
      (factor * 128, acc + factor * fromIntegral (0x7f .&. byte))

---------------------------------
-- * Message Body
---------------------------------

-- | «@mqttBody header msgtype remaining@» parses a 'Message' of type
-- @msgtype@ that is @remaining@ bytes long.
--
-- Any bytes the type-specific parser does not consume are skipped.
-- Otherwise a message carrying trailing data (for example a PINGREQ with a
-- non-zero remaining length) would leave those bytes in the input, and they
-- would be misparsed as the start of the next message.
mqttBody :: MqttHeader -> SMsgType t -> Word32 -> Parser (MessageBody t)
mqttBody header msgType remaining = ctxt "mqttBody" $
  let parser = case msgType of
        SCONNECT     -> connect
        SCONNACK     -> connAck
        SPUBLISH     -> publish header
        SPUBACK      -> PubAck <$> parseMsgID
        SPUBREC      -> PubRec <$> parseMsgID
        SPUBREL      -> PubRel <$> parseMsgID
        SPUBCOMP     -> PubComp <$> parseMsgID
        SSUBSCRIBE   -> subscribe
        SSUBACK      -> subAck
        SUNSUBSCRIBE -> unsubscribe
        SUNSUBACK    -> UnsubAck <$> parseMsgID
        SPINGREQ     -> pure PingReq
        SPINGRESP    -> pure PingResp
        SDISCONNECT  -> pure Disconnect
      -- Consume exactly the bytes still accounted for by the remaining length.
      skipRest = get >>= take'
  in evalStateT (parser <* skipRest) remaining

-- | Parse the variable header and payload of a CONNECT message.
--
-- Only protocol name @\"MQIsdp\"@ with version 3 is accepted. The flags
-- byte selects which optional fields (will, username, password) follow the
-- client ID. Fails if the client ID is longer than 23 bytes.
connect :: MessageParser (MessageBody 'CONNECT)
connect = ctxt' "connect" $ do
  protocol
  version
  flags <- anyWord8'
  let clean = testBit flags 1
      willFlag = testBit flags 2
      usernameFlag = testBit flags 7
      passwordFlag = testBit flags 6
  keepAlive <- anyWord16BE
  clientID <- getClientID
  -- Will retain flag is bit 5, will QoS is bits 4-3.
  mWill <- parseIf willFlag $
    Will (testBit flags 5)
      <$> toQoS (3 .&. shiftR flags 3)
      <*> ctxt' "Will Topic" (fmap toTopic mqttText)
      <*> ctxt' "Will Message" mqttText
  username <- ctxt' "Username" $ parseIf usernameFlag mqttText
  password <- ctxt' "Password" $ parseIf passwordFlag mqttText
  return $ Connect clean mWill clientID username password keepAlive
  where
    protocol = ctxt' "protocol" $ do
      prot <- mqttText
      when (prot /= "MQIsdp") $ fail $ "Invalid protocol: " ++ show prot
    version = ctxt' "version" $ do
      ver <- anyWord8'
      when (ver /= 3) $ fail $ "Invalid version: " ++ show ver
    getClientID = ctxt' "getClientID" $ do
      before <- get
      clientID <- mqttText
      after <- get
      -- Encoded byte length of the ID; 2 for length prefix.
      let len = before - after - 2
      when (len > 23) $
        fail $ "Client ID must not be longer than 23 chars: "
          ++ show (text clientID) ++ " (" ++ show len ++ ")"
      return clientID

-- | Run @parser@ only when @flag@ is set, wrapping its result in 'Just';
-- otherwise return 'Nothing' without running it.
parseIf :: Applicative f => Bool -> f a -> f (Maybe a)
parseIf flag parser = if flag then Just <$> parser else pure Nothing

-- | Parse a CONNACK body: one ignored byte followed by the return code.
connAck :: MessageParser (MessageBody 'CONNACK)
connAck = ctxt' "connAck" $ anyWord8' {- reserved -} *> (ConnAck <$> anyWord8')

-- | Parse a PUBLISH body: the topic, a message ID only when the header's
-- QoS is above 'NoConfirm', and the rest of the message as payload.
publish :: MqttHeader -> MessageParser (MessageBody 'PUBLISH)
publish header = ctxt' "publish" $
  Publish
    <$> getTopic
    <*> (if qos header > NoConfirm
           then Just <$> parseMsgID
           else return Nothing)
    <*> (get >>= take')

-- | Parse a SUBSCRIBE body: a message ID followed by (topic, requested QoS)
-- pairs until the remaining length is exhausted.
subscribe :: MessageParser (MessageBody 'SUBSCRIBE)
subscribe = ctxt' "subscribe" $
  Subscribe
    <$> parseMsgID
    <*> whileM ((0 <) <$> get) ((,) <$> getTopic <*> (anyWord8' >>= toQoS))

-- | Parse a SUBACK body: a message ID followed by granted QoS levels until
-- the remaining length is exhausted.
subAck :: MessageParser (MessageBody 'SUBACK)
subAck = ctxt' "subAck" $
  SubAck
    <$> parseMsgID
    <*> whileM ((0 <) <$> get) (anyWord8' >>= toQoS)

-- | Parse an UNSUBSCRIBE body: a message ID followed by topics until the
-- remaining length is exhausted.
unsubscribe :: MessageParser (MessageBody 'UNSUBSCRIBE)
unsubscribe = ctxt' "unsubscribe" $
  Unsubscribe
    <$> parseMsgID
    <*> whileM ((0 <) <$> get) getTopic

---------------------------------
-- * Utility functions
---------------------------------

-- | Parse a topic name.
getTopic :: MessageParser Topic
getTopic = ctxt' "getTopic" $ toTopic <$> mqttText

-- | Parse a length-prefixed UTF-8 string.
--
-- The length is a 16-bit big-endian prefix. Invalid UTF-8 is decoded
-- leniently (replacement characters) instead of causing a parse failure.
mqttText :: MessageParser MqttText
mqttText = ctxt' "mqttText" $
  MqttText . decodeUtf8With lenientDecode <$> (anyWord16BE >>= take')

-- | Synonym for 'anyWord16BE'.
parseMsgID :: MessageParser Word16
parseMsgID = ctxt' "parseMsgID" anyWord16BE

-- | Parse a big-endian 16bit integer.
anyWord16BE :: (Num a, Bits a) => MessageParser a
anyWord16BE = do
  msb <- anyWord8'
  lsb <- anyWord8'
  return $ shiftL (fromIntegral msb) 8 .|. fromIntegral lsb

-- | A lifted version of attoparsec's 'anyWord8' that also subtracts 1 from
-- the remaining length.
anyWord8' :: MessageParser Word8
anyWord8' = parseLength 1 >> lift anyWord8

-- | Attach a context name to a parser's error messages. This is
-- attoparsec's '<?>' with its arguments flipped.
ctxt :: String -> Parser a -> Parser a
ctxt = flip (<?>)

-- | 'ctxt' lifted through the remaining-length state.
ctxt' :: String -> MessageParser a -> MessageParser a
ctxt' = mapStateT . ctxt

-- | A lifted version of attoparsec's 'take' that also subtracts the
-- length.
take' :: Word32 -> MessageParser BS.ByteString
take' n = parseLength n >> lift (take (fromIntegral n))

-- | Subtract @n@ from the remaining length, or 'fail' if there is not
-- enough input left.
parseLength :: Word32 -> MessageParser ()
parseLength n = do
  remaining <- get
  if remaining < n
    then fail "Reached remaining = 0 before end of message."
    else put $ remaining - n

-- | Convert a number to a 'QoS'. 'fail' if the number can't be converted.
--
-- Requires 'MonadFail': since GHC 8.8, 'fail' is no longer a method of
-- 'Monad', so a plain @Monad m@ constraint does not compile.
toQoS :: (Num a, Eq a, Show a, MonadFail m) => a -> m QoS
toQoS 0 = return NoConfirm
toQoS 1 = return Confirm
toQoS 2 = return Handshake
toQoS x = fail $ "Invalid QoS value: " ++ show x