{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}

module Network.QUIC.Client.Reader (
    readerClient
  , recvClient
  , ConnectionControl(..)
  , controlConnection
  ) where

import Control.Concurrent
import qualified Data.ByteString as BS
import Data.List (intersect)
import Network.Socket (Socket, getPeerName, close)
import qualified Network.Socket.ByteString as NSB
import qualified UnliftIO.Exception as E

import Network.QUIC.Connection
import Network.QUIC.Connector
import Network.QUIC.Crypto
import Network.QUIC.Exception
import Network.QUIC.Imports
import Network.QUIC.Packet
import Network.QUIC.Parameters
import Network.QUIC.Qlog
import Network.QUIC.Recovery
import Network.QUIC.Socket
import Network.QUIC.Types

-- | Receive loop for one client socket.
--
-- Repeatedly reads a datagram from the socket, accounts its size with
-- 'addRxBytes', decodes it with 'decodePackets' and dispatches every
-- packet:
--
--   * broken packets are dropped;
--   * crypt packets are written to the connection's receive queue;
--   * a Version Negotiation packet whose echoed connection IDs match ours
--     throws 'VerNego' to the main thread;
--   * a Retry packet that passes 'checkCIDs' and carries a non-empty token
--     restarts the handshake with the server-chosen CID and token.
--
-- When the server address is known (@getServerAddr@ returns @Just@), the
-- socket is read with @recvFrom@ and datagrams from any other address are
-- skipped. When nothing arrives within the minimum idle timeout, the
-- socket is closed and the loop returns. Exceptions raised inside the loop
-- (e.g. @recv@ on a socket closed by another thread) go to 'handleLogUnit'
-- with a debug logger; presumably they are logged and swallowed (verify
-- against 'handleLogUnit').
--
-- @myVers@ is the client's version list. Only @drop 1 myVers@ is
-- considered as a fallback on Version Negotiation; presumably the head is
-- the version that was already tried.
readerClient :: [Version] -> Socket -> Connection -> IO ()
readerClient myVers s0 conn = handleLogUnit logAction $ getServerAddr conn >>= loop
  where
    loop msa0 = do
        ito <- readMinIdleTimeout conn
        mbs <- timeout ito $ case msa0 of
            Nothing  -> NSB.recv s0 maximumUdpPayloadSize
            Just sa0 -> do
                (bs, sa) <- NSB.recvFrom s0 maximumUdpPayloadSize
                -- Datagrams from a foreign address become "" and are skipped.
                return $ if sa == sa0 then bs else ""
        case mbs of
          Nothing -> close s0 -- idle timeout: give up on this socket
          Just "" -> loop msa0
          Just bs -> do
            now <- getTimeMicrosecond
            let bytes = BS.length bs
            addRxBytes conn bytes
            pkts <- decodePackets bs
            mapM_ (putQ now bytes) pkts
            loop msa0

    logAction msg = connDebugLog conn ("debug: readerClient: " <> msg)

    putQ _ _ (PacketIB BrokenPacket) = return ()
    putQ t _ (PacketIV pkt@(VersionNegotiationPacket dCID sCID peerVers)) = do
        qlogReceived conn pkt t
        -- The server echoes our connection IDs (RFC 9000, 17.2.1) so we can
        -- tell the packet came from someone who observed our Initial.
        -- A packet failing that check is ignored rather than being allowed
        -- to abort the connection.
        ok <- checkCIDs conn dCID (Left sCID)
        -- NOTE(review): RFC 9000, 6.2 also requires discarding a Version
        -- Negotiation packet that lists the version the client selected;
        -- that rule is not enforced here.
        when ok $ do
            -- Nothing means "no usable fallback version".
            let mver = case drop 1 myVers `intersect` peerVers of
                    ver : _ -> Just ver
                    []      -> Nothing
            E.throwTo (mainThreadId conn) $ VerNego mver
    putQ t z (PacketIC pkt lvl) =
        writeRecvQ (connRecvQ conn) $ mkReceivedPacket pkt t z lvl
    putQ t _ (PacketIR pkt@(RetryPacket ver dCID sCID token ex)) = do
        qlogReceived conn pkt t
        ok <- checkCIDs conn dCID ex
        -- RFC 9000, 17.2.5.2: a client MUST discard a Retry packet with a
        -- zero-length Retry Token field.
        when (ok && not (BS.null token)) $ do
            resetPeerCID conn sCID
            setPeerAuthCIDs conn $ \auth -> auth { retrySrcCID = Just sCID }
            initializeCoder conn InitialLevel $ initialSecrets ver sCID
            setToken conn token
            setRetried conn True
            -- Whatever 'releaseByRetry' hands back is queued for
            -- retransmission (presumably the packets in flight before the
            -- Retry).
            releaseByRetry (connLDCC conn) >>= mapM_ put
      where
        put ppkt = putOutput conn $ OutRetrans ppkt

