{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards #-}

module Network.QUIC.Recovery.Persistent (
    getMaxAckDelay
  , calcPTO
  , backOff
  , inPersistentCongestion
  , findDuration -- for testing
  , getPTO
  ) where

import Data.Sequence (Seq, ViewL(..))
import qualified Data.Sequence as Seq
import Data.UnixTime

import Network.QUIC.Imports
import Network.QUIC.Recovery.Constants
import Network.QUIC.Recovery.Misc
import Network.QUIC.Recovery.Types
import Network.QUIC.Types

-- | Choose the max_ack_delay term to use for a given encryption level.
--
-- * With no level ('Nothing') the supplied delay @n@ is returned unchanged.
-- * For 'InitialLevel' and 'HandshakeLevel' the result is @0@.
-- * For every other level the supplied delay @n@ is returned unchanged.
getMaxAckDelay :: Maybe EncryptionLevel -> Microseconds -> Microseconds
getMaxAckDelay Nothing n = n
getMaxAckDelay (Just lvl) n
  | lvl `elem` [InitialLevel, HandshakeLevel] = 0
  | otherwise = n

-- | Compute the probe timeout (PTO).
--
-- Sec 6.2.1. Computing PTO
--
-- > PTO = smoothed_rtt + max(4*rttvar, kGranularity) + max_ack_delay
--
-- The @max_ack_delay@ term is obtained from 'getMaxAckDelay' applied to the
-- given level and the @maxAckDelay1RTT@ field of the 'RTT' record.
calcPTO :: RTT -> Maybe EncryptionLevel -> Microseconds
calcPTO RTT{..} mlvl = smoothedRTT + max (rttvar .<<. 2) kGranularity + dly
  where
    -- A left shift by two bits is the 4*rttvar term of the formula.
    dly = getMaxAckDelay mlvl maxAckDelay1RTT

-- | Scale a duration by a power of two: @backOff n cnt == n * 2 ^ cnt@.
--
-- NOTE(review): '^' raises an error for a negative exponent, so @cnt@ is
-- presumably always non-negative -- confirm against callers.
backOff :: Microseconds -> Int -> Microseconds
backOff n cnt = n * (2 ^ cnt)

-- | Decide whether the given lost packets establish persistent congestion.
--
-- Sec 7.8. Persistent Congestion
--
-- The packet number returned by 'getPktNumPersistent' and the lost packets
-- are handed to 'findDuration'.  When no duration is found the answer is
-- 'False'.  Otherwise the duration is compared against
-- 'kPersistentCongestionThreshold' applied to the PTO computed from the
-- current 'recoveryRTT' (with no encryption level); the result is 'True'
-- only when the duration is strictly greater than that threshold.
inPersistentCongestion :: LDCC -> Seq SentPacket -> IO Bool
inPersistentCongestion ldcc@LDCC{..} lostPackets = do
    pn <- getPktNumPersistent ldcc
    case findDuration lostPackets pn of
        Nothing -> return False
        Just duration -> do
            rtt <- readIORef recoveryRTT
            let pto = calcPTO rtt Nothing
                -- Unwrap the newtype so the value can be converted to a
                -- 'UnixDiffTime' for comparison with the duration.
                Microseconds congestionPeriod = kPersistentCongestionThreshold pto
                threshold = microSecondsToUnixDiffTime congestionPeriod
            return (duration > threshold)

-- | Find the longest send-time span covered by a run of lost packets with
-- consecutive packet numbers.
--
-- Packets are scanned from the left.  A run starts at the first
-- ack-eliciting packet whose packet number is at least @pn@ and extends while
-- each following packet's number is exactly one more than the previous one.
-- The span of a run is the difference between the send time of its last
-- ack-eliciting packet (after the start) and the send time of its start.
-- Runs with no such later ack-eliciting packet contribute nothing.  The
-- largest span over all runs is returned, or 'Nothing' when no run yields
-- a span.
findDuration :: Seq SentPacket -> PacketNumber -> Maybe UnixDiffTime
findDuration pkts0 pn = leftEdge pkts0 Nothing
  where
    -- A packet that may start a run.
    isCandidate :: SentPacket -> Bool
    isCandidate x = spAckEliciting x && spPacketNumber x >= pn

    -- Skip to the next run start, measure that run, and carry the largest
    -- span seen so far in the accumulator.
    leftEdge :: Seq SentPacket -> Maybe UnixDiffTime -> Maybe UnixDiffTime
    leftEdge pkts mdiff = case Seq.viewl (snd (Seq.breakl isCandidate pkts)) of
        EmptyL -> mdiff
        l :< rest -> case rightEdge (spPacketNumber l) rest Nothing of
            (Nothing, rest') -> leftEdge rest' mdiff
            (Just r, rest') ->
                let diff' = spTimeSent r `diffUnixTime` spTimeSent l
                 in leftEdge rest' $ Just (maybe diff' (max diff') mdiff)

    -- Walk forward while packet numbers stay consecutive after @n@,
    -- remembering the most recent ack-eliciting packet.  Returns that packet
    -- (if any) together with the packets not consumed by the run.
    rightEdge
        :: PacketNumber
        -> Seq SentPacket
        -> Maybe SentPacket
        -> (Maybe SentPacket, Seq SentPacket)
    rightEdge n pkts mr = case Seq.viewl pkts of
        EmptyL -> (mr, Seq.empty)
        r :< rest
            | spPacketNumber r == n + 1 ->
                rightEdge (n + 1) rest (if spAckEliciting r then Just r else mr)
            | otherwise -> (mr, pkts)

-- | Read the current 'RTT' from 'recoveryRTT' and compute the PTO with
-- 'calcPTO', passing no encryption level ('Nothing').
getPTO :: LDCC -> IO Microseconds
getPTO LDCC{..} = do
    rtt <- readIORef recoveryRTT
    return $ calcPTO rtt Nothing