{-# LANGUAGE GADTs #-}
{-# LANGUAGE OverloadedStrings #-}
{-|
Module: Network.MQTT.Parser
Copyright: Lukas Braun 2014
License: GPL-3
Maintainer: koomi+mqtt@hackerspace-bamberg.de

Parsers for MQTT messages.
-}
module Network.MQTT.Parser where

import Control.Monad
import Control.Monad.Loops
import Control.Monad.State.Strict
import Control.Applicative
import Data.Attoparsec.ByteString
import Data.Bits
import qualified Data.ByteString as BS
import Data.Singletons (withSomeSing)
import Data.Text.Encoding (decodeUtf8')
import Data.Word
import Prelude hiding (takeWhile, take)

import Network.MQTT.Types hiding (body)

-- | A 'Parser' wrapped in a 'StateT' whose 'Word32' state is the number of
-- bytes of the current message that are still to be read.
type MessageParser a = StateT Word32 Parser a

-- | Parse any MQTT message: the fixed header, the 'remaining length' field
-- and then the body that belongs to the message type found in the header.
message :: Parser SomeMessage
message = do
    (kind, fixedHeader) <- mqttHeader
    bodyLength <- parseRemaining
    -- Promote the runtime message type to a singleton so that the body
    -- parser can be selected with a matching type index.
    withSomeSing kind $ \sKind -> do
      parsedBody <- mqttBody fixedHeader sKind bodyLength
      return $ SomeMessage (Message fixedHeader parsedBody)

---------------------------------
-- * Fixed Header
---------------------------------

-- | Parser for the first byte of the fixed header of a MQTT message.
--
-- Bit 0 is the retain flag, bits 1-2 are the QoS level, bit 3 is the dup
-- flag and bits 4-7 are the message type. Fails on an invalid QoS value or
-- an unknown message type.
mqttHeader :: Parser (MsgType, MqttHeader)
mqttHeader = do
    firstByte <- anyWord8
    qosLevel <- toQoS (shiftR firstByte 1 .&. 3)
    kind <- decodeType (shiftR firstByte 4)
    return (kind, Header (testBit firstByte 3) qosLevel (testBit firstByte 0))
  where
    -- Map the numeric code from the upper four bits to a 'MsgType'.
    decodeType :: Word8 -> Parser MsgType
    decodeType code = case code of
      1 -> return CONNECT
      2 -> return CONNACK
      3 -> return PUBLISH
      4 -> return PUBACK
      5 -> return PUBREC
      6 -> return PUBREL
      7 -> return PUBCOMP
      8 -> return SUBSCRIBE
      9 -> return SUBACK
      10 -> return UNSUBSCRIBE
      11 -> return UNSUBACK
      12 -> return PINGREQ
      13 -> return PINGRESP
      14 -> return DISCONNECT
      other -> fail $ "Invalid message type: " ++ show other

-- | Parse the 'remaining length' field that indicates how long the rest of
-- the message is.
parseRemaining :: Parser Word32
parseRemaining = go 1 0 4
  where
    -- The field is a base-128 number stored least significant group first:
    -- the low seven bits of every byte carry data and the high bit signals
    -- that another byte follows. Reading one byte at a time means that at
    -- most four bytes are ever requested from the input before giving up.
    go :: Word32 -> Word32 -> Int -> Parser Word32
    go _ _ 0 =
      fail "'Remaining length' field must not be longer than 4 bytes"
    go factor acc bytesLeft = do
      byte <- anyWord8
      let acc' = acc + factor * fromIntegral (0x7f .&. byte)
      if testBit byte 7
        then go (factor * 128) acc' (bytesLeft - 1)
        else return acc'

---------------------------------
-- * Message Body
---------------------------------

-- | «@mqttBody header msgtype remaining@» parses a 'Message' of type
-- @msgtype@ that is @remaining@ bytes long.
--
-- The body parser runs with @remaining@ as its initial state, so it fails
-- as soon as it tries to read more bytes than the message contains.
mqttBody :: MqttHeader -> SMsgType t -> Word32 -> Parser (MessageBody t)
mqttBody header msgType remaining =
    let parser =
          case msgType of
            SCONNECT -> MConnect <$> connect
            SCONNACK -> MConnAck <$> connAck
            SPUBLISH -> MPublish <$> publish header
            SPUBACK -> MPubAck <$> simpleMsg
            SPUBREC -> MPubRec <$> simpleMsg
            SPUBREL -> MPubRel <$> simpleMsg
            SPUBCOMP -> MPubComp <$> simpleMsg
            SSUBSCRIBE -> MSubscribe <$> subscribe
            SSUBACK -> MSubAck <$> subAck
            SUNSUBSCRIBE -> MUnsubscribe <$> unsubscribe
            SUNSUBACK -> MUnsubAck <$> simpleMsg
            SPINGREQ -> pure MPingReq
            SPINGRESP -> pure MPingResp
            SDISCONNECT -> pure MDisconnect
    in evalStateT parser remaining

-- | Parse the body of a CONNECT message: protocol name and version, the
-- connect flags byte, the keep-alive value, the client ID and, depending on
-- the flags, the will, the username and the password.
--
-- Fails if the protocol name is not @MQIsdp@, the version is not 3, or the
-- client ID is longer than 23 bytes.
connect :: MessageParser Connect
connect = do
    protocol
    version
    flags <- anyWord8'
    let clean = testBit flags 1
        willFlag = testBit flags 2
        usernameFlag = testBit flags 7
        passwordFlag = testBit flags 6
    keepAlive <- anyWord16BE
    clientID <- getClientID
    -- Bit 5 and bits 3-4 of the flags feed the first two fields of 'Will';
    -- they are only looked at when the will flag is set.
    mWill <- parseIf willFlag $
               Will (testBit flags 5)
                 <$> toQoS (3 .&. shiftR flags 3)
                 <*> fmap toTopic mqttText
                 <*> mqttText
    username <- parseIf usernameFlag mqttText
    password <- parseIf passwordFlag mqttText
    return $ Connect clean mWill clientID username password keepAlive
  where
    -- Check the protocol name.
    protocol = do
      prot <- mqttText
      when (prot /= "MQIsdp") $
        fail $ "Invalid protocol: " ++ show prot
    -- Check the protocol version byte.
    version = do
      v <- anyWord8'
      when (v /= 3) $
        fail $ "Invalid version: " ++ show v
    -- Read the client ID and measure its encoded size through the change
    -- in the remaining-length state.
    getClientID = do
      before <- get
      clientID <- mqttText
      after <- get
      let len = before - after - 2 -- 2 for length prefix
      when (len > 23) $
        fail $ "Client ID must not be longer than 23 chars: "
                 ++ show (text clientID) ++ " (" ++ show len ++ ")"
      return clientID

-- | «@parseIf flag parser@» runs @parser@ and wraps its result in 'Just'
-- when @flag@ is 'True'; otherwise it yields 'Nothing' without running
-- @parser@.
parseIf :: Applicative f => Bool -> f a -> f (Maybe a)
parseIf flag parser = if flag then Just <$> parser else pure Nothing

-- | Parse the body of a CONNACK message.
--
-- In MQTT 3.1 the variable header is two bytes long: a reserved byte
-- followed by the connect return code. The reserved byte is skipped so that
-- the return code is the value handed to 'ConnAck' and both bytes are
-- consumed.
connAck :: MessageParser ConnAck
connAck = ConnAck <$> (anyWord8' *> anyWord8')

-- | Parse the body of a PUBLISH message: the topic, a message ID that is
-- only present when the QoS in the header is above 'NoConfirm', and all
-- bytes that remain after that.
publish :: MqttHeader -> MessageParser Publish
publish header = Publish
                   <$> getTopic
                   <*> (if qos header > NoConfirm
                          then Just <$> parseMsgID
                          else return Nothing)
                   <*> (get >>= take')

-- | Parse the body of a SUBSCRIBE message: a message ID followed by pairs
-- of topic and QoS byte until the remaining length reaches zero.
subscribe :: MessageParser Subscribe
subscribe = Subscribe
              <$> parseMsgID
              <*> whileM ((0 <) <$> get)
                    ((,) <$> getTopic <*> (anyWord8' >>= toQoS))

-- | Parse the body of a SUBACK message: a message ID followed by QoS bytes
-- until the remaining length reaches zero.
subAck :: MessageParser SubAck
subAck = SubAck
           <$> parseMsgID
           <*> whileM ((0 <) <$> get) (anyWord8' >>= toQoS)

-- | Parse the body of an UNSUBSCRIBE message: a message ID followed by
-- topics until the remaining length reaches zero.
unsubscribe :: MessageParser Unsubscribe
unsubscribe = Unsubscribe
                <$> parseMsgID
                <*> whileM ((0 <) <$> get) getTopic

-- | Parse a message body that consists of nothing but a message ID.
simpleMsg :: MessageParser SimpleMsg
simpleMsg = SimpleMsg <$> parseMsgID

---------------------------------
-- * Utility functions
---------------------------------

-- | Parse a topic name.
getTopic :: MessageParser Topic
getTopic = toTopic <$> mqttText

-- | Parse a length-prefixed UTF-8 string.
--
-- The prefix is a big-endian 16 bit byte count. Fails if the bytes are not
-- valid UTF-8.
mqttText :: MessageParser MqttText
mqttText = do
    n <- anyWord16BE
    rslt <- decodeUtf8' <$> take' n
    case rslt of
      Left err -> fail $ "Invalid UTF-8: " ++ show err
      Right txt -> return $ MqttText txt

-- | Synonym for 'anyWord16BE'.
parseMsgID :: MessageParser Word16
parseMsgID = anyWord16BE

-- | Parse a big-endian 16bit integer: the first byte read becomes the upper
-- eight bits, the second byte the lower eight bits.
anyWord16BE :: (Num a, Bits a) => MessageParser a
anyWord16BE = combine <$> anyWord8' <*> anyWord8'
  where
    combine high low = (fromIntegral high `shiftL` 8) .|. fromIntegral low

-- | A lifted version of attoparsec's 'anyWord8' that also subtracts 1 from
-- the remaining length.
anyWord8' :: MessageParser Word8
anyWord8' = do
    parseLength 1
    lift anyWord8

-- | A lifted version of attoparsec's 'take' that also subtracts the
-- number of bytes taken from the remaining length.
take' :: Word32 -> MessageParser BS.ByteString
take' n = do
    parseLength n
    lift $ take (fromIntegral n)

-- | Subtract 'n' from the remaining length or 'fail' if fewer than 'n'
-- bytes remain.
parseLength :: Word32 -> MessageParser ()
parseLength n = do
    available <- get
    when (available < n) $
      fail "Reached remaining = 0 before end of message."
    put (available - n)

-- | Convert a number to a 'QoS'. 'fail' if the number can't be converted.
--
-- 0, 1 and 2 map to 'NoConfirm', 'Confirm' and 'Handshake' respectively.
toQoS :: (Num a, Eq a, Show a, Monad m) => a -> m QoS
toQoS level = case level of
    0 -> return NoConfirm
    1 -> return Confirm
    2 -> return Handshake
    _ -> fail $ "Invalid QoS value: " ++ show level