{-# LANGUAGE OverloadedStrings #-}

module Network.QUIC.IO where

import qualified Data.ByteString as BS
import qualified UnliftIO.Exception as E
import UnliftIO.STM

import Network.QUIC.Connection
import Network.QUIC.Connector
import Network.QUIC.Imports
import Network.QUIC.Parameters
import Network.QUIC.Stream
import Network.QUIC.Types

-- | Creating a bidirectional stream.
--   A stream ID is obtained with 'waitMyNewStreamId' and then
--   registered on the connection with 'addStream'.
stream :: Connection -> IO Stream
stream conn = waitMyNewStreamId conn >>= addStream conn

-- | Creating a unidirectional stream.
--   A stream ID is obtained with 'waitMyNewUniStreamId' and then
--   registered on the connection with 'addStream'.
unidirectionalStream :: Connection -> IO Stream
unidirectionalStream conn = waitMyNewUniStreamId conn >>= addStream conn

-- | Sending data in the stream.
--   This is 'sendStreamMany' applied to a one-element list.
sendStream :: Stream -> ByteString -> IO ()
sendStream s = sendStreamMany s . pure

----------------------------------------------------------------

-- | Result of 'checkBlocked' when no send credit is available.
--   Each 'Int' is the 'flowMaxData' limit of a window that is smaller
--   than the requested length. 'sendBlocked' turns these values into
--   STREAM_DATA_BLOCKED / DATA_BLOCKED frames.
data Blocked
    = -- | Both windows: stream limit first, then connection limit.
      BothBlocked Stream Int Int
    | -- | Only the connection window: connection limit.
      ConnBlocked Int
    | -- | Only the stream window: stream limit.
      StrmBlocked Stream Int
    deriving Show

-- | Records @len@ transmitted bytes on the stream ('addTxStreamData')
--   and on the connection ('addTxData') in one 'atomically' transaction.
addTx :: Connection -> Stream -> Int -> IO ()
addTx conn s len = atomically (addTxStreamData s len *> addTxData conn len)

-- | Sending a list of data in the stream.
--
--   Throws 'StreamIsClosed' if the sending side of the stream is already
--   closed. A list whose total length is zero sends nothing.
--
--   Before 1-RTT keys are ready the data is queued as-is (0-RTT).
--   Afterwards it is sent in pieces as stream/connection flow control
--   permits, and BLOCKED frames are sent when no credit is available.
sendStreamMany :: Stream -> [ByteString] -> IO ()
sendStreamMany _   [] = return ()
sendStreamMany s dats0 = do
    sclosed <- isTxStreamClosed s
    when sclosed $ E.throwIO StreamIsClosed
    -- fixme: size check for 0RTT
    let len = totalLen dats0
    -- Nothing to send (e.g. @sendStream s ""@). Without this guard the
    -- 1-RTT path would ask 'checkBlocked' for 0 bytes. The sendable amount
    -- (@min 0 window@) is then never positive, so a spurious DATA_BLOCKED
    -- would be sent and the retry with @wait = True@ would block forever.
    unless (len == 0) $ do
        ready <- isConnection1RTTReady conn
        if not ready then do
            -- 0-RTT: queued without a flow-control check.
            putSendStreamQ conn $ TxStreamData s dats0 len False
            addTx conn s len
          else
            flowControl dats0 len False
  where
    conn = streamConnection s

    -- 1-RTT path: send as much as flow control allows now, then loop on
    -- the remainder. @wait@ is True only after a 'Blocked' result, which
    -- makes 'checkBlocked' wait for credit instead of reporting BLOCKED
    -- again.
    flowControl :: [ByteString] -> Int -> Bool -> IO ()
    flowControl dats len wait = do
        eblocked <- checkBlocked s len wait
        case eblocked of
          Right n
            | len == n  -> do
                  putSendStreamQ conn $ TxStreamData s dats len False
                  addTx conn s n
            | otherwise -> do
                  -- Only n (< len) bytes are allowed: send the first n
                  -- bytes and continue with the rest.
                  let (dats1, dats2) = split n dats
                  putSendStreamQ conn $ TxStreamData s dats1 n False
                  addTx conn s n
                  flowControl dats2 (len - n) False
          Left blocked -> do
              -- fixme: RTT0Level?
              sendBlocked conn RTT1Level blocked
              flowControl dats len True

-- | Sends the BLOCKED frame(s) describing @blocked@ at level @lvl@:
--   STREAM_DATA_BLOCKED for a stream window, DATA_BLOCKED for the
--   connection window, or both (stream frame first).
sendBlocked :: Connection -> EncryptionLevel -> Blocked -> IO ()
sendBlocked conn lvl blocked = sendFrames conn lvl (blockedFrames blocked)
  where
    blockedFrames (StrmBlocked strm n)   = [strmFrame strm n]
    blockedFrames (ConnBlocked n)        = [DataBlocked n]
    blockedFrames (BothBlocked strm n m) = [strmFrame strm n, DataBlocked m]

    strmFrame strm = StreamDataBlocked (streamId strm)

-- | @split n chunks@ divides @chunks@ so that the first component holds
--   the first @n@ bytes and the second component holds the rest. A chunk
--   that straddles the boundary is cut with 'BS.splitAt'. If @chunks@
--   has @n@ bytes or fewer, everything goes into the first component.
split :: Int -> [BS.ByteString] -> ([BS.ByteString], [BS.ByteString])
split n0 dats0 = go n0 [] dats0
  where
    -- @acc@ collects the taken chunks in reverse order.
    go 0 acc rest = (reverse acc, rest)
    go _ acc []   = (reverse acc, [])
    go n acc (chunk : rest) =
        let clen = BS.length chunk
        in case compare clen n of
            LT -> go (n - clen) (chunk : acc) rest
            EQ -> (reverse (chunk : acc), rest)
            GT -> let (front, back) = BS.splitAt n chunk
                  in (reverse (front : acc), back : rest)

-- | Determines how many of @len@ bytes may be sent now on stream @s@.
--   The stream and connection transmit windows ('flowWindow') are read
--   in a single STM transaction.
--
--   * @Right n@ with @0 < n <= len@: @n@ bytes may be sent.
--   * @Left blocked@: no credit. @blocked@ names the window(s) smaller
--     than @len@ and carries their 'flowMaxData' limits.
--
--   When @wait@ is True the transaction retries until the credit is
--   positive, so only 'Right' can be returned. With @len <= 0@ the credit
--   is never positive, so @wait = True@ never returns.
checkBlocked :: Stream -> Int -> Bool -> IO (Either Blocked Int)
checkBlocked s len wait = atomically $ do
    strmFlow <- readStreamFlowTx s
    connFlow <- readConnectionFlowTx (streamConnection s)
    let strmWindow = flowWindow strmFlow
        connWindow = flowWindow connFlow
        sendable   = min len (min strmWindow connWindow)
    when wait $ checkSTM (sendable > 0)
    return $
        if sendable > 0
            then Right sendable
            else Left $ classify (len > strmWindow) (len > connWindow) strmFlow connFlow
  where
    -- Arguments: stream window too small, connection window too small,
    -- then the two flows whose limits are reported.
    classify True  True  sf cf = BothBlocked s (flowMaxData sf) (flowMaxData cf)
    classify True  False sf _  = StrmBlocked s (flowMaxData sf)
    classify False _     _  cf = ConnBlocked (flowMaxData cf)

----------------------------------------------------------------

-- | Sending FIN in a stream.
--   'closeStream' should be called later.
--
--   Throws 'StreamIsClosed' if the sending side is already closed.
--   Otherwise it marks the sending side closed, queues an empty
--   'TxStreamData' with the FIN flag set, and waits in 'waitFinTx'.
shutdownStream :: Stream -> IO ()
shutdownStream s = do
    txClosed <- isTxStreamClosed s
    if txClosed
        then E.throwIO StreamIsClosed
        else do
            setTxStreamClosed s
            putSendStreamQ (streamConnection s) finOnly
            waitFinTx s
  where
    finOnly = TxStreamData s [] 0 True

-- | Closing a stream without an error.
--   This sends FIN if necessary.
--
--   If the sending side is still open, both sides are marked closed, a
--   FIN-only 'TxStreamData' is queued, and 'waitFinTx' is awaited. The
--   stream is then removed from the connection. For a bidirectional
--   stream initiated by the peer, a MAX_STREAMS (bidirectional) frame
--   carrying the value from 'getPeerMaxStreams' is queued at 'RTT1Level'.
closeStream :: Stream -> IO ()
closeStream s = do
    sclosed <- isTxStreamClosed s
    unless sclosed $ do
        setTxStreamClosed s
        setRxStreamClosed s
        putSendStreamQ conn $ TxStreamData s [] 0 True
        waitFinTx s
    delStream conn s
    when peerInitiatedBidi $ do
        n <- getPeerMaxStreams conn
        -- The guard admits only bidirectional streams, so the credit must
        -- go into the bidirectional MAX_STREAMS frame (type 0x12,
        -- RFC 9000 Section 19.11). Sending 'Unidirectional' would raise
        -- the wrong limit and leave the peer's bidirectional limit fixed.
        putOutput conn $ OutControl RTT1Level [MaxStreams Bidirectional n] $ return ()
  where
    conn = streamConnection s
    sid  = streamId s
    peerInitiatedBidi =
           (isClient conn && isServerInitiatedBidirectional sid)
        || (isServer conn && isClientInitiatedBidirectional sid)

-- | Accepting a stream initiated by the peer.
--   The next input is taken with 'takeInput' and its 'Stream' returned.
--
--   NOTE(review): only 'InpStream' is handled. If 'takeInput' can yield
--   another 'Input' constructor, the pattern bind below fails, and in
--   'IO' that throws a user-error 'IOError'. Confirm this cannot happen.
acceptStream :: Connection -> IO Stream
acceptStream conn = takeInput conn >>= fromStreamInput
  where
    fromStreamInput inp = do
        InpStream strm <- return inp
        return strm

-- | Receiving data in the stream. In the case where a FIN is received
--   an empty bytestring is returned.
--
--   The received length is recorded with 'addRxStreamData' and
--   'addRxData'. Then the stream window and the connection window are
--   each checked. When a window is at or below half of its initial size,
--   'addRxMaxStreamData' / 'addRxMaxData' is called with that initial
--   size, and the returned limit is sent as MAX_STREAM_DATA / MAX_DATA at
--   'RTT1Level'. A 'fire' action (50000 microseconds) also sends the
--   then-current limit once more.
recvStream :: Stream -> Int -> IO ByteString
recvStream s n = do
    bs <- takeRecvStreamQwithSize s n
    let received = BS.length bs
    addRxStreamData s received
    addRxData conn received
    replenishStream
    replenishConnection
    return bs
  where
    conn  = streamConnection s
    sid   = streamId s
    delay = Microseconds 50000

    -- Stream-level receive window.
    replenishStream = do
        window <- getRxStreamWindow s
        let initial = initialRxMaxStreamData conn sid
        when (window <= (initial !>>. 1)) $ do
            newMax <- addRxMaxStreamData s initial
            sendFrames conn RTT1Level [MaxStreamData sid newMax]
            fire conn delay $ do
                latest <- getRxMaxStreamData s
                sendFrames conn RTT1Level [MaxStreamData sid latest]

    -- Connection-level receive window.
    replenishConnection = do
        window <- getRxDataWindow conn
        let initial = initialMaxData (getMyParameters conn)
        when (window <= (initial !>>. 1)) $ do
            newMax <- addRxMaxData conn initial
            sendFrames conn RTT1Level [MaxData newMax]
            fire conn delay $ do
                latest <- getRxMaxData conn
                sendFrames conn RTT1Level [MaxData latest]

-- | Closing a stream with an error code.
--   This sends RESET_STREAM to the peer.
--   This is an alternative of 'closeStream'.
--
--   If the sending side is still open, both sides are marked closed and
--   RESET_STREAM is queued at the connection's current encryption level.
--   In all cases the stream is removed from the connection.
resetStream :: Stream -> ApplicationProtocolError -> IO ()
resetStream s aerr = do
    txClosed <- isTxStreamClosed s
    unless txClosed $ do
        setTxStreamClosed s
        setRxStreamClosed s
        lvl <- getEncryptionLevel conn
        -- NOTE(review): the last field (presumably the Final Size) is
        -- always 0. RFC 9000 Section 19.4 requires the stream's final
        -- size, so this looks wrong once any data has been sent. Confirm.
        putOutput conn $ OutControl lvl [ResetStream (streamId s) aerr 0] $ return ()
    delStream conn s
  where
    conn = streamConnection s

-- | Asking the peer to stop sending.
--   This sends STOP_SENDING to the peer
--   and it will send RESET_STREAM back.
--   'closeStream' should be called later.
--
--   Does nothing if the receiving side is already closed. Otherwise it
--   marks the receiving side closed and queues STOP_SENDING at the
--   connection's current encryption level.
stopStream :: Stream -> ApplicationProtocolError -> IO ()
stopStream s aerr = do
    rxClosed <- isRxStreamClosed s
    unless rxClosed $ do
        setRxStreamClosed s
        lvl <- getEncryptionLevel conn
        putOutput conn $ OutControl lvl [StopSending (streamId s) aerr] $ return ()
  where
    conn = streamConnection s