-- | Length-prefixed messages
--
-- These are used both for inputs and outputs.
module Network.GRPC.Spec.Serialization.LengthPrefixed (
    -- * Message prefix
    MessagePrefix(..)
    -- * Length-prefixed messages
    -- ** Construction
  , OutboundMeta(..)
  , buildInput
  , buildOutput
    -- ** Parsing
  , InboundMeta(..)
  , parseInput
  , parseOutput
  ) where

import Data.Binary.Get (Get)
import Data.Binary.Get qualified as Binary
import Data.ByteString.Builder (Builder)
import Data.ByteString.Builder qualified as Builder
import Data.ByteString.Lazy qualified as BS.Lazy
import Data.ByteString.Lazy qualified as Lazy (ByteString)
import Data.Proxy
import Data.Word

import Network.GRPC.Spec
import Network.GRPC.Spec.Util.Parser (Parser)
import Network.GRPC.Spec.Util.Parser qualified as Parser

{-------------------------------------------------------------------------------
  Message prefix
-------------------------------------------------------------------------------}

-- | Prefix that precedes every length-prefixed message
--
-- On the wire this occupies five bytes: a single byte for the compressed
-- flag, followed by the message length as a 4-byte big-endian unsigned
-- integer.
data MessagePrefix = MessagePrefix {
      msgIsCompressed :: Bool
      -- ^ Is the message body that follows the prefix compressed?

    , msgLength :: Word32
      -- ^ Number of bytes in the message body as it appears on the wire
      -- (that is, after compression, if the body is compressed)
    }
  deriving (Show)

-- | Serialize a 'MessagePrefix'
--
-- Emits one byte for the compressed flag (@1@ when compressed, @0@
-- otherwise), followed by the message length as a 4-byte big-endian
-- unsigned integer: five bytes in total.
buildMessagePrefix :: MessagePrefix -> Builder
buildMessagePrefix prefix =
       Builder.word8 compressedFlag
    <> Builder.word32BE (msgLength prefix)
  where
    compressedFlag :: Word8
    compressedFlag
      | msgIsCompressed prefix = 1
      | otherwise              = 0

-- | Parse a 'MessagePrefix'
--
-- Reads the compressed flag (a single byte, which must be @0@ or @1@)
-- followed by the message length as a 4-byte big-endian unsigned integer.
-- Any other value for the compressed flag makes the parser fail.
getMessagePrefix :: Get MessagePrefix
getMessagePrefix = do
    msgIsCompressed <- Binary.getWord8 >>= decodeFlag
    msgLength       <- Binary.getWord32be
    return MessagePrefix{msgIsCompressed, msgLength}
  where
    decodeFlag :: Word8 -> Get Bool
    decodeFlag 0 = return False
    decodeFlag 1 = return True
    decodeFlag n = fail $
        "getMessagePrefix: unexpected compressed flag " ++ show n

{-------------------------------------------------------------------------------
  Construction
-------------------------------------------------------------------------------}

-- | Serialize RPC input
--
-- > Length-Prefixed-Message → Compressed-Flag Message-Length Message
-- >
-- > Compressed-Flag → 0 / 1
-- >                     # encoded as 1 byte unsigned integer
-- > Message-Length  → {length of Message}
-- >                     # encoded as 4 byte unsigned integer (big endian)
-- > Message         → *{binary octet}
--
-- The input is serialized with 'rpcSerializeInput' and then framed as
-- above. Whether the message is compressed is decided by @buildMsg@.
buildInput ::
     SupportsClientRpc rpc
  => Proxy rpc
  -> Compression
  -> (OutboundMeta, Input rpc)
  -> Builder
buildInput proxy = buildMsg (rpcSerializeInput proxy)

-- | Serialize RPC output
--
-- Uses the same length-prefixed framing as 'buildInput'. The output is
-- serialized with 'rpcSerializeOutput'.
buildOutput ::
     SupportsServerRpc rpc
  => Proxy rpc
  -> Compression
  -> (OutboundMeta, Output rpc)
  -> Builder
buildOutput proxy = buildMsg (rpcSerializeOutput proxy)

-- | Generalization of 'buildInput' and 'buildOutput'
--
-- Serializes the value with the given function, then emits the message
-- prefix followed by the message body. The body is sent compressed only
-- when all of the following hold:
--
-- * 'uncompressedSizeThreshold' accepts the uncompressed size,
-- * compression is enabled for this message ('outboundEnableCompression'),
-- * the compressed body is strictly smaller than the uncompressed one.
--
-- The length field of the prefix is only 4 bytes wide. If the body that
-- would be sent does not fit, this calls 'error' (when the builder is run)
-- rather than silently emitting a truncated, inconsistent length.
buildMsg ::
     (x -> Lazy.ByteString)
  -> Compression
  -> (OutboundMeta, x)
  -> Builder
