{-# Language ScopedTypeVariables, MultiWayIf, TypeFamilies, FlexibleContexts #-}
-- | Machinery shared by tagged Bond protocols: the 'TaggedProtocol' class,
-- schema-driven struct reading and writing ('getStruct', 'putStruct',
-- 'putField'), schema-less 'Struct' reading and writing ('getTaggedData',
-- 'putTaggedData'), and whole-buffer encode/decode entry points.
module Data.Bond.Internal.TaggedProtocol where

import Data.Bond.Schema.BondDataType
import Data.Bond.Struct
import Data.Bond.TypedSchema
import Data.Bond.Types
import Data.Bond.Internal.BinaryUtils
import Data.Bond.Internal.Default
import Data.Bond.Internal.OrdinalSet
import Data.Bond.Internal.Protocol
import Data.Bond.Internal.SchemaOps
import Data.Bond.Internal.SchemaUtils

import Control.Applicative
import Control.Monad
import Control.Monad.Error
import Data.Bits
import Data.Proxy
import Prelude -- ghc 7.10 workaround for Control.Applicative
import qualified Data.Binary.Get as B
import qualified Data.Binary.Put as B
import qualified Data.ByteString.Lazy as BL
import qualified Data.Map as M
import qualified Data.Map.Strict as MS

-- | Position of the struct being read or written.
--
-- A 'TopLevelStruct' is terminated with 'bT_STOP'. A 'BaseStruct' (the base
-- part of a derived struct) is terminated with 'bT_STOP_BASE'.
data StructLevel = TopLevelStruct | BaseStruct
    deriving (Show, Eq)

-- | Primitive operations a protocol supplies so that the generic tagged
-- helpers in this module can drive it.
class Protocol t => TaggedProtocol t where
    -- | Read the next field header: wire type and ordinal. The callers in this
    -- module treat 'bT_STOP' and 'bT_STOP_BASE' wire types as terminators.
    getFieldHeader :: BondGet t (BondDataType, Ordinal)
    -- | Read a list/set header: element wire type and element count.
    getListHeader :: BondGet t (BondDataType, Int)
    -- | Read a whole struct into a schema-less 'Struct'.
    getTaggedStruct :: BondGet t Struct
    -- | Write a field header for the given wire type and ordinal.
    putFieldHeader :: BondDataType -> Ordinal -> BondPut t
    -- | Write a list/set header for the given element wire type and count.
    putListHeader :: (Integral a, FiniteBits a) => BondDataType -> a -> BondPut t
    -- | Write a schema-less 'Struct'.
    putTaggedStruct :: MonadError String (BondPutM t) => Struct -> BondPut t
    -- | NOTE(review): presumably skips an entire struct in the input; not used
    -- in this module — confirm against the protocol instances.
    skipStruct :: BondGet t ()
    -- | Used by 'getStruct' after an unexpected 'bT_STOP_BASE' at top level;
    -- presumably consumes the remainder of the current struct.
    skipRestOfStruct :: BondGet t ()
    -- | Skip one value of the given wire type (used for fields the schema
    -- does not know about).
    skipType :: TaggedProtocol t => BondDataType -> BondGet t ()

-- | Read a struct of type @a@ from a tagged stream.
--
-- The base part is read first via 'bondStructGetBase'. Field headers are then
-- consumed until a terminator suitable for @level@ is found:
--
-- * 'TopLevelStruct' stops at 'bT_STOP'. On 'bT_STOP_BASE' it calls
--   'skipRestOfStruct' and stops.
-- * 'BaseStruct' stops at 'bT_STOP_BASE' and fails on 'bT_STOP'.
--
-- Fields whose ordinal is not in @a@'s schema are skipped with 'skipType'.
-- Fails if any required field was not encountered.
getStruct :: forall a t.
    (Functor (ReaderM t), Monad (ReaderM t), TaggedProtocol t, BondStruct a)
    => StructLevel -> BondGet t a
getStruct level = do
    let schema = getSchema (Proxy :: Proxy a)
    let fieldsMap = structFields schema
    b <- bondStructGetBase defaultValue
    -- iterate over stream, update fields
    let readField wiretype ordinal s =
            if M.member ordinal fieldsMap
                then bondStructGetField ordinal s
                else do
                    skipType wiretype -- unknown field, ignore it
                    return s
    -- ords tracks the required ordinals that have not been seen yet.
    let loop (s, ords) = do
            (wiretype, ordinal) <- getFieldHeader
            if | wiretype == bT_STOP && level == BaseStruct -> fail "BT_STOP found where BT_STOP_BASE expected"
               | wiretype == bT_STOP && level == TopLevelStruct -> return (s, ords)
               | wiretype == bT_STOP_BASE && level == BaseStruct -> return (s, ords)
               | wiretype == bT_STOP_BASE && level == TopLevelStruct -> skipRestOfStruct >> return (s, ords)
               | otherwise -> do
                    s' <- readField wiretype ordinal s
                    loop (s', deleteOrdinal ordinal ords)
    (value, notRead) <- loop (b, structRequiredOrdinals schema)
    unless (isEmptySet notRead) $
        fail $ "required fields not read: "
            ++ show (map (getFieldName schema) $ toOrdinalList notRead)
    return value

-- | Serialize @a@ with 'bondStructPut', then write the terminator for
-- @level@: 'bT_STOP' for 'TopLevelStruct', 'bT_STOP_BASE' for 'BaseStruct'.
putStruct :: (WriterM t ~ ErrorT String B.PutM, TaggedProtocol t, BondStruct a)
    => StructLevel -> a -> BondPut t
putStruct level a = do
    bondStructPut a
    case level of
        TopLevelStruct -> putTag bT_STOP
        BaseStruct -> putTag bT_STOP_BASE

-- | 'putStruct' at 'BaseStruct' level (terminated by 'bT_STOP_BASE').
putBaseStruct :: (WriterM t ~ ErrorT String B.PutM, TaggedProtocol t, BondStruct a)
    => a -> BondPut t
putBaseStruct = putStruct BaseStruct

-- | Write one field of struct @b@, identified by @ordinal@. The field is
-- written as a header tagged with @a@'s wire type, followed by the value.
--
-- The field is omitted when the schema marks it 'FieldOptional' and
-- 'equalToDefault' reports the value as the default. Calls 'error' if
-- @ordinal@ is not in @b@'s schema. This happens lazily, when the field info
-- is first inspected.
putField :: forall a b t.
    (Monad (BondPutM t), TaggedProtocol t, BondType a, BondStruct b)
    => Proxy b -> Ordinal -> a -> BondPut t
putField p ordinal value = do
    let tag = getWireType (Proxy :: Proxy a)
    let info = M.findWithDefault (error "internal error: unknown field ordinal") ordinal (structFields $ getSchema p)
    let needToSave = not (equalToDefault (fieldType info) value) || fieldModifier info /= FieldOptional
    when needToSave $ do
        putFieldHeader tag ordinal
        bondPut value

-- | Write a data type as a single byte holding its 'Enum' index.
putTag :: WriterM t ~ ErrorT String B.PutM => BondDataType -> BondPut t
putTag = putWord8 . fromIntegral . fromEnum

-- | Run a decoder over the whole input.
--
-- A decoder failure is reported together with the byte offset it reached.
-- Input left over after a successful decode is also an error. This is shared
-- by 'binaryDecode', 'readTagged' and 'readTaggedWithSchema' so that they
-- report errors identically.
runGetExhaustive :: B.Get a -> BL.ByteString -> Either String a
runGetExhaustive g s = case B.runGetOrFail g s of
    Left (_, used, msg) ->
        Left $ "parse error at " ++ show used ++ ": " ++ msg
    Right (rest, used, _) | not (BL.null rest) ->
        Left $ "incomplete parse, used " ++ show used ++ ", left " ++ show (BL.length rest)
    Right (_, _, a) -> Right a

-- | Decode a struct of type @a@ from the whole buffer, using protocol @t@.
-- The @t@ argument is used only to select the protocol.
binaryDecode :: forall a t. (ReaderM t ~ B.Get, BondStruct a, Protocol t)
    => t -> BL.ByteString -> Either String a
binaryDecode _ s =
    let BondGet g = bondGetStruct :: BondGet t a
    in runGetExhaustive g s

-- | Encode a struct with protocol @t@ (used only to select the protocol).
-- Errors raised while writing are returned via 'tryPut'.
binaryEncode :: forall a t. (WriterM t ~ ErrorT String B.PutM, BondStruct a, Protocol t)
    => t -> a -> Either String BL.ByteString
binaryEncode _ a =
    let BondPut g = bondPutStruct a :: BondPut t
    in tryPut g

-- | Read a schema-less 'Struct' from a tagged stream.
--
-- Field headers are read until 'bT_STOP'. On 'bT_STOP_BASE', the fields
-- collected so far become the base of a fresh, empty struct, and reading
-- continues into it. Each value is decoded according to its wire type.
-- Lists and sets use 'getListHeader'. Maps are read as a key-type byte, a
-- value-type byte, a varint count and that many key/value pairs. An
-- unrecognised wire type fails.
getTaggedData :: forall t. (ReaderM t ~ B.Get, TaggedProtocol t) => BondGet t Struct
getTaggedData = fieldLoop $ Struct Nothing M.empty
  where
    getValue :: BondDataType -> BondGet t Value
    getValue t =
        if | t == bT_BOOL -> BOOL <$> bondGetBool
           | t == bT_UINT8 -> UINT8 <$> bondGetUInt8
           | t == bT_UINT16 -> UINT16 <$> bondGetUInt16
           | t == bT_UINT32 -> UINT32 <$> bondGetUInt32
           | t == bT_UINT64 -> UINT64 <$> bondGetUInt64
           | t == bT_FLOAT -> FLOAT <$> bondGetFloat
           | t == bT_DOUBLE -> DOUBLE <$> bondGetDouble
           | t == bT_STRING -> STRING <$> bondGetString
           | t == bT_STRUCT -> STRUCT <$> getTaggedStruct
           | t == bT_LIST -> do
                (td, n) <- getListHeader
                LIST td <$> replicateM n (getValue td)
           | t == bT_SET -> do
                (td, n) <- getListHeader
                SET td <$> replicateM n (getValue td)
           | t == bT_MAP -> do
                tk <- BondDataType . fromIntegral <$> getWord8
                tv <- BondDataType . fromIntegral <$> getWord8
                n <- getVarInt
                -- key is read before value for every entry
                let getPair = do
                        k <- getValue tk
                        v <- getValue tv
                        return (k, v)
                MAP tk tv <$> replicateM n getPair
           | t == bT_INT8 -> INT8 <$> bondGetInt8
           | t == bT_INT16 -> INT16 <$> bondGetInt16
           | t == bT_INT32 -> INT32 <$> bondGetInt32
           | t == bT_INT64 -> INT64 <$> bondGetInt64
           | t == bT_WSTRING -> WSTRING <$> bondGetWString
           | otherwise -> fail $ "invalid field type " ++ bondTypeName t
    -- a later field with the same ordinal overwrites an earlier one
    setField s o v = return $ s { fields = MS.insert o v (fields s) }
    fieldLoop s = do
        (t, o) <- getFieldHeader
        if | t == bT_STOP -> return s
           | t == bT_STOP_BASE -> fieldLoop $ Struct (Just s) M.empty
           | otherwise -> getValue t >>= setField s o >>= fieldLoop

-- | Parse a schema-less 'Struct' from the whole buffer with protocol @t@.
-- The @t@ argument is used only to select the protocol.
readTagged :: forall t. (ReaderM t ~ B.Get, TaggedProtocol t)
    => t -> BL.ByteString -> Either String Struct
readTagged _ s =
    let BondGet g = getTaggedStruct :: BondGet t Struct
    in runGetExhaustive g s

-- | Like 'readTagged', but the parsed struct is then passed through
-- 'checkStructSchema' against @schema@.
readTaggedWithSchema :: forall t. (ReaderM t ~ B.Get, TaggedProtocol t)
    => t -> StructSchema -> BL.ByteString -> Either String Struct
readTaggedWithSchema _ schema s =
    let BondGet g = getTaggedStruct :: BondGet t Struct
    in runGetExhaustive g s >>= checkStructSchema schema

-- | Write the fields of a schema-less 'Struct'.
--
-- If the struct has a base, the base is written first (recursively) and
-- followed by 'bT_STOP_BASE'. Each field is then written as a header plus its
-- value, in 'M.toList' (ascending ordinal) order. No closing 'bT_STOP' is
-- written here.
-- NOTE(review): presumably the protocol's 'putTaggedStruct' appends it —
-- confirm.
putTaggedData :: forall t.
    (MonadError String (BondPutM t), WriterM t ~ ErrorT String B.PutM, TaggedProtocol t)
    => Struct -> BondPut t
putTaggedData s = do
    case base s of
        Just b -> putTaggedData b >> putTag bT_STOP_BASE
        Nothing -> return ()
    forM_ (M.toList $ fields s) $ \ (o, v) -> do
        let (typ, writer) = saveValue v
        putFieldHeader typ o
        writer
  where
    -- | Wire type of a value, paired with the action that writes its payload.
    saveValue :: Value -> (BondDataType, BondPut t)
    saveValue (BOOL v) = (bT_BOOL, bondPutBool v)
    saveValue (INT8 v) = (bT_INT8, bondPutInt8 v)
    saveValue (INT16 v) = (bT_INT16, bondPutInt16 v)
    saveValue (INT32 v) = (bT_INT32, bondPutInt32 v)
    saveValue (INT64 v) = (bT_INT64, bondPutInt64 v)
    saveValue (UINT8 v) = (bT_UINT8, bondPutUInt8 v)
    saveValue (UINT16 v) = (bT_UINT16, bondPutUInt16 v)
    saveValue (UINT32 v) = (bT_UINT32, bondPutUInt32 v)
    saveValue (UINT64 v) = (bT_UINT64, bondPutUInt64 v)
    saveValue (FLOAT v) = (bT_FLOAT, bondPutFloat v)
    saveValue (DOUBLE v) = (bT_DOUBLE, bondPutDouble v)
    saveValue (STRING v) = (bT_STRING, bondPutString v)
    saveValue (WSTRING v) = (bT_WSTRING, bondPutWString v)
    saveValue (STRUCT v) = (bT_STRUCT, putTaggedStruct v)
    saveValue (LIST td xs) = (bT_LIST, putListHeader td (length xs) >> mapM_ (saveTypedValue td) xs)
    saveValue (SET td xs) = (bT_SET, putListHeader td (length xs) >> mapM_ (saveTypedValue td) xs)
    saveValue (MAP tk tv xs) = (bT_MAP, do
            putTag tk
            putTag tv
            putVarInt $ length xs
            forM_ xs $ \ (k, v) -> do
                saveTypedValue tk k
                saveTypedValue tv v
        )
    saveValue (BONDED (BondedObject v)) = (bT_STRUCT, putTaggedStruct v)
    saveValue (BONDED _) = (bT_STRUCT, throwError "not implemented: should decode bonded values before tagged writes")
    -- FIXME be smart here
    -- same sig - copy stream
    -- tagged sig - unmarshal struct blindly, then marshal
    -- untagged sig - return error, but do untagged decoding while matching struct with schema

    -- | Write a container element. Fails if the element's own wire type
    -- differs from the container's declared element type.
    saveTypedValue td v =
        let (realtd, writer) = saveValue v
        in if td == realtd
            then writer
            else throwError $ "element type do not match container type: "
                ++ bondTypeName td ++ " expected, "
                ++ bondTypeName realtd ++ " found"

-- | Serialize a schema-less 'Struct' with protocol @t@, which is used only to
-- select the protocol. Writer errors are returned via 'tryPut'.
writeTagged :: forall t.
    (MonadError String (BondPutM t), WriterM t ~ ErrorT String B.PutM, TaggedProtocol t)
    => t -> Struct -> Either String BL.ByteString
writeTagged _ a =
    let BondPut g = putTaggedStruct a :: BondPut t
    in tryPut g

-- | Validate @struct@ with 'checkStructSchema' against @schema@, then
-- serialize the result with 'writeTagged'.
writeTaggedWithSchema ::
    (MonadError String (BondPutM t), WriterM t ~ ErrorT String B.PutM, TaggedProtocol t)
    => t -> StructSchema -> Struct -> Either String BL.ByteString
writeTaggedWithSchema t schema struct = checkStructSchema schema struct >>= writeTagged t