module Text.ProtocolBuffers.WireMessage
(
messageSize,messagePut,messageGet,messagePutM,messageGetM
, messageWithLengthSize,messageWithLengthPut,messageWithLengthGet,messageWithLengthPutM,messageWithLengthGetM
, messageAsFieldSize,messageAsFieldPutM,messageAsFieldGetM
, Put,Get,runPut,runGet,runGetOnLazy,getFromBS
, Wire(..)
, size'Varint,toWireType,toWireTag,mkWireTag
, prependMessageSize,putSize,putVarUInt,getVarInt,putLazyByteString,splitWireTag
, wireSizeReq,wireSizeOpt,wireSizeRep
, wirePutReq,wirePutOpt,wirePutRep
, wireSizeErr,wirePutErr,wireGetErr
, getMessage,getBareMessage,getMessageWith,getBareMessageWith
, unknownField,unknown,wireGetFromWire
, castWord64ToDouble,castWord32ToFloat,castDoubleToWord64,castFloatToWord32
, zzEncode64,zzEncode32,zzDecode64,zzDecode32
) where
import Control.Monad(when)
import Control.Monad.ST
import Data.Array.ST
import Data.Bits (Bits(..))
import qualified Data.ByteString.Lazy as BS (length)
import qualified Data.Foldable as F(foldl',forM_)
import Data.List (genericLength)
import qualified Data.Set as Set(notMember,delete,null)
import Data.Typeable (Typeable(..))
import Data.Binary.Put (Put,runPut,putWord8,putWord32le,putWord64le,putLazyByteString)
import Text.ProtocolBuffers.Basic
import Text.ProtocolBuffers.Get as Get (Result(..),Get,runGet,bytesRead,isReallyEmpty
,spanOf,skip,lookAhead
,getWord8,getWord32le,getWord64le,getLazyByteString)
import Text.ProtocolBuffers.Mergeable()
import Text.ProtocolBuffers.Reflections(ReflectDescriptor(reflectDescriptorInfo,getMessageInfo)
,DescriptorInfo(..),GetMessageInfo(..))
-- | Encoded size of a message written with field type 10
-- (the encoding 'messagePut' / 'messagePutM' use).
messageSize :: (ReflectDescriptor msg,Wire msg) => msg -> WireSize
messageSize = wireSize 10

-- | Encoded size of a message written with field type 11
-- (the encoding 'messageWithLengthPut' / 'messageWithLengthPutM' use).
messageWithLengthSize :: (ReflectDescriptor msg,Wire msg) => msg -> WireSize
messageWithLengthSize = wireSize 11

-- | Size of a message written as a field: the varint-encoded tag for
-- @(fi, field type 11)@ plus 'messageWithLengthSize'.
messageAsFieldSize :: (ReflectDescriptor msg,Wire msg) => FieldId -> msg -> WireSize
messageAsFieldSize fi msg = tagBytes + messageWithLengthSize msg
  where
    tagBytes = size'Varint (getWireTag (toWireTag fi 11))
-- | Run 'messagePutM' and return the produced bytes.
messagePut :: (ReflectDescriptor msg, Wire msg) => msg -> ByteString
messagePut = runPut . messagePutM

-- | Run 'messageWithLengthPutM' and return the produced bytes.
messageWithLengthPut :: (ReflectDescriptor msg, Wire msg) => msg -> ByteString
messageWithLengthPut = runPut . messageWithLengthPutM

-- | Write a message using field type 10.
messagePutM :: (ReflectDescriptor msg, Wire msg) => msg -> Put
messagePutM = wirePut 10

-- | Write a message using field type 11.
messageWithLengthPutM :: (ReflectDescriptor msg, Wire msg) => msg -> Put
messageWithLengthPutM = wirePut 11

-- | Write a message as field @fi@: its tag (field type 11) followed by the
-- field-type-11 encoding of the message.
messageAsFieldPutM :: (ReflectDescriptor msg, Wire msg) => FieldId -> msg -> Put
messageAsFieldPutM fi msg = wirePutReq (toWireTag fi 11) 11 msg
-- | Parse a field-type-10 message from lazy bytes, returning it with the
-- unconsumed remainder.
messageGet :: (ReflectDescriptor msg, Wire msg) => ByteString -> Either String (msg,ByteString)
messageGet = runGetOnLazy messageGetM

-- | Parse a field-type-11 message from lazy bytes, returning it with the
-- unconsumed remainder.
messageWithLengthGet :: (ReflectDescriptor msg, Wire msg) => ByteString -> Either String (msg,ByteString)
messageWithLengthGet = runGetOnLazy messageWithLengthGetM

-- | Parser counterpart of 'messagePutM' (field type 10).
messageGetM :: (ReflectDescriptor msg, Wire msg) => Get msg
messageGetM = wireGet 10

-- | Parser counterpart of 'messageWithLengthPutM' (field type 11).
messageWithLengthGetM :: (ReflectDescriptor msg, Wire msg) => Get msg
messageWithLengthGetM = wireGet 11

-- | Parser counterpart of 'messageAsFieldPutM': read a tag, insist on wire
-- type 2, then read a field-type-11 message. The field id is returned
-- unchecked.
messageAsFieldGetM :: (ReflectDescriptor msg, Wire msg) => Get (FieldId,msg)
messageAsFieldGetM = do
  (fieldId, wireType) <- fmap (splitWireTag . WireTag) getVarInt
  when (wireType /= 2) $
    fail ("messageAsFieldGetM: wireType was not 2 " ++ show (fieldId, wireType))
  fmap ((,) fieldId) (wireGet 11)
-- | Run a parser over lazy bytes and keep only the result. Unconsumed input
-- is discarded. A parse failure is raised with 'error'.
getFromBS :: Get r -> ByteString -> r
getFromBS parser bs = either error fst (runGetOnLazy parser bs)

-- | Run a parser over a complete lazy 'ByteString'.
--
-- A 'Partial' result is resumed with 'Nothing' (no more input). A failure is
-- reported as @"Failed at <position> : <message>"@. Success yields the result
-- together with the unconsumed bytes.
runGetOnLazy :: Get r -> ByteString -> Either String (r,ByteString)
runGetOnLazy parser bs = finish (runGet parser bs)
  where
    finish :: Result r -> Either String (r,ByteString)
    finish outcome = case outcome of
      Failed pos reason      -> Left ("Failed at " ++ show pos ++ " : " ++ reason)
      Finished rest _ value  -> Right (value, rest)
      Partial k              -> finish (k Nothing)
-- | Total size of a payload of @n@ bytes once its varint length prefix is
-- added in front of it.
prependMessageSize :: WireSize -> WireSize
prependMessageSize n = size'Varint n + n
-- | Write one field: its tag, then its value.
--
-- Field type 10 (group) is bracketed instead: the start tag, the value, and
-- then an end tag equal to @succ@ of the start tag.
wirePutReq :: Wire b => WireTag -> FieldType -> b -> Put
wirePutReq wireTag fieldType b = case fieldType of
  10 -> do
    let openTag = getWireTag wireTag
    putVarUInt openTag
    wirePut 10 b
    putVarUInt (succ openTag)
  _ -> do
    putVarUInt (getWireTag wireTag)
    wirePut fieldType b

-- | Write an optional field; 'Nothing' writes nothing.
wirePutOpt :: Wire b => WireTag -> FieldType -> Maybe b -> Put
wirePutOpt wireTag fieldType = maybe (return ()) (wirePutReq wireTag fieldType)

-- | Write every element of a repeated field, each with its own tag.
wirePutRep :: Wire b => WireTag -> FieldType -> Seq b -> Put
wirePutRep wireTag fieldType bs = F.forM_ bs (wirePutReq wireTag fieldType)
-- | Size of one field given the size of its tag. A group (field type 10)
-- pays for the tag twice (start and end tag).
wireSizeReq :: Wire b => Int64 -> FieldType -> b -> Int64
wireSizeReq tagSize fieldType v = case fieldType of
  10 -> tagSize + tagSize + wireSize 10 v
  _  -> tagSize + wireSize fieldType v

-- | Size of an optional field; 'Nothing' contributes 0.
wireSizeOpt :: Wire b => Int64 -> FieldType -> Maybe b -> Int64
wireSizeOpt tagSize fieldType = maybe 0 (wireSizeReq tagSize fieldType)

-- | Size of a repeated field: the sum of each element's 'wireSizeReq'.
wireSizeRep :: Wire b => Int64 -> FieldType -> Seq b -> Int64
wireSizeRep tagSize fieldType s = F.foldl' addElement 0 s
  where
    addElement total v = total + wireSizeReq tagSize fieldType v
-- | Write a size as an unsigned varint.
putSize :: WireSize -> Put
putSize n = putVarUInt n

-- | Tag for a field id and a field type, using the wire type that
-- 'toWireType' assigns to the field type.
toWireTag :: FieldId -> FieldType -> WireTag
toWireTag fieldId fieldType = mkWireTag fieldId (toWireType fieldType)

-- | Tag for a field id and wire type: the field id shifted left by 3, with
-- the wire type in the low 3 bits.
mkWireTag :: FieldId -> WireType -> WireTag
mkWireTag fieldId wireType =
  let idBits   = fromIntegral (getFieldId fieldId) `shiftL` 3
      typeBits = fromIntegral (getWireType wireType)
  in idBits .|. typeBits

-- | Inverse of 'mkWireTag': field id from the high bits, wire type from the
-- low 3 bits.
splitWireTag :: WireTag -> (FieldId,WireType)
splitWireTag (WireTag raw) =
  (FieldId (fromIntegral (raw `shiftR` 3)), WireType (fromIntegral (raw .&. 7)))
-- | 'getMessageWith' using 'unknown' as the handler for unexpected tags, so
-- any tag outside the message's allowed set is a parse failure.
getMessage :: (Mergeable message, ReflectDescriptor message,Typeable message)
           => (FieldId -> message -> Get message)
           -> Get message
getMessage updater = getMessageWith unknown updater
-- | Parse a length-prefixed message body.
--
-- Reads a varint byte count, then repeatedly reads a wire tag and dispatches
-- on it until exactly that many bytes (measured with 'bytesRead') have been
-- consumed. Parsing starts from 'mergeEmpty'.
--
--   * Tags in the message's 'allowedTags' go to @updater@, which receives
--     only the 'FieldId'.
--   * Every other tag goes to @punt@ together with its 'WireType'.
--
-- Fails if the byte count is reached while some 'requiredTags' are still
-- unseen, or if parsing reads past the byte count.
getMessageWith :: (Mergeable message, ReflectDescriptor message)
               => (FieldId -> WireType -> message -> Get message) -- ^ handler for tags not in 'allowedTags'
               -> (FieldId -> message -> Get message)             -- ^ handler for allowed tags
               -> Get message
getMessageWith punt updater = do
  messageLength <- getVarInt
  start <- bytesRead
  -- 'stop' is the absolute 'bytesRead' offset where this message must end.
  -- 'go' loops while required tags (reqs) are outstanding; once the set is
  -- empty it hands over to go', which no longer tracks them.
  let stop = messageLength+start
      go reqs message
        | Set.null reqs = go' message
        | otherwise = do
            here <- bytesRead
            case compare stop here of
              -- end of message reached with required tags still missing
              EQ -> notEnoughData messageLength start
              -- a field handler consumed bytes beyond the declared length
              LT -> tooMuchData messageLength start here
              GT -> do
                wireTag <- fmap WireTag getVarInt
                let (fieldId,wireType) = splitWireTag wireTag
                if Set.notMember wireTag allowed
                  then punt fieldId wireType message >>= go reqs
                  -- seeing a required tag removes it from the outstanding set
                  else let reqs' = Set.delete wireTag reqs
                       in updater fieldId message >>= go reqs'
      go' message = do
        here <- bytesRead
        case compare stop here of
          EQ -> return message
          LT -> tooMuchData messageLength start here
          GT -> do
            wireTag <- fmap WireTag getVarInt
            let (fieldId,wireType) = splitWireTag wireTag
            if Set.notMember wireTag allowed
              then punt fieldId wireType message >>= go'
              else updater fieldId message >>= go'
  go required initialMessage
  where
    initialMessage = mergeEmpty
    -- required/allowed tag sets come from the message type's reflection info
    (GetMessageInfo {requiredTags=required,allowedTags=allowed}) = getMessageInfo initialMessage
    notEnoughData messageLength start =
      fail ("Text.ProtocolBuffers.WireMessage.getMessage: Required fields missing when processing "
            ++ (show . descName . reflectDescriptorInfo $ initialMessage)
            ++ " at (messageLength,start) == " ++ show (messageLength,start))
    tooMuchData messageLength start here =
      fail ("Text.ProtocolBuffers.WireMessage.getMessage : overran expected length when processing"
            ++ (show . descName . reflectDescriptorInfo $ initialMessage)
            ++ " at (messageLength,start,here) == " ++ show (messageLength,start,here))
-- | Default handler for unexpected wire tags: always fails.
--
-- The error names the Haskell type of the message ('typeOf'), the offending
-- field id and wire type, the current 'bytesRead' offset, and the descriptor
-- name of the message being parsed. The message correctly names this
-- function as @unknown@.
unknown :: (Typeable a,ReflectDescriptor a) => FieldId -> WireType -> a -> Get a
unknown fieldId wireType initialMessage = do
  here <- bytesRead
  fail ("Text.ProtocolBuffers.WireMessage.unknown: Unknown wire tag read (type,fieldId,wireType,here) == "
        ++ show (typeOf initialMessage,fieldId,wireType,here) ++ " when processing "
        ++ (show . descName . reflectDescriptorInfo $ initialMessage))
-- | 'getBareMessageWith' using 'unknown' as the handler for unexpected tags,
-- so any tag outside the message's allowed set is a parse failure.
getBareMessage :: (Typeable message, Mergeable message, ReflectDescriptor message)
               => (FieldId -> message -> Get message)
               -> Get message
getBareMessage updater = getBareMessageWith unknown updater
-- | Parse message fields that have no length prefix.
--
-- Starting from 'mergeEmpty', reads tags and dispatches them like
-- 'getMessageWith': allowed tags go to @updater@ and all others go to @punt@.
-- Parsing stops when the input is exhausted ('isReallyEmpty') or when a tag
-- with wire type 4 (end group) is read. That tag is consumed.
--
-- Fails if it stops while some 'requiredTags' are still unseen.
--
-- NOTE(review): the field id of the terminating wire-type-4 tag is not
-- compared with anything here. Confirm callers do not rely on that check.
getBareMessageWith :: (Mergeable message, ReflectDescriptor message)
                   => (FieldId -> WireType -> message -> Get message) -- ^ handler for tags not in 'allowedTags'
                   -> (FieldId -> message -> Get message)             -- ^ handler for allowed tags
                   -> Get message
getBareMessageWith punt updater = go required initialMessage
  where
    -- Loop while required tags are outstanding; hands over to go' once the
    -- set is empty.
    go reqs message
      | Set.null reqs = go' message
      | otherwise = do
          done <- isReallyEmpty
          if done then notEnoughData
            else do
              wireTag <- fmap WireTag getVarInt
              let (fieldId,wireType) = splitWireTag wireTag
              -- an end-group tag before all required tags were seen is an error
              if wireType == 4 then notEnoughData
                else if Set.notMember wireTag allowed
                  then punt fieldId wireType message >>= go reqs
                  else let reqs' = Set.delete wireTag reqs
                       in updater fieldId message >>= go reqs'
    -- Loop once every required tag has been seen.
    go' message = do
      done <- isReallyEmpty
      if done then return message
        else do
          wireTag <- fmap WireTag getVarInt
          let (fieldId,wireType) = splitWireTag wireTag
          if wireType == 4 then return message
            else if Set.notMember wireTag allowed
              then punt fieldId wireType message >>= go'
              else updater fieldId message >>= go'
    initialMessage = mergeEmpty
    (GetMessageInfo {requiredTags=required,allowedTags=allowed}) = getMessageInfo initialMessage
    notEnoughData = fail ("Text.ProtocolBuffers.WireMessage.getBareMessage: Required fields missing when processing "
                          ++ (show . descName . reflectDescriptorInfo $ initialMessage))
-- | Failure used when a message's updater is handed a field id it does not
-- recognise. The error includes the field id and the current 'bytesRead'
-- offset.
unknownField :: FieldId -> Get a
unknownField fieldId = bytesRead >>= \here ->
  fail (concat [ "Impossible? Text.ProtocolBuffers.WireMessage.unknownField "
               , " The Message's updater claims there is an unknown field id on wire: "
               , show fieldId
               , " at a position just before here == "
               , show here ])
-- | Reinterpret the bits of a 'Word32' as a 'Float' (no numeric conversion).
-- The value is written into a one-element unboxed ST array and read back
-- through 'castSTUArray' at the other element type.
castWord32ToFloat :: Word32 -> Float
castWord32ToFloat w = runST (do
  cell <- newArray (0 :: Int, 0) w
  view <- castSTUArray cell
  readArray view 0)

-- | Reinterpret the bits of a 'Float' as a 'Word32'.
castFloatToWord32 :: Float -> Word32
castFloatToWord32 f = runST (do
  cell <- newArray (0 :: Int, 0) f
  view <- castSTUArray cell
  readArray view 0)

-- | Reinterpret the bits of a 'Word64' as a 'Double'.
castWord64ToDouble :: Word64 -> Double
castWord64ToDouble w = runST (do
  cell <- newArray (0 :: Int, 0) w
  view <- castSTUArray cell
  readArray view 0)

-- | Reinterpret the bits of a 'Double' as a 'Word64'.
castDoubleToWord64 :: Double -> Word64
castDoubleToWord64 d = runST (do
  cell <- newArray (0 :: Int, 0) d
  view <- castSTUArray cell
  readArray view 0)
-- | Raised (via 'error') when 'wireSize' is asked for a field type the value's
-- Haskell type does not support.
wireSizeErr :: Typeable a => FieldType -> a -> WireSize
wireSizeErr ft x = error complaint
  where
    complaint = "Impossible? wireSize field type mismatch error: Field type number " ++ show ft
                ++ " does not match internal type " ++ show (typeOf x)

-- | 'fail' counterpart of 'wireSizeErr' for 'wirePut'.
wirePutErr :: Typeable a => FieldType -> a -> Put
wirePutErr ft x = fail complaint
  where
    complaint = "Impossible? wirePut field type mismatch error: Field type number " ++ show ft
                ++ " does not match internal type " ++ show (typeOf x)

-- | 'fail' counterpart of 'wireSizeErr' for 'wireGet'. There is no value to
-- inspect, so the result type is recovered from the parser's own type via a
-- never-evaluated witness passed to 'typeOf'.
wireGetErr :: Typeable a => FieldType -> Get a
wireGetErr ft = parser
  where
    parser = fail ("Impossible? wireGet field type mismatch error: Field type number " ++ show ft
                   ++ " does not match internal type "
                   ++ show (typeOf (undefined `asTypeOf` typeHack parser)))
    typeHack :: Get a -> a
    typeHack = undefined
-- | 'Double' supports only field type 1: 8 bytes, little-endian bit pattern.
instance Wire Double where
  wireSize ft x = case ft of
    1 -> 8
    _ -> wireSizeErr ft x
  wirePut ft x = case ft of
    1 -> putWord64le (castDoubleToWord64 x)
    _ -> wirePutErr ft x
  wireGet ft = case ft of
    1 -> fmap castWord64ToDouble getWord64le
    _ -> wireGetErr ft

-- | 'Float' supports only field type 2: 4 bytes, little-endian bit pattern.
instance Wire Float where
  wireSize ft x = case ft of
    2 -> 4
    _ -> wireSizeErr ft x
  wirePut ft x = case ft of
    2 -> putWord32le (castFloatToWord32 x)
    _ -> wirePutErr ft x
  wireGet ft = case ft of
    2 -> fmap castWord32ToFloat getWord32le
    _ -> wireGetErr ft
-- | 'Int64' field types: 3 = signed varint, 16 = 8-byte little-endian,
-- 18 = zig-zag varint.
instance Wire Int64 where
  wireSize ft x = case ft of
    3  -> size'Varint x
    16 -> 8
    18 -> size'Varint (zzEncode64 x)
    _  -> wireSizeErr ft x
  wirePut ft x = case ft of
    3  -> putVarSInt x
    16 -> putWord64le (fromIntegral x)
    18 -> putVarUInt (zzEncode64 x)
    _  -> wirePutErr ft x
  wireGet ft = case ft of
    3  -> getVarInt
    16 -> fmap fromIntegral getWord64le
    18 -> fmap zzDecode64 getVarInt
    _  -> wireGetErr ft

-- | 'Int32' field types: 5 = signed varint, 15 = 4-byte little-endian,
-- 17 = zig-zag varint.
instance Wire Int32 where
  wireSize ft x = case ft of
    5  -> size'Varint x
    15 -> 4
    17 -> size'Varint (zzEncode32 x)
    _  -> wireSizeErr ft x
  wirePut ft x = case ft of
    5  -> putVarSInt x
    15 -> putWord32le (fromIntegral x)
    17 -> putVarUInt (zzEncode32 x)
    _  -> wirePutErr ft x
  wireGet ft = case ft of
    5  -> getVarInt
    15 -> fmap fromIntegral getWord32le
    17 -> fmap zzDecode32 getVarInt
    _  -> wireGetErr ft
-- | 'Word64' field types: 4 = unsigned varint, 6 = 8-byte little-endian.
instance Wire Word64 where
  wireSize ft x = case ft of
    4 -> size'Varint x
    6 -> 8
    _ -> wireSizeErr ft x
  wirePut ft x = case ft of
    4 -> putVarUInt x
    6 -> putWord64le x
    _ -> wirePutErr ft x
  wireGet ft = case ft of
    4 -> getVarInt
    6 -> getWord64le
    _ -> wireGetErr ft

-- | 'Word32' field types: 13 = unsigned varint, 7 = 4-byte little-endian.
instance Wire Word32 where
  wireSize ft x = case ft of
    13 -> size'Varint x
    7  -> 4
    _  -> wireSizeErr ft x
  wirePut ft x = case ft of
    13 -> putVarUInt x
    7  -> putWord32le x
    _  -> wirePutErr ft x
  wireGet ft = case ft of
    13 -> getVarInt
    7  -> getWord32le
    _  -> wireGetErr ft
-- | 'Bool' supports field type 8. It is written as a single byte 0 or 1. When
-- reading, the value is decoded as an 'Int32' varint: 0 is 'False', any other
-- value below 128 is 'True', and anything else is a parse failure.
instance Wire Bool where
  wireSize ft x = case ft of
    8 -> 1
    _ -> wireSizeErr ft x
  wirePut ft x = case ft of
    8 -> putWord8 (if x then 1 else 0)
    _ -> wirePutErr ft x
  wireGet ft = case ft of
    8 -> do
      raw <- getVarInt :: Get Int32
      case raw of
        0 -> return False
        _ | raw < 128 -> return True
          | otherwise -> fail ("TYPE_BOOL read failure : " ++ show raw)
    _ -> wireGetErr ft
-- | 'Utf8' supports field type 9: a varint byte length followed by the raw
-- bytes. No UTF-8 validation happens here.
instance Wire Utf8 where
  wireSize ft x = case ft of
    9 -> prependMessageSize (BS.length (utf8 x))
    _ -> wireSizeErr ft x
  wirePut ft x = case ft of
    9 -> let bytes = utf8 x
         in putVarUInt (BS.length bytes) >> putLazyByteString bytes
    _ -> wirePutErr ft x
  wireGet ft = case ft of
    9 -> fmap Utf8 (getVarInt >>= getLazyByteString)
    _ -> wireGetErr ft

-- | Lazy 'ByteString' supports field type 12: a varint byte length followed
-- by the raw bytes.
instance Wire ByteString where
  wireSize ft x = case ft of
    12 -> prependMessageSize (BS.length x)
    _  -> wireSizeErr ft x
  wirePut ft x = case ft of
    12 -> putVarUInt (BS.length x) >> putLazyByteString x
    _  -> wirePutErr ft x
  wireGet ft = case ft of
    12 -> getVarInt >>= getLazyByteString
    _  -> wireGetErr ft
-- | 'Int' supports field type 14 (enum), encoded as a varint.
--
-- Values are written with 'putVarSInt', so a negative value takes the full
-- two's-complement width (ceiling of bitSize/7 bytes). That is exactly the
-- count 'size'Varint' (used by 'wireSize' here) reports for a negative
-- argument. 'putVarUInt' must not be used: for a negative argument its loop
-- stops after one truncated byte, which corrupts the output and disagrees
-- with 'wireSize'. For values >= 0 both writers produce identical bytes.
instance Wire Int where
  wireSize 14 x = size'Varint x
  wireSize ft x = wireSizeErr ft x
  wirePut 14 x = putVarSInt x
  wirePut ft x = wirePutErr ft x
  wireGet 14 = getVarInt
  wireGet ft = wireGetErr ft
-- | Number of bytes a varint encoding of the argument occupies.
--
-- Zero takes one byte. A positive value takes one byte per 7-bit group up to
-- its highest set bit. A negative value takes the whole two's-complement
-- width, @bitSize b@ divided by 7 and rounded up.
size'Varint :: (Bits a,Integral a) => a -> Int64
size'Varint b
  | b < 0     = fromIntegral (bitSize b `divBy` 7)
  | b == 0    = 1
  | otherwise = countGroups 0 b
  where
    countGroups acc v
      | v > 0     = countGroups (acc + 1) (v `shiftR` 7)
      | otherwise = acc

-- | Division of @|a|@ by @b@, rounded up.
divBy :: (Ord a, Integral a) => a -> a -> a
divBy a b = case abs a `quotRem` b of
  (q, 0) -> q
  (q, _) -> succ q
-- | Zig-zag encode: 0,-1,1,-2,... map to 0,1,2,3,... The arithmetic right
-- shift yields an all-ones mask for negative inputs and zero otherwise.
zzEncode32 :: Int32 -> Word32
zzEncode32 n =
  let doubled  = n `shiftL` 1
      signMask = n `shiftR` 31
  in fromIntegral (doubled `xor` signMask)

-- | 64-bit variant of 'zzEncode32'.
zzEncode64 :: Int64 -> Word64
zzEncode64 n =
  let doubled  = n `shiftL` 1
      signMask = n `shiftR` 63
  in fromIntegral (doubled `xor` signMask)

-- | Inverse of 'zzEncode32': the low bit selects the sign mask that is xored
-- into the value shifted right by one.
zzDecode32 :: Word32 -> Int32
zzDecode32 w =
  let magnitude, signMask :: Int32
      magnitude = fromIntegral (w `shiftR` 1)
      signMask  = negate (fromIntegral (w .&. 1))
  in magnitude `xor` signMask

-- | Inverse of 'zzEncode64'.
zzDecode64 :: Word64 -> Int64
zzDecode64 w =
  let magnitude, signMask :: Int64
      magnitude = fromIntegral (w `shiftR` 1)
      signMask  = negate (fromIntegral (w .&. 1))
  in magnitude `xor` signMask
-- | Read a base-128 varint, least-significant group first. Each byte
-- contributes its low 7 bits at the next 7-bit offset. A set high bit means
-- another byte follows.
getVarInt :: (Integral a, Bits a) => Get a
getVarInt = accumulate 0 0
  where
    accumulate offset acc = do
      byte <- getWord8
      let acc' = acc .|. (fromIntegral (byte .&. 0x7F) `shiftL` offset)
      if testBit byte 7
        then accumulate (offset + 7) acc'
        else return acc'
-- | Write a signed integer as a varint.
--
-- Non-negative values are delegated to 'putVarUInt'. A negative value is
-- written as its full two's-complement pattern in @len = ceil(bitSize/7)@
-- bytes, which matches what 'size'Varint' reports. The last byte carries only
-- the @bitSize - (len-1)*7@ remaining bits (1 bit for 64-bit types, 4 for
-- 32-bit), so it is masked to drop the extra sign-extension bits left by the
-- arithmetic shift.
--
-- NOTE(review): the protobuf encoding guide sign-extends negative int32 values
-- to 10 bytes. For 32-bit types this writes 5 bytes instead. Confirm
-- interoperability requirements before relying on it.
putVarSInt :: (Integral a, Bits a) => a -> Put
putVarSInt b =
  case compare b 0 of
    LT -> let len = divBy (bitSize b) 7
              last'Size = bitSize b - (pred len * 7)
              last'Mask = pred (1 `shiftL` last'Size)
              go i 1 = putWord8 (fromIntegral i .&. last'Mask)
              go i n = putWord8 (fromIntegral (i .&. 0x7F) .|. 0x80) >> go (i `shiftR` 7) (pred n)
          in go b len
    EQ -> putWord8 0
    GT -> putVarUInt b
-- | Write a value as a base-128 varint, least-significant group first. Every
-- byte except the last has its high bit set. Intended for non-negative
-- values: a negative argument falls into the first guard and is written as
-- one truncated byte.
putVarUInt :: (Integral a, Bits a) => a -> Put
putVarUInt b
  | b < 0x80  = putWord8 (fromIntegral b)
  | otherwise = do
      putWord8 (fromIntegral (b .&. 0x7F) .|. 0x80)
      putVarUInt (b `shiftR` 7)
-- | Read the raw bytes of one field payload (the tag is assumed to have been
-- consumed already). The payload length is measured first with 'lookAhead',
-- then exactly that many bytes are taken with 'getLazyByteString'.
--
--   * wire type 0: the varint bytes;
--   * 1 / 5: 8 / 4 bytes;
--   * 2: the length varint plus the bytes it announces;
--   * 3: everything up to and including the matching end-group tag
--     (measured with 'skipGroup');
--   * 4 (end group) and anything else: parse failure.
wireGetFromWire :: FieldId -> WireType -> Get ByteString
wireGetFromWire fi wt = getLazyByteString =<< calcLen
  where
    calcLen = case wt of
      0 -> lenOf (spanOf (>=128) >> skip 1)
      1 -> return 8
      2 -> lookAhead $ do
             here <- bytesRead
             len <- getVarInt
             there <- bytesRead
             -- bytes of the length prefix itself plus the announced payload
             return ((there - here) + len)
      3 -> lenOf (skipGroup fi)
      4 -> fail $ "Cannot wireGetFromWire with wireType of STOP_GROUP: "++show (fi,wt)
      5 -> return 4
      wtf -> fail $ "Invalid wire type (expected 0,1,2,3,or 5) found: "++show (fi,wtf)
    -- Number of bytes parser g would consume, without consuming them.
    lenOf g = do
      here <- bytesRead
      there <- lookAhead (g >> bytesRead)
      return (there - here)
-- | Skip the rest of a group whose start tag (for field @start_fi@) has
-- already been read.
--
-- Reads tags and skips each payload according to its wire type. Nested
-- groups are skipped recursively. Stops after a wire-type-4 (end group) tag
-- whose field id equals @start_fi@. Fails on a mismatched end-group field id
-- or an invalid wire type.
skipGroup :: FieldId -> Get ()
skipGroup start_fi = go
  where
    go = do
      (fieldId,wireType) <- fmap (splitWireTag . WireTag) getVarInt
      case wireType of
        -- varint: presumably 'spanOf' passes over bytes >= 128, then the final byte
        0 -> spanOf (>=128) >> skip 1 >> go
        1 -> skip 8 >> go
        2 -> getVarInt >>= skip >> go
        3 -> skipGroup fieldId >> go
        4 | start_fi /= fieldId -> fail $ "skipGroup failed, fieldId mismatch between START_GROUP and STOP_GROUP: "++show (start_fi,(fieldId,wireType))
          | otherwise -> return ()
        5 -> skip 4 >> go
        wtf -> fail $ "Invalid wire type (expected 0,1,2,3,4,or 5) found: "++show (fieldId,wtf)
-- | Wire type used on the wire for each field type number.
--
-- This must agree with how the 'Wire' instances in this module encode each
-- field type. In particular, 17 and 18 are written as zig-zag varints
-- (@putVarUInt . zzEncode32/64@) and therefore use wire type 0, not the
-- fixed-width wire types 5/1.
toWireType :: FieldType -> WireType
toWireType  1 = 1  -- double: 8 bytes
toWireType  2 = 5  -- float: 4 bytes
toWireType  3 = 0  -- int64: varint
toWireType  4 = 0  -- uint64: varint
toWireType  5 = 0  -- int32: varint
toWireType  6 = 1  -- fixed64: 8 bytes
toWireType  7 = 5  -- fixed32: 4 bytes
toWireType  8 = 0  -- bool: varint
toWireType  9 = 2  -- string: length-delimited
toWireType 10 = 3  -- group: start-group tag
toWireType 11 = 2  -- message: length-delimited
toWireType 12 = 2  -- bytes: length-delimited
toWireType 13 = 0  -- uint32: varint
toWireType 14 = 0  -- enum: varint
toWireType 15 = 5  -- sfixed32: 4 bytes
toWireType 16 = 1  -- sfixed64: 8 bytes
toWireType 17 = 0  -- sint32: zig-zag varint
toWireType 18 = 0  -- sint64: zig-zag varint
toWireType x = error $ "Text.ProtocolBuffers.WireMessage.toWireType: Bad FieldType: "++show x