{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}

module Network.QUIC.Client.Reader (
    readerClient,
    recvClient,
    ConnectionControl (..),
    controlConnection,
    clientSocket,
) where

import Control.Concurrent
import qualified Control.Exception as E
import Data.List (intersect)
import Network.Socket (Socket, close, getSocketName)
import qualified Network.Socket.ByteString as NSB

import Network.QUIC.Connection
import Network.QUIC.Connector
import Network.QUIC.Crypto
import Network.QUIC.Exception
import Network.QUIC.Imports
import Network.QUIC.Packet
import Network.QUIC.Parameters
import Network.QUIC.Qlog
import Network.QUIC.Recovery
import Network.QUIC.Socket
import Network.QUIC.Types

-- | Receive loop for a client connection on socket @s0@.
--
-- The loop first waits until @s0@ is bound. It then repeatedly:
--
-- * receives a datagram;
-- * records its sender address and control messages as the connection's
--   'PeerInfo';
-- * decodes the QUIC packets it carries and dispatches each one:
--
--     * broken packets are dropped;
--     * a Version Negotiation packet that passes the checks below makes the
--       main thread abort with 'VerNego';
--     * protected packets are written to the connection's receive queue;
--     * a Retry packet whose CIDs and integrity tag check out switches the
--       peer CID, re-derives the Initial keys and re-sends the packets
--       released by 'releaseByRetry'.
--
-- If no datagram arrives within the minimum idle timeout, @s0@ is closed
-- and the loop ends.
--
-- readerClient dies when the socket is closed (for instance the old socket
-- closed by 'rebind'). Whatever escapes the loop is handed to
-- 'handleLogUnit' together with a @"debug: readerClient: "@ logger.
readerClient :: Socket -> Connection -> IO ()
readerClient s0 conn = handleLogUnit logAction $ do
    wait
    loop
  where
    -- Spin, yielding to other threads, until 'getSocketName' succeeds on
    -- @s0@. A failure from 'getSocketName' counts as "not bound yet".
    -- 'throughAsync' is presumably there to let asynchronous exceptions
    -- propagate instead of being swallowed.
    wait :: IO ()
    wait = do
        bound <- E.handle (throughAsync (return False)) $ do
            _ <- getSocketName s0
            return True
        unless bound $ do
            yield
            wait

    loop :: IO ()
    loop = do
        ito <- readMinIdleTimeout conn
        mbs <-
            timeout ito "readerClient" $
                -- fixme (presumably the hard-coded 2048-byte buffer sizes)
                NSB.recvMsg s0 2048 2048 0
        case mbs of
            -- Nothing received within the idle timeout: give up on s0.
            Nothing -> close s0
            Just (peersa, bs, cmsgs, _) -> do
                setPeerInfo conn $ PeerInfo peersa cmsgs
                now <- getTimeMicrosecond
                let quicBit = greaseQuicBit $ getMyParameters conn
                pkts <- decodePackets bs (not quicBit)
                mapM_ (putQ now) pkts
                loop

    logAction msg = connDebugLog conn ("debug: readerClient: " <> msg)

    -- Dispatch one decoded packet that was received at time @t@.
    putQ _ (PacketIB BrokenPacket _) = return ()
    putQ t (PacketIV pkt@(VersionNegotiationPacket dCID sCID peerVers)) = do
        qlogReceived conn pkt t
        myVerInfo <- getVersionInfo conn
        let myVer = chosenVersion myVerInfo
            myVers0 = otherVersions myVerInfo
        -- Ignore the VN if it lists the version we chose, or if it
        -- contains 'Negotiation' itself.
        when (myVer `notElem` peerVers && Negotiation `notElem` peerVers) $ do
            ok <- checkCIDs conn dCID (Left sCID)
            let myVers = filter (not . isGreasingVersion) myVers0
                -- 'intersect' keeps the order of our own version list,
                -- so the first common version is our most preferred one.
                nextVerInfo = case myVers `intersect` peerVers of
                    vers@(ver : _) | ok -> VersionInfo ver vers
                    _ -> brokenVersionInfo
            E.throwTo (mainThreadId conn) $ VerNego nextVerInfo
    putQ t (PacketIC pkt lvl siz) =
        writeRecvQ (connRecvQ conn) $ mkReceivedPacket pkt t siz lvl
    putQ t (PacketIR pkt@(RetryPacket ver dCID sCID token ex)) = do
        qlogReceived conn pkt t
        ok <- checkCIDs conn dCID ex
        -- NOTE(review): nothing here rejects a second Retry packet, while
        -- RFC 9000 (section 17.2.5.2) lets a client process at most one per
        -- connection attempt -- confirm this is enforced elsewhere.
        when ok $ do
            -- From now on, address the CID the server chose in the Retry.
            resetPeerCID conn sCID
            setPeerAuthCIDs conn $ \auth -> auth{retrySrcCID = Just sCID}
            -- The Initial keys depend on the destination CID, so re-derive them.
            initializeCoder conn InitialLevel $ initialSecrets ver sCID
            setToken conn token
            setRetried conn True
            -- Queue the packets released by the retry for retransmission.
            releaseByRetry (connLDCC conn) >>= mapM_ put
      where
        put ppkt = putOutput conn $ OutRetrans ppkt