-- | Validate the connection IDs carried by a Version Negotiation or Retry
-- packet against the connection's current state.
--
-- In both cases the packet's destination CID must equal our own CID.
-- Additionally:
--
--   * @Left sCID@: the packet's source CID must equal the current peer CID.
--   * @Right (pseudo, tag)@: the integrity tag computed by
--     'calculateIntegrityTag' from the connection's version, the current
--     peer CID and @pseudo@ must equal @tag@.
checkCIDs :: Connection -> CID -> Either CID (ByteString, ByteString) -> IO Bool
checkCIDs conn dCID ex = do
    myCID   <- getMyCID conn
    peerCID <- getPeerCID conn
    let dstMatches = dCID == myCID
    case ex of
      Left sCID ->
          return $ dstMatches && sCID == peerCID
      Right (pseudo, expectedTag) -> do
          ver <- getVersion conn
          let tagMatches = calculateIntegrityTag ver peerCID pseudo == expectedTag
          return $ dstMatches && tagMatches

-- | Fetch the next received packet from a client's receive queue
-- (a thin wrapper around 'readRecvQ').
recvClient :: RecvQ -> IO ReceivedPacket
recvClient q = readRecvQ q

----------------------------------------------------------------

-- | How to control a connection.
data ConnectionControl
  = ChangeServerCID
    -- ^ Wait for a connection ID from the server and send
    -- RETIRE_CONNECTION_ID for its sequence number.
  | ChangeClientCID
    -- ^ Issue a new connection ID of our own via NEW_CONNECTION_ID.
  | NATRebinding
    -- ^ Switch to a freshly created local socket; the old one is handed
    -- off to be closed almost immediately.
  | ActiveMigration
    -- ^ Obtain a new server connection ID, switch to a freshly created
    -- local socket and validate the new path.
  deriving (Eq, Show)

-- | Apply a 'ConnectionControl' action to a connection.
--
-- Only clients can do this: the call first blocks in 'waitEstablished'
-- and then reports whether the action was carried out. For a server-side
-- connection it returns 'False' without doing anything.
controlConnection :: Connection -> ConnectionControl -> IO Bool
controlConnection conn typ
  | not (isClient conn) = return False
  | otherwise           = waitEstablished conn >> controlConnection' conn typ

-- | Carry out one 'ConnectionControl' action on an established client
-- connection.
--
-- Returns 'False' only when an action that needs a new server connection
-- ID ('ChangeServerCID', 'ActiveMigration') does not get one within one
-- second; every other path returns 'True'.
controlConnection' :: Connection -> ConnectionControl -> IO Bool
controlConnection' conn typ = case typ of
    ChangeServerCID -> withNewPeerCID $ \(CIDInfo seqNum _ _) ->
        sendFrames conn RTT1Level [RetireConnectionID seqNum]
    ChangeClientCID -> do
        newCID <- getNewMyCID conn
        -- Second field of NEW_CONNECTION_ID: current sequence number + 1.
        retirePriorTo <- (+ 1) <$> getMyCIDSeqNum conn
        sendFrames conn RTT1Level [NewConnectionID newCID retirePriorTo]
        return True
    NATRebinding -> do
        rebind conn $ Microseconds 5000 -- nearly 0
        return True
    ActiveMigration -> withNewPeerCID $ \cidInfo -> do
        rebind conn $ Microseconds 5000000
        validatePath conn $ Just cidInfo
  where
    -- Wait up to one second for 'waitPeerCID'; run the action and report
    -- success if a CID arrived, otherwise report failure.
    withNewPeerCID :: (CIDInfo -> IO ()) -> IO Bool
    withNewPeerCID action = do
        mcid <- timeout (Microseconds 1000000) $ waitPeerCID conn -- fixme
        case mcid of
          Nothing  -> return False
          Just cid -> action cid >> return True

-- | Move the connection onto a freshly created UDP socket.
--
-- The new socket targets the recorded server address. If none is
-- recorded, it targets the peer of the connection's first socket. The new
-- socket is registered with 'addSocket' and gets its own 'readerClient'
-- thread, registered with 'addReader'. Closing the old first socket is
-- handed to 'fire' with the given delay, presumably so that it happens
-- once the delay has elapsed; closing it ends that socket's reader.
--
-- NOTE(review): the @s0:_@ bind fails in IO (pattern-match failure) if
-- the connection has no sockets.
rebind :: Connection -> Microseconds -> IO ()
rebind conn microseconds = do
    s0:_ <- getSockets conn
    s1 <- getServerAddr conn
            >>= maybe (getPeerName s0 >>= udpNATRebindingConnectedSocket)
                      udpNATRebindingSocket
    _ <- addSocket conn s1
    ver <- getVersion conn
    -- The version list is a dummy: with a single element, Version
    -- Negotiation can never produce a fallback version on this socket.
    readerTid <- forkIO $ readerClient [ver] s1 conn
    addReader conn readerTid
    fire conn microseconds $ close s0