{-# LANGUAGE CPP                 #-}
{-# LANGUAGE GADTs               #-}
{-# LANGUAGE ScopedTypeVariables #-}
-- |
-- Module      :  Pinch.Protocol.Binary
-- Copyright   :  (c) Abhinav Gupta 2015
-- License     :  BSD3
--
-- Maintainer  :  Abhinav Gupta
-- Stability   :  experimental
--
-- Implements the Thrift Binary Protocol as a 'Protocol'.
module Pinch.Protocol.Binary (binaryProtocol) where

#if __GLASGOW_HASKELL__ < 709
import Control.Applicative
#endif

import Control.Monad
import Data.Bits           (shiftR, (.&.))
import Data.ByteString     (ByteString)
import Data.HashMap.Strict (HashMap)
import Data.HashSet        (HashSet)
import Data.Int            (Int16, Int8)
import Data.Vector         (Vector)

import qualified Data.ByteString     as B
import qualified Data.HashMap.Strict as M
import qualified Data.HashSet        as S
import qualified Data.Text.Encoding  as TE
import qualified Data.Vector         as V

import Pinch.Internal.Builder (Build)
import Pinch.Internal.Message
import Pinch.Internal.Parser  (Parser, runParser)
import Pinch.Internal.TType
import Pinch.Internal.Value
import Pinch.Protocol         (Protocol (..))

import qualified Pinch.Internal.Builder as BB
import qualified Pinch.Internal.Parser  as P

-- | Provides an implementation of the Thrift Binary Protocol.
binaryProtocol :: Protocol
binaryProtocol = Protocol
    { serializeValue     = BB.run . binarySerialize
    , deserializeValue   = binaryDeserialize ttype
    , serializeMessage   = BB.run . binarySerializeMessage
    , deserializeMessage = binaryDeserializeMessage
    }

------------------------------------------------------------------------------

-- | Serializes a message envelope followed by its payload.
--
-- The layout written is the same one 'binaryMessageParser' reads in its
-- non-strict branch: the UTF-8 encoded name as a length-prefixed binary, the
-- one-byte message type code, the 32-bit message ID, then the payload.
binarySerializeMessage :: Message -> Build
binarySerializeMessage msg = do
    binarySerialize . VBinary . TE.encodeUtf8 $ messageName msg
    BB.int8 $ messageCode (messageType msg)
    BB.int32 $ messageId msg
    binarySerialize (messagePayload msg)

-- | Runs 'binaryMessageParser' over the given bytes.
binaryDeserializeMessage :: ByteString -> Either String Message
binaryDeserializeMessage = runParser binaryMessageParser

-- | Parses a message envelope in either the strict or the non-strict layout.
--
-- The first 32-bit integer decides the layout: a negative value is treated
-- as the strict @versionAndType@ word, anything else as the length of the
-- message name in the non-strict layout.
--
-- The parser fails (rather than throwing) when the version is not 1, when
-- the message type code is unknown, when a name length is negative, or when
-- the name is not valid UTF-8.
binaryMessageParser :: Parser Message
binaryMessageParser = do
    size <- P.int32
    if size < 0
        then parseStrict size
        else parseNonStrict size
  where
    -- versionAndType:4 name~4 seqid:4 payload
    -- versionAndType = version:2 0x00 type:1
    parseStrict versionAndType = do
        unless (version == 1) $
            fail $ "Unsupported version: " ++ show version
        Message
            <$> (P.int32 >>= parseName)
            <*> typ
            <*> P.int32
            <*> binaryParser ttype
      where
        -- The sign bit is masked off before shifting so that only the
        -- version number itself remains.
        version = (0x7fff0000 .&. versionAndType) `shiftR` 16
        code = fromIntegral $ 0x00ff .&. versionAndType
        typ = case fromMessageCode code of
            Nothing -> fail $ "Unknown message type: " ++ show code
            Just t  -> return t

    -- name~4 type:1 seqid:4 payload
    parseNonStrict nameLength =
        Message
            <$> parseName nameLength
            <*> parseMessageType
            <*> P.int32
            <*> binaryParser ttype

    -- Reads @len@ bytes and decodes them as the UTF-8 message name.
    --
    -- The length comes straight off the wire, so a negative value is
    -- rejected here instead of being handed to 'P.take'. Decoding uses
    -- 'TE.decodeUtf8'' so that malformed input surfaces as a parse failure;
    -- 'TE.decodeUtf8' would throw an exception from pure code instead.
    parseName len = do
        when (len < 0) $
            fail $ "Negative length for message name: " ++ show len
        bytes <- P.take (fromIntegral len)
        case TE.decodeUtf8' bytes of
            Left err   -> fail $ "Invalid UTF-8 in message name: " ++ show err
            Right name -> return name

-- | Reads a one-byte message type code and maps it to a 'MessageType',
-- failing on codes 'fromMessageCode' does not recognize.
parseMessageType :: Parser MessageType
parseMessageType = P.int8 >>= \code -> case fromMessageCode code of
    Nothing -> fail $ "Unknown message type: " ++ show code
    Just t  -> return t

------------------------------------------------------------------------------

-- | Runs the parser for the given 'TType' over the given bytes.
binaryDeserialize :: TType a -> ByteString -> Either String (Value a)
binaryDeserialize t = runParser (binaryParser t)

-- | Selects the value parser that corresponds to the given 'TType'.
binaryParser :: TType a -> Parser (Value a)
binaryParser typ = case typ of
    TBool   -> parseBool
    TByte   -> parseByte
    TDouble -> parseDouble
    TInt16  -> parseInt16
    TInt32  -> parseInt32
    TInt64  -> parseInt64
    TBinary -> parseBinary
    TStruct -> parseStruct
    TMap    -> parseMap
    TSet    -> parseSet
    TList   -> parseList

-- | Resolves a type code to its 'TType', failing the parse when the code is
-- not one that 'fromTypeCode' knows.
getTType :: Int8 -> Parser SomeTType
getTType code =
    maybe (fail $ "Unknown TType: " ++ show code) return $ fromTypeCode code

-- | Reads a one-byte type code and resolves it with 'getTType'.
parseTType :: Parser SomeTType
parseTType = P.int8 >>= getTType

-- | Parses a boolean stored as a single byte; only the value 1 is true.
parseBool :: Parser (Value TBool)
parseBool = VBool . (== 1) <$> P.int8

-- | Parses a single signed byte.
parseByte :: Parser (Value TByte)
parseByte = VByte <$> P.int8

-- | Parses a double using 'P.double'.
parseDouble :: Parser (Value TDouble)
parseDouble = VDouble <$> P.double

-- | Parses a 16-bit integer using 'P.int16'.
parseInt16 :: Parser (Value TInt16)
parseInt16 = VInt16 <$> P.int16

-- | Parses a 32-bit integer using 'P.int32'.
parseInt32 :: Parser (Value TInt32)
parseInt32 = VInt32 <$> P.int32

-- | Parses a 64-bit integer using 'P.int64'.
parseInt64 :: Parser (Value TInt64)
parseInt64 = VInt64 <$> P.int64

-- | Parses a binary blob: a 32-bit length followed by that many bytes.
--
-- The length is read from the wire, so a negative value is reported as a
-- parse failure instead of being passed on to 'P.take'.
parseBinary :: Parser (Value TBinary)
parseBinary = do
    len <- P.int32
    when (len < 0) $
        fail $ "Negative length for binary: " ++ show len
    VBinary <$> P.take (fromIntegral len)

-- | Reads the 32-bit element count of a collection and rejects negative
-- values.
--
-- The argument names the kind of collection and is used only in the error
-- message. Without this check a negative count would be accepted as an empty
-- collection, because 'replicateM' performs no iterations for counts below 1.
parseCollectionSize :: String -> Parser Int
parseCollectionSize what = do
    n <- P.int32
    when (n < 0) $
        fail $ "Negative size for " ++ what ++ ": " ++ show n
    return (fromIntegral n)

-- | Parses a map: key type code, value type code, 32-bit entry count, then
-- that many key/value pairs.
--
-- Pairs are collected with 'M.fromList', so a later pair replaces an earlier
-- one that has an equal key.
parseMap :: Parser (Value TMap)
parseMap = do
    ktype' <- parseTType
    vtype' <- parseTType
    count <- parseCollectionSize "map"
    case (ktype', vtype') of
      (SomeTType ktype, SomeTType vtype) -> do
        pairs <- replicateM count $
            (,) <$> binaryParser ktype <*> binaryParser vtype
        return $ VMap (M.fromList pairs)

-- | Parses a set: element type code, 32-bit element count, then that many
-- elements. Duplicates collapse because the result is built with
-- 'S.fromList'.
parseSet :: Parser (Value TSet)
parseSet = do
    vtype' <- parseTType
    count <- parseCollectionSize "set"
    case vtype' of
      SomeTType vtype -> do
        items <- replicateM count (binaryParser vtype)
        return $ VSet (S.fromList items)

-- | Parses a list: element type code, 32-bit element count, then that many
-- elements.
parseList :: Parser (Value TList)
parseList = do
    vtype' <- parseTType
    count <- parseCollectionSize "list"
    case vtype' of
      SomeTType vtype ->
        VList <$> V.replicateM count (binaryParser vtype)

-- | Parses a struct as a sequence of fields terminated by a zero type byte.
--
-- Each field is a one-byte type code, a 16-bit field ID and the value. When
-- a field ID occurs more than once the last occurrence wins, since fields
-- are accumulated with 'M.insert'.
parseStruct :: Parser (Value TStruct)
parseStruct = P.int8 >>= loop M.empty
  where
    loop :: HashMap Int16 SomeValue -> Int8 -> Parser (Value TStruct)
    -- A type code of 0 marks the end of the struct.
    loop fields 0 = return $ VStruct fields
    loop fields code = do
        vtype' <- getTType code
        fieldId <- P.int16
        case vtype' of
          SomeTType vtype -> do
            value <- SomeValue <$> binaryParser vtype
            -- The next byte is either another field's type code or the
            -- terminating 0.
            loop (M.insert fieldId value fields) =<< P.int8

------------------------------------------------------------------------------

-- | Serializes a 'Value' in the layout the parsers in this module read back.
binarySerialize :: Value a -> Build
binarySerialize v0 = case v0 of
    VBinary x -> do
        -- Length prefix first, then the raw bytes.
        BB.int32 . fromIntegral . B.length $ x
        BB.byteString x
    VBool x   -> BB.int8 $ if x then 1 else 0
    VByte x   -> BB.int8 x
    VDouble x -> BB.double x
    VInt16 x  -> BB.int16 x
    VInt32 x  -> BB.int32 x
    VInt64 x  -> BB.int64 x
    VStruct xs -> serializeStruct xs
    VList xs   -> serializeList ttype xs
    VMap xs    -> serializeMap ttype ttype xs
    VSet xs    -> serializeSet ttype xs

-- | Writes every field as type code, field ID and value, followed by a
-- single zero byte that terminates the struct.
serializeStruct :: HashMap Int16 SomeValue -> Build
serializeStruct fields = do
    forM_ (M.toList fields) $ \(fieldId, SomeValue fieldValue) ->
        writeField fieldId ttype fieldValue
    BB.int8 0
  where
    writeField :: Int16 -> TType a -> Value a -> Build
    writeField fieldId fieldType fieldValue = do
        BB.int8 (toTypeCode fieldType)
        BB.int16 fieldId
        binarySerialize fieldValue

-- | Writes the element type code, the element count as a 32-bit integer,
-- then each element in order.
serializeList :: TType a -> Vector (Value a) -> Build
serializeList vtype xs = do
    BB.int8 $ toTypeCode vtype
    BB.int32 $ fromIntegral (V.length xs)
    V.mapM_ binarySerialize xs

-- | Writes the key and value type codes, the entry count as a 32-bit
-- integer, then each key immediately followed by its value.
serializeMap :: TType k -> TType v -> HashMap (Value k) (Value v) -> Build
serializeMap kt vt xs = do
    BB.int8 $ toTypeCode kt
    BB.int8 $ toTypeCode vt
    BB.int32 $ fromIntegral (M.size xs)
    forM_ (M.toList xs) $ \(k, v) -> do
        binarySerialize k
        binarySerialize v

-- | Writes the element type code, the element count as a 32-bit integer,
-- then each element.
serializeSet :: TType a -> HashSet (Value a) -> Build
serializeSet vtype xs = do
    BB.int8 $ toTypeCode vtype
    BB.int32 $ fromIntegral (S.size xs)
    mapM_ binarySerialize (S.toList xs)

------------------------------------------------------------------------------

-- | Maps a 'MessageType' to its one-byte wire code.
messageCode :: MessageType -> Int8
messageCode Call      = 1
messageCode Reply     = 2
messageCode Exception = 3
messageCode Oneway    = 4

-- | Inverse of 'messageCode'; 'Nothing' for any other code.
fromMessageCode :: Int8 -> Maybe MessageType
fromMessageCode 1 = Just Call
fromMessageCode 2 = Just Reply
fromMessageCode 3 = Just Exception
fromMessageCode 4 = Just Oneway
fromMessageCode _ = Nothing

-- | Map a TType to its type code.
toTypeCode :: TType a -> Int8
toTypeCode t = case t of
    -- These codes are the exact inverse of the table in 'fromTypeCode'.
    TBool   -> 2
    TByte   -> 3
    TDouble -> 4
    TInt16  -> 6
    TInt32  -> 8
    TInt64  -> 10
    TBinary -> 11
    TStruct -> 12
    TMap    -> 13
    TSet    -> 14
    TList   -> 15

-- | Map a type code to the corresponding TType.
--
-- Returns 'Nothing' for every code that 'toTypeCode' never produces.
fromTypeCode :: Int8 -> Maybe SomeTType
fromTypeCode code = case code of
    2  -> known TBool
    3  -> known TByte
    4  -> known TDouble
    6  -> known TInt16
    8  -> known TInt32
    10 -> known TInt64
    11 -> known TBinary
    12 -> known TStruct
    13 -> known TMap
    14 -> known TSet
    15 -> known TList
    _  -> Nothing
  where
    known = Just . SomeTType