-- | Validates the connection IDs of a Version Negotiation or Retry packet.
--
-- The destination CID @dCID@ must equal our own CID ('getMyCID'). The
-- second check depends on the packet kind:
--
-- * @Left sCID@ (Version Negotiation): the packet's source CID must equal
--   the peer CID we are currently using ('getPeerCID').
-- * @Right (pseudo, tag)@ (Retry): @tag@ must equal the integrity tag that
--   'calculateIntegrityTag' computes over @pseudo@, using the current
--   version and peer CID.
--
-- Returns 'True' only if both checks hold.
checkCIDs :: Connection -> CID -> Either CID (ByteString, ByteString) -> IO Bool
checkCIDs conn dCID ex = do
    localCID <- getMyCID conn
    remoteCID <- getPeerCID conn
    srcOK <- case ex of
        Left sCID -> return (sCID == remoteCID)
        Right (pseudo0, tag) -> do
            ver <- getVersion conn
            return (calculateIntegrityTag ver remoteCID pseudo0 == tag)
    return (dCID == localCID && srcOK)

-- | Takes the next received packet from a receive queue, such as the
-- connection's 'connRecvQ' that 'readerClient' writes protected packets to.
recvClient :: RecvQ -> IO ReceivedPacket
recvClient = readRecvQ

----------------------------------------------------------------

-- | How to control a connection. See 'controlConnection'.
data ConnectionControl
    = ChangeServerCID
    -- ^ Retire the peer connection ID reported by 'waitPeerCID'.
    | ChangeClientCID
    -- ^ Issue a new connection ID of our own to the peer.
    | NATRebinding
    -- ^ Switch to a new local socket; the old one is closed almost at once.
    | ActiveMigration
    -- ^ Switch to a new local socket and validate the new path.
    deriving (Eq, Show)

-- | Applies a 'ConnectionControl' action to a connection.
--
-- Only client connections support this. On a client, it first waits for
-- 'waitEstablished' and then reports whether the action was carried out.
-- On a server, it does nothing and returns 'False'.
controlConnection :: Connection -> ConnectionControl -> IO Bool
controlConnection conn typ
    | isClient conn = do
        waitEstablished conn
        controlConnection' conn typ
    | otherwise = return False

-- | Client-side implementation of 'controlConnection'.
--
-- * 'ChangeServerCID' and 'ActiveMigration' first wait up to one second
--   for 'waitPeerCID'. If nothing arrives in time they return 'False'.
-- * Every other path returns 'True'.
controlConnection' :: Connection -> ConnectionControl -> IO Bool
controlConnection' conn ctl = case ctl of
    ChangeServerCID -> do
        mn <- awaitPeerCID "controlConnection' 1"
        case mn of
            Nothing -> return False
            Just (CIDInfo n _ _) -> do
                sendFrames conn RTT1Level [RetireConnectionID n]
                return True
    ChangeClientCID -> do
        cidInfo <- getNewMyCID conn
        -- Presumably the frame's "Retire Prior To" value -- confirm.
        x <- (+ 1) <$> getMyCIDSeqNum conn
        sendFrames conn RTT1Level [NewConnectionID cidInfo x]
        return True
    NATRebinding -> do
        rebind conn $ Microseconds 5000 -- nearly 0
        return True
    ActiveMigration -> do
        mn <- awaitPeerCID "controlConnection' 2"
        case mn of
            Nothing -> return False
            Just _ -> do
                -- Keep the old socket alive for a while during migration.
                rebind conn $ Microseconds 5000000
                validatePath conn mn
                return True
  where
    -- Wait at most one second for 'waitPeerCID'; the label names the call
    -- site for 'timeout'.
    awaitPeerCID label =
        timeout (Microseconds 1000000) label $ waitPeerCID conn -- fixme

-- | Moves the connection to a fresh local socket.
--
-- Steps, in order:
--
-- * open a new socket towards the current peer address with
--   'natRebinding';
-- * install it as the connection's socket;
-- * start a 'readerClient' thread on it and register that thread with
--   'addReader';
-- * use 'fire' to close the previous socket after @microseconds@
--   (presumably a delay). Closing that socket is what ends the old
--   socket's reader thread.
rebind :: Connection -> Microseconds -> IO ()
rebind conn microseconds = do
    PeerInfo peersa _ <- getPeerInfo conn
    newSock <- natRebinding peersa
    oldSock <- setSocket conn newSock
    forkIO (readerClient newSock conn) >>= addReader conn
    fire conn microseconds $ close oldSock