{-# Language OverloadedStrings, GADTs, DataKinds #-}
module Network.MQTT.Parser where
import Prelude hiding (takeWhile, take)

import Control.Applicative
import Control.Monad
import qualified Control.Monad.Fail as Fail
import Data.Bits
import Data.Word

import Control.Monad.Loops
import Control.Monad.State.Strict
import Data.Attoparsec.ByteString
import qualified Data.ByteString as BS
import Data.Text.Encoding (decodeUtf8With)
import Data.Text.Encoding.Error (lenientDecode)

import Network.MQTT.Types hiding (body)
-- | A message-body parser: an attoparsec 'Parser' threaded with a 'Word32'
-- state. Within this module the state is the number of bytes of the current
-- message body that have not been consumed yet ('mqttBody' seeds it with the
-- decoded remaining length; 'parseLength' decrements it).
type MessageParser a = StateT Word32 Parser a
-- | Parse one complete MQTT message: the fixed header, the remaining-length
-- field, and then the type-specific body, wrapped existentially in
-- 'SomeMessage' so messages of any type share one result type.
message :: Parser SomeMessage
message = do
    (msgType, header) <- mqttHeader
    remaining <- parseRemaining
    -- Promote the runtime message type to its singleton so 'mqttBody' can
    -- choose the body parser (and result type) for exactly that type.
    withSomeSingI msgType $ \sMsgType ->
      SomeMessage . Message header <$> mqttBody header sMsgType remaining
-- | Parse the first byte of the fixed header.
--
-- The upper nibble selects the message type (codes 1-14; anything else is
-- rejected). Bit 3 is the DUP flag, bits 1-2 the QoS level and bit 0 the
-- RETAIN flag. The QoS value is validated before the message type.
mqttHeader :: Parser (MsgType, MqttHeader)
mqttHeader = ctxt "mqttHeader" $ do
    firstByte <- anyWord8
    qosLevel <- toQoS (shiftR firstByte 1 .&. 3)
    let typeCode = shiftR firstByte 4
    msgType' <- maybe (fail $ "Invalid message type: " ++ show typeCode)
                      return
                      (lookup typeCode typeCodes)
    let header = Header (testBit firstByte 3) qosLevel (testBit firstByte 0)
    return (msgType', header)
  where
    -- Wire code of each message type.
    typeCodes :: [(Word8, MsgType)]
    typeCodes =
      [ (1, CONNECT)
      , (2, CONNACK)
      , (3, PUBLISH)
      , (4, PUBACK)
      , (5, PUBREC)
      , (6, PUBREL)
      , (7, PUBCOMP)
      , (8, SUBSCRIBE)
      , (9, SUBACK)
      , (10, UNSUBSCRIBE)
      , (11, UNSUBACK)
      , (12, PINGREQ)
      , (13, PINGRESP)
      , (14, DISCONNECT)
      ]
-- | Decode the variable-length "remaining length" field of the fixed header.
--
-- The value occupies 1 to 4 bytes, least-significant 7-bit group first.
-- Bit 7 of a byte is set when another byte follows. Bytes are read one at a
-- time, so at most 4 bytes are ever consumed. A 4th byte that still has
-- bit 7 set is rejected right away, instead of first swallowing an
-- arbitrarily long run of continuation bytes from the peer.
parseRemaining :: Parser Word32
parseRemaining = ctxt "parseRemaining" $ go 0 1 0
  where
    -- go count factor acc:
    --   count  = continuation bytes already read,
    --   factor = weight of the next 7-bit group (1, 128, 128^2, 128^3),
    --   acc    = value decoded so far.
    go :: Int -> Word32 -> Word32 -> Parser Word32
    go count factor acc = do
      byte <- anyWord8
      let acc' = acc + factor * fromIntegral (0x7f .&. byte)
      if testBit byte 7
        then do
          -- This is continuation byte number count+1; a 4th one would
          -- force a 5th length byte, which the encoding does not allow.
          when (count >= 3) $
            fail "'Remaining length' field must not be longer than 4 bytes"
          go (count + 1) (factor * 128) acc'
        else return acc'
-- | Parse the variable header and payload of a message of the given type.
--
-- The type-specific parser runs in 'MessageParser', seeded with
-- @remaining@ (the decoded remaining-length field) as the count of unread
-- body bytes. Any of those bytes that the parser leaves unread are skipped
-- afterwards. Otherwise they would be taken as the start of the next
-- message's fixed header and desync the stream. Well-formed messages are
-- fully consumed by their parsers, so they are unaffected.
mqttBody :: MqttHeader -> SMsgType t -> Word32 -> Parser (MessageBody t)
mqttBody header msgType remaining = ctxt "mqttBody" $
    let parser =
          case msgType of
            SCONNECT     -> connect
            SCONNACK     -> connAck
            SPUBLISH     -> publish header
            SPUBACK      -> PubAck  <$> parseMsgID
            SPUBREC      -> PubRec  <$> parseMsgID
            SPUBREL      -> PubRel  <$> parseMsgID
            SPUBCOMP     -> PubComp <$> parseMsgID
            SSUBSCRIBE   -> subscribe
            SSUBACK      -> subAck
            SUNSUBSCRIBE -> unsubscribe
            SUNSUBACK    -> UnsubAck <$> parseMsgID
            SPINGREQ     -> pure PingReq
            SPINGRESP    -> pure PingResp
            SDISCONNECT  -> pure Disconnect
    in evalStateT (parser <* skipUnread) remaining
  where
    -- Drop whatever is left of the declared body so the input stays
    -- aligned on the next message boundary.
    skipUnread :: MessageParser ()
    skipUnread = get >>= void . take'
-- | Parse the body of a CONNECT message.
--
-- Reads, in order: the protocol name (must be "MQIsdp"), the protocol
-- version (must be 3), the connect-flags byte, the keep-alive value and the
-- client ID. Then come the optional will topic/message, username and
-- password, each present only when its flag bit is set. Flag bits used:
-- 1 clean session, 2 will, 3-4 will QoS, 5 will retain, 6 password,
-- 7 username.
connect :: MessageParser (MessageBody 'CONNECT)
connect = ctxt' "connect" $ do
    expectProtocol
    expectVersion
    flagByte <- anyWord8'
    let flagSet = testBit flagByte
    keepAlive <- anyWord16BE
    clientID <- getClientID
    mWill <- optionally (flagSet 2) $ do
      willQoS <- toQoS (3 .&. shiftR flagByte 3)
      willTopic <- ctxt' "Will Topic" (toTopic <$> mqttText)
      willMessage <- ctxt' "Will Message" mqttText
      return (Will (flagSet 5) willQoS willTopic willMessage)
    username <- ctxt' "Username" $ optionally (flagSet 7) mqttText
    password <- ctxt' "Password" $ optionally (flagSet 6) mqttText
    return $ Connect (flagSet 1) mWill clientID username password keepAlive
  where
    -- Fail unless the protocol name is "MQIsdp".
    expectProtocol = ctxt' "protocol" $ do
      name <- mqttText
      when (name /= "MQIsdp") $
        fail $ "Invalid protocol: " ++ show name

    -- Fail unless the protocol version byte is 3.
    expectVersion = ctxt' "version" $ do
      v <- anyWord8'
      when (v /= 3) $
        fail $ "Invalid version: " ++ show v

    -- Parse the client ID and reject it if its encoded form is longer than
    -- 23 bytes. The length is the drop in unread bytes minus the 2-byte
    -- length prefix, so it counts bytes, not characters.
    getClientID = ctxt' "getClientID" $ do
      unreadBefore <- get
      cid <- mqttText
      unreadAfter <- get
      let idLength = unreadBefore - unreadAfter - 2
      when (idLength > 23) $
        fail $ "Client ID must not be longer than 23 chars: "
                  ++ show (text cid) ++ " (" ++ show idLength ++ ")"
      return cid

    -- Run the parser only when its presence flag is set.
    optionally :: Applicative f => Bool -> f a -> f (Maybe a)
    optionally present p
      | present   = Just <$> p
      | otherwise = pure Nothing
-- | Parse the body of a CONNACK message. The first byte is read and
-- discarded (reserved in MQTT 3.1). The second is wrapped in 'ConnAck'.
connAck :: MessageParser (MessageBody 'CONNACK)
connAck = ctxt' "connAck" $ do
    _ <- anyWord8'
    code <- anyWord8'
    return (ConnAck code)
-- | Parse the body of a PUBLISH message: the topic, then a message ID only
-- when the header's QoS is above 'NoConfirm', then every remaining byte as
-- the payload.
publish :: MqttHeader -> MessageParser (MessageBody 'PUBLISH)
publish header = ctxt' "publish" $ do
    topic <- getTopic
    mMsgID <- if qos header > NoConfirm
                then Just <$> parseMsgID
                else return Nothing
    payload <- take' =<< get
    return (Publish topic mMsgID payload)
-- | Parse the body of a SUBSCRIBE message: a message ID followed by
-- (topic, requested QoS) pairs until the body is exhausted.
subscribe :: MessageParser (MessageBody 'SUBSCRIBE)
subscribe = ctxt' "subscribe" $ do
    msgID <- parseMsgID
    requests <- whileM bytesLeft topicRequest
    return (Subscribe msgID requests)
  where
    bytesLeft :: MessageParser Bool
    bytesLeft = (> 0) <$> get

    topicRequest :: MessageParser (Topic, QoS)
    topicRequest = do
      topic <- getTopic
      requested <- anyWord8' >>= toQoS
      return (topic, requested)
-- | Parse the body of a SUBACK message: a message ID followed by one QoS
-- byte per subscription until the body is exhausted.
subAck :: MessageParser (MessageBody 'SUBACK)
subAck = ctxt' "subAck" $ do
    msgID <- parseMsgID
    granted <- whileM bytesLeft (toQoS =<< anyWord8')
    return (SubAck msgID granted)
  where
    bytesLeft :: MessageParser Bool
    bytesLeft = fmap (> 0) get
-- | Parse the body of an UNSUBSCRIBE message: a message ID followed by
-- topics until the body is exhausted.
unsubscribe :: MessageParser (MessageBody 'UNSUBSCRIBE)
unsubscribe = ctxt' "unsubscribe" $ do
    msgID <- parseMsgID
    topics <- whileM bytesLeft getTopic
    return (Unsubscribe msgID topics)
  where
    bytesLeft :: MessageParser Bool
    bytesLeft = fmap (> 0) get
-- | Parse a length-prefixed string and convert it with 'toTopic'.
getTopic :: MessageParser Topic
getTopic = ctxt' "getTopic" (fmap toTopic mqttText)
-- | Parse a string prefixed by its 16-bit big-endian byte length. The bytes
-- are decoded as UTF-8, and invalid sequences are replaced rather than
-- rejected ('lenientDecode').
mqttText :: MessageParser MqttText
mqttText = ctxt' "mqttText" $ do
    byteCount <- anyWord16BE
    encoded <- take' byteCount
    return . MqttText $ decodeUtf8With lenientDecode encoded
-- | Parse a 16-bit big-endian message identifier.
parseMsgID :: MessageParser Word16
parseMsgID =
    ctxt' "parseMsgID" $ anyWord16BE
-- | Read two body bytes and combine them big-endian (most significant byte
-- first) into any numeric bit type wide enough to hold 16 bits.
anyWord16BE :: (Num a, Bits a) => MessageParser a
anyWord16BE = combine <$> anyWord8' <*> anyWord8'
  where
    combine hi lo = (fromIntegral hi `shiftL` 8) .|. fromIntegral lo
-- | Read one body byte, charging it against the unread-byte budget first.
anyWord8' :: MessageParser Word8
anyWord8' = do
    parseLength 1
    lift anyWord8
-- | Label a parser with a context name for error messages (argument-flipped
-- '<?>').
ctxt :: String -> Parser a -> Parser a
ctxt name p = p <?> name
-- | 'ctxt' lifted to 'MessageParser': labels the underlying attoparsec
-- parser while leaving the state threading untouched.
ctxt' :: String -> MessageParser a -> MessageParser a
ctxt' name = mapStateT (<?> name)
-- | Take exactly @n@ body bytes, charging them against the unread-byte
-- budget first.
take' :: Word32 -> MessageParser BS.ByteString
take' n = do
    parseLength n
    lift . take $ fromIntegral n
-- | Deduct @n@ bytes from the count of unread body bytes, failing if fewer
-- than @n@ remain.
parseLength :: Word32 -> MessageParser ()
parseLength n = do
    unread <- get
    when (unread < n) $
      fail "Reached remaining = 0 before end of message."
    put (unread - n)
-- | Convert a numeric QoS code to 'QoS': 0 is 'NoConfirm', 1 is 'Confirm'
-- and 2 is 'Handshake'. Any other value fails in the surrounding monad.
--
-- The constraint is 'Fail.MonadFail' rather than 'Monad'. Since GHC 8.8
-- 'fail' belongs to MonadFail, so the old @Monad m@ signature no longer
-- type-checks. Both callers in this module ('Parser' and
-- 'StateT' 'Word32' 'Parser') are MonadFail instances.
toQoS :: (Num a, Eq a, Show a, Fail.MonadFail m) => a -> m QoS
toQoS 0 = return NoConfirm
toQoS 1 = return Confirm
toQoS 2 = return Handshake
toQoS x = Fail.fail $ "Invalid QoS value: " ++ show x