{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}

module Network.QUIC.Client.Run (
    run
  , migrate
  ) where

import qualified Network.Socket as NS
import UnliftIO.Async
import UnliftIO.Concurrent
import qualified UnliftIO.Exception as E

import Network.QUIC.Client.Reader
import Network.QUIC.Closer
import Network.QUIC.Common
import Network.QUIC.Config
import Network.QUIC.Connection
import Network.QUIC.Crypto
import Network.QUIC.Handshake
import Network.QUIC.Imports
import Network.QUIC.Logger
import Network.QUIC.Parameters
import Network.QUIC.QLogger
import Network.QUIC.Receiver
import Network.QUIC.Recovery
import Network.QUIC.Sender
import Network.QUIC.Socket
import Network.QUIC.Types

----------------------------------------------------------------

-- | Running a QUIC client.
--   A UDP socket is created according to 'ccServerName' and 'ccPortName'.
--
--   If 'ccAutoMigration' is 'True', an unconnected socket is made.
--   Otherwise, a connected socket is made.
--   Use the 'migrate' API for the connected socket.
--
--   The first entry of 'ccVersions' is tried first. If that attempt ends
--   with the internal 'NextVersion' signal carrying another version, the
--   connection is retried exactly once with that version.
--
--   Throws 'NoVersionIsSpecified' when 'ccVersions' is empty, and
--   'VersionNegotiationFailed' when no version can be agreed on.
run :: ClientConfig -> (Connection -> IO a) -> IO a
-- Don't use handleLogUnit here because of a return value.
run conf client = case ccVersions conf of
  []     -> E.throwIO NoVersionIsSpecified
  ver1:_ -> do
      ex <- E.try $ runClient conf client ver1
      case ex of
        Right v                        -> return v
        Left (NextVersion Nothing)     -> E.throwIO VersionNegotiationFailed
        Left (NextVersion (Just ver2)) -> retry ver2
  where
    -- Only one retry is made. Without this guard, a second 'NextVersion'
    -- would escape to the caller as an internal control exception
    -- instead of the public 'VersionNegotiationFailed'.
    retry ver2 = do
        ex <- E.try $ runClient conf client ver2
        case ex of
          Right v              -> return v
          Left (NextVersion _) -> E.throwIO VersionNegotiationFailed

-- | Make one connection attempt using QUIC version @ver@.
--
--   Connection resources are acquired with 'createClientConnection'.
--   The bracket's release action frees them on every exit path: it marks
--   the connection dead, frees resources, kills readers, closes all
--   sockets and runs the timeouter replacement.
--
--   Inside the bracket:
--
--   * The reader and timeouter threads are forked and registered.
--   * The handshaker, sender, receiver, resender and 'ldccTimer' run
--     concurrently, raced against the user action.
--   * The outcome (value or any exception) is handed to 'closure'.
runClient :: ClientConfig -> (Connection -> IO a) -> Version -> IO a
runClient conf client0 ver = E.bracket acquire release body
  where
    acquire = createClientConnection conf ver

    release connRes = do
        let conn = connResConnection connRes
        setDead conn
        freeResources conn
        killReaders conn
        getSockets conn >>= mapM_ NS.close
        join $ replaceKillTimeouter conn

    body (ConnRes conn send recv myAuthCIDs reader) = do
        addReader conn =<< forkIO reader
        addTimeouter conn =<< forkIO timeouter
        handshaker <- handshakeClient conf conn myAuthCIDs
        let ldcc = connLDCC conn
            -- Wait until the connection is usable (0-RTT or 1-RTT,
            -- depending on configuration), then hand it to the user.
            userAction = do
                (if ccUse0RTT conf then wait0RTTReady else wait1RTTReady) conn
                setToken conn . resumptionToken $ ccResumption conf
                client0 conn
            background = foldr1 concurrently_
                [ handshaker
                , sender conn send
                , receiver conn recv
                , resender ldcc
                , ldccTimer ldcc
                ]
            -- The background workers are not expected to finish on their
            -- own; if they do, that is an internal error.
            guarded = race background userAction
                  >>= either (\() -> E.throwIO MustNotReached) return
        E.trySyncOrAsync guarded >>= closure conn ldcc

-- | Create the client UDP socket and a fresh 'Connection' for version @ver@.
--   Returns the pieces 'runClient' needs as a 'ConnRes'.
--
--   If 'ccAutoMigration' is 'True', an unconnected socket is used and every
--   datagram is sent explicitly to the resolved server address. Otherwise a
--   connected socket is used.
--
--   This function is the acquire step of the 'E.bracket' in 'runClient', so
--   that bracket's release step never runs when setup fails part-way. To
--   avoid leaking, the socket is closed, and the qlog cleanup action is
--   run, if a later setup step throws.
createClientConnection :: ClientConfig -> Version -> IO ConnRes
createClientConnection conf@ClientConfig{..} ver = do
    (s0,sa0) <- if ccAutoMigration then
                  udpClientSocket ccServerName ccPortName
                else
                  udpClientConnectedSocket ccServerName ccPortName
    setup s0 sa0 `E.onException` NS.close s0
  where
    setup s0 sa0 = do
        q <- newRecvQ
        sref <- newIORef [s0]
        -- Always send on the head of the socket list in 'sref'.
        let send buf siz = do
                s:_ <- readIORef sref
                if ccAutoMigration then
                    void $ NS.sendBufTo s buf siz sa0
                  else
                    void $ NS.sendBuf s buf siz
            recv = recvClient q
        myCID   <- newCID
        peerCID <- newCID
        now <- getTimeMicrosecond
        (qLog, qclean) <- dirQLogger ccQLog now peerCID "client"
        -- From here on, run qclean if a later step fails.
        flip E.onException qclean $ do
            let debugLog msg | ccDebugLog = stdoutLogger msg
                             | otherwise  = return ()
            debugLog $ "Original CID: " <> bhow peerCID
            let myAuthCIDs   = defaultAuthCIDs { initSrcCID = Just myCID }
                peerAuthCIDs = defaultAuthCIDs { initSrcCID = Just peerCID
                                               , origDstCID = Just peerCID }
            conn <- clientConnection conf ver myAuthCIDs peerAuthCIDs
                                     debugLog qLog ccHooks sref q
            addResource conn qclean
            initializeCoder conn InitialLevel $ initialSecrets ver peerCID
            setupCryptoStreams conn -- fixme: cleanup
            -- Clamp the configured packet size between the default and
            -- maximum sizes for the server address.
            let pktSiz0 = fromMaybe 0 ccPacketSize
                pktSiz  = (defaultPacketSize sa0 `max` pktSiz0)
                              `min` maximumPacketSize sa0
            setMaxPacketSize conn pktSiz
            setInitialCongestionWindow (connLDCC conn) pktSiz
            setAddressValidated conn
            when ccAutoMigration $ setServerAddr conn sa0
            let reader = readerClient ccVersions s0 conn -- dies when s0 is closed.
            return $ ConnRes conn send recv myAuthCIDs reader

-- | Create a new socket and perform path validation on it with a new
--   connection ID. This is typically used for migration when
--   'ccAutoMigration' is 'False', but it can also be used when the value
--   is 'True'.
--
--   NOTE(review): the 'Bool' presumably reports whether the migration
--   succeeded; confirm against 'controlConnection'.
migrate :: Connection -> IO Bool
migrate = (`controlConnection` ActiveMigration)