{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards #-}

module Network.QUIC.Recovery.Utils (
    retransmit
  , sendPing
  , mergeLostCandidates
  , mergeLostCandidatesAndClear
  , peerCompletedAddressValidation
  , countAckEli
  , inCongestionRecovery
  , delay
  ) where

import Control.Concurrent
import Control.Concurrent.STM
import Data.Sequence (Seq, (<|), ViewL(..))
import qualified Data.Sequence as Seq

import Network.QUIC.Connector
import Network.QUIC.Imports
import Network.QUIC.Recovery.Types
import Network.QUIC.Types

----------------------------------------------------------------

-- | Hand the ack-eliciting packets among @lostPackets@ back to
-- 'putRetrans' (one call per packet, in sequence order).
--
-- When none of the lost packets is ack-eliciting (this includes an empty
-- @lostPackets@), nothing is re-queued; instead 'sendPing' is invoked
-- with the connection's current encryption level.
retransmit :: LDCC -> Seq SentPacket -> IO ()
retransmit ldcc lostPackets =
    if Seq.null ackEliciting
        then sendPing ldcc =<< getEncryptionLevel ldcc
        else mapM_ resend ackEliciting
  where
    -- Only ack-eliciting packets are worth retransmitting.
    ackEliciting = Seq.filter spAckEliciting lostPackets
    resend = putRetrans ldcc . spPlainPacket

----------------------------------------------------------------

-- | Request a probe PING at encryption level @lvl@.
--
-- This does two things:
--
-- 1. It stamps 'timeOfLastAckElicitingPacket' of that level's
--    loss-detection state with the current time.
-- 2. It stores @Just lvl@ in the 'ptoPing' TVar.
--
-- NOTE(review): the PING frame itself is not built here; presumably
-- whoever watches 'ptoPing' emits it — confirm against the sender.
sendPing :: LDCC -> EncryptionLevel -> IO ()
sendPing LDCC{lossDetection = ldRefs, ptoPing = pingVar} lvl = do
    now <- getTimeMicrosecond
    let stamp ld = ld { timeOfLastAckElicitingPacket = now }
    atomicModifyIORef'' (ldRefs ! lvl) stamp
    atomically (writeTVar pingVar (Just lvl))

----------------------------------------------------------------

-- | Atomically fold @lostPackets@ into the pending set stored in the
-- 'lostCandidates' TVar.
--
-- The combined sequence is produced by 'merge'. It is ordered by packet
-- number provided both inputs already are.
mergeLostCandidates :: LDCC -> Seq SentPacket -> IO ()
mergeLostCandidates LDCC{lostCandidates = candidatesVar} lostPackets =
    atomically $ do
        SentPackets pending <- readTVar candidatesVar
        let merged = merge pending lostPackets
        writeTVar candidatesVar (SentPackets merged)

-- | Atomically drain the 'lostCandidates' TVar and return its contents
-- merged (via 'merge') with @lostPackets@.
--
-- After the transaction the TVar holds 'emptySentPackets'.
mergeLostCandidatesAndClear :: LDCC -> Seq SentPacket -> IO (Seq SentPacket)
mergeLostCandidatesAndClear LDCC{lostCandidates = candidatesVar} lostPackets =
    atomically $ do
        -- swapTVar = read the old value, then write the new one, in one step.
        SentPackets pending <- swapTVar candidatesVar emptySentPackets
        pure (merge pending lostPackets)

-- | Merge two sequences of sent packets by ascending 'spPacketNumber'.
--
-- NOTE(review): both inputs are assumed to be sorted by packet number
-- already; this is not checked.
--
-- Duplicates are kept. When two packets carry the same number, the one
-- from the second argument is emitted first.
--
-- Implemented as an accumulating loop. The accumulator is forced with
-- '$!' at each step, so no chain of pending appends builds up. Whatever
-- remains of the non-exhausted input is appended in one step.
merge :: Seq SentPacket -> Seq SentPacket -> Seq SentPacket
merge s1 s2 = go s1 s2 Seq.empty
  where
    go lhs rhs acc = case (Seq.viewl lhs, Seq.viewl rhs) of
        (EmptyL, _) -> acc Seq.>< rhs
        (_, EmptyL) -> acc Seq.>< lhs
        (l :< lhs', r :< rhs')
            | spPacketNumber l < spPacketNumber r -> go lhs' rhs $! acc Seq.|> l
            | otherwise                           -> go lhs rhs' $! acc Seq.|> r

----------------------------------------------------------------

-- | Whether the peer has completed address validation. This is relevant
-- to the PTO-backoff rule quoted below.
--
-- Sec 6.2.1. Computing PTO
-- "That is, a client does not reset the PTO backoff factor on
--  receiving acknowledgements until it receives a HANDSHAKE_DONE
--  frame or an acknowledgement for one of its Handshake or 1-RTT
--  packets."
peerCompletedAddressValidation :: LDCC -> IO Bool
peerCompletedAddressValidation ldcc
    -- Server side: assume clients validate the server's address
    -- implicitly, so the answer is always True.
    | isServer ldcc = pure True
    -- Client side: servers complete address validation when a protected
    -- packet is received; this is approximated by
    -- 'isConnectionEstablished'.
    | otherwise = isConnectionEstablished ldcc

----------------------------------------------------------------

-- | 1 for an ack-eliciting packet, 0 otherwise.
--
-- Presumably callers sum this over a collection of sent packets to count
-- the ack-eliciting ones.
countAckEli :: SentPacket -> Int
countAckEli pkt = if spAckEliciting pkt then 1 else 0

----------------------------------------------------------------

-- | Was a packet sent at @sentTime@ sent at or before the given recovery
-- start time?
--
-- With 'Nothing' (no recovery period recorded) the answer is always
-- 'False'.
inCongestionRecovery :: TimeMicrosecond -> Maybe TimeMicrosecond -> Bool
inCongestionRecovery sentTime = maybe False (sentTime <=)

----------------------------------------------------------------

-- | Block the calling thread for the given number of microseconds, using
-- 'threadDelay'.
delay :: Microseconds -> IO ()
delay (Microseconds usec) = threadDelay usec