module Data.ProtoLens.Encoding(
encodeMessage,
buildMessage,
decodeMessage,
decodeMessageOrDie,
) where
import Data.ProtoLens.Message
import Data.ProtoLens.Encoding.Bytes
import Data.ProtoLens.Encoding.Wire
import Control.Applicative ((<|>), (<$>))
import Control.Monad (foldM)
import Data.Attoparsec.ByteString as Parse
import Data.Bool (bool)
import Data.Text.Encoding (encodeUtf8, decodeUtf8')
import Data.Text.Encoding.Error (UnicodeException(..))
import qualified Data.ByteString as B
import qualified Data.Map.Strict as Map
import Data.ByteString.Lazy.Builder as Builder
import qualified Data.ByteString.Lazy as L
import Data.Monoid (mconcat, mempty)
import Data.Foldable (foldMap, toList, foldl')
import Lens.Family2 (set, over, (^.), (&))
-- | Parse a message from its protobuf wire encoding.
--
-- The entire input is first split into tagged wire values (parsing until end
-- of input), and those values are then assembled into a message by
-- 'taggedValuesToMessage'. Failures reported by the parser or by
-- 'taggedValuesToMessage' come back as a 'Left' carrying a description.
decodeMessage :: Message msg => B.ByteString -> Either String msg
decodeMessage input = do
    taggedValues <- parseOnly (Parse.manyTill getTaggedValue endOfInput) input
    taggedValuesToMessage taggedValues
-- | Like 'decodeMessage', but calls 'error' on failure instead of returning
-- a 'Left'. The error text is the decoding message prefixed with
-- @"decodeMessageOrDie: "@.
decodeMessageOrDie :: Message msg => B.ByteString -> msg
decodeMessageOrDie =
    either (error . ("decodeMessageOrDie: " ++)) id . decodeMessage
-- | Assemble a message from a sequence of decoded tagged wire values.
--
-- If any field marked required in the message descriptor has no tagged value
-- in the input, the result is a 'Left' listing the missing field names.
-- Otherwise, starting from 'def', each tagged value whose tag is known to the
-- descriptor is merged in with 'parseAndAddField'. Unknown tags are skipped.
-- Repeated fields are accumulated in reverse, so 'reverseRepeatedFields' is
-- applied to the final message.
taggedValuesToMessage :: Message msg => [TaggedValue] -> Either String msg
taggedValuesToMessage tvs =
    case missingFields fields tvs of
        [] -> reverseRepeatedFields fields <$> foldM step def tvs
        missing -> Left $ "Missing required fields " ++ show missing
  where
    -- Kept as a monomorphic binding: its message type is fixed by the uses
    -- below, which tie it to the result type.
    fields = fieldsByTag descriptor
    step msg tv@(TaggedValue tag _) =
        maybe (Right msg)
              (\field -> parseAndAddField msg field tv)
              (Map.lookup (Tag tag) fields)
-- | Names of the required fields (per 'isRequired') whose tag does not occur
-- in the given tagged values. Names come out in the ascending key order of
-- the field map.
missingFields :: Map.Map Tag (FieldDescriptor msg) -> [TaggedValue] -> [String]
missingFields fields tvs =
    map fieldDescriptorName (Map.elems (foldl' dropSeen required tvs))
  where
    required = Map.filter isRequired fields
    -- Every tag that appears in the input satisfies its required field.
    dropSeen remaining (TaggedValue t _) = Map.delete (Tag t) remaining
-- | Embed an 'Either' in the attoparsec 'Parser'. A 'Left' message becomes a
-- parser failure via 'fail'; a 'Right' value is returned unchanged.
runEither :: Either String a -> Parser a
runEither = either fail return
-- | Decode one 'TaggedValue' as the field described by the given
-- 'FieldDescriptor' and merge the decoded value(s) into @msg@.
--
-- Failures from the field type's decoder ('get') or from parsing a packed
-- payload ('parseOnly') are returned as 'Left'. A wire-type mismatch is
-- reported by whatever 'equalWireTypes' does on mismatch.
--
-- How the value is merged depends on the accessor:
--
-- * plain field: overwritten with 'set';
-- * optional field: overwritten with 'Just' the value;
-- * unpacked repeated field: the element is prepended;
-- * packed repeated field: the (reversed) decoded elements are prepended;
-- * map field: the decoded entry's key/value pair is inserted with
--   'Map.insert'.
--
-- Repeated elements therefore accumulate in reverse order. The caller is
-- expected to restore the order (presumably via 'reverseRepeatedFields').
parseAndAddField :: msg
                 -> FieldDescriptor msg
                 -> TaggedValue
                 -> Either String msg
parseAndAddField
    msg
    (FieldDescriptor name typeDescriptor accessor)
    (TaggedValue tag (WireValue wt val))
    = case fieldWireType typeDescriptor of
        FieldWireType fieldWt _ get -> let
            -- A single value: the input's wire type must match the wire type
            -- this field's type is encoded with. Matching 'Equal' lets 'get'
            -- be applied to 'val'.
            getSimpleVal = do
                Equal <- equalWireTypes name fieldWt wt
                get val
            -- A packed run of values: the input must be length-delimited,
            -- and its payload is parsed as consecutive wire values of the
            -- field's element wire type until the payload is exhausted.
            -- 'manyReversedTill' yields them in reverse order.
            getPackedVals = do
                Equal <- equalWireTypes name Lengthy wt
                let getElt = getWireValue fieldWt tag >>= runEither . get
                parseOnly (manyReversedTill getElt endOfInput) val
          in case accessor of
            PlainField _ f -> do
                x <- getSimpleVal
                return $ set f x msg
            OptionalField f -> do
                x <- getSimpleVal
                return $ set f (Just x) msg
            -- Prepending keeps each element O(1); order is restored later.
            RepeatedField Unpacked f -> do
                x <- getSimpleVal
                return $ over f (x:) msg
            RepeatedField Packed f -> do
                xs <- getPackedVals
                return $ over f (xs++) msg
            -- Map entries are decoded as a single entry value, then split
            -- into key and value with the entry's lenses.
            MapField keyLens valueLens f -> do
                entry <- getSimpleVal
                return $ over f
                    (Map.insert (entry ^. keyLens) (entry ^. valueLens))
                    msg
-- | Run @p@ repeatedly until @end@ succeeds and return the results in
-- reverse order: the last item parsed is the head of the list.
--
-- At each step @end@ is tried first. Only if it fails is another @p@
-- attempted.
manyReversedTill :: Parser a -> Parser b -> Parser [a]
manyReversedTill p end = go []
  where
    go acc = finish acc <|> step acc
    finish acc = end >> return acc
    step acc = do
        item <- p
        go (item : acc)
-- | Serialize a message to a strict 'B.ByteString' in protobuf wire format,
-- by running the 'Builder' from 'buildMessage'.
encodeMessage :: Message msg => msg -> B.ByteString
encodeMessage msg = L.toStrict (toLazyByteString (buildMessage msg))
-- | Render a message as a 'Builder'. Each tagged wire value produced by
-- 'messageToTaggedValues' is written in turn with 'putTaggedValue'.
buildMessage :: Message msg => msg -> Builder
buildMessage = foldMap putTaggedValue . messageToTaggedValues
-- | Flatten a message into tagged wire values.
--
-- Fields are visited in the order 'Map.toList' yields them from the
-- descriptor's tag map, i.e. ascending by 'Tag'. Each field contributes zero
-- or more wire values, all labelled with that field's tag.
messageToTaggedValues :: Message msg => msg -> [TaggedValue]
messageToTaggedValues msg =
    concatMap fieldValues (Map.toList (fieldsByTag descriptor))
  where
    fieldValues (Tag t, fieldDescr) =
        TaggedValue t <$> messageFieldToVals fieldDescr msg
-- | Encode the value(s) stored in one field of @msg@ as wire values.
--
-- * Plain fields are always emitted, except when the accessor marks them
--   'Optional' and the value equals 'fieldDefault'.
-- * Optional fields are emitted only when set ('Just').
-- * Unpacked repeated fields emit one wire value per element.
-- * Packed repeated fields emit a single length-delimited wire value holding
--   every element encoded back to back. When there are no elements, nothing
--   is emitted, because protobuf requires an empty packed repeated field to
--   be absent from the encoding rather than written as a zero-length record.
-- * Map fields emit one entry value per key/value pair. Each entry is built
--   from 'def' by setting its key and value lenses.
messageFieldToVals :: FieldDescriptor msg -> msg -> [WireValue]
messageFieldToVals (FieldDescriptor _ typeDescriptor accessor) msg =
    case fieldWireType typeDescriptor of
        FieldWireType wt convert _ -> case accessor of
            PlainField d f
                | Optional <- d, src == fieldDefault -> []
                | otherwise -> [WireValue wt (convert src)]
              where src = msg ^. f
            OptionalField f -> case msg ^. f of
                Just src -> [WireValue wt (convert src)]
                Nothing -> []
            RepeatedField Unpacked f ->
                [WireValue wt (convert src) | src <- toList (msg ^. f)]
            RepeatedField Packed f
                -- Omit empty packed fields entirely instead of writing a
                -- zero-length record.
                | null srcs -> []
                | otherwise -> [WireValue Lengthy packed]
              where
                srcs = toList (msg ^. f)
                packed = L.toStrict $ toLazyByteString
                    $ foldMap (putWireValue wt . convert) srcs
            MapField keyLens valueLens f ->
                [ WireValue wt (convert entry)
                | (key, value) <- Map.toList (msg ^. f)
                , let entry = def & set keyLens key & set valueLens value
                ]
-- | The wire representation of a field type. It packages:
--
-- * the 'WireType' the field is written with (its payload type @w@ is
--   hidden by the constructor);
-- * a function converting a field value to that payload;
-- * a possibly-failing function converting a payload back to a field value.
data FieldWireType value where
    FieldWireType :: WireType w -> (value -> w) -> (w -> Either String value)
                  -> FieldWireType value
-- | Select the wire representation for each field type: the wire type it is
-- written with, plus the conversions to and from that wire payload.
fieldWireType :: FieldTypeDescriptor value -> FieldWireType value
fieldWireType descr = case descr of
    -- Varint-encoded types.
    EnumField -> simpleFieldWireType VarInt
        (fromIntegral . fromEnum)
        (toEnum . fromIntegral)
    BoolField -> simpleFieldWireType VarInt (bool 0 1) (/= 0)
    Int32Field -> integralFieldWireType VarInt
    Int64Field -> integralFieldWireType VarInt
    UInt32Field -> integralFieldWireType VarInt
    UInt64Field -> identityFieldWireType VarInt
    SInt32Field -> simpleFieldWireType VarInt
        (fromIntegral . signedInt32ToWord)
        (wordToSignedInt32 . fromIntegral)
    SInt64Field -> simpleFieldWireType VarInt
        signedInt64ToWord
        wordToSignedInt64
    -- Fixed-width types.
    Fixed32Field -> identityFieldWireType Fixed32
    Fixed64Field -> identityFieldWireType Fixed64
    SFixed32Field -> integralFieldWireType Fixed32
    SFixed64Field -> integralFieldWireType Fixed64
    FloatField -> simpleFieldWireType Fixed32 floatToWord wordToFloat
    DoubleField -> simpleFieldWireType Fixed64 doubleToWord wordToDouble
    -- Length-delimited types. Strings are UTF-8, and invalid UTF-8 decodes
    -- to a 'Left'.
    StringField -> FieldWireType Lengthy
        encodeUtf8
        (stringizeError . decodeUtf8')
    BytesField -> identityFieldWireType Lengthy
    MessageField -> FieldWireType Lengthy encodeMessage decodeMessage
    -- Groups use the StartGroup wire type and carry nested tagged values.
    GroupField -> FieldWireType StartGroup
        messageToTaggedValues
        taggedValuesToMessage
-- | Build a 'FieldWireType' whose decoding direction never fails: the pure
-- payload-to-value conversion is wrapped in 'Right'.
simpleFieldWireType :: WireType w -> (value -> w) -> (w -> value)
                    -> FieldWireType value
simpleFieldWireType wire toWire fromWire =
    FieldWireType wire toWire (Right . fromWire)
-- | A 'FieldWireType' for values stored exactly as their wire payload: both
-- conversions are the identity, and decoding always succeeds.
identityFieldWireType :: WireType w -> FieldWireType w
identityFieldWireType wire = FieldWireType wire id Right
-- | A 'FieldWireType' for integral values carried in an integral wire
-- payload. Both directions use 'fromIntegral', so narrowing follows
-- 'fromIntegral' semantics. Decoding always succeeds.
integralFieldWireType
    :: (Integral w, Integral value) => WireType w -> FieldWireType value
integralFieldWireType wire =
    FieldWireType wire fromIntegral (Right . fromIntegral)
-- | Replace a 'UnicodeException' failure with its 'show' text, passing
-- successful results through untouched.
stringizeError :: Either UnicodeException a -> Either String a
stringizeError = either (Left . show) Right