{-# LANGUAGE OverloadedStrings #-}

module Network.QUIC.Stream.Table (
    StreamTable
  , emptyStreamTable
  , lookupStream
  , insertStream
  , deleteStream
  , insertCryptoStreams
  , deleteCryptoStream
  , lookupCryptoStream
  ) where

import Data.IntMap.Strict (IntMap)
import qualified Data.IntMap.Strict as Map

import {-# Source #-} Network.QUIC.Connection.Types
import Network.QUIC.Stream.Types
import Network.QUIC.Types

----------------------------------------------------------------

-- | A table of streams keyed by 'StreamId'.
--
-- Backed by an 'IntMap': a 'StreamId' is used directly as the 'Int' key.
-- Besides ordinary streams, the table also holds the crypto streams that
-- 'insertCryptoStreams' registers under reserved negative IDs.
newtype StreamTable = StreamTable (IntMap Stream)

-- | A table containing no streams.
emptyStreamTable :: StreamTable
emptyStreamTable = StreamTable Map.empty

----------------------------------------------------------------

-- | Look up the stream registered under the given ID, if any.
lookupStream :: StreamId -> StreamTable -> Maybe Stream
lookupStream sid (StreamTable tbl) = Map.lookup sid tbl

-- | Register a stream under the given ID.
--
-- Any stream already stored under the same ID is replaced.
insertStream :: StreamId -> Stream -> StreamTable -> StreamTable
insertStream sid strm (StreamTable tbl) = StreamTable $ Map.insert sid strm tbl

-- | Remove the stream registered under the given ID.
--
-- If no stream is registered under that ID, the table is returned unchanged.
deleteStream :: StreamId -> StreamTable -> StreamTable
deleteStream sid (StreamTable tbl) = StreamTable $ Map.delete sid tbl

----------------------------------------------------------------

-- | Reserved table keys for the per-encryption-level crypto streams.
--
-- They are negative, presumably so they can never collide with the IDs of
-- application streams, which QUIC defines as non-negative integers.
initialCryptoStreamId, handshakeCryptoStreamId, rtt1CryptoStreamId :: StreamId
initialCryptoStreamId   = -1
handshakeCryptoStreamId = -2
rtt1CryptoStreamId      = -3

-- | Map an encryption level to the table key of its crypto stream.
toCryptoStreamId :: EncryptionLevel -> StreamId
toCryptoStreamId InitialLevel   = initialCryptoStreamId
-- 0-RTT has no crypto stream of its own; it is mapped onto the 1-RTT key.
-- This is to generate an error packet for CRYPTO frames in 0-RTT.
toCryptoStreamId RTT0Level      = rtt1CryptoStreamId
toCryptoStreamId HandshakeLevel = handshakeCryptoStreamId
toCryptoStreamId RTT1Level      = rtt1CryptoStreamId

----------------------------------------------------------------

-- | Create one crypto stream each for the Initial, Handshake and 1-RTT
-- levels, and register them in the table under their reserved IDs.
--
-- The streams are created with 'newStream' in that order. No separate stream
-- is created for 0-RTT.
insertCryptoStreams :: Connection -> StreamTable -> IO StreamTable
insertCryptoStreams conn stbl = do
    strm1 <- newStream conn initialCryptoStreamId
    strm2 <- newStream conn handshakeCryptoStreamId
    strm3 <- newStream conn rtt1CryptoStreamId
    -- The keys are distinct, so the order of insertion does not matter.
    return $ insertStream initialCryptoStreamId   strm1
           $ insertStream handshakeCryptoStreamId strm2
           $ insertStream rtt1CryptoStreamId      strm3 stbl

-- | Remove the crypto stream associated with an encryption level.
--
-- Delegates to 'toCryptoStreamId' so that the level-to-key mapping is
-- always the same one 'lookupCryptoStream' uses. Because 'RTT0Level' maps to
-- the 1-RTT key, deleting at 'RTT0Level' removes the 1-RTT crypto stream.
-- NOTE(review): confirm callers never delete at RTT0Level while 1-RTT crypto
-- data is still expected.
deleteCryptoStream :: EncryptionLevel -> StreamTable -> StreamTable
deleteCryptoStream lvl = deleteStream (toCryptoStreamId lvl)

----------------------------------------------------------------

-- | Find the crypto stream associated with an encryption level.
--
-- Returns 'Nothing' if that stream has not been registered or has already
-- been deleted.
lookupCryptoStream :: EncryptionLevel -> StreamTable -> Maybe Stream
lookupCryptoStream lvl stbl = lookupStream sid stbl
  where
    sid :: StreamId
    sid = toCryptoStreamId lvl