module Text.ProtocolBuffers.Extensions
(
getKeyFieldId,getKeyFieldType,getKeyDefaultValue
, Key(..),ExtKey(..),MessageAPI(..)
, wireSizeExtField,wirePutExtField,loadExtension,notExtension
, GPB,ExtField(..),ExtendMessage(..),ExtFieldValue(..)
) where
import Control.Monad.Error.Class(throwError)
import qualified Data.ByteString.Lazy as L
import qualified Data.Foldable as F
import Data.Generics
import Data.Map(Map)
import qualified Data.Map as M
import Data.Maybe(fromMaybe,isJust)
import Data.Monoid(mappend)
import Data.Sequence(Seq,(|>))
import qualified Data.Sequence as Seq
import Data.Typeable
import Text.ProtocolBuffers.Basic
import Text.ProtocolBuffers.Default()
import Text.ProtocolBuffers.WireMessage
import Text.ProtocolBuffers.Reflections
import Text.ProtocolBuffers.Get as Get (Get,runGet,Result(..),bytesRead)
-- | Abort with an 'error' whose message is prefixed by this module's name.
-- Used throughout the module for situations the code treats as
-- impossible or unrecoverable.
err :: String -> b
err msg = error (header ++ msg)
  where
    header = "Text.ProtocolBuffers.Extensions error\n"
-- | A typed handle for one extension field of message type @msg@.
--
-- @c@ is the container the value is read/written through ('Maybe' for
-- optional extensions, 'Seq' for repeated ones; see the 'ExtKey'
-- instances) and @v@ is the element type.
data Key c msg v where
  Key :: (ExtKey c, ExtendMessage msg, GPB v)
      => FieldId     -- field number of the extension
      -> FieldType   -- declared protocol-buffer field type
      -> Maybe v     -- declared default, if any ('getKeyDefaultValue')
      -> Key c msg v
-- | The field number stored in a 'Key'.
getKeyFieldId :: Key c msg v -> FieldId
getKeyFieldId key = case key of
  Key fieldId _ _ -> fieldId

-- | The declared field type stored in a 'Key'.
getKeyFieldType :: Key c msg v -> FieldType
getKeyFieldType key = case key of
  Key _ fieldType _ -> fieldType

-- | The key's declared default, or the type's 'defaultValue' when the key
-- carries none.
getKeyDefaultValue :: Key c msg v -> v
getKeyDefaultValue key = case key of
  Key _ _ md -> maybe defaultValue id md
-- | Hand-written 'Typeable2' instance for 'Key' (a GADT); the type
-- representation records the container @c@.
instance Typeable1 c => Typeable2 (Key c) where
  typeOf2 _ = mkTyConApp keyTyCon [typeOf1 (undefined :: c ())]
    where
      keyTyCon = mkTyCon "Text.ProtocolBuffers.Extensions.Key"
-- | Render a 'Key' as its constructor fields followed by its full type,
-- e.g. @(Key (fid) (ftype) (default) :: type)@.
instance (Typeable1 c, Typeable msg, Typeable v) => Show (Key c msg v) where
  show key@(Key fid ftype mdef) =
    "(Key (" ++ show fid
      ++ ") (" ++ show ftype
      ++ ") (" ++ show mdef
      ++ ") :: " ++ show (typeOf key)
      ++ ")"
-- | Evidence that @a@ has a 'GPB' instance. Pattern matching on
-- 'GPWitness' brings that instance (and therefore 'Typeable', 'Show',
-- 'Mergeable', ...) back into scope.
data GPWitness a where
    GPWitness :: (GPB a) => GPWitness a
  deriving (Typeable)

-- | One value of some 'GPB' type, packaged with its witness so the
-- concrete type can be recovered with 'cast'.
data GPDyn = forall a . GPDyn !(GPWitness a) a
  deriving (Typeable)

-- | A sequence of values of a single 'GPB' type, packaged with its witness.
data GPDynSeq = forall a . GPDynSeq !(GPWitness a) !(Seq a)
  deriving (Typeable)
