-- Copyright 2016 Google Inc. All Rights Reserved. -- -- Use of this source code is governed by a BSD-style -- license that can be found in the LICENSE file or at -- https://developers.google.com/open-source/licenses/bsd -- | Functions for encoding and decoding protocol buffer Messages. -- -- TODO: Currently all operations are on strict ByteStrings; -- we should try to generalize to lazy Bytestrings as well. {-# LANGUAGE GADTs #-} {-# LANGUAGE PatternGuards #-} {-# LANGUAGE ScopedTypeVariables #-} module Data.ProtoLens.Encoding( encodeMessage, buildMessage, decodeMessage, decodeMessageOrDie, ) where import Data.ProtoLens.Message import Data.ProtoLens.Encoding.Bytes import Data.ProtoLens.Encoding.Wire import Control.Applicative ((<|>), (<$>)) import Control.Monad (foldM) import Data.Attoparsec.ByteString as Parse import Data.Bool (bool) import Data.Text.Encoding (encodeUtf8, decodeUtf8') import Data.Text.Encoding.Error (UnicodeException(..)) import qualified Data.ByteString as B import qualified Data.Map.Strict as Map import Data.ByteString.Lazy.Builder as Builder import qualified Data.ByteString.Lazy as L import Data.Monoid (mconcat, mempty) import Data.Foldable (foldMap, toList, foldl') import Lens.Family2 (set, over, (^.), (&)) -- TODO: We could be more incremental when parsing/encoding length-based fields, -- rather than forcing the whole thing. E.g., for encoding we're doing extra -- allocation by building an intermediate bytestring. -- | Decode a message from its wire format. Returns 'Left' if the decoding -- fails. decodeMessage :: Message msg => B.ByteString -> Either String msg decodeMessage input = parseOnly (Parse.manyTill getTaggedValue endOfInput) input >>= taggedValuesToMessage -- | Decode a message from its wire format. Throws an error if the decoding -- fails. decodeMessageOrDie :: Message msg => B.ByteString -> msg decodeMessageOrDie bs = case decodeMessage bs of Left e -> error $ "decodeMessageOrDie: " ++ e Right x -> x -- | Convert a sequence of parsed key-value pairs into a Message via its -- descriptor. Will fail if any of the key-value pairs do not match those -- expected by the field descriptors. taggedValuesToMessage :: Message msg => [TaggedValue] -> Either String msg taggedValuesToMessage tvs | missing <- missingFields fields tvs, not $ null missing = Left $ "Missing required fields " ++ show missing | otherwise = reverseRepeatedFields fields <$> result where addTaggedValue msg tv@(TaggedValue tag _) = case Map.lookup (Tag tag) fields of Nothing -> return msg Just field -> parseAndAddField msg field tv fields = fieldsByTag descriptor result = foldM addTaggedValue def tvs missingFields :: Map.Map Tag (FieldDescriptor msg) -> [TaggedValue] -> [String] missingFields fields = map fieldDescriptorName . Map.elems . foldl' (\m (TaggedValue t _) -> Map.delete (Tag t) m) requiredFields where requiredFields = Map.filter isRequired fields runEither :: Either String a -> Parser a runEither (Left x) = fail x runEither (Right x) = return x parseAndAddField :: msg -> FieldDescriptor msg -> TaggedValue -> Either String msg parseAndAddField msg (FieldDescriptor name typeDescriptor accessor) (TaggedValue tag (WireValue wt val)) = case fieldWireType typeDescriptor of FieldWireType fieldWt _ get -> let getSimpleVal = do Equal <- equalWireTypes name fieldWt wt get val -- Get a block of packed values, reversed. getPackedVals = do Equal <- equalWireTypes name Lengthy wt let getElt = getWireValue fieldWt tag >>= runEither . get parseOnly (manyReversedTill getElt endOfInput) val in case accessor of PlainField _ f -> do x <- getSimpleVal return $ set f x msg OptionalField f -> do x <- getSimpleVal return $ set f (Just x) msg RepeatedField Unpacked f -> do x <- getSimpleVal return $ over f (x:) msg RepeatedField Packed f -> do xs <- getPackedVals return $ over f (xs++) msg MapField keyLens valueLens f -> do entry <- getSimpleVal return $ over f (Map.insert (entry ^. keyLens) (entry ^. valueLens)) msg -- | Run the parser zero or more times, until the "end" parser succeeds. -- Returns a list of the parsed elements, in reverse order. manyReversedTill :: Parser a -> Parser b -> Parser [a] manyReversedTill p end = loop [] where loop xs = (end >> return xs) <|> (p >>= \x -> loop (x:xs)) -- | Encode a message to the wire format. encodeMessage :: Message msg => msg -> B.ByteString encodeMessage = L.toStrict . toLazyByteString . buildMessage -- | Encode a message to the wire format, as part of a 'Builder'. buildMessage :: Message msg => msg -> Builder buildMessage msg = foldMap putTaggedValue (messageToTaggedValues msg) -- | Encode a message as a sequence of key-value pairs. messageToTaggedValues :: Message msg => msg -> [TaggedValue] messageToTaggedValues msg = mconcat [ map (TaggedValue t) (messageFieldToVals fieldDescr msg) | (Tag t, fieldDescr) <- Map.toList (fieldsByTag descriptor) ] messageFieldToVals :: FieldDescriptor msg -> msg -> [WireValue] messageFieldToVals (FieldDescriptor _ typeDescriptor accessor) msg = case fieldWireType typeDescriptor of FieldWireType wt convert _ -> case accessor of PlainField d f | Optional <- d, src == fieldDefault -> [] | otherwise -> [WireValue wt (convert src)] where src = msg ^. f OptionalField f -> case msg ^. f of Just src -> [WireValue wt (convert src)] _ -> mempty RepeatedField Unpacked f -> [ WireValue wt (convert src) | src <- toList (msg ^. f) ] RepeatedField Packed f -> [WireValue Lengthy v] where v = L.toStrict $ toLazyByteString $ mconcat [ putWireValue wt (convert src) | src <- toList (msg ^. f) ] MapField keyLens valueLens f -> [ WireValue wt v | (key, value) <- Map.toList (msg ^. f) , let entry = def & set keyLens key & set valueLens value , let v = convert entry ] data FieldWireType value where FieldWireType :: WireType w -> (value -> w) -> (w -> Either String value) -> FieldWireType value fieldWireType :: FieldTypeDescriptor value -> FieldWireType value -- TODO: Don't let toEnum crash on unknown enum values. fieldWireType EnumField = simpleFieldWireType VarInt (fromIntegral . fromEnum) (toEnum . fromIntegral) fieldWireType BoolField = simpleFieldWireType VarInt (bool 0 1) (/= 0) -- Note: int{32,64} and sfixed{32,64} are stored using the signed -> unsigned -- conversion of fromIntegral. fieldWireType Int32Field = integralFieldWireType VarInt fieldWireType Int64Field = integralFieldWireType VarInt fieldWireType UInt32Field = integralFieldWireType VarInt fieldWireType UInt64Field = identityFieldWireType VarInt fieldWireType SInt32Field = simpleFieldWireType VarInt (fromIntegral . signedInt32ToWord) (wordToSignedInt32 . fromIntegral) fieldWireType SInt64Field = simpleFieldWireType VarInt signedInt64ToWord wordToSignedInt64 fieldWireType Fixed32Field = identityFieldWireType Fixed32 fieldWireType Fixed64Field = identityFieldWireType Fixed64 fieldWireType SFixed32Field = integralFieldWireType Fixed32 fieldWireType SFixed64Field = integralFieldWireType Fixed64 fieldWireType FloatField = simpleFieldWireType Fixed32 floatToWord wordToFloat fieldWireType DoubleField = simpleFieldWireType Fixed64 doubleToWord wordToDouble fieldWireType StringField = FieldWireType Lengthy encodeUtf8 (stringizeError . decodeUtf8') fieldWireType BytesField = identityFieldWireType Lengthy fieldWireType MessageField = FieldWireType Lengthy encodeMessage decodeMessage fieldWireType GroupField = FieldWireType StartGroup messageToTaggedValues taggedValuesToMessage -- | Helper function to define a field type whose decoding operation can't fail. simpleFieldWireType :: WireType w -> (value -> w) -> (w -> value) -> FieldWireType value simpleFieldWireType w f g = FieldWireType w f (return . g) -- | A simple field type which is the same as its wire type. identityFieldWireType :: WireType w -> FieldWireType w identityFieldWireType w = simpleFieldWireType w id id -- | A simple field type which converts to/from its wire type via -- "fromIntegral". integralFieldWireType :: (Integral w, Integral value) => WireType w -> FieldWireType value integralFieldWireType w = simpleFieldWireType w fromIntegral fromIntegral stringizeError :: Either UnicodeException a -> Either String a stringizeError (Left e) = Left (show e) stringizeError (Right a) = Right a