{-# Language ScopedTypeVariables, MultiWayIf, TypeFamilies, FlexibleContexts #-} module Data.Bond.Internal.CompactBinaryProto ( CompactBinaryProto(..), CompactBinaryV1Proto(..) ) where import Data.Bond.Proto import Data.Bond.Types import Data.Bond.Internal.BinaryUtils import Data.Bond.Internal.BondedUtils import Data.Bond.Internal.Cast import Data.Bond.Internal.ProtoUtils import Data.Bond.Internal.Protocol import Data.Bond.Internal.SchemaUtils import Data.Bond.Internal.TaggedProtocol import Data.Bond.Internal.ZigZag import Data.Bond.Schema.BondDataType import Data.Bond.Schema.ProtocolType import Control.Applicative hiding (optional) import Control.Monad import Control.Monad.Error import Data.Bits import Data.List import Data.Maybe import Data.Proxy import Prelude -- ghc 7.10 workaround for Control.Applicative import qualified Data.Binary.Get as B import qualified Data.Binary.Put as B import qualified Data.ByteString as BS import qualified Data.ByteString.Lazy as BL import qualified Data.HashSet as H import qualified Data.Map as M import qualified Data.Set as S import qualified Data.Vector as V {-| A binary, tagged protocol using variable integer encoding and compact field header. Version 2 of Compact Binary adds length prefix for structs. This enables deserialization of bonded\ and skipping of unknown struct fields in constant time. -} data CompactBinaryProto = CompactBinaryProto -- |A binary, tagged protocol using variable integer encoding and compact field header. data CompactBinaryV1Proto = CompactBinaryV1Proto instance TaggedProtocol CompactBinaryProto where getFieldHeader = getCompactFieldHeader getListHeader = do v <- getWord8 if v `shiftR` 5 /= 0 then return (BondDataType $ fromIntegral (v .&. 0x1f), fromIntegral (v `shiftR` 5) - 1) else do n <- getVarInt return (BondDataType (fromIntegral v), n) getTaggedStruct = do size <- getVarInt isolate size getTaggedData putFieldHeader = putCompactFieldHeader putListHeader t n = do let tag = fromIntegral $ fromEnum t if n < 7 then putWord8 $ tag .|. fromIntegral ((1 + n) `shiftL` 5) else do putWord8 tag putVarInt n putTaggedStruct v = do let BondPut g = putTaggedData v >> putTag bT_STOP :: BondPut CompactBinaryProto case tryPut g of Left msg -> throwError msg Right bs -> do putVarInt $ BL.length bs putLazyByteString bs skipStruct = getVarInt >>= skip skipRestOfStruct = let loop = do (wiretype, _) <- getFieldHeader if | wiretype == bT_STOP -> return () | wiretype == bT_STOP_BASE -> loop | otherwise -> skipType wiretype >> loop in loop skipType = compactSkipType instance BondProto CompactBinaryProto where bondRead = binaryDecode bondWrite = binaryEncode bondReadWithSchema = readTaggedWithSchema bondWriteWithSchema = writeTaggedWithSchema protoSig _ = protoHeader cOMPACT_PROTOCOL 2 instance BondTaggedProto CompactBinaryProto where bondReadTagged = readTagged bondWriteTagged = writeTagged instance Protocol CompactBinaryProto where type ReaderM CompactBinaryProto = B.Get type WriterM CompactBinaryProto = ErrorT String B.PutM bondGetStruct = do size <- getVarInt isolate size $ getStruct TopLevelStruct bondGetBaseStruct = getStruct BaseStruct bondGetBool = do v <- getWord8 return $ v /= 0 bondGetUInt8 = getWord8 bondGetUInt16 = getVarInt bondGetUInt32 = getVarInt bondGetUInt64 = getVarInt bondGetInt8 = fromIntegral <$> getWord8 bondGetInt16 = fromZigZag <$> getVarInt bondGetInt32 = fromZigZag <$> getVarInt bondGetInt64 = fromZigZag <$> getVarInt bondGetFloat = wordToFloat <$> getWord32le bondGetDouble = wordToDouble <$> getWord64le bondGetString = do n <- getVarInt Utf8 <$> getByteString n bondGetWString = do n <- getVarInt Utf16 <$> getByteString (n * 2) bondGetBlob = getBlob bondGetDefNothing = Just <$> bondGet bondGetList = getList bondGetHashSet = H.fromList <$> bondGetList bondGetSet = S.fromList <$> bondGetList bondGetMap = getMap bondGetVector = getVector bondGetNullable = do v <- bondGetList case v of [] -> return Nothing [x] -> return (Just x) _ -> fail $ "list of length " ++ show (length v) ++ " where nullable expected" bondGetBonded = getBonded cOMPACT_PROTOCOL 2 bondPutStruct v = do let BondPut g = putStruct TopLevelStruct v :: BondPut CompactBinaryProto case tryPut g of Left msg -> throwError msg Right bs -> do putVarInt $ BL.length bs putLazyByteString bs bondPutBaseStruct = putBaseStruct bondPutField = putField bondPutDefNothingField p n Nothing = unless (isOptionalField p n) $ fail "can't write nothing to non-optional field" bondPutDefNothingField p n (Just v) = putField p n v bondPutBool True = putWord8 1 bondPutBool False = putWord8 0 bondPutUInt8 = putWord8 bondPutUInt16 = putVarInt bondPutUInt32 = putVarInt bondPutUInt64 = putVarInt bondPutInt8 = putWord8 . fromIntegral bondPutInt16 = putVarInt . toZigZag bondPutInt32 = putVarInt . toZigZag bondPutInt64 = putVarInt . toZigZag bondPutFloat = putWord32le . floatToWord bondPutDouble = putWord64le . doubleToWord bondPutString (Utf8 s) = do putVarInt $ BS.length s putByteString s bondPutWString (Utf16 s) = do putVarInt $ BS.length s `div` 2 putByteString s bondPutList = putList bondPutNullable = bondPutList . maybeToList bondPutHashSet = putHashSet bondPutSet = putSet bondPutMap = putMap bondPutVector = putVector bondPutBlob (Blob b) = do putListHeader bT_INT8 (BS.length b) putByteString b bondPutBonded (BondedObject a) = bondPut a bondPutBonded s = do BondedStream stream <- case bondRecode CompactBinaryProto s of Left msg -> throwError $ "Bonded recode error: " ++ msg Right v -> return v putLazyByteString (BL.drop 4 stream) instance TaggedProtocol CompactBinaryV1Proto where getFieldHeader = getCompactFieldHeader getListHeader = do t <- BondDataType . fromIntegral <$> getWord8 n <- getVarInt return (t, n) getTaggedStruct = getTaggedData putFieldHeader = putCompactFieldHeader putListHeader t n = do putTag t putVarInt n putTaggedStruct s = putTaggedData s >> putTag bT_STOP skipStruct = let loop = do (td, _) <- getFieldHeader if | td == bT_STOP -> return () | td == bT_STOP_BASE -> loop | otherwise -> skipType td >> loop in loop skipRestOfStruct = skipType bT_STRUCT skipType = compactSkipType instance BondProto CompactBinaryV1Proto where bondRead = binaryDecode bondWrite = binaryEncode bondReadWithSchema = readTaggedWithSchema bondWriteWithSchema = writeTaggedWithSchema protoSig _ = protoHeader cOMPACT_PROTOCOL 1 instance BondTaggedProto CompactBinaryV1Proto where bondReadTagged = readTagged bondWriteTagged = writeTagged instance Protocol CompactBinaryV1Proto where type ReaderM CompactBinaryV1Proto = B.Get type WriterM CompactBinaryV1Proto = ErrorT String B.PutM bondGetStruct = getStruct TopLevelStruct bondGetBaseStruct = getStruct BaseStruct bondGetBool = do v <- getWord8 return $ v /= 0 bondGetUInt8 = getWord8 bondGetUInt16 = getVarInt bondGetUInt32 = getVarInt bondGetUInt64 = getVarInt bondGetInt8 = fromIntegral <$> getWord8 bondGetInt16 = fromZigZag <$> getVarInt bondGetInt32 = fromZigZag <$> getVarInt bondGetInt64 = fromZigZag <$> getVarInt bondGetFloat = wordToFloat <$> getWord32le bondGetDouble = wordToDouble <$> getWord64le bondGetString = do n <- getVarInt Utf8 <$> getByteString n bondGetWString = do n <- getVarInt Utf16 <$> getByteString (n * 2) bondGetBlob = getBlob bondGetDefNothing = Just <$> bondGet bondGetList = getList bondGetHashSet = H.fromList <$> bondGetList bondGetSet = S.fromList <$> bondGetList bondGetMap = getMap bondGetVector = getVector bondGetNullable = do v <- bondGetList case v of [] -> return Nothing [x] -> return (Just x) _ -> fail $ "list of length " ++ show (length v) ++ " where nullable expected" bondGetBonded = getBonded cOMPACT_PROTOCOL 1 bondPutStruct = putStruct TopLevelStruct bondPutBaseStruct = putBaseStruct bondPutField = putField bondPutDefNothingField p n Nothing = unless (isOptionalField p n) $ fail "can't write nothing to non-optional field" bondPutDefNothingField p n (Just v) = putField p n v bondPutBool True = putWord8 1 bondPutBool False = putWord8 0 bondPutUInt8 = putWord8 bondPutUInt16 = putVarInt bondPutUInt32 = putVarInt bondPutUInt64 = putVarInt bondPutInt8 = putWord8 . fromIntegral bondPutInt16 = putVarInt . toZigZag bondPutInt32 = putVarInt . toZigZag bondPutInt64 = putVarInt . toZigZag bondPutFloat = putWord32le . floatToWord bondPutDouble = putWord64le . doubleToWord bondPutString (Utf8 s) = do putVarInt $ BS.length s putByteString s bondPutWString (Utf16 s) = do putVarInt $ BS.length s `div` 2 putByteString s bondPutList = putList bondPutNullable = bondPutList . maybeToList bondPutHashSet = putHashSet bondPutSet = putSet bondPutMap = putMap bondPutVector = putVector bondPutBlob (Blob b) = do putListHeader bT_INT8 (BS.length b) putByteString b bondPutBonded (BondedObject a) = bondPut a bondPutBonded s = do BondedStream stream <- case bondRecode CompactBinaryV1Proto s of Left msg -> throwError $ "Bonded recode error: " ++ msg Right v -> return v putLazyByteString (BL.drop 4 stream) getCompactFieldHeader :: (ReaderM t ~ B.Get) => BondGet t (BondDataType, Ordinal) getCompactFieldHeader = do tag <- getWord8 case tag `shiftR` 5 of 6 -> do n <- getWord8 return (BondDataType $ fromIntegral $ tag .&. 31, Ordinal (fromIntegral n)) 7 -> do n <- getWord16le return (BondDataType $ fromIntegral $ tag .&. 31, Ordinal n) n -> return (BondDataType $ fromIntegral $ tag .&. 31, Ordinal (fromIntegral n)) putCompactFieldHeader :: (WriterM t ~ ErrorT String B.PutM) => BondDataType -> Ordinal -> BondPut t putCompactFieldHeader t (Ordinal n) = let tbits = fromIntegral $ fromEnum t nbits = fromIntegral n in if | n <= 5 -> putWord8 $ tbits .|. (nbits `shiftL` 5) | n <= 255 -> do putWord8 $ tbits .|. 0xC0 putWord8 nbits | otherwise -> do putWord8 $ tbits .|. 0xE0 putWord16le n getBlob :: (TaggedProtocol t, ReaderM t ~ B.Get) => BondGet t Blob getBlob = do (t, n) <- getListHeader unless (t == bT_INT8) $ fail $ "invalid element tag " ++ bondTypeName t ++ " in blob field" Blob <$> getByteString n getList :: forall a t. (TaggedProtocol t, ReaderM t ~ B.Get, BondType a) => BondGet t [a] getList = do let et = getWireType (Proxy :: Proxy a) (t, n) <- getListHeader unless (t == et) $ fail $ "invalid element tag " ++ bondTypeName t ++ " in list field, " ++ bondTypeName et ++ " expected" replicateM n bondGet getVector :: forall a t. (TaggedProtocol t, ReaderM t ~ B.Get, BondType a) => BondGet t (Vector a) getVector = do let et = getWireType (Proxy :: Proxy a) (t, n) <- getListHeader unless (t == et) $ fail $ "invalid element tag " ++ bondTypeName t ++ " in list field, " ++ bondTypeName et ++ " expected" V.replicateM n bondGet getMap :: forall k v t. (TaggedProtocol t, ReaderM t ~ B.Get, Ord k, BondType k, BondType v) => BondGet t (Map k v) getMap = do let etk = getWireType (Proxy :: Proxy k) let etv = getWireType (Proxy :: Proxy v) tk <- BondDataType . fromIntegral <$> getWord8 tv <- BondDataType . fromIntegral <$> getWord8 unless (tk == etk) $ fail $ "invalid key tag " ++ bondTypeName tk ++ " in map field, " ++ bondTypeName etk ++ " expected" unless (tv == etv) $ fail $ "invalid value tag " ++ bondTypeName tv ++ " in map field, " ++ bondTypeName etv ++ " expected" n <- getVarInt fmap M.fromList $ replicateM n $ do k <- bondGet v <- bondGet return (k, v) getBonded :: (TaggedProtocol t, ReaderM t ~ B.Get) => ProtocolType -> Word16 -> BondGet t (Bonded a) getBonded sig ver = do size <- lookAhead $ do start <- bytesRead skipType bT_STRUCT stop <- bytesRead return (stop - start) bs <- getLazyByteString (fromIntegral size) return $ BondedStream $ BL.append (protoHeader sig ver) bs skipVarInt :: forall t. (ReaderM t ~ B.Get) => BondGet t () skipVarInt = void (getVarInt :: BondGet t Word64) compactSkipType :: (TaggedProtocol t, ReaderM t ~ B.Get) => BondDataType -> BondGet t () compactSkipType t = if | t == bT_BOOL -> skip 1 | t == bT_UINT8 -> skip 1 | t == bT_UINT16 -> skipVarInt | t == bT_UINT32 -> skipVarInt | t == bT_UINT64 -> skipVarInt | t == bT_FLOAT -> skip 4 | t == bT_DOUBLE -> skip 8 | t == bT_STRING -> getVarInt >>= skip | t == bT_STRUCT -> skipStruct | t == bT_LIST -> do (td, n) <- getListHeader replicateM_ n (skipType td) | t == bT_SET -> skipType bT_LIST | t == bT_MAP -> do tk <- BondDataType . fromIntegral <$> getWord8 tv <- BondDataType . fromIntegral <$> getWord8 n <- getVarInt replicateM_ n $ skipType tk >> skipType tv | t == bT_INT8 -> skip 1 | t == bT_INT16 -> skipVarInt | t == bT_INT32 -> skipVarInt | t == bT_INT64 -> skipVarInt | t == bT_WSTRING -> do n <- getVarInt skip $ n * 2 | otherwise -> fail $ "Invalid type to skip " ++ bondTypeName t putList :: forall a t. (TaggedProtocol t, Monad (WriterM t), BondType a) => [a] -> BondPut t putList xs = do putListHeader (getWireType (Proxy :: Proxy a)) (length xs) mapM_ bondPut xs putHashSet :: forall a t. (TaggedProtocol t, Monad (WriterM t), BondType a) => HashSet a -> BondPut t putHashSet xs = do putListHeader (getWireType (Proxy :: Proxy a)) (H.size xs) mapM_ bondPut $ H.toList xs putSet :: forall a t. (TaggedProtocol t, Monad (WriterM t), BondType a) => Set a -> BondPut t putSet xs = do putListHeader (getWireType (Proxy :: Proxy a)) (S.size xs) mapM_ bondPut $ S.toList xs putMap :: forall k v t. (Protocol t, WriterM t ~ ErrorT String B.PutM, BondType k, BondType v) => Map k v -> BondPut t putMap m = do putTag $ getWireType (Proxy :: Proxy k) putTag $ getWireType (Proxy :: Proxy v) putVarInt $ M.size m forM_ (M.toList m) $ \(k, v) -> do bondPut k bondPut v putVector :: forall a t. (TaggedProtocol t, Monad (WriterM t), BondType a) => Vector a -> BondPut t putVector xs = do putListHeader (getWireType (Proxy :: Proxy a)) (V.length xs) V.mapM_ bondPut xs