{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
module Data.ProtocolBuffers.Wire
  ( Enumeration(..)
  , WireField(..)
  , Tag
  , EncodeWire(..)
  , DecodeWire(..)
  , wireFieldTag
  , getWireField
  , getVarInt
  , getVarintPrefixedBS
  , putVarSInt
  , putVarUInt
  , putVarintPrefixedBS
  , zzEncode32
  , zzEncode64
  , zzDecode32
  , zzDecode64
  ) where
import Control.Applicative
import Data.Bits
import Data.ByteString (ByteString)
import qualified Data.ByteString as B
import Data.Foldable
import Data.Int
import Data.Monoid
import Data.Serialize.Get
import Data.Serialize.IEEE754
import Data.Serialize.Put
import qualified Data.Text as T
import qualified Data.Text.Encoding as T
import Data.Typeable
import Data.Word
import Data.Binary.IEEE754 (wordToDouble, wordToFloat)
import Data.ProtocolBuffers.Types
-- | A protocol buffer field number. On the wire it shares one varint key
-- with a 3-bit wire type; 'getWireField' recovers it by shifting the key
-- right by 3.
type Tag = Word32
-- | One field as read from (or written to) the wire: its tag plus a raw
-- payload whose shape is fixed by the wire type. Constructor order matters
-- for the derived 'Ord' instance.
data WireField
  = VarintField    {-# UNPACK #-} !Tag {-# UNPACK #-} !Word64  -- ^ wire type 0: varint payload
  | Fixed64Field   {-# UNPACK #-} !Tag {-# UNPACK #-} !Word64  -- ^ wire type 1: 8-byte little-endian payload
  | DelimitedField {-# UNPACK #-} !Tag !ByteString             -- ^ wire type 2: varint length, then that many bytes
  | StartField     {-# UNPACK #-} !Tag                         -- ^ wire type 3: no payload (protobuf group start)
  | EndField       {-# UNPACK #-} !Tag                         -- ^ wire type 4: no payload (protobuf group end)
  | Fixed32Field   {-# UNPACK #-} !Tag {-# UNPACK #-} !Word32  -- ^ wire type 5: 4-byte little-endian payload
    deriving (Eq, Ord, Show, Typeable)
-- | Read a varint byte count, then exactly that many bytes.
getVarintPrefixedBS :: Get ByteString
getVarintPrefixedBS = getVarInt >>= getBytes
-- | Write the byte count of @bs@ as an unsigned varint, followed by the
-- bytes themselves. Counterpart of 'getVarintPrefixedBS'.
putVarintPrefixedBS :: ByteString -> Put
putVarintPrefixedBS bs = do
  putVarUInt (B.length bs)
  putByteString bs
-- | Read one wire field.
--
-- The field starts with a varint key: its low 3 bits are the wire type and
-- the remaining bits are the tag. The payload that follows depends on the
-- wire type. Wire types 6 and 7 are rejected with 'fail'.
getWireField :: Get WireField
getWireField = getVarInt >>= readField
  where
    readField key =
      let tag = key `shiftR` 3
      in case key .&. 7 of
           0 -> VarintField    tag <$> getVarInt
           1 -> Fixed64Field   tag <$> getWord64le
           2 -> DelimitedField tag <$> getVarintPrefixedBS
           3 -> pure $! StartField tag
           4 -> pure $! EndField   tag
           5 -> Fixed32Field   tag <$> getWord32le
           x -> fail $ "Wire type out of range: " ++ show x
-- | Write a 'WireField': its key (tag and wire type), then its payload.
-- Counterpart of 'getWireField'.
putWireField :: WireField -> Put
putWireField field = case field of
  VarintField    tag w     -> putWireTag tag 0 >> putVarUInt w
  Fixed64Field   tag w     -> putWireTag tag 1 >> putWord64le w
  DelimitedField tag bytes -> putWireTag tag 2 >> putVarintPrefixedBS bytes
  StartField     tag       -> putWireTag tag 3
  EndField       tag       -> putWireTag tag 4
  Fixed32Field   tag w     -> putWireTag tag 5 >> putWord32le w
-- | Write a field key: @tag@ shifted left by 3, OR'd with the wire type.
--
-- Tags above 0x1FFFFFFF (the largest that still fits in 32 bits after the
-- shift) and wire types above 7 are rejected via 'legacyFail'.
putWireTag :: Tag -> Word32 -> Put
putWireTag tag typ
  | tag > 0x1FFFFFFF = legacyFail $ "Wire tag out of range: "  ++ show tag
  | typ > 7          = legacyFail $ "Wire type out of range: " ++ show typ
  | otherwise        = putVarUInt ((tag `shiftL` 3) .|. (typ .&. 7))
-- | Read a base-128 varint into any integral type.
--
-- Bytes are consumed until one has its high bit clear. Each byte
-- contributes its low 7 bits, least-significant group first. There is no
-- limit on how many bytes are read; for fixed-width result types, bits
-- shifted past the type's width are lost.
getVarInt :: (Integral a, Bits a) => Get a
getVarInt = loop 0 0
  where
    loop !offset !acc = do
      byte <- getWord8
      let acc' = acc .|. (fromIntegral (byte .&. 0x7F) `shiftL` offset)
      if testBit byte 7
        then loop (offset + 7) acc'
        else return $! acc'
{-# INLINE putVarSInt #-}
-- | Write a signed integer as a base-128 varint.
--
-- Zero and positive values are written exactly as 'putVarUInt' would
-- write them. A negative value is converted to 'Int64', and its
-- two's-complement bit pattern is written as a fixed 10-byte varint. As a
-- result, equal negative values of different widths (e.g. 'Int32' and
-- 'Int64') produce identical bytes.
--
-- NOTE(review): for types wider than 64 bits, the 'Int64' conversion
-- truncates negative values.
putVarSInt :: (Integral a, Bits a) => a -> Put
putVarSInt bIn =
  case compare bIn 0 of
    LT -> let 
              -- Sign-extend to 64 bits (the type of 'go' fixes b as Int64).
              b = fromIntegral bIn
              len = 10                                -- 9 * 7 = 63 < 64 <= 10 * 7 bits
              last'Mask = 1                           -- the 10th byte carries only the remaining top bit
              go :: Int64 -> Int -> Put
              -- Final byte: remaining bit only, no continuation bit.
              go !i 1 = putWord8 (fromIntegral (i .&. last'Mask))
              -- Low 7 bits with the continuation bit set, then the rest.
              go !i n = putWord8 (fromIntegral (i .&. 0x7F) .|. 0x80) >> go (i `shiftR` 7) (pred n)
          in go b len
    EQ -> putWord8 0
    GT -> putVarUInt bIn
{-# INLINE putVarUInt #-}
-- | Write a non-negative integer as a base-128 varint.
--
-- Each byte holds 7 bits, least-significant group first. Every byte except
-- the last has its high bit set.
--
-- NOTE(review): a negative argument satisfies @i < 0x80@ and is written as
-- a single truncated byte; use 'putVarSInt' for signed values.
putVarUInt :: (Integral a, Bits a) => a -> Put
putVarUInt i
  | i >= 0x80 = do
      putWord8 (fromIntegral (i .&. 0x7F) .|. 0x80)
      putVarUInt (i `shiftR` 7)
  | otherwise = putWord8 (fromIntegral i)
-- | The tag carried by any 'WireField', whatever its wire type.
wireFieldTag :: WireField -> Tag
wireFieldTag (VarintField    tag _) = tag
wireFieldTag (Fixed64Field   tag _) = tag
wireFieldTag (DelimitedField tag _) = tag
wireFieldTag (StartField     tag  ) = tag
wireFieldTag (EndField       tag  ) = tag
wireFieldTag (Fixed32Field   tag _) = tag
-- | Types that can be written to the wire under a given field tag.
-- Depending on the instance this writes zero or more fields (e.g. lists
-- write one field per element, and 'Nothing' writes nothing).
class EncodeWire a where
  encodeWire :: Tag -> a -> Put
-- | Types that can be read back from one already-parsed 'WireField'.
-- Instances in this module use 'empty' when the field's wire type does
-- not match the expected type.
class DecodeWire a where
  decodeWire :: WireField -> Get a
-- The 'Always' and 'Last' wrappers reuse the EncodeWire/DecodeWire instances
-- of their underlying representation. This relies on standalone deriving of
-- a user-defined class, which GeneralizedNewtypeDeriving provides for
-- newtypes.
deriving instance EncodeWire a => EncodeWire (Always (Value a))
deriving instance EncodeWire a => EncodeWire (Last (Value a))
deriving instance DecodeWire a => DecodeWire (Always (Value a))
deriving instance DecodeWire a => DecodeWire (Last (Value a))
-- Repeated (unpacked) values: each element is written as its own field
-- under the same tag, in list order. An empty list writes nothing.
instance EncodeWire a => EncodeWire [Value a] where
  encodeWire t values = for_ values (encodeWire t)
-- A raw 'WireField' is written verbatim, but only under its own tag; a
-- mismatched tag is rejected via 'legacyFail'.
instance EncodeWire WireField where
  encodeWire t f =
    if wireFieldTag f /= t
      then legacyFail "Specified tag and field tag do not match"
      else putWireField f
-- Decoding a 'WireField' as itself always succeeds and returns it unchanged.
instance DecodeWire WireField where
  decodeWire f = return f
-- A single 'Value' is encoded and decoded as the value it contains.
instance EncodeWire a => EncodeWire (Value a) where
  encodeWire t = mapM_ (encodeWire t)
instance DecodeWire a => DecodeWire (Value a) where
  decodeWire f = Value <$> decodeWire f
-- An optional value: 'Nothing' writes nothing and 'Just' writes the value.
-- Decoding a present field always yields 'Just'.
instance EncodeWire a => EncodeWire (Maybe (Value a)) where
  encodeWire t = maybe (return ()) (encodeWire t)
instance DecodeWire a => DecodeWire (Maybe (Value a)) where
  decodeWire f = Just . Value <$> decodeWire f
-- Plain integral scalars are varints (wire type 0).
-- Signed types are written with 'putVarSInt', so negative values take
-- 10 bytes; unsigned types are written with 'putVarUInt'.
-- Decoding accepts only a 'VarintField' and narrows its 64-bit payload with
-- 'fromIntegral'; any other field shape yields 'empty'.
instance EncodeWire Int32 where
  encodeWire tag n = do
    putWireTag tag 0
    putVarSInt n
instance DecodeWire Int32 where
  decodeWire field = case field of
    VarintField _ w -> return (fromIntegral w)
    _               -> empty
instance EncodeWire Int64 where
  encodeWire tag n = do
    putWireTag tag 0
    putVarSInt n
instance DecodeWire Int64 where
  decodeWire field = case field of
    VarintField _ w -> return (fromIntegral w)
    _               -> empty
instance EncodeWire Word32 where
  encodeWire tag n = do
    putWireTag tag 0
    putVarUInt n
instance DecodeWire Word32 where
  decodeWire field = case field of
    VarintField _ w -> return (fromIntegral w)
    _               -> empty
instance EncodeWire Word64 where
  encodeWire tag n = do
    putWireTag tag 0
    putVarUInt n
instance DecodeWire Word64 where
  decodeWire field = case field of
    VarintField _ w -> return w
    _               -> empty
-- 'Signed' integers are ZigZag-mapped ('zzEncode32' / 'zzEncode64') before
-- being written as a varint, so values of small magnitude stay short.
-- Decoding reverses the mapping on the payload of a 'VarintField'.
instance EncodeWire (Signed Int32) where
  encodeWire tag (Signed n) = do
    putWireTag tag 0
    putVarSInt (zzEncode32 n)
instance DecodeWire (Signed Int32) where
  decodeWire field = case field of
    VarintField _ w -> return (Signed (zzDecode32 (fromIntegral w)))
    _               -> empty
instance EncodeWire (Signed Int64) where
  encodeWire tag (Signed n) = do
    putWireTag tag 0
    putVarSInt (zzEncode64 n)
instance DecodeWire (Signed Int64) where
  decodeWire field = case field of
    VarintField _ w -> return (Signed (zzDecode64 (fromIntegral w)))
    _               -> empty
-- 'Fixed' integers use fixed-width little-endian payloads.
-- 32-bit types use wire type 5 and 64-bit types use wire type 1.
-- Signed values are stored as their two's-complement bit pattern via
-- 'fromIntegral'.
instance EncodeWire (Fixed Int32) where
  encodeWire tag (Fixed n) = do
    putWireTag tag 5
    putWord32le (fromIntegral n)
instance DecodeWire (Fixed Int32) where
  decodeWire field = case field of
    Fixed32Field _ w -> return (Fixed (fromIntegral w))
    _                -> empty
instance EncodeWire (Fixed Int64) where
  encodeWire tag (Fixed n) = do
    putWireTag tag 1
    putWord64le (fromIntegral n)
instance DecodeWire (Fixed Int64) where
  decodeWire field = case field of
    Fixed64Field _ w -> return (Fixed (fromIntegral w))
    _                -> empty
instance EncodeWire (Fixed Word32) where
  encodeWire tag (Fixed w) = do
    putWireTag tag 5
    putWord32le w
instance DecodeWire (Fixed Word32) where
  decodeWire field = case field of
    Fixed32Field _ w -> return (Fixed w)
    _                -> empty
instance EncodeWire (Fixed Word64) where
  encodeWire tag (Fixed w) = do
    putWireTag tag 1
    putWord64le w
instance DecodeWire (Fixed Word64) where
  decodeWire field = case field of
    Fixed64Field _ w -> return (Fixed w)
    _                -> empty
-- Booleans are varints: 'True' is written as 1 and 'False' as 0. Any
-- non-zero varint decodes as 'True'.
instance EncodeWire Bool where
  encodeWire tag b = do
    putWireTag tag 0
    putVarUInt ((if b then 1 else 0) :: Int32)
instance DecodeWire Bool where
  decodeWire field = case field of
    VarintField _ w -> return (w /= 0)
    _               -> empty
-- Floating-point scalars use IEEE 754 little-endian payloads: 'Float' as
-- wire type 5 (4 bytes) and 'Double' as wire type 1 (8 bytes).
instance EncodeWire Float where
  encodeWire tag x = do
    putWireTag tag 5
    putFloat32le x
instance DecodeWire Float where
  decodeWire field = case field of
    Fixed32Field _ w -> return (wordToFloat w)
    _                -> empty
instance EncodeWire Double where
  encodeWire tag x = do
    putWireTag tag 1
    putFloat64le x
instance DecodeWire Double where
  decodeWire field = case field of
    Fixed64Field _ w -> return (wordToDouble w)
    _                -> empty
-- Raw bytes are a length-delimited field (wire type 2).
instance EncodeWire ByteString where
  encodeWire tag bytes = do
    putWireTag tag 2
    putVarUInt (B.length bytes)
    putByteString bytes
instance DecodeWire ByteString where
  decodeWire field = case field of
    DelimitedField _ bytes -> return bytes
    _                      -> empty
-- Strings go through the 'T.Text' instances.
instance EncodeWire String where
  encodeWire tag s = encodeWire tag (T.pack s)
instance DecodeWire String where
  decodeWire field = T.unpack <$> decodeWire field
-- Text is written as its UTF-8 bytes. Decoding fails (via 'fail') when the
-- payload is not valid UTF-8.
instance EncodeWire T.Text where
  encodeWire tag txt = encodeWire tag (T.encodeUtf8 txt)
instance DecodeWire T.Text where
  decodeWire field = case field of
    DelimitedField _ bytes ->
      either (\err -> fail $ "Decoding failed: " ++ show err) return
             (T.decodeUtf8' bytes)
    _ -> empty
-- | Decode the payload of a packed repeated field.
--
-- A packed field arrives as a single 'DelimitedField' whose payload is the
-- concatenation of the elements, each read with @g@. Any other wire field
-- yields 'empty'.
--
-- The whole payload must be consumed. If bytes remain after the last
-- element @g@ can read (e.g. a truncated fixed-width value or varint),
-- decoding fails instead of silently returning a shorter list.
decodePackedList :: Get a -> WireField -> Get [a]
{-# INLINE decodePackedList #-}
decodePackedList g (DelimitedField _ bs) =
  case runGet elements bs of
    Right val -> return val
    Left err  -> fail err
  where
    -- 'many' stops at the first element that fails to parse; make sure
    -- that happened because the input ran out, not because of leftovers.
    elements = do
      xs <- many g
      done <- isEmpty
      if done
        then return xs
        else fail "decodePackedList: trailing bytes after last packed element"
decodePackedList _ _ = empty
-- | Wrap the bytes produced by @p@ in a single length-delimited field with
-- tag @t@, which is the packed repeated-field layout. If @p@ produces no
-- bytes, nothing at all is written, so an empty packed list is omitted.
encodePackedList :: Tag -> Put -> Put
{-# INLINE encodePackedList #-}
encodePackedList t p =
  let payload = runPut p
  in if B.null payload
       then pure ()
       else encodeWire t payload
-- Packed repeated varint scalars: all elements are written back to back in
-- one length-delimited field (see 'encodePackedList').
-- Signed types use 'putVarSInt' and unsigned types use 'putVarUInt'.
-- Decoding reads varints until the payload is exhausted.
instance EncodeWire (PackedList (Value Int32)) where
  encodeWire t (PackedList xs) =
    encodePackedList t (mapM_ (putVarSInt . runValue) xs)
instance DecodeWire (PackedList (Value Int32)) where
  decodeWire field =
    PackedList . map Value <$> decodePackedList getVarInt field
instance EncodeWire (PackedList (Value Int64)) where
  encodeWire t (PackedList xs) =
    encodePackedList t (mapM_ (putVarSInt . runValue) xs)
instance DecodeWire (PackedList (Value Int64)) where
  decodeWire field =
    PackedList . map Value <$> decodePackedList getVarInt field
instance EncodeWire (PackedList (Value Word32)) where
  encodeWire t (PackedList xs) =
    encodePackedList t (mapM_ (putVarUInt . runValue) xs)
instance DecodeWire (PackedList (Value Word32)) where
  decodeWire field =
    PackedList . map Value <$> decodePackedList getVarInt field
instance EncodeWire (PackedList (Value Word64)) where
  encodeWire t (PackedList xs) =
    encodePackedList t (mapM_ (putVarUInt . runValue) xs)
instance DecodeWire (PackedList (Value Word64)) where
  decodeWire field =
    PackedList . map Value <$> decodePackedList getVarInt field
-- Packed ZigZag-encoded integers: each element is mapped with
-- 'zzEncode32' / 'zzEncode64' and written as a varint inside one
-- length-delimited field; decoding applies the inverse mapping.
instance EncodeWire (PackedList (Value (Signed Int32))) where
  encodeWire t (PackedList xs) = encodePackedList t (mapM_ putOne xs)
    where
      putOne v = case runValue v of
        Signed n -> putVarSInt (zzEncode32 n)
instance DecodeWire (PackedList (Value (Signed Int32))) where
  decodeWire field =
    PackedList . map (Value . Signed . zzDecode32) <$> decodePackedList getVarInt field
instance EncodeWire (PackedList (Value (Signed Int64))) where
  encodeWire t (PackedList xs) = encodePackedList t (mapM_ putOne xs)
    where
      putOne v = case runValue v of
        Signed n -> putVarSInt (zzEncode64 n)
instance DecodeWire (PackedList (Value (Signed Int64))) where
  decodeWire field =
    PackedList . map (Value . Signed . zzDecode64) <$> decodePackedList getVarInt field
-- Packed fixed-width integers: elements are written back to back as
-- little-endian 4-byte ('Word32' / 'Int32') or 8-byte ('Word64' / 'Int64')
-- words inside one length-delimited field. Signed values are stored as
-- their two's-complement bit pattern via 'fromIntegral'.
instance EncodeWire (PackedList (Value (Fixed Word32))) where
  encodeWire t (PackedList xs) = encodePackedList t (mapM_ putOne xs)
    where
      putOne v = case runValue v of
        Fixed w -> putWord32le w
instance DecodeWire (PackedList (Value (Fixed Word32))) where
  decodeWire field =
    PackedList . map (Value . Fixed) <$> decodePackedList getWord32le field
instance EncodeWire (PackedList (Value (Fixed Word64))) where
  encodeWire t (PackedList xs) = encodePackedList t (mapM_ putOne xs)
    where
      putOne v = case runValue v of
        Fixed w -> putWord64le w
instance DecodeWire (PackedList (Value (Fixed Word64))) where
  decodeWire field =
    PackedList . map (Value . Fixed) <$> decodePackedList getWord64le field
instance EncodeWire (PackedList (Value (Fixed Int32))) where
  encodeWire t (PackedList xs) = encodePackedList t (mapM_ putOne xs)
    where
      putOne v = case runValue v of
        Fixed n -> putWord32le (fromIntegral n)
instance DecodeWire (PackedList (Value (Fixed Int32))) where
  decodeWire field =
    PackedList . map (Value . Fixed . fromIntegral) <$> decodePackedList getWord32le field
instance EncodeWire (PackedList (Value (Fixed Int64))) where
  encodeWire t (PackedList xs) = encodePackedList t (mapM_ putOne xs)
    where
      putOne v = case runValue v of
        Fixed n -> putWord64le (fromIntegral n)
instance DecodeWire (PackedList (Value (Fixed Int64))) where
  decodeWire field =
    PackedList . map (Value . Fixed . fromIntegral) <$> decodePackedList getWord64le field
-- Packed floating-point values: IEEE 754 little-endian words (4 bytes for
-- 'Float', 8 for 'Double') written back to back in one length-delimited
-- field.
instance EncodeWire (PackedList (Value Float)) where
  encodeWire t (PackedList xs) =
    encodePackedList t (mapM_ (putFloat32le . runValue) xs)
instance DecodeWire (PackedList (Value Float)) where
  decodeWire field =
    PackedList . map Value <$> decodePackedList getFloat32le field
instance EncodeWire (PackedList (Value Double)) where
  encodeWire t (PackedList xs) =
    encodePackedList t (mapM_ (putFloat64le . runValue) xs)
instance DecodeWire (PackedList (Value Double)) where
  decodeWire field =
    PackedList . map Value <$> decodePackedList getFloat64le field
-- Packed booleans: each element is written as the varint of its 'fromEnum'
-- value inside one length-delimited field. On decode any non-zero varint is
-- 'True', matching the non-packed 'DecodeWire' 'Bool' instance.
instance EncodeWire (PackedList (Value Bool)) where
  encodeWire t (PackedList xs) =
    encodePackedList t $ traverse_ (putVarUInt . fromEnum) xs
instance DecodeWire (PackedList (Value Bool)) where
  decodeWire x = do
    xs <- decodePackedList getVarInt x
    -- Compare against zero rather than using 'toEnum', which is partial for
    -- 'Bool' and would throw on any varint other than 0 or 1.
    return . PackedList $ Value . (/= (0 :: Word64)) <$> xs
-- Packed enumerations: each element is written as the varint of its
-- 'fromEnum' value inside one length-delimited field, and decoded back
-- with 'toEnum'.
-- NOTE(review): 'toEnum' may throw for out-of-range values, depending on
-- the 'Enum' instance of @a@.
instance Enum a => EncodeWire (PackedList (Enumeration a)) where
  encodeWire t (PackedList xs) =
    -- 'putVarSInt' rather than 'putVarUInt': a negative enum value gets the
    -- full 10-byte varint, as on the non-packed path. 'putVarUInt' would
    -- emit a single truncated byte for it.
    encodePackedList t $ traverse_ (putVarSInt . fromEnum) xs
instance Enum a => DecodeWire (PackedList (Enumeration a)) where
  decodeWire x = do
    xs <- decodePackedList getVarInt x
    return . PackedList $ toEnum <$> xs
-- Any 'Foldable' collection of enumerations is written as repeated
-- (unpacked) fields. Each element is converted with 'fromEnum', narrowed to
-- 'Int32', and written with the 'Int32' instance.
instance (Foldable f, Enum a) => EncodeWire (f (Enumeration a)) where
  encodeWire t = mapM_ putOne
    where
      putOne e = encodeWire t (fromIntegral (fromEnum (runEnumeration e)) :: Int32)
-- Enumerations are decoded through the 'Int32' instance, then converted
-- with 'toEnum'. The 'Maybe' and 'Always' variants just wrap the result.
instance Enum a => DecodeWire (Enumeration a) where
  decodeWire f = do
    n <- decodeWire f
    return (Enumeration (toEnum (fromIntegral (n :: Int32))))
instance Enum a => DecodeWire (Maybe (Enumeration a)) where
  decodeWire f = do
    n <- decodeWire f
    return (Just (Enumeration (toEnum (fromIntegral (n :: Int32)))))
instance Enum a => DecodeWire (Always (Enumeration a)) where
  decodeWire f = do
    n <- decodeWire f
    return (Always (Enumeration (toEnum (fromIntegral (n :: Int32)))))
-- | ZigZag-encode a signed 32-bit integer.
--
-- 0, -1, 1, -2, … map to 0, 1, 2, 3, …, so values of small magnitude become
-- small unsigned numbers. The arithmetic right shift by 31 yields all ones
-- for negative input and zero otherwise.
zzEncode32 :: Int32 -> Word32
zzEncode32 n = fromIntegral (shiftL n 1 `xor` shiftR n 31)
-- | 64-bit counterpart of 'zzEncode32'.
zzEncode64 :: Int64 -> Word64
zzEncode64 n = fromIntegral (shiftL n 1 `xor` shiftR n 63)
-- | Inverse of 'zzEncode32': the low bit selects the sign and the remaining
-- bits hold the magnitude.
zzDecode32 :: Word32 -> Int32
zzDecode32 w = magnitude `xor` signMask
  where
    magnitude, signMask :: Int32
    magnitude = fromIntegral (shiftR w 1)
    -- all ones when the low bit is set, zero otherwise
    signMask  = negate (fromIntegral (w .&. 1))
-- | Inverse of 'zzEncode64'.
zzDecode64 :: Word64 -> Int64
zzDecode64 w = magnitude `xor` signMask
  where
    magnitude, signMask :: Int64
    magnitude = fromIntegral (shiftR w 1)
    signMask  = negate (fromIntegral (w .&. 1))
-- | Abort with a message. In this module it reports encoding errors in
-- 'Put': a tag or wire type out of range, or a tag mismatch.
--
-- On GHC 7.10 and older this is the monad's 'fail'. On newer compilers it
-- is 'errorWithoutStackTrace', so the failure is an exception raised when
-- the result is evaluated, not a value in @m@.
--
-- NOTE(review): presumably because 'fail' moved to 'MonadFail', which
-- 'Put' may not provide — confirm.
legacyFail :: Monad m => String -> m a
#if __GLASGOW_HASKELL__ <= 710
legacyFail = fail
#else
legacyFail = errorWithoutStackTrace
#endif