{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}

module Network.QUIC.Client.Run (
    run,
    migrate,
) where

import qualified Network.Socket as NS
import Network.UDP (UDPSocket (..))
import qualified Network.UDP as UDP
import UnliftIO.Async
import UnliftIO.Concurrent
import qualified UnliftIO.Exception as E

import Network.QUIC.Client.Reader
import Network.QUIC.Closer
import Network.QUIC.Common
import Network.QUIC.Config
import Network.QUIC.Connection
import Network.QUIC.Crypto
import Network.QUIC.Handshake
import Network.QUIC.Imports
import Network.QUIC.Logger
import Network.QUIC.Parameters
import Network.QUIC.QLogger
import Network.QUIC.Receiver
import Network.QUIC.Recovery
import Network.QUIC.Sender
import Network.QUIC.Types

----------------------------------------------------------------

-- | Running a QUIC client.
--   A UDP socket is created according to 'ccServerName' and 'ccPortName'.
--
--   If 'ccAutoMigration' is 'True', an unconnected socket is made.
--   Otherwise, a connected socket is made.
--   Use the 'migrate' API for the connected socket.
run :: ClientConfig -> (Connection -> IO a) -> IO a
-- Don't use handleLogUnit here because of a return value.
--
-- The initial version information comes from the configuration
-- ('ccVersion' / 'ccVersions') when there is neither a resumption
-- session nor a resumption token; otherwise the version recorded in the
-- resumption information is used as the only candidate.
--
-- Only a 'NextVersion' exception from the first attempt is caught here;
-- every other exception propagates to the caller. At most one retry is
-- made, and a 'NextVersion' raised by the retry is not caught.
run conf client = NS.withSocketsDo $ do
    let resInfo = ccResumption conf
        verInfo = case resumptionSession resInfo of
            Nothing
                | resumptionToken resInfo == emptyToken ->
                    VersionInfo (ccVersion conf) (ccVersions conf)
            -- A session or a token exists (the guard above fell through).
            _ -> let ver = resumptionVersion resInfo in VersionInfo ver [ver]
    ex <- E.try $ runClient conf client False verInfo
    case ex of
        Right v -> return v
        Left (NextVersion nextVerInfo)
            -- Inspect the version information proposed for the retry,
            -- not the one already used for the first attempt: the
            -- latter was built above from the configuration, so testing
            -- it would let a failed negotiation result through to the
            -- retry. (Presumably 'brokenVersionInfo' is the marker for
            -- "no usable version"; it is defined elsewhere.)
            | nextVerInfo == brokenVersionInfo ->
                E.throwIO VersionNegotiationFailed
            -- The 'True' flag is passed on to 'runClient' as @isICVN@.
            | otherwise -> runClient conf client True nextVerInfo

-- | Run a single connection attempt with the given version information.
--
--   The connection is created by 'createClientConnection' and released
--   through 'E.bracket', which calls 'setDead', 'freeResources' and
--   'killReaders' on it however the attempt ends. The @isICVN@ flag is
--   handed to 'setIncompatibleVN' before the handshaker is built. The
--   user action @client0@ is raced against the supporting threads, and
--   the outcome (value or exception) is passed to 'closure' after
--   'sendFinal' has run.
runClient :: ClientConfig -> (Connection -> IO a) -> Bool -> VersionInfo -> IO a
runClient conf client0 isICVN verInfo =
    E.bracket acquire release $ \(ConnRes conn myAuthCIDs reader) -> do
        readerTid <- forkIO reader
        addReader conn readerTid
        setIncompatibleVN conn isICVN -- must be before handshaker
        setToken conn (resumptionToken (ccResumption conf))
        handshaker <- handshakeClient confWithVerInfo conn myAuthCIDs
        let ldcc = connLDCC conn
            -- Supporting threads, all combined with 'concurrently_'.
            background =
                foldr1
                    concurrently_
                    [ handshaker
                    , sender conn
                    , receiver conn
                    , resender ldcc
                    , ldccTimer ldcc
                    ]
            application = do
                -- For 0-RTT, the following variables should be
                -- initialized in advance.
                setTxMaxStreams conn (initialMaxStreamsBidi defaultParameters)
                setTxUniMaxStreams conn (initialMaxStreamsUni defaultParameters)
                if ccUse0RTT conf
                    then wait0RTTReady conn
                    else wait1RTTReady conn
                client0 conn
            -- The background side finishing first is reported as
            -- 'MustNotReached'.
            threads =
                race background application
                    >>= either (\() -> E.throwIO MustNotReached) return
        outcome <- E.trySyncOrAsync threads
        sendFinal conn
        closure conn ldcc outcome
  where
    -- The caller's configuration with the version information of this
    -- attempt stored in its parameters.
    confWithVerInfo =
        conf
            { ccParameters =
                (ccParameters conf){versionInformation = Just verInfo}
            }
    acquire = createClientConnection conf verInfo
    release connRes =
        mapM_
            ($ connResConnection connRes)
            [setDead, freeResources, killReaders]

-- | Build the client 'Connection' together with its packet reader.
--
--   A UDP socket is opened towards 'ccServerName' / 'ccPortName'; the
--   flag given to 'UDP.clientSocket' is the negation of
--   'ccAutoMigration'. The socket is kept in an 'IORef' and the send
--   action reads that reference on every call. The result bundles the
--   connection, the client's own 'AuthCIDs' and the (not yet started)
--   reader action.
createClientConnection :: ClientConfig -> VersionInfo -> IO ConnRes
createClientConnection conf@ClientConfig{..} verInfo = do
    sock@(UDPSocket _ peerAddr _) <-
        UDP.clientSocket ccServerName ccPortName (not ccAutoMigration)
    recvQ <- newRecvQ
    sockRef <- newIORef sock
    let send buf siz = readIORef sockRef >>= \cur -> UDP.sendBuf cur buf siz
        recv = recvClient recvQ
    myCID <- newCID
    peerCID <- newCID
    now <- getTimeMicrosecond
    (qLog, qclean) <- dirQLogger ccQLog now peerCID "client"
    -- Debug output goes to stdout only when 'ccDebugLog' is set.
    let debugLog msg =
            if ccDebugLog
                then stdoutLogger msg
                else return ()
    debugLog ("Original CID: " <> bhow peerCID)
    let myAuthCIDs = defaultAuthCIDs{initSrcCID = Just myCID}
        peerAuthCIDs =
            defaultAuthCIDs
                { initSrcCID = Just peerCID
                , origDstCID = Just peerCID
                }
    conn <-
        clientConnection
            conf
            verInfo
            myAuthCIDs
            peerAuthCIDs
            debugLog
            qLog
            ccHooks
            sockRef
            recvQ
            send
            recv
    -- Register the qlog cleanup action as a resource of the connection.
    addResource conn qclean
    initializeCoder conn InitialLevel $
        initialSecrets (chosenVersion verInfo) peerCID
    setupCryptoStreams conn -- fixme: cleanup
    -- Packet size: at least the default for this address, at most the
    -- maximum for it; a missing 'ccPacketSize' counts as 0.
    let requested = fromMaybe 0 ccPacketSize
        atLeastDefault = max (defaultPacketSize peerAddr) requested
        pktSiz = min atLeastDefault (maximumPacketSize peerAddr)
    setMaxPacketSize conn pktSiz
    setInitialCongestionWindow (connLDCC conn) pktSiz
    setAddressValidated conn
    -- The reader is bound to the initially created socket
    -- (original note: dies when s0 is closed).
    return (ConnRes conn myAuthCIDs (readerClient sock conn))

-- | Creates a new socket and executes a path validation
--   with a new connection ID. Typically, this is used
--   for migration in the case where 'ccAutoMigration' is 'False'.
--   But this can also be used even when the value is 'True'.
migrate :: Connection -> IO Bool
-- Delegates to 'controlConnection' with the 'ActiveMigration' request;
-- the 'Bool' is whatever that call returns.
migrate conn = conn `controlConnection` ActiveMigration