{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE BinaryLiterals #-}

module Network.HTTP3.Frame (
    H3Frame(..)
  , H3FrameType(..)
  , fromH3FrameType
  , toH3FrameType
  , encodeH3Frame
  , encodeH3Frames
  , decodeH3Frame
  , IFrame(..)
  , parseH3Frame
  , QInt(..)
  , parseQInt
  , permittedInControlStream
  , permittedInRequestStream
  , permittedInPushStream
  ) where

import qualified Data.ByteString as BS
import Network.ByteOrder
import Network.QUIC.Internal

import Imports

-- | An HTTP/3 frame: a frame type paired with its payload bytes.
data H3Frame
    = H3Frame
        H3FrameType -- ^ Frame type
        ByteString  -- ^ Payload (without the type/length header)

-- | HTTP/3 frame types.  Type codes without a dedicated constructor are
-- preserved verbatim in 'H3FrameUnknown'.
data H3FrameType
    = H3FrameData          -- ^ DATA (0x00)
    | H3FrameHeaders       -- ^ HEADERS (0x01)
    | H3FrameCancelPush    -- ^ CANCEL_PUSH (0x03)
    | H3FrameSettings      -- ^ SETTINGS (0x04)
    | H3FramePushPromise   -- ^ PUSH_PROMISE (0x05)
    | H3FrameGoaway        -- ^ GOAWAY (0x07)
    | H3FrameMaxPushId     -- ^ MAX_PUSH_ID (0x0D)
    | H3FrameUnknown Int64 -- ^ Any other type code
    deriving (Eq, Show)

-- | Numeric type code of a frame type, as written on the wire by
-- 'encodeH3Frame'.  'H3FrameUnknown' returns its stored code unchanged.
fromH3FrameType :: H3FrameType -> Int64
fromH3FrameType ft = case ft of
    H3FrameData        -> 0x00
    H3FrameHeaders     -> 0x01
    H3FrameCancelPush  -> 0x03
    H3FrameSettings    -> 0x04
    H3FramePushPromise -> 0x05
    H3FrameGoaway      -> 0x07
    H3FrameMaxPushId   -> 0x0D
    H3FrameUnknown code -> code

-- | Map a numeric type code (as decoded from the wire) to a frame type.
-- Codes without a dedicated constructor become 'H3FrameUnknown'.
toH3FrameType :: Int64 -> H3FrameType
toH3FrameType code = case code of
    0x00 -> H3FrameData
    0x01 -> H3FrameHeaders
    0x03 -> H3FrameCancelPush
    0x04 -> H3FrameSettings
    0x05 -> H3FramePushPromise
    0x07 -> H3FrameGoaway
    0x0D -> H3FrameMaxPushId
    _    -> H3FrameUnknown code

-- | Whether a frame of this type is permitted on a control stream.
--
-- Among the known types, only CANCEL_PUSH, SETTINGS, GOAWAY and MAX_PUSH_ID
-- are allowed.  Unknown codes above 0x9 are allowed; unknown codes up to
-- 0x9 are rejected.  For codes produced by 'toH3FrameType' those are 0x2,
-- 0x6, 0x8 and 0x9, which RFC 9114 reserves as former HTTP/2 frame types.
permittedInControlStream :: H3FrameType -> Bool
permittedInControlStream ft = case ft of
    H3FrameCancelPush  -> True
    H3FrameSettings    -> True
    H3FrameGoaway      -> True
    H3FrameMaxPushId   -> True
    H3FrameData        -> False
    H3FrameHeaders     -> False
    H3FramePushPromise -> False
    H3FrameUnknown code -> code > 0x9

-- | Whether a frame of this type is permitted on a request stream.
--
-- Among the known types, only DATA, HEADERS and PUSH_PROMISE are allowed.
-- Unknown codes above 0x9 are allowed; unknown codes up to 0x9 are
-- rejected.  For codes produced by 'toH3FrameType' those are 0x2, 0x6, 0x8
-- and 0x9, which RFC 9114 reserves as former HTTP/2 frame types.
permittedInRequestStream :: H3FrameType -> Bool
permittedInRequestStream ft = case ft of
    H3FrameData        -> True
    H3FrameHeaders     -> True
    H3FramePushPromise -> True
    H3FrameCancelPush  -> False
    H3FrameSettings    -> False
    H3FrameGoaway      -> False
    H3FrameMaxPushId   -> False
    H3FrameUnknown code -> code > 0x9

-- | Whether a frame of this type is permitted on a push stream.
--
-- Among the known types, only DATA and HEADERS are allowed.  Unknown codes
-- above 0x9 are allowed; unknown codes up to 0x9 are rejected.  For codes
-- produced by 'toH3FrameType' those are 0x2, 0x6, 0x8 and 0x9, which
-- RFC 9114 reserves as former HTTP/2 frame types.
permittedInPushStream :: H3FrameType -> Bool
permittedInPushStream ft = case ft of
    H3FrameData        -> True
    H3FrameHeaders     -> True
    H3FrameCancelPush  -> False
    H3FrameSettings    -> False
    H3FramePushPromise -> False
    H3FrameGoaway      -> False
    H3FrameMaxPushId   -> False
    H3FrameUnknown code -> code > 0x9

-- | Serialise a frame into a single 'ByteString': the type code and the
-- payload length (each written with 'encodeInt''), followed by the payload.
encodeH3Frame :: H3Frame -> IO ByteString
encodeH3Frame (H3Frame frameType payload) =
    (<> payload) <$> withWriteBuffer 16 writeHeader
  where
    -- A variable-length integer is at most 8 bytes (first byte plus up to
    -- 7 more, see 'requiredLen'), so 16 bytes always hold type + length.
    writeHeader wbuf = do
        encodeInt' wbuf (fromH3FrameType frameType)
        encodeInt' wbuf (fromIntegral (BS.length payload))

-- | Serialise a list of frames into a list of chunks.  Each frame
-- contributes three consecutive chunks: its type code and its payload
-- length (each encoded with 'encodeInt'), followed by the payload itself.
--
-- The chunks are produced lazily, frame by frame, so a consumer can start
-- sending before the whole list has been traversed.
encodeH3Frames :: [H3Frame] -> [ByteString]
encodeH3Frames = concatMap encodeOne
  where
    encodeOne (H3Frame ty val) =
        [ encodeInt (fromH3FrameType ty)
        , encodeInt (fromIntegral (BS.length val))
        , val
        ]

-- | Decode one complete frame from the start of a 'ByteString': a type
-- code, a payload length (both read with 'decodeInt''), then that many
-- payload bytes.  Bytes after the payload are not examined.
decodeH3Frame :: ByteString -> IO H3Frame
decodeH3Frame hf = withReadBuffer hf readFrame
  where
    readFrame rbuf = do
        frameType <- toH3FrameType <$> decodeInt' rbuf
        payloadLen <- fromIntegral <$> decodeInt' rbuf
        H3Frame frameType <$> extractByteString rbuf payloadLen

-- | State of the incremental variable-length integer parser 'parseQInt'.
data QInt
    = QInit
      -- ^ No input consumed yet
    | QMore
        Word8        -- ^ First byte with its two length bits masked off
        Int          -- ^ Bytes required after the first byte
        Int          -- ^ Bytes received so far after the first byte
                     --   (@sum . map length@ of the chunk list)
        [ByteString] -- ^ Chunks received so far, newest first
    | QDone
        Int64        -- ^ Decoded value
        ByteString   -- ^ Input left over after the integer
    deriving (Eq, Show)

-- | Incrementally parse a variable-length integer, one input chunk at a time.
--
-- The top two bits of the first byte select how many further bytes follow
-- (0, 1, 3 or 7; see 'requiredLen').  The remaining six bits are the most
-- significant bits of the value (see 'toLen').
--
-- * An empty chunk returns the state unchanged.
-- * 'QMore' means more bytes are still needed.
-- * 'QDone' carries the decoded value and any input that followed it.
--
-- Feeding a non-empty chunk to a 'QDone' state appends the chunk to the
-- leftover, so no input is lost.
parseQInt :: QInt -> ByteString -> QInt
parseQInt st bs
  | BS.null bs = st
parseQInt QInit bs0 = case BS.uncons bs0 of
    -- Unreachable: empty input is handled by the first equation.
    Nothing -> QInit
    Just (hd, bs1)
      | len1 < reqLen -> QMore ft reqLen len1 [bs1]
      | otherwise     ->
          let (bs2, bs3) = BS.splitAt reqLen bs1
          in QDone (toLen ft bs2) bs3
      where
        -- Top two bits: how many more bytes belong to this integer.
        reqLen = requiredLen (hd .&. 0b11000000)
        -- Low six bits: the most significant bits of the value.
        ft     = hd .&. 0b00111111
        len1   = BS.length bs1
parseQInt (QMore ft reqLen len0 bss0) bs0
  | len1 < reqLen = QMore ft reqLen len1 (bs0 : bss0)
  | otherwise     =
      let (bs2, bs3) = BS.splitAt reqLen (compose bs0 bss0)
      in QDone (toLen ft bs2) bs3
  where
    len1 = len0 + BS.length bs0
parseQInt (QDone i leftover) bs = QDone i (leftover `BS.append` bs)

-- | Number of bytes that follow the first byte of a variable-length
-- integer, given that first byte masked with 0xC0 (its two length bits).
-- Any value other than 0x00, 0x40 and 0x80 yields 7.
requiredLen :: Word8 -> Int
requiredLen lenBits = case lenBits of
    0x00 -> 0  -- 0b00......: 1-byte encoding
    0x40 -> 1  -- 0b01......: 2-byte encoding
    0x80 -> 3  -- 0b10......: 4-byte encoding
    _    -> 7  -- 0b11......: 8-byte encoding

-- | Assemble the value of a variable-length integer.
--
-- The first argument is the integer's first byte, which the caller is
-- expected to have masked so that only the six value bits remain.  The
-- following bytes are appended to it in big-endian order.
--
-- Uses the strict 'BS.foldl'' so the accumulator is forced at each step
-- instead of building a chain of thunks.
toLen :: Word8 -> ByteString -> Int64
toLen w0 = BS.foldl' step (fromIntegral w0)
  where
    step n w = n * 256 + fromIntegral w

-- | State of the incremental frame parser 'parseH3Frame'.
data IFrame
    -- | Parsing is about to start
    = IInit
    -- | Parsing the frame type
    | IType QInt
    -- | Parsing the payload length
    | ILen H3FrameType QInt
    -- | Parsing the payload
    | IPay
        H3FrameType  -- ^ Frame type
        Int          -- ^ Payload bytes required
        Int          -- ^ Payload bytes received so far
                     --   (@sum . map length@ of the chunk list)
        [ByteString] -- ^ Payload chunks received so far, newest first
    -- | Parsing done
    | IDone
        H3FrameType  -- ^ Frame type
        ByteString   -- ^ Complete payload (empty for a zero-length frame)
        ByteString   -- ^ Input left over after the frame
    deriving (Eq, Show)

-- | Incrementally parse one HTTP/3 frame (type, length, payload), one input
-- chunk at a time.
--
-- * An empty chunk returns the state unchanged.
-- * The returned state records how far parsing has progressed.
-- * 'IDone' carries the frame type, the complete payload, and any input
--   that followed the frame.
--
-- Feeding a non-empty chunk to an 'IDone' state appends the chunk to the
-- leftover rather than silently discarding it.
parseH3Frame :: IFrame -> ByteString -> IFrame
parseH3Frame st bs
  | BS.null bs = st
-- Starting from scratch is the same as parsing a type with a fresh QInt.
parseH3Frame IInit bs = parseH3Frame (IType QInit) bs
parseH3Frame (IType ist) bs = case parseQInt ist bs of
    QDone i bs' -> parseH3Frame (ILen (toH3FrameType i) QInit) bs'
    ist'        -> IType ist'
parseH3Frame (ILen typ ist) bs = case parseQInt ist bs of
    QDone i bs'
      | reqLen == 0 -> IDone typ BS.empty bs'
      | otherwise   -> parseH3Frame (IPay typ reqLen 0 []) bs'
      where
        reqLen = fromIntegral i :: Int
    ist' -> ILen typ ist'
parseH3Frame (IPay typ reqLen len0 bss0) bs0 = case len1 `compare` reqLen of
    LT -> IPay typ reqLen len1 (bs0 : bss0)
    EQ -> IDone typ (compose bs0 bss0) BS.empty
    -- This chunk completes the payload; the rest of it is leftover input.
    GT -> let (bs2, leftover) = BS.splitAt (reqLen - len0) bs0
          in IDone typ (compose bs2 bss0) leftover
  where
    len1 = len0 + BS.length bs0
parseH3Frame (IDone typ payload leftover) bs =
    IDone typ payload (leftover `BS.append` bs)

-- | Join chunks that were collected newest-first.  @bs@ is the newest chunk
-- and @bss@ holds the earlier chunks in reverse arrival order.  The result
-- contains all bytes in arrival order.
compose :: ByteString -> [ByteString] -> ByteString
compose bs bss = BS.concat (reverse bss ++ [bs])

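{-
-- Manual check: encode @i@, feed the encoding to 'parseQInt' one byte at a
-- time, and return the final parser state.
test :: Int64 -> QInt
test i = loop QInit bss0
  where
    loop st [] = st
    loop st (bs:bss) = case parseQInt st bs of
        st1@(QDone _ _) -> st1
        st1             -> loop st1 bss
    bs0 = encodeInt i
    bss0 = map BS.singleton $ BS.unpack bs0
-}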