-- | The contents stored for one extension field of a message.
--
-- NOTE(review): 'Ord' is derived while 'Eq' is hand-written below. 'Eq'
-- can report an 'ExtOptional'/'ExtRepeated' equal to an 'ExtFromWire'
-- after decoding, but the derived 'compare' orders by constructor first,
-- so the two instances can disagree.
data ExtFieldValue
  = ExtFromWire !WireType !(Seq ByteString)
    -- ^ Undecoded payloads, one per occurrence read from the wire
    -- (see 'loadExtension'), together with their wire type.
  | ExtOptional !FieldType !GPDyn
    -- ^ A decoded single value (written by the 'Maybe' 'ExtKey' instance).
  | ExtRepeated !FieldType !GPDynSeq
    -- ^ Decoded repeated values (written by the 'Seq' 'ExtKey' instance).
  deriving (Typeable, Ord, Show)
-- | Empty phantom message type. Within this module it only supplies the
-- @msg@ parameter of the 'Key' values built in the 'Eq' 'ExtFieldValue'
-- instance, which are passed to 'parseWireExtMaybe' / 'parseWireExtSeq'.
-- Those functions never touch the message, so the methods below are not
-- expected to be called.
data DummyMessageType deriving (Typeable)

-- Each method fails with a descriptive message (instead of a bare
-- 'undefined') so an accidental call is easy to diagnose.
instance ExtendMessage DummyMessageType where
  getExtField _ = err "DummyMessageType.getExtField: phantom type, should never be called"
  putExtField _ _ = err "DummyMessageType.putExtField: phantom type, should never be called"
  validExtRanges _ = err "DummyMessageType.validExtRanges: phantom type, should never be called"
-- | Structural equality on extension values. In addition to same-constructor
-- comparison, a decoded value ('ExtOptional' / 'ExtRepeated') is compared
-- with raw wire data ('ExtFromWire') by decoding the raw data at the
-- decoded value's type. Decoding failure or a wire-type mismatch means
-- "not equal".
instance Eq ExtFieldValue where
  ExtFromWire wt raw == ExtFromWire wt' raw' = wt == wt' && raw == raw'
  ExtOptional ft d == ExtOptional ft' d' = ft == ft' && d == d'
  ExtRepeated ft d == ExtRepeated ft' d' = ft == ft' && d == d'
  x@(ExtOptional ft (GPDyn w@GPWitness _)) == ExtFromWire wt' raw'
    | wt == wt' =
        -- A throw-away key whose value type is taken from the witness.
        let mkKey :: GPWitness a -> Key Maybe DummyMessageType a
            mkKey = undefined
            key = Key 0 ft Nothing `asTypeOf` mkKey w
        in either (const False) ((x ==) . snd) (parseWireExtMaybe key wt raw')
    | otherwise = False
    where
      wt = toWireType ft
  a@ExtFromWire{} == b@ExtOptional{} = b == a
  x@(ExtRepeated ft (GPDynSeq w@GPWitness _)) == ExtFromWire wt' raw'
    | wt == wt' =
        let mkKey :: GPWitness a -> Key Seq DummyMessageType a
            mkKey = undefined
            key = Key 0 ft Nothing `asTypeOf` mkKey w
        in either (const False) ((x ==) . snd) (parseWireExtSeq key wt raw')
    | otherwise = False
    where
      wt = toWireType ft
  a@ExtFromWire{} == b@ExtRepeated{} = b == a
  _ == _ = False
-- | All extension fields present in a message, keyed by field number.
newtype ExtField = ExtField (Map FieldId ExtFieldValue)
  deriving (Typeable, Eq, Ord, Show)
-- | Messages that carry extension fields.
class Typeable msg => ExtendMessage msg where
  -- | Read the message's extension map.
  getExtField :: msg -> ExtField
  -- | Replace the message's extension map.
  putExtField :: ExtField -> msg -> msg
  -- | Ranges of field ids, as (low, high) pairs; presumably the message's
  -- declared extension ranges (not used in this module) — TODO confirm.
  validExtRanges :: msg -> [(FieldId,FieldId)]
-- | Access to an extension field through its 'Key', where @c@ is the
-- container the value is exposed in ('Maybe' or 'Seq' in this module).
class ExtKey c where
  -- | Store a value for the key's field in the message.
  putExt :: Key c msg v -> c v -> msg -> msg
  -- | Fetch the key's field, or a 'Left' description of why it could not
  -- be read at the key's type.
  getExt :: Key c msg v -> msg -> Either String (c v)
  -- | Remove the key's field from the message.
  clearExt :: Key c msg v -> msg -> msg
  -- | Decode one occurrence of the key's field and merge it into the message.
  wireGetKey :: Key c msg v -> msg -> Get msg
-- | Types that can be stored as extension field values. The class has no
-- methods of its own; it only bundles the capabilities the extension code
-- needs (merging, defaults, wire encoding, showing, runtime type info and
-- comparison).
class (Mergeable a, Default a, Wire a, Show a, Typeable a, Eq a, Ord a) => GPB a

-- Scalar protocol-buffer value types.
instance GPB Bool
instance GPB ByteString
instance GPB Utf8
instance GPB Double
instance GPB Float
instance GPB Int32
instance GPB Int64
instance GPB Word32
instance GPB Word64
-- | Merging two extension maps merges the values of fields present in both
-- (via 'mergeExtFieldValue') and keeps all other fields unchanged.
instance Mergeable ExtField where
  mergeEmpty = ExtField M.empty
  mergeAppend (ExtField older) (ExtField newer) =
    ExtField $ M.unionWith mergeExtFieldValue older newer
-- | Combine two values stored for the same extension field.
--
-- * Raw wire data is concatenated (wire types must agree).
-- * Optional values are combined with 'mergeAppend' (field types and value
--   types must agree).
-- * Repeated values are concatenated (field types and element types must
--   agree).
--
-- Any mismatch, including mixing raw and decoded constructors, aborts with
-- 'err'.
mergeExtFieldValue :: ExtFieldValue -> ExtFieldValue -> ExtFieldValue
mergeExtFieldValue (ExtFromWire wt1 s1) (ExtFromWire wt2 s2)
  | wt1 /= wt2 = err $ "mergeExtFieldValue : ExtFromWire WireType mismatch " ++ show (wt1,wt2)
  | otherwise = ExtFromWire wt2 (mappend s1 s2)
mergeExtFieldValue (ExtOptional ft1 (GPDyn GPWitness d1)) (ExtOptional ft2 (GPDyn GPWitness d2))
  | ft1 /= ft2 = err $ "mergeExtFieldValue : ExtOptional FieldType mismatch " ++ show (ft1,ft2)
  | otherwise = case cast d2 of
      Just d2' -> ExtOptional ft2 (GPDyn GPWitness (mergeAppend d1 d2'))
      Nothing -> err $ "mergeExtFieldValue : ExtOptional cast failed, FieldType " ++ show (ft2,typeOf d1,typeOf d2)
