{-# LANGUAGE OverloadedStrings #-}

module Network.HTTP3.Control (
    setupUnidirectional,
    controlStream,
) where

import qualified Data.ByteString as BS
import Data.IORef
import Network.QUIC

import Imports
import qualified Network.HTTP3.Config as H3
import Network.HTTP3.Error
import Network.HTTP3.Frame
import Network.HTTP3.Settings
import Network.HTTP3.Stream
import Network.QPACK

-- | Encode an HTTP/3 unidirectional stream type as the single byte that is
-- written at the start of the stream.
--
-- NOTE: the numeric code from 'fromH3StreamType' is narrowed to one byte
-- with 'fromIntegral'.  A single byte equals the QUIC variable-length
-- integer encoding only for codes below 0x40; larger codes would be
-- truncated.  Within this module it is only applied to the control and
-- QPACK encoder/decoder stream types.
mkType :: H3StreamType -> ByteString
mkType streamType = BS.singleton (fromIntegral (fromH3StreamType streamType))

-- | Open this endpoint's three outgoing unidirectional streams: control,
-- QPACK encoder and QPACK decoder.
--
-- Every stream begins with its stream-type byte (see 'mkType').  The control
-- stream is additionally sent the SETTINGS frame, after it has been passed
-- through the 'H3.onControlFrameCreated' hook.  Once the data is sent, the
-- matching @on*StreamCreated@ hook from the configuration is called for
-- each stream.
setupUnidirectional :: Connection -> H3.Config -> IO ()
setupUnidirectional conn conf = do
    -- Locally advertised limits; currently hard-coded.
    settingsPayload <-
        encodeH3Settings
            [ (SettingsQpackBlockedStreams, 100)
            , (SettingsQpackMaxTableCapacity, 4096)
            , (SettingsMaxFieldSectionSize, 32768)
            ] -- fixme
    ctrl <- unidirectionalStream conn
    enc <- unidirectionalStream conn
    dec <- unidirectionalStream conn
    let settingsFrames =
            H3.onControlFrameCreated hooks [H3Frame H3FrameSettings settingsPayload]
    -- fixme
    sendStreamMany ctrl $ mkType H3ControlStreams : encodeH3Frames settingsFrames
    sendStream enc $ mkType QPACKEncoderStream
    sendStream dec $ mkType QPACKDecoderStream
    H3.onControlStreamCreated hooks ctrl
    H3.onEncoderStreamCreated hooks enc
    H3.onDecoderStreamCreated hooks dec
  where
    hooks = H3.confHooks conf

-- | Consume the peer's HTTP/3 control stream.
--
-- Data is pulled with @recv 1024@.  The incremental frame-parser state is
-- kept in @ref@, so a frame split across several reads is resumed on the
-- next chunk.
--
-- * An empty read aborts the connection with 'H3ClosedCriticalStream'.
-- * The first complete frame must be SETTINGS; its payload is validated
--   by 'checkSettings'.  Otherwise the connection is aborted with
--   'H3MissingSettings'.
-- * Afterwards, another SETTINGS frame aborts with 'H3FrameUnexpected'.
--   CANCEL_PUSH, GOAWAY and MAX_PUSH_ID are accepted and currently ignored.
--   Any other type is accepted iff 'permittedInControlStream' says so.
controlStream :: Connection -> IORef IFrame -> InstructionHandler
controlStream conn ref recv = awaitSettings
  where
    -- Read the next chunk and pass it to @k@.  An empty chunk means the
    -- peer closed this critical stream.
    onChunk k = do
        chunk <- recv 1024
        if BS.null chunk
            then abortConnection conn H3ClosedCriticalStream ""
            else k chunk

    -- Phase 1: feed chunks until the first frame is complete.
    awaitSettings = onChunk $ \chunk -> do
        (settled, st') <- parseFirst chunk =<< readIORef ref
        writeIORef ref st'
        if settled then steadyState else awaitSettings

    -- Phase 2: handle every later control frame until the stream ends.
    steadyState = onChunk $ \chunk -> do
        st' <- parseRest chunk =<< readIORef ref
        writeIORef ref st'
        steadyState

    -- Returns True once the first frame has been seen.  Bytes after it are
    -- handed to 'parseRest' straight away.
    parseFirst chunk st0 =
        case parseH3Frame st0 chunk of
            IDone typ payload leftover -> do
                requireSettings typ payload
                st' <- parseRest leftover IInit
                pure (True, st')
            partial -> pure (False, partial)

    requireSettings H3FrameSettings payload = checkSettings conn payload
    requireSettings _ _ = abortConnection conn H3MissingSettings ""

    -- Parse as many complete frames as @chunk@ holds and return the
    -- parser state for the remaining (incomplete) bytes.
    parseRest chunk st0 =
        case parseH3Frame st0 chunk of
            IDone typ _payload leftover -> do
                checkFrameType typ
                parseRest leftover IInit
            partial -> pure partial

    checkFrameType H3FrameCancelPush = pure ()
    checkFrameType H3FrameSettings = abortConnection conn H3FrameUnexpected ""
    checkFrameType H3FrameGoaway = pure ()
    checkFrameType H3FrameMaxPushId = pure ()
    checkFrameType typ
        | permittedInControlStream typ = pure ()
        | otherwise = abortConnection conn H3FrameUnexpected ""

-- | Validate the payload of the peer's SETTINGS frame.
--
-- The connection is aborted with 'H3SettingsError' when:
--
-- * a setting identifier occurs more than once.  Seen identifiers are
--   tracked as bits of an 'Int', so duplicates are only detected for
--   identifiers below the bit width of 'Int'.
-- * a reserved HTTP/2 identifier (@<= 0x6@, other than the HTTP/3 settings
--   matched explicitly below) is present.
--
-- Identifiers not understood here (e.g. GREASE values) are ignored, and
-- validation continues with the remaining entries.  The setting values
-- themselves are not inspected.
checkSettings :: Connection -> ByteString -> IO ()
checkSettings conn payload = do
    h3settings <- decodeH3Settings payload
    loop (0 :: Int) h3settings
  where
    loop _ [] = return ()
    loop flags ((k@(H3SettingsKey i), _v) : ss)
        | flags `testBit` i = abortConnection conn H3SettingsError ""
        | otherwise = do
            let flags' = flags `setBit` i
            case k of
                SettingsQpackMaxTableCapacity -> loop flags' ss
                SettingsMaxFieldSectionSize -> loop flags' ss
                SettingsQpackBlockedStreams -> loop flags' ss
                _
                    -- HTTP/2 settings reserved in HTTP/3
                    | i <= 0x6 -> abortConnection conn H3SettingsError ""
                    -- Unknown identifiers must be ignored, but the rest of
                    -- the frame still needs validating.  Returning here
                    -- would let later duplicates or HTTP/2 identifiers
                    -- through unchecked.
                    | otherwise -> loop flags' ss