-- Copyright 2016 Google Inc. All Rights Reserved.
--
-- Use of this source code is governed by a BSD-style
-- license that can be found in the LICENSE file or at
-- https://developers.google.com/open-source/licenses/bsd

{-# LANGUAGE GADTs #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE PatternGuards #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}

-- | Module defining the individual base wire types (e.g. VarInt, Fixed64) and
-- how to encode/decode them.
module Data.ProtoLens.Encoding.Wire(
    WireType(..),
    SomeWireType(..),
    WireValue(..),
    Tag(..),
    TaggedValue(..),
    getTaggedValue,
    putTaggedValue,
    getWireValue,
    putWireValue,
    Equal(..),
    equalWireTypes,
    decodeFieldSet,
    ) where

import Control.DeepSeq (NFData(..))
-- MonadFail lives here on every GHC since 8.0; from GHC 8.8 on, 'fail' is no
-- longer a method of 'Monad', so generic failure needs this class.
import qualified Control.Monad.Fail as Fail
import Data.Attoparsec.ByteString as Parse
import Data.Bits
import qualified Data.ByteString as B
-- Data.ByteString.Lazy.Builder was deprecated and then removed in
-- bytestring-0.11; Data.ByteString.Builder provides the same names.
import Data.ByteString.Builder (Builder, byteString, word32LE, word64LE)
import Data.Monoid ((<>))
import Data.Word

import Data.ProtoLens.Encoding.Bytes

-- | The wire type of an encoded field.  The type index @a@ is the Haskell
-- type used to hold a value of that wire type.
data WireType a where
    -- Note: all of these types are fully strict (vs, say,
    -- Data.ByteString.Lazy.ByteString).  If that changes, we'll
    -- need to update the NFData instance.
    VarInt :: WireType Word64
    Fixed64 :: WireType Word64
    Fixed32 :: WireType Word32
    Lengthy :: WireType B.ByteString
    StartGroup :: WireType ()
    EndGroup :: WireType ()

-- | Shows the numeric wire-type code (as produced by 'wireTypeToInt').
instance Show (WireType a) where
    show = show . wireTypeToInt

-- | A value read from the wire, paired with the wire type it was encoded as.
data WireValue = forall a . WireValue !(WireType a) !a

-- | Shows only the payload.  Matching on each wire type is what brings the
-- payload type's 'Show' instance into scope.
instance Show WireValue where
    show (WireValue VarInt x) = show x
    show (WireValue Fixed64 x) = show x
    show (WireValue Fixed32 x) = show x
    show (WireValue Lengthy x) = show x
    show (WireValue StartGroup x) = show x
    show (WireValue EndGroup x) = show x

-- | The wire contents of a single key-value pair in a Message.
data TaggedValue = TaggedValue !Tag !WireValue
    deriving (Show, Eq, Ord)

-- TaggedValue, WireValue and Tag are strict, so their NFData instances are
-- trivial:
instance NFData TaggedValue where
    rnf = (`seq` ())

instance NFData WireValue where
    rnf = (`seq` ())

-- | A tag that identifies a particular field of the message when converting
-- to/from the wire format.
newtype Tag = Tag { unTag :: Int}
    deriving (Show, Eq, Ord, Num, NFData)

-- | Evidence that two types are the same, bundled with their 'Eq' and 'Ord'
-- instances.
data Equal a b where
    -- TODO: move Eq/Ord instance somewhere else?
    Equal :: (Eq a, Ord a) => Equal a a

-- | Assert that two wire types are the same, or fail (via 'Fail.fail') with a
-- message naming both wire-type codes.
--
-- Callers instantiate @m@ at e.g. 'Maybe' (failure is 'Nothing') or at the
-- attoparsec 'Parser' (failure aborts the parse).
{-# INLINE equalWireTypes #-}
equalWireTypes :: Fail.MonadFail m => WireType a -> WireType b -> m (Equal a b)
equalWireTypes VarInt VarInt = return Equal
equalWireTypes Fixed64 Fixed64 = return Equal
equalWireTypes Fixed32 Fixed32 = return Equal
equalWireTypes Lengthy Lengthy = return Equal
equalWireTypes StartGroup StartGroup = return Equal
equalWireTypes EndGroup EndGroup = return Equal
equalWireTypes expected actual =
    Fail.fail $ "Expected wire type " ++ show expected
        ++ " but found " ++ show actual

-- | Values are equal only if they have the same wire type and equal payloads.
instance Eq WireValue where
    WireValue t v == WireValue t' v'
        | Just Equal <- equalWireTypes t t' = v == v'
        | otherwise = False

-- | Values of the same wire type compare by payload; values of different wire
-- types compare by their numeric wire-type codes.
instance Ord WireValue where
    WireValue t v `compare` WireValue t' v'
        | Just Equal <- equalWireTypes t t' = v `compare` v'
        | otherwise = wireTypeToInt t `compare` wireTypeToInt t'

-- | Parse the payload of a value with the given wire type.
--
-- A 'Lengthy' value is a varint byte count followed by that many bytes.
-- 'StartGroup' and 'EndGroup' consume no input.
getWireValue :: WireType a -> Parser a
getWireValue VarInt = getVarInt
getWireValue Fixed64 = anyBits
getWireValue Fixed32 = anyBits
getWireValue Lengthy = do
    len <- getVarInt
    -- The length comes from untrusted input.  A value that does not fit in an
    -- Int would wrap around when converted (possibly to a negative count) and
    -- the rest of the input would be mis-parsed, so reject it outright.
    if len > fromIntegral (maxBound :: Int)
        then fail $ "Length-delimited field too long: " ++ show len
        else Parse.take (fromIntegral len)
getWireValue StartGroup = return ()
getWireValue EndGroup = return ()

-- | Encode the payload of a value with the given wire type.  This is the
-- inverse of 'getWireValue'.
putWireValue :: WireType a -> a -> Builder
putWireValue VarInt n = putVarInt n
putWireValue Fixed64 n = word64LE n
putWireValue Fixed32 n = word32LE n
putWireValue Lengthy b = putVarInt (fromIntegral $ B.length b) <> byteString b
putWireValue StartGroup _ = mempty
putWireValue EndGroup _ = mempty

-- | A 'WireType' with its payload type hidden.
data SomeWireType where
    SomeWireType :: WireType a -> SomeWireType

-- | The numeric code stored in the low three bits of a field key.
wireTypeToInt :: WireType a -> Word64
wireTypeToInt VarInt = 0
wireTypeToInt Fixed64 = 1
wireTypeToInt Lengthy = 2
wireTypeToInt StartGroup = 3
wireTypeToInt EndGroup = 4
wireTypeToInt Fixed32 = 5

-- | Inverse of 'wireTypeToInt'; codes other than 0-5 are rejected.
intToWireType :: Word64 -> Either String SomeWireType
intToWireType 0 = Right $ SomeWireType VarInt
intToWireType 1 = Right $ SomeWireType Fixed64
intToWireType 2 = Right $ SomeWireType Lengthy
intToWireType 3 = Right $ SomeWireType StartGroup
intToWireType 4 = Right $ SomeWireType EndGroup
intToWireType 5 = Right $ SomeWireType Fixed32
intToWireType n = Left $ "Unrecognized wire type " ++ show n

-- | Encode a field key as a varint: the tag shifted left by three bits,
-- OR'd with the wire-type code.
putTypeAndTag :: WireType a -> Tag -> Builder
putTypeAndTag wt (Tag tag)
    = putVarInt $ wireTypeToInt wt .|. fromIntegral tag `shiftL` 3

-- | Decode a field key written by 'putTypeAndTag'.  The low three bits select
-- the wire type and the remaining bits are the tag.
getTypeAndTag :: Parser (SomeWireType, Tag)
getTypeAndTag = do
    n <- getVarInt
    case intToWireType (n .&. 7) of
        Left err -> fail err
        Right wt -> return (wt, fromIntegral $ n `shiftR` 3)

-- | Parse one field key followed by its payload.
getTaggedValue :: Parser TaggedValue
getTaggedValue = do
    (SomeWireType wt, tag) <- getTypeAndTag
    val <- getWireValue wt
    return $ TaggedValue tag (WireValue wt val)

-- | Encode one field key followed by its payload.  This is the inverse of
-- 'getTaggedValue'.
putTaggedValue :: TaggedValue -> Builder
putTaggedValue (TaggedValue tag (WireValue wt val)) =
    putTypeAndTag wt tag <> putWireValue wt val

-- | Decode an entire buffer as a flat sequence of tagged values, failing
-- unless all input is consumed.  Group markers are returned as ordinary
-- 'StartGroup' / 'EndGroup' entries.  The fields between them are not nested.
decodeFieldSet :: B.ByteString -> Either String [TaggedValue]
decodeFieldSet = parseOnly (manyTill getTaggedValue endOfInput)