mergeExtFieldValue (ExtRepeated ft1 (GPDynSeq GPWitness s1)) (ExtRepeated ft2 (GPDynSeq GPWitness s2))
  | ft1 /= ft2 = err $ "mergeExtFieldValue : ExtRepeated FieldType mismatch " ++ show (ft1,ft2)
  | otherwise = case cast s2 of
      Just s2' -> ExtRepeated ft2 (GPDynSeq GPWitness (mappend s1 s2'))
      Nothing -> err $ "mergeExtFieldValue : ExtRepeated cast failed, FieldType " ++ show (ft2,typeOf s1,typeOf s2)
mergeExtFieldValue a b = err $ "mergeExtFieldValue : mismatch of constructors " ++ show (a,b)
-- | The default extension map is empty, i.e. the 'Mergeable' identity.
instance Default ExtField where
  defaultValue = mergeEmpty
-- | Shows the witness as a type-annotated expression naming the witnessed
-- type.
instance Show (GPWitness a) where
  showsPrec _ GPWitness =
    showString "(GPWitness :: GPWitness (" . shows (typeOf (undefined :: a)) . showString "))"

-- | 'GPWitness' has a single constructor, so all (defined) witnesses are
-- equal. Inequality comes from the class default.
instance Eq (GPWitness a) where
  GPWitness == GPWitness = True

instance Ord (GPWitness a) where
  GPWitness `compare` GPWitness = EQ