buildMsg build compr (meta, x) =
       buildMessagePrefix prefix
    <> Builder.lazyByteString payload
  where
    uncompressed, compressed :: Lazy.ByteString
    uncompressed = build x
    compressed   = compress compr uncompressed

    -- 'and' short-circuits, so 'compressed' is only computed once the first
    -- two conditions have been established
    shouldCompress :: Bool
    shouldCompress = and [
          uncompressedSizeThreshold compr uncompressedLength
        , outboundEnableCompression meta
        , compressedLength < uncompressedLength
        ]
      where
        uncompressedLength = BS.Lazy.length uncompressed
        compressedLength   = BS.Lazy.length compressed

    -- The bytes that actually go on the wire after the prefix
    payload :: Lazy.ByteString
    payload = if shouldCompress then compressed else uncompressed

    prefix :: MessagePrefix
    prefix = MessagePrefix {
          msgIsCompressed = shouldCompress
        , msgLength       = payloadLength
        }

    -- Length of 'payload' as stored in the prefix. Without the check,
    -- 'fromIntegral' would silently wrap lengths of 2^32 bytes or more,
    -- producing a prefix that disagrees with the bytes that follow it.
    payloadLength :: Word32
    payloadLength
      | payloadBytes > fromIntegral (maxBound :: Word32)
      = error $ "buildMsg: message body of " ++ show payloadBytes
             ++ " bytes does not fit in the 4-byte length prefix"
      | otherwise
      = fromIntegral payloadBytes
      where
        payloadBytes = BS.Lazy.length payload

{-------------------------------------------------------------------------------
  Parsing
-------------------------------------------------------------------------------}

-- | Parse input
--
-- Counterpart of 'buildInput'. It parses one length-prefixed message and
-- decompresses the body if the prefix marks it as compressed. The body is
-- then deserialized with 'rpcDeserializeInput'.
parseInput ::
     SupportsServerRpc rpc
  => Proxy rpc
  -> Compression
  -> Parser String (InboundMeta, Input rpc)
parseInput proxy = parseMsg (rpcDeserializeInput proxy)

-- | Parse output
--
-- Counterpart of 'buildOutput'. It parses one length-prefixed message and
-- decompresses the body if the prefix marks it as compressed. The body is
-- then deserialized with 'rpcDeserializeOutput'.
parseOutput ::
     SupportsClientRpc rpc
  => Proxy rpc
  -> Compression
  -> Parser String (InboundMeta, Output rpc)
parseOutput proxy = parseMsg (rpcDeserializeOutput proxy)

-- | Generalization of 'parseInput' and 'parseOutput'
--
-- Reads the 5-byte message prefix, then consumes exactly as many bytes as
-- the prefix announces for the body. A body marked as compressed is
-- decompressed with the given 'Compression' before it is deserialized.
-- The returned 'InboundMeta' always records the uncompressed size of the
-- body. For compressed bodies it also records the on-the-wire
-- (compressed) size.
parseMsg :: forall x.
     (Lazy.ByteString -> Either String x)
  -> Compression
  -> Parser String (InboundMeta, x)
parseMsg parse compr = do
    prefix <- Parser.getExactly 5 getMessagePrefix
    let parseBody
          | msgIsCompressed prefix = parseCompressed
          | otherwise              = parseUncompressed
    Parser.consumeExactly (fromIntegral (msgLength prefix)) parseBody
  where
    -- Body was sent as-is: deserialize it directly
    parseUncompressed :: Lazy.ByteString -> Either String (InboundMeta, x)
    parseUncompressed body = do
        value <- parse body
        return (
            InboundMeta {
                inboundCompressedSize   = Nothing
              , inboundUncompressedSize = lengthOf body
              }
          , value
          )

    -- Body was sent compressed: decompress it before deserializing
    parseCompressed :: Lazy.ByteString -> Either String (InboundMeta, x)
    parseCompressed onTheWire = do
        value <- parse plain
        return (
            InboundMeta {
                inboundCompressedSize   = Just (lengthOf onTheWire)
              , inboundUncompressedSize = lengthOf plain
              }
          , value
          )
      where
        plain :: Lazy.ByteString
        plain = decompress compr onTheWire

    -- Converts the length reported by 'BS.Lazy.length' to the numeric type
    -- used in 'InboundMeta'. 'fromIntegral' wraps if the length does not
    -- fit in the target type.
    lengthOf :: Num a => Lazy.ByteString -> a
    lengthOf bytes = fromIntegral (BS.Lazy.length bytes)