{-# LANGUAGE BangPatterns #-} {-# LANGUAGE CPP #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE ScopedTypeVariables #-} -- | -- Module : Pinch.Protocol.Binary -- Copyright : (c) Abhinav Gupta 2015 -- License : BSD3 -- -- Maintainer : Abhinav Gupta -- Stability : experimental -- -- Implements the Thrift Binary Protocol as a 'Protocol'. module Pinch.Protocol.Binary (binaryProtocol) where #if __GLASGOW_HASKELL__ < 709 import Control.Applicative #endif import Control.Monad import Data.Bits (shiftR, (.&.)) import Data.ByteString (ByteString) import Data.HashMap.Strict (HashMap) import Data.Int (Int16, Int32, Int8) import Data.Monoid import qualified Data.ByteString as B import qualified Data.HashMap.Strict as M import qualified Data.Text.Encoding as TE import Pinch.Internal.Builder (Builder) import Pinch.Internal.Message import Pinch.Internal.Parser (Parser, runParser, runParser') import Pinch.Internal.TType import Pinch.Internal.Value import Pinch.Protocol (Protocol (..)) import qualified Pinch.Internal.Builder as BB import qualified Pinch.Internal.FoldList as FL import qualified Pinch.Internal.Parser as P -- | Provides an implementation of the Thrift Binary Protocol. 
binaryProtocol :: Protocol
binaryProtocol = Protocol
    { serializeValue     = binarySerialize
    , deserializeValue'  = binaryDeserialize ttype
    , serializeMessage   = binarySerializeMessage
    , deserializeMessage = binaryDeserializeMessage
    }

------------------------------------------------------------------------------

-- | Serialize a message envelope followed by its payload.
--
-- Written in order: the length-prefixed UTF-8 encoded name, the one-byte
-- message type code, the 32-bit big-endian message ID and the payload.
binarySerializeMessage :: Message -> Builder
binarySerializeMessage msg =
    string (TE.encodeUtf8 $ messageName msg) <>
    BB.int8 (messageCode (messageType msg)) <>
    BB.int32BE (messageId msg) <>
    binarySerialize (messagePayload msg)

-- | Parse a message envelope and its payload from the given bytes.
binaryDeserializeMessage :: ByteString -> Either String Message
binaryDeserializeMessage = runParser binaryMessageParser

-- | Parser for a message envelope.
--
-- The first 32-bit integer decides the layout: a negative value is treated
-- as the @versionAndType@ word of the strict layout, anything else as the
-- name length of the non-strict layout.
--
-- A message name that is not valid UTF-8 is reported as a parse failure
-- rather than raising an exception from pure code.
binaryMessageParser :: Parser Message
binaryMessageParser = do
    size <- P.int32
    if size < 0
        then parseStrict size
        else parseNonStrict size
  where
    -- versionAndType:4 name~4 seqid:4 payload
    -- versionAndType = version:2 0x00 type:1
    parseStrict versionAndType = do
        unless (version == 1) $
            fail $ "Unsupported version: " ++ show version
        Message
            <$> (P.int32 >>= P.take . fromIntegral >>= parseName)
            <*> typ
            <*> P.int32
            <*> binaryParser ttype
      where
        -- The mask drops the sign bit before shifting the version down.
        version = (0x7fff0000 .&. versionAndType) `shiftR` 16
        code = fromIntegral $ 0x00ff .&. versionAndType
        typ = case fromMessageCode code of
            Nothing -> fail $ "Unknown message type: " ++ show code
            Just t -> return t

    -- name~4 type:1 seqid:4 payload
    parseNonStrict nameLength =
        Message
            <$> (P.take (fromIntegral nameLength) >>= parseName)
            <*> parseMessageType
            <*> P.int32
            <*> binaryParser ttype

    -- Decode the raw name bytes. 'TE.decodeUtf8'' is used instead of the
    -- partial 'TE.decodeUtf8' so that malformed input fails the parser.
    parseName bytes = case TE.decodeUtf8' bytes of
        Left err -> fail $ "Invalid UTF-8 in message name: " ++ show err
        Right name -> return name

-- | Parse a one-byte message type code, failing on unknown codes.
parseMessageType :: Parser MessageType
parseMessageType = P.int8 >>= \code -> case fromMessageCode code of
    Nothing -> fail $ "Unknown message type: " ++ show code
    Just t -> return t

------------------------------------------------------------------------------

-- | Run the parser for a value of the given type over the given bytes.
--
-- The 'ByteString' in the result is whatever 'runParser'' returns alongside
-- the parsed value (presumably the unconsumed input -- confirm against
-- "Pinch.Internal.Parser").
binaryDeserialize :: TType a -> ByteString -> Either String (ByteString, Value a)
binaryDeserialize t = runParser' (binaryParser t)

-- | Select the parser matching the requested type.
binaryParser :: TType a -> Parser (Value a)
binaryParser typ = case typ of
    TBool   -> parseBool
    TByte   -> parseByte
    TDouble -> parseDouble
    TInt16  -> parseInt16
    TInt32  -> parseInt32
    TInt64  -> parseInt64
    TBinary -> parseBinary
    TStruct -> parseStruct
    TMap    -> parseMap
    TSet    -> parseSet
    TList   -> parseList

-- | Resolve a type code, failing the parser when the code is unknown.
getTType :: Int8 -> Parser SomeTType
getTType code =
    maybe (fail $ "Unknown TType: " ++ show code) return $ fromTypeCode code

-- | Read a one-byte type code and resolve it.
parseTType :: Parser SomeTType
parseTType = P.int8 >>= getTType

-- | A boolean is one byte; only the value 1 is read as 'True'.
parseBool :: Parser (Value TBool)
parseBool = VBool . (== 1) <$> P.int8

parseByte :: Parser (Value TByte)
parseByte = VByte <$> P.int8

parseDouble :: Parser (Value TDouble)
parseDouble = VDouble <$> P.double

parseInt16 :: Parser (Value TInt16)
parseInt16 = VInt16 <$> P.int16

parseInt32 :: Parser (Value TInt32)
parseInt32 = VInt32 <$> P.int32

parseInt64 :: Parser (Value TInt64)
parseInt64 = VInt64 <$> P.int64

-- | Binary data is a 32-bit length followed by that many bytes.
parseBinary :: Parser (Value TBinary)
parseBinary = VBinary <$> (P.int32 >>= P.take . fromIntegral)

-- | A map is: key type code, value type code, 32-bit item count, then
-- that many key\/value pairs.
parseMap :: Parser (Value TMap)
parseMap = do
    ktype' <- parseTType
    vtype' <- parseTType
    count <- P.int32

    case (ktype', vtype') of
      (SomeTType ktype, SomeTType vtype) -> do
        items <- FL.replicateM (fromIntegral count) $
            MapItem <$> binaryParser ktype
                    <*> binaryParser vtype
        return $ VMap items

-- | A set is: element type code, 32-bit element count, then the elements.
parseSet :: Parser (Value TSet)
parseSet = do
    vtype' <- parseTType
    count <- P.int32

    case vtype' of
      SomeTType vtype ->
          VSet <$> FL.replicateM (fromIntegral count) (binaryParser vtype)

-- | A list is: element type code, 32-bit element count, then the elements.
parseList :: Parser (Value TList)
parseList = do
    vtype' <- parseTType
    count <- P.int32

    case vtype' of
      SomeTType vtype ->
        VList <$> FL.replicateM (fromIntegral count) (binaryParser vtype)

-- | A struct is a sequence of fields, each being a type code, a 16-bit
-- field ID and the value, terminated by a type code of 0.
--
-- Fields are collected with 'M.insert' into a map keyed by field ID.
parseStruct :: Parser (Value TStruct)
parseStruct = P.int8 >>= loop M.empty
  where
    loop :: HashMap Int16 SomeValue -> Int8 -> Parser (Value TStruct)
    loop fields    0 = return $ VStruct fields
    loop fields code = do
        vtype' <- getTType code
        fieldId <- P.int16
        case vtype' of
          SomeTType vtype -> do
            value <- SomeValue <$> binaryParser vtype
            loop (M.insert fieldId value fields) =<< P.int8

------------------------------------------------------------------------------

-- | Serialize a value, choosing the encoder from the value's type.
binarySerialize :: forall a. IsTType a => Value a -> Builder
binarySerialize = case (ttype :: TType a) of
    TBinary -> serializeBinary
    TBool   -> serializeBool
    TByte   -> serializeByte
    TDouble -> serializeDouble
    TInt16  -> serializeInt16
    TInt32  -> serializeInt32
    TInt64  -> serializeInt64
    TStruct -> serializeStruct
    TList   -> serializeList
    TMap    -> serializeMap
    TSet    -> serializeSet
{-# INLINE binarySerialize #-}

serializeBinary :: Value TBinary -> Builder
serializeBinary (VBinary x) = string x
{-# INLINE serializeBinary #-}

-- | 'True' is written as the byte 1 and 'False' as the byte 0.
serializeBool :: Value TBool -> Builder
serializeBool (VBool x) = BB.int8 $ if x then 1 else 0
{-# INLINE serializeBool #-}

serializeByte :: Value TByte -> Builder
serializeByte (VByte x) = BB.int8 x
{-# INLINE serializeByte #-}

serializeDouble :: Value TDouble -> Builder
serializeDouble (VDouble x) = BB.doubleBE x
{-# INLINE serializeDouble #-}

serializeInt16 :: Value TInt16 -> Builder
serializeInt16 (VInt16 x) = BB.int16BE x
{-# INLINE serializeInt16 #-}

serializeInt32 :: Value TInt32 -> Builder
serializeInt32 (VInt32 x) = BB.int32BE x
{-# INLINE serializeInt32 #-}

serializeInt64 :: Value TInt64 -> Builder
serializeInt64 (VInt64 x) = BB.int64BE x
{-# INLINE serializeInt64 #-}

serializeList :: Value TList -> Builder
serializeList (VList xs) = serializeCollection ttype xs
{-# INLINE serializeList #-}

serializeSet :: Value TSet -> Builder
serializeSet (VSet xs) = serializeCollection ttype xs
{-# INLINE serializeSet #-}

-- | Write every field as type code, 16-bit field ID and value, followed by
-- a terminating zero byte.
serializeStruct :: Value TStruct -> Builder
serializeStruct (VStruct fields) =
    M.foldlWithKey'
        (\rest fid (SomeValue val) -> rest <> writeField fid ttype val)
        mempty fields
    <> BB.int8 0
  where
    writeField :: IsTType a => Int16 -> TType a -> Value a -> Builder
    writeField fieldId fieldType fieldValue =
        typeCode fieldType <> BB.int16BE fieldId <> binarySerialize fieldValue
    {-# INLINE writeField #-}
{-# INLINE serializeStruct #-}

-- | Write key type code, value type code, 32-bit item count and then the
-- key\/value pairs.
--
-- NOTE(review): 'VNullMap' is not serializable here and aborts via 'error';
-- confirm that callers never hand one to the serializer.
serializeMap :: Value TMap -> Builder
serializeMap VNullMap = error "serializeMap: VNullMap"
serializeMap (VMap items) = serialize ttype ttype items
  where
    serialize
        :: (IsTType k, IsTType v)
        => TType k -> TType v -> FL.FoldList (MapItem k v) -> Builder
    serialize kt vt xs =
        typeCode kt <> typeCode vt <> BB.int32BE size <> body
      where
        -- One pass builds the body and counts the items; the count is
        -- forced at every step by the bang pattern.
        (body, size) = FL.foldl' go (mempty, 0 :: Int32) xs
        go (prev, !c) (MapItem k v) =
            ( prev <> binarySerialize k <> binarySerialize v
            , c + 1
            )
{-# INLINE serializeMap #-}

-- | Write element type code, 32-bit element count and then the elements.
-- Shared by lists and sets.
serializeCollection
    :: IsTType a
    => TType a -> FL.FoldList (Value a) -> Builder
serializeCollection vtype xs =
    let go (prev, !c) item = (prev <> binarySerialize item, c + 1)
        (body, size) = FL.foldl' go (mempty, 0 :: Int32) xs
    in typeCode vtype <> BB.int32BE size <> body
{-# INLINE serializeCollection #-}

------------------------------------------------------------------------------

-- | Map a message type to its one-byte wire code.
messageCode :: MessageType -> Int8
messageCode Call      = 1
messageCode Reply     = 2
messageCode Exception = 3
messageCode Oneway    = 4
{-# INLINE messageCode #-}

-- | Inverse of 'messageCode'; unknown codes yield 'Nothing'.
fromMessageCode :: Int8 -> Maybe MessageType
fromMessageCode 1 = Just Call
fromMessageCode 2 = Just Reply
fromMessageCode 3 = Just Exception
fromMessageCode 4 = Just Oneway
fromMessageCode _ = Nothing
{-# INLINE fromMessageCode #-}

-- | Map a TType to its type code.
toTypeCode :: TType a -> Int8
toTypeCode TBool   = 2
toTypeCode TByte   = 3
toTypeCode TDouble = 4
toTypeCode TInt16  = 6
toTypeCode TInt32  = 8
toTypeCode TInt64  = 10
toTypeCode TBinary = 11
toTypeCode TStruct = 12
toTypeCode TMap    = 13
toTypeCode TSet    = 14
toTypeCode TList   = 15
{-# INLINE toTypeCode #-}

-- | Map a type code to the corresponding TType.
fromTypeCode :: Int8 -> Maybe SomeTType
fromTypeCode code = case code of
    2  -> Just (SomeTType TBool)
    3  -> Just (SomeTType TByte)
    4  -> Just (SomeTType TDouble)
    6  -> Just (SomeTType TInt16)
    8  -> Just (SomeTType TInt32)
    10 -> Just (SomeTType TInt64)
    11 -> Just (SomeTType TBinary)
    12 -> Just (SomeTType TStruct)
    13 -> Just (SomeTType TMap)
    14 -> Just (SomeTType TSet)
    15 -> Just (SomeTType TList)
    -- Every other code has no corresponding TType.
    _  -> Nothing
{-# INLINE fromTypeCode #-}

------------------------------------------------------------------------------

-- | Write a byte string preceded by its length as a 32-bit big-endian
-- integer.
string :: ByteString -> Builder
string bytes = lengthPrefix <> BB.byteString bytes
  where
    lengthPrefix = BB.int32BE (fromIntegral (B.length bytes))
{-# INLINE string #-}

-- | Write the one-byte type code of the given TType.
typeCode :: TType a -> Builder
typeCode t = BB.int8 (toTypeCode t)
{-# INLINE typeCode #-}