module Network.MQTT.Parser where
import Control.Monad
import Control.Monad.Loops
import Control.Monad.State.Strict
import Control.Applicative
import Data.Attoparsec.ByteString
import Data.Bits
import qualified Data.ByteString as BS
import Data.Singletons (withSomeSing)
import Data.Text.Encoding (decodeUtf8')
import Data.Word
import Prelude hiding (takeWhile, take)
import Network.MQTT.Types hiding (body)
-- | Parser for the part of a message that follows the fixed header.
--
-- The 'Word32' state is a byte budget: 'mqttBody' seeds it with the
-- message's declared remaining length and 'parseLength' decrements it
-- before any input is consumed.
type MessageParser a = StateT Word32 Parser a
-- | Parse one complete MQTT message: fixed header, remaining-length field
-- and the type-specific body.
--
-- The message type read from the header is promoted to a singleton so that
-- 'mqttBody' can return a body indexed by that type; the result is then
-- wrapped existentially in 'SomeMessage'.
message :: Parser SomeMessage
message = do
    (kind, header) <- mqttHeader
    bodyLen <- parseRemaining
    withSomeSing kind (\sKind ->
      fmap (SomeMessage . Message header) (mqttBody header sKind bodyLen))
-- | Parse the first byte of the fixed header.
--
-- Bit layout of that byte, as decoded here: bit 0 is the retain flag,
-- bits 1-2 the QoS level, bit 3 the dup flag and bits 4-7 the message type.
-- Fails on a QoS value rejected by 'toQoS' or on a message type outside
-- 1..14.
mqttHeader :: Parser (MsgType, MqttHeader)
mqttHeader = do
    firstByte <- anyWord8
    -- The QoS bits are validated before the message type, so an invalid QoS
    -- is the error reported when both fields are bad.
    qosLevel <- toQoS (shiftR firstByte 1 .&. 3)
    let typeCode   = shiftR firstByte 4
        dupFlag    = testBit firstByte 3
        retainFlag = testBit firstByte 0
    kind <- maybe (fail $ "Invalid message type: " ++ show typeCode)
                  return
                  (lookup typeCode msgTypes)
    return (kind, Header dupFlag qosLevel retainFlag)
  where
    -- Wire codes 1..14 paired with their message types, in order.
    msgTypes = zip [1 ..]
      [ CONNECT, CONNACK, PUBLISH, PUBACK, PUBREC, PUBREL, PUBCOMP
      , SUBSCRIBE, SUBACK, UNSUBSCRIBE, UNSUBACK
      , PINGREQ, PINGRESP, DISCONNECT
      ]
-- | Parse the variable-length \"remaining length\" field of the fixed header.
--
-- The field is one to four bytes long. Each byte carries seven bits of the
-- value in its low bits; a set high bit means another byte follows. The
-- FIRST byte holds the least significant seven bits, so the value is
--
-- > b0 + 128 * b1 + 128^2 * b2 + 128^3 * b3
--
-- (with each @bN@ masked by @0x7f@). For example the bytes @0xC1 0x02@
-- decode to @65 + 2 * 128 = 321@.
--
-- Fails if more than three continuation bytes precede the final byte.
parseRemaining :: Parser Word32
parseRemaining = do
    -- All bytes with the continuation bit set.
    continued <- takeWhile (> 0x7f)
    when (BS.length continued > 3) $
      fail "'Remaining length' field must not be longer than 4 bytes"
    -- The terminating byte (continuation bit clear) is the most significant
    -- digit, so it is appended and the fold runs from the left.
    stopByte <- anyWord8
    return $ snd $ BS.foldl' step (1, 0) (BS.snoc continued stopByte)
  where
    step (factor, acc) byte =
      (factor * 128, acc + factor * fromIntegral (0x7f .&. byte))
-- | Parse the body that belongs to the given message type.
--
-- The singleton selects the type-specific parser, which is run with the
-- declared remaining length as its initial byte budget. The final budget is
-- discarded: bytes the body parser does not read are neither consumed nor
-- reported here. PINGREQ, PINGRESP and DISCONNECT read no input at all.
mqttBody :: MqttHeader -> SMsgType t -> Word32 -> Parser (MessageBody t)
mqttBody header msgType remaining = evalStateT bodyParser remaining
  where
    bodyParser = case msgType of
      SCONNECT     -> fmap MConnect connect
      SCONNACK     -> fmap MConnAck connAck
      SPUBLISH     -> fmap MPublish (publish header)
      SPUBACK      -> fmap MPubAck simpleMsg
      SPUBREC      -> fmap MPubRec simpleMsg
      SPUBREL      -> fmap MPubRel simpleMsg
      SPUBCOMP     -> fmap MPubComp simpleMsg
      SSUBSCRIBE   -> fmap MSubscribe subscribe
      SSUBACK      -> fmap MSubAck subAck
      SUNSUBSCRIBE -> fmap MUnsubscribe unsubscribe
      SUNSUBACK    -> fmap MUnsubAck simpleMsg
      SPINGREQ     -> pure MPingReq
      SPINGRESP    -> pure MPingResp
      SDISCONNECT  -> pure MDisconnect
connect :: MessageParser Connect
connect = do
protocol
version
flags <- anyWord8'
let clean = testBit flags 1
willFlag = testBit flags 2
usernameFlag = testBit flags 7
passwordFlag = testBit flags 6
keepAlive <- anyWord16BE
clientID <- getClientID
mWill <- parseIf willFlag $
Will (testBit flags 5)
<$> toQoS (3 .&. shiftR flags 3)
<*> fmap toTopic mqttText
<*> mqttText
username <- parseIf usernameFlag mqttText
password <- parseIf passwordFlag mqttText
return $ Connect clean mWill clientID username password keepAlive
where
protocol = do
prot <- mqttText
when (prot /= "MQIsdp") $
fail $ "Invalid protocol: " ++ show prot
version = do
version <- anyWord8'
when (version /= 3) $
fail $ "Invalid version: " ++ show version
getClientID = do
before <- get
clientID <- mqttText
after <- get
let len = before after 2
when (len > 23) $
fail $ "Client ID must not be longer than 23 chars: "
++ show (text clientID) ++ " (" ++ show len ++ ")"
return clientID
-- | Run the given action only when the flag is set.
--
-- Returns the action's result wrapped in 'Just' when the flag is 'True';
-- otherwise returns 'Nothing' without running the action.
parseIf :: Applicative f => Bool -> f a -> f (Maybe a)
parseIf flag parser
    | flag      = fmap Just parser
    | otherwise = pure Nothing
-- | Parse the variable header of a CONNACK message.
--
-- The header is two bytes long: the first is reserved and carries no
-- information, the second is the connect return code. The reserved byte is
-- read and dropped so that the return code, not the reserved byte, ends up
-- in the result and both bytes are consumed.
connAck :: MessageParser ConnAck
connAck = ConnAck <$> (anyWord8' *> anyWord8')
-- | Parse the body of a PUBLISH message.
--
-- Reads the topic, then a message ID only when the header's QoS is above
-- 'NoConfirm', and finally takes every byte still left in the byte budget
-- as the payload.
publish :: MqttHeader -> MessageParser Publish
publish header = do
    topic <- getTopic
    msgID <- parseIf (qos header > NoConfirm) parseMsgID
    payload <- get >>= take'
    return $ Publish topic msgID payload
-- | Parse the body of a SUBSCRIBE message: a message ID followed by
-- (topic, requested QoS) pairs, repeated until the byte budget is used up.
subscribe :: MessageParser Subscribe
subscribe = do
    msgID <- parseMsgID
    topics <- whileM bytesLeft topicWithQoS
    return $ Subscribe msgID topics
  where
    -- True while the message still has unread bytes.
    bytesLeft = (> 0) <$> get

    topicWithQoS = do
      topic <- getTopic
      level <- anyWord8' >>= toQoS
      return (topic, level)
-- | Parse the body of a SUBACK message: a message ID followed by one QoS
-- byte per entry, repeated until the byte budget is used up.
subAck :: MessageParser SubAck
subAck = do
    msgID <- parseMsgID
    levels <- whileM bytesLeft (anyWord8' >>= toQoS)
    return $ SubAck msgID levels
  where
    -- True while the message still has unread bytes.
    bytesLeft = (> 0) <$> get
-- | Parse the body of an UNSUBSCRIBE message: a message ID followed by
-- topics, repeated until the byte budget is used up.
unsubscribe :: MessageParser Unsubscribe
unsubscribe = do
    msgID <- parseMsgID
    topics <- whileM bytesLeft getTopic
    return $ Unsubscribe msgID topics
  where
    -- True while the message still has unread bytes.
    bytesLeft = (> 0) <$> get
-- | Parse a body that consists of nothing but a message ID. Used by
-- 'mqttBody' for PUBACK, PUBREC, PUBREL, PUBCOMP and UNSUBACK.
simpleMsg :: MessageParser SimpleMsg
simpleMsg = fmap SimpleMsg parseMsgID
-- | Parse a length-prefixed string and convert it with 'toTopic'.
getTopic :: MessageParser Topic
getTopic = fmap toTopic mqttText
-- | Parse a length-prefixed UTF-8 string: a big-endian 16-bit byte count
-- followed by that many bytes. Fails if the bytes are not valid UTF-8.
mqttText :: MessageParser MqttText
mqttText = do
    len <- anyWord16BE
    raw <- take' len
    either invalid (return . MqttText) (decodeUtf8' raw)
  where
    invalid err = fail $ "Invalid UTF-8: " ++ show err
-- | Parse a message ID: two bytes read as a big-endian 'Word16'.
parseMsgID :: MessageParser Word16
parseMsgID = anyWord16BE
-- | Read two bytes and combine them big-endian (first byte is the high
-- byte) into any numeric type that supports bit operations.
anyWord16BE :: (Num a, Bits a) => MessageParser a
anyWord16BE = combine <$> anyWord8' <*> anyWord8'
  where
    combine hi lo = shiftL (fromIntegral hi) 8 .|. fromIntegral lo
-- | Read a single byte, charging it against the byte budget first.
anyWord8' :: MessageParser Word8
anyWord8' = do
    parseLength 1
    lift anyWord8
-- | Read exactly @n@ bytes, charging them against the byte budget first.
take' :: Word32 -> MessageParser BS.ByteString
take' n = do
    parseLength n
    lift $ take (fromIntegral n)
-- | Charge @n@ bytes against the byte budget held in the state.
--
-- Fails if fewer than @n@ bytes are left; otherwise stores the budget
-- reduced by @n@. No input is consumed here; callers read the bytes
-- themselves afterwards.
parseLength :: Word32 -> MessageParser ()
parseLength n = do
    remaining <- get
    if remaining < n
      then fail "Reached remaining = 0 before end of message."
      else put $ remaining - n
-- | Convert a numeric QoS level to 'QoS': 0 is 'NoConfirm', 1 is 'Confirm'
-- and 2 is 'Handshake'. Any other value fails in the surrounding monad.
--
-- NOTE(review): this calls 'fail' under a plain 'Monad' constraint; newer
-- GHC versions require 'MonadFail' for that. Confirm the targeted compiler.
toQoS :: (Num a, Eq a, Show a, Monad m) => a -> m QoS
toQoS level = case level of
    0 -> return NoConfirm
    1 -> return Confirm
    2 -> return Handshake
    _ -> fail $ "Invalid QoS value: " ++ show level