-- | Hand-written 'Data' instance for the single-constructor GADT
-- 'GPWitness'.
instance (GPB a) => Data (GPWitness a) where
  gunfold _k z c
    | constrIndex c == 1 = z GPWitness
    | otherwise = err "gunfold of GPWitness error"
  toConstr GPWitness = gpWitnessC
  dataTypeOf _ = gpWitnessDT

-- | Data type description for 'GPWitness'; its only constructor is
-- 'gpWitnessC'.
gpWitnessDT :: DataType
gpWitnessDT = mkDataType "GPWitness" [gpWitnessC]

-- | The 'GPWitness' constructor (prefix, no fields).
gpWitnessC :: Constr
gpWitnessC = mkConstr gpWitnessDT "GPWitness" [] Prefix
-- | Values of different underlying types are never equal.
instance Eq GPDyn where
  a == b = maybe False id (eqGPDyn a b)

-- | Same-typed values use their own 'Ord'; values of different types are
-- ordered by their 'show' output.
instance Ord GPDyn where
  compare a b = case ordGPDyn a b of
    Just o -> o
    Nothing -> compare (show a) (show b)

instance Show GPDyn where
  showsPrec _ (GPDyn w@GPWitness x) =
    showString "(GPDyn " . shows w . showString " (" . shows x . showString "))"

-- | Sequences of different element types are never equal.
instance Eq GPDynSeq where
  a == b = maybe False id (eqGPDynSeq a b)

-- | Same-typed sequences use their own 'Ord'; otherwise ordered by 'show'
-- output.
instance Ord GPDynSeq where
  compare a b = case ordGPDynSeq a b of
    Just o -> o
    Nothing -> compare (show a) (show b)

instance Show GPDynSeq where
  showsPrec _ (GPDynSeq w@GPWitness xs) =
    showString "(GPDynSeq " . shows w . showString " (" . shows xs . showString "))"
-- | Compare two dynamic values; 'Nothing' when their types differ.
ordGPDyn :: GPDyn -> GPDyn -> Maybe Ordering
ordGPDyn (GPDyn GPWitness x) (GPDyn GPWitness y) = maybe Nothing (Just . compare x) (cast y)

-- | Test two dynamic values for equality; 'Nothing' when their types differ.
eqGPDyn :: GPDyn -> GPDyn -> Maybe Bool
eqGPDyn (GPDyn GPWitness x) (GPDyn GPWitness y) = maybe Nothing (Just . (x ==)) (cast y)

-- | Compare two dynamic sequences; 'Nothing' when their element types differ.
ordGPDynSeq :: GPDynSeq -> GPDynSeq -> Maybe Ordering
ordGPDynSeq (GPDynSeq GPWitness xs) (GPDynSeq GPWitness ys) = maybe Nothing (Just . compare xs) (cast ys)

-- | Test two dynamic sequences for equality; 'Nothing' when their element
-- types differ.
eqGPDynSeq :: GPDynSeq -> GPDynSeq -> Maybe Bool
eqGPDynSeq (GPDynSeq GPWitness xs) (GPDynSeq GPWitness ys) = maybe Nothing (Just . (xs ==)) (cast ys)
-- | Optional extension fields: at most one value is stored per field id.
-- Several occurrences of the field on the wire are combined with
-- 'mergeAppend'.
instance ExtKey Maybe where
  -- Storing 'Nothing' removes the field.
  putExt key Nothing msg = clearExt key msg
  putExt (Key i t _) (Just v) msg =
    let (ExtField ef) = getExtField msg
        v' = ExtOptional t (GPDyn GPWitness v)
        ef' = M.insert i v' ef
    in seq v' $ seq ef' (putExtField (ExtField ef') msg)

  clearExt (Key i _ _) msg =
    let (ExtField ef) = getExtField msg
        ef' = M.delete i ef
    in seq ef' (putExtField (ExtField ef') msg)

  -- Raw wire data ('ExtFromWire') is decoded at the key's type on each call.
  -- The decoded value is returned but not written back into the message.
  getExt k@(Key i t _) msg =
    let (ExtField ef) = getExtField msg
    in case M.lookup i ef of
         Nothing -> Right Nothing
         Just (ExtFromWire wt raw) -> either Left (getExt' . snd) (parseWireExtMaybe k wt raw)
         Just x -> getExt' x
    where
      getExt' (ExtRepeated t' _) =
        Left $ "getExt Maybe: ExtField has repeated type: " ++ show (k,t')
      getExt' (ExtOptional t' (GPDyn GPWitness d))
        | t /= t' = Left $ "getExt Maybe: Key's FieldType does not match ExtField's: " ++ show (k,t')
        | otherwise = case cast d of
            Nothing -> Left $ "getExt Maybe: Key's value cast failed: " ++ show (k,typeOf d)
            Just d' -> Right (Just d')
      getExt' _ = err $ "Impossible? getExt.getExt' Maybe should not have this case (after parseWireExtMaybe)!"

  -- Read one value of the key's type from the wire and merge it into
  -- whatever the message already stores for the field. Raw wire data is
  -- decoded first, so the stored result is always an 'ExtOptional'.
  wireGetKey k@(Key i t mv) msg = do
    let myCast :: Maybe a -> Get a
        myCast = undefined
    v <- wireGet t `asTypeOf` myCast mv
    let (ExtField ef) = getExtField msg
        -- Merge a previously stored value with the one just read, after
        -- checking that its field type and value type match the key.
        mergeOld :: FieldType -> GPDyn -> Get ExtFieldValue
        mergeOld t' (GPDyn GPWitness vOld)
          | t /= t' = fail $ "wireGetKey Maybe: Key mismatch! found wrong field type: " ++ show (k,t,t')
          | otherwise = case cast vOld of
              Nothing -> fail $ "wireGetKey Maybe: previous Maybe value cast failed: " ++ show (k,typeOf vOld)
              Just vOld' -> return $ ExtOptional t (GPDyn GPWitness (mergeAppend vOld' v))
    v' <- case M.lookup i ef of
      Nothing -> return $ ExtOptional t (GPDyn GPWitness v)
      Just (ExtOptional t' old) -> mergeOld t' old
      Just (ExtFromWire wt raw) ->
        case parseWireExtMaybe k wt raw of
          Left errMsg -> fail $ "wireGetKey Maybe: Could not parseWireExtMaybe: " ++ show k ++ "\n" ++ errMsg
          Right (_, ExtOptional t' old) -> mergeOld t' old
          wtf -> fail $ "wireGetKey Maybe: Weird parseWireExtMaybe return value: " ++ show (k,wtf)
      wtf -> fail $ "wireGetKey Maybe: ExtRepeated found with ExtOptional expected: " ++ show (k,wtf)
    let ef' = M.insert i v' ef
    seq v' $ seq ef' $ return (putExtField (ExtField ef') msg)
-- | Decode the raw payloads of an optional extension at the key's value
-- type. Every payload must decode completely; the decoded values are
-- combined with 'mergeConcat'. Fails when the key's wire type differs from
-- the stored one, or with the newline-joined decoding errors.
parseWireExtMaybe :: Key Maybe msg v -> WireType -> Seq ByteString -> Either String (FieldId,ExtFieldValue)
parseWireExtMaybe k@(Key fi ft mv) wt raw
  | wt /= toWireType ft =
      Left $ "parseWireExt Maybe: Key's FieldType does not match ExtField's wire type: " ++ show (k,toWireType ft,wt)
  | otherwise =
      let -- Only used to fix the witness's type to the key's value type.
          mkWitType :: Maybe a -> GPWitness a
          mkWitType = undefined
          witness = GPWitness `asTypeOf` mkWitType mv
          results = map (applyGet (wireGet ft)) (F.toList raw)
      in case [ m | Left m <- results ] of
           [] -> Right (fi, ExtOptional ft (GPDyn witness (mergeConcat [ a | Right a <- results ])))
           problems -> Left (unlines problems)
-- | Run a 'Get' decoder over a complete input. The decoder must succeed
-- and consume all input; otherwise a 'Left' description is returned.
applyGet :: Get r -> ByteString -> Either String r
applyGet g bsIn = case runGet g bsIn of
  Failed i s -> Left ("Failed at " ++ show i ++ " : " ++ s)
  Finished rest _i r
    | L.null rest -> Right r
    | otherwise -> Left "Not all input consumed"
  Partial {} -> Left "Not enough input"
-- | Repeated extension fields: all values of a field are kept in a 'Seq'.
-- Each occurrence read from the wire is appended.
instance ExtKey Seq where
  -- Storing an empty sequence removes the field.
  putExt key@(Key i t _) s msg
    | Seq.null s = clearExt key msg
    | otherwise =
        let (ExtField ef) = getExtField msg
            v' = ExtRepeated t (GPDynSeq GPWitness s)
            ef' = M.insert i v' ef
        in seq v' $ seq ef' (putExtField (ExtField ef') msg)

  clearExt (Key i _ _) msg =
    let (ExtField ef) = getExtField msg
        ef' = M.delete i ef
    in seq ef' (putExtField (ExtField ef') msg)

  -- Raw wire data ('ExtFromWire') is decoded at the key's type on each call.
  -- The decoded values are returned but not written back into the message.
  getExt k@(Key i t _) msg =
    let (ExtField ef) = getExtField msg
    in case M.lookup i ef of
         Nothing -> Right Seq.empty
         Just (ExtFromWire wt raw) -> either Left (getExt' . snd) (parseWireExtSeq k wt raw)
         Just x -> getExt' x
    where
      getExt' (ExtOptional t' _) =
        Left $ "getExt Seq: ExtField has optional type: " ++ show (k,t')
      getExt' (ExtRepeated t' (GPDynSeq GPWitness s))
        | t' /= t = Left $ "getExt Seq: Key's FieldType does not match ExtField's: " ++ show (k,t')
        | otherwise = case cast s of
            Nothing -> Left $ "getExt Seq: Key's Seq value cast failed: " ++ show (k,typeOf s)
            Just s' -> Right s'
      getExt' _ = err $ "Impossible? getExt.getExt' Seq should not have this case (after parseWireExtSeq)!"

  -- Read one element of the key's type from the wire and append it to the
  -- elements already stored for the field. Raw wire data is decoded first,
  -- so the stored result is always an 'ExtRepeated'.
  wireGetKey k@(Key i t mv) msg = do
    let myCast :: Maybe a -> Get a
        myCast = undefined
    v <- wireGet t `asTypeOf` myCast mv
    let (ExtField ef) = getExtField msg
        -- Append the new element to previously stored elements, after
        -- checking that their field type and element type match the key.
        -- The mismatch message prefix differs between the call sites.
        appendTo :: String -> FieldType -> GPDynSeq -> Get ExtFieldValue
        appendTo mismatchMsg t' (GPDynSeq GPWitness s)
          | t /= t' = fail $ mismatchMsg ++ show (k,t,t')
          | otherwise = case cast s of
              Nothing -> fail $ "wireGetKey Seq: previous Seq value cast failed: " ++ show (k,typeOf s)
              Just s' -> return $ ExtRepeated t (GPDynSeq GPWitness (s' |> v))
    v' <- case M.lookup i ef of
      Nothing -> return $ ExtRepeated t (GPDynSeq GPWitness (Seq.singleton v))
      Just (ExtRepeated t' old) ->
        appendTo "wireGetKey Seq: Key mismatch! found wrong field type: " t' old
      Just (ExtFromWire wt raw) ->
        case parseWireExtSeq k wt raw of
          Left errMsg -> fail $ "wireGetKey Seq: Could not parseWireExtSeq: " ++ show k ++ "\n" ++ errMsg
          Right (_, ExtRepeated t' old) ->
            appendTo "wireGetKey Seq: Key mismatch! parseWireExtSeq returned wrong field type: " t' old
          wtf -> fail $ "wireGetKey Seq: Weird parseWireExtSeq return value: " ++ show (k,wtf)
      wtf -> fail $ "wireGetKey Seq: ExtOptional found when ExtRepeated expected: " ++ show (k,wtf)
    let ef' = M.insert i v' ef
    seq v' $ seq ef' $ return (putExtField (ExtField ef') msg)
-- | Decode the raw payloads of a repeated extension at the key's element
-- type, one element per payload, preserving order. Every payload must
-- decode completely. Fails when the key's wire type differs from the stored
-- one, or with the newline-joined decoding errors.
--
-- The wire-type error message now names the 'Seq' variant; it previously
-- said "parseWireExt Maybe".
parseWireExtSeq :: Key Seq msg v -> WireType -> Seq ByteString -> Either String (FieldId,ExtFieldValue)
parseWireExtSeq k@(Key i t mv) wt raw
  | wt /= toWireType t =
      Left $ "parseWireExt Seq: Key mismatch! Key's FieldType does not match ExtField's wire type: " ++ show (k,toWireType t,wt)
  | otherwise = do
      let -- Only used to fix the witness's type to the key's value type.
          mkWitType :: Maybe a -> GPWitness a
          mkWitType = undefined
          witness = GPWitness `asTypeOf` (mkWitType mv)
          parsed = map (applyGet (wireGet t)) . F.toList $ raw
          errs = [ m | Left m <- parsed ]
      if null errs
        then Right (i, ExtRepeated t (GPDynSeq witness (Seq.fromList [ a | Right a <- parsed ])))
        else Left (unlines errs)
-- | Encoded size in bytes of all extension fields, tags included.
--
-- * Raw fields: one tag per stored payload plus the payload bytes.
-- * Decoded fields: sized by 'wireSizeReq' / 'wireSizeRep'.
wireSizeExtField :: ExtField -> WireSize
wireSizeExtField (ExtField m) = F.foldl' (\acc entry -> acc + entrySize entry) 0 (M.assocs m)
  where
    entrySize :: (FieldId, ExtFieldValue) -> WireSize
    entrySize (fi, ExtFromWire wt raw) =
      let tagSize = size'Varint (getWireTag (mkWireTag fi wt))
          payload = F.foldl' (\n bs -> n + L.length bs) 0 raw
      in fromIntegral (Seq.length raw) * tagSize + payload
    entrySize (fi, ExtOptional ft (GPDyn GPWitness d)) =
      wireSizeReq (size'Varint (getWireTag (toWireTag fi ft))) ft d
    entrySize (fi, ExtRepeated ft (GPDynSeq GPWitness s)) =
      wireSizeRep (size'Varint (getWireTag (toWireTag fi ft))) ft s
-- | Serialize all extension fields in ascending field-id order.
--
-- * Raw payloads are re-emitted verbatim, each preceded by its tag.
-- * Decoded values go through 'wirePutOpt' / 'wirePutRep'.
wirePutExtField :: ExtField -> Put
wirePutExtField (ExtField m) = mapM_ putEntry (M.toList m)
  where
    putEntry :: (FieldId, ExtFieldValue) -> Put
    putEntry (fi, ExtFromWire wt raw) =
      let tag = getWireTag (mkWireTag fi wt)
      in F.mapM_ (\bs -> putVarUInt tag >> putLazyByteString bs) raw
    putEntry (fi, ExtOptional ft (GPDyn GPWitness d)) = wirePutOpt (toWireTag fi ft) ft (Just d)
    putEntry (fi, ExtRepeated ft (GPDynSeq GPWitness s)) = wirePutRep (toWireTag fi ft) ft s
-- | Field handler for messages where the given field id is not a valid
-- extension: always fails (via 'throwError') with a message naming the
-- field id and message type.
notExtension :: (ReflectDescriptor a, ExtendMessage a,Typeable a) => FieldId -> WireType -> a -> Get a
notExtension fieldId _wireType msg =
  throwError $ "Field id " ++ show fieldId ++ " is not a valid extension field id for " ++ show (typeOf msg)
-- | Read one occurrence of extension field @fieldId@ from the wire and
-- store it in the message.
--
-- * Field not yet present: the raw payload is kept undecoded as
--   'ExtFromWire'.
-- * Raw data already present: the new payload is appended.
-- * Decoded optional value present: the new value is decoded and merged
--   with 'mergeAppend'.
-- * Decoded repeated values present: the new element is appended.
--
-- Fails when the incoming wire type disagrees with what is already stored.
loadExtension :: (ReflectDescriptor a, ExtendMessage a) => FieldId -> WireType -> a -> Get a
loadExtension fieldId wireType msg =
  case M.lookup fieldId ef of
    Nothing -> do
      bs <- wireGetFromWire fieldId wireType
      store (ExtFromWire wireType (Seq.singleton bs))
    Just (ExtFromWire wt raw)
      | wt /= wireType -> badwt wt
      | otherwise -> do
          bs <- wireGetFromWire fieldId wireType
          store (ExtFromWire wt (raw |> bs))
    Just (ExtOptional ft (GPDyn x@GPWitness a))
      | toWireType ft /= wireType -> badwt (toWireType ft)
      | otherwise -> do
          b <- wireGet ft
          store (ExtOptional ft (GPDyn x (mergeAppend a b)))
    Just (ExtRepeated ft (GPDynSeq x@GPWitness s))
      | toWireType ft /= wireType -> badwt (toWireType ft)
      | otherwise -> do
          a <- wireGet ft
          store (ExtRepeated ft (GPDynSeq x (s |> a)))
  where
    ExtField ef = getExtField msg
    -- Insert the new field value and rebuild the message, forcing the value
    -- and the updated map before returning.
    store v' =
      let ef' = M.insert fieldId v' ef
      in seq v' $ seq ef' $ return (putExtField (ExtField ef') msg)
    badwt wt = do
      here <- bytesRead
      fail $ "Conflicting wire types at byte position " ++ show here ++ " for extension to message: " ++ show (typeOf msg,fieldId,wireType,wt)
-- | Uniform read access to a field of @msg@ through an accessor @a@ (a
-- record selector or an extension 'Key'), yielding a value of type @b@.
class MessageAPI msg a b | msg a -> b where
  -- | Read the field, using a fallback where the instance defines one.
  getVal :: msg -> a -> b
  -- | Whether the field is present. The default answers 'True'.
  isSet :: msg -> a -> Bool
  isSet _ _ = True
-- | Optional fields. When the field is unset, fall back to the field's value
-- in the default message, and if that is unset too, to the type's
-- 'defaultValue'.
instance (Default msg,Default a) => MessageAPI msg (msg -> Maybe a) a where
  getVal m f = case f m of
    Just v -> v
    Nothing -> fromMaybe defaultValue (f defaultValue)
  isSet m f = isJust (f m)

-- | Repeated fields. Set means non-empty.
instance MessageAPI msg (msg -> (Seq a)) (Seq a) where
  getVal m f = f m
  isSet m f = not . Seq.null $ f m
-- | Optional extension keys. A missing or unreadable value falls back to the
-- key's declared default, then to 'defaultValue'. Set means the field id is
-- present in the message's extension map.
instance (Default v) => MessageAPI msg (Key Maybe msg v) v where
  getVal m k@(Key _ _ md) = either (const fallback) (fromMaybe fallback) (getExt k m)
    where
      fallback = fromMaybe defaultValue md
  isSet m (Key fid _ _) = case getExtField m of
    ExtField x -> M.member fid x

-- | Repeated extension keys. A missing or unreadable value reads as the
-- empty sequence. Set means the field id is present in the extension map.
instance (Default v) => MessageAPI msg (Key Seq msg v) (Seq v) where
  getVal m k@(Key _ _ _) = either (const Seq.empty) id (getExt k m)
  isSet m (Key fid _ _) = case getExtField m of
    ExtField x -> M.member fid x
-- | Required scalar fields: the selector's result is returned as-is, and the
-- class default makes 'isSet' always 'True'.
--
-- 'Bool' was missing here even though it is a 'GPB' type like the others;
-- a 'Bool' instance is included so required bool fields work with 'getVal'.
instance MessageAPI msg (msg -> Bool) Bool where
  getVal m f = f m
instance MessageAPI msg (msg -> ByteString) ByteString where
  getVal m f = f m
instance MessageAPI msg (msg -> Utf8) Utf8 where
  getVal m f = f m
instance MessageAPI msg (msg -> Double) Double where
  getVal m f = f m
instance MessageAPI msg (msg -> Float) Float where
  getVal m f = f m
instance MessageAPI msg (msg -> Int32) Int32 where
  getVal m f = f m
instance MessageAPI msg (msg -> Int64) Int64 where
  getVal m f = f m
instance MessageAPI msg (msg -> Word32) Word32 where
  getVal m f = f m
instance MessageAPI msg (msg -> Word64) Word64 where
  getVal m f = f m