{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE PatternGuards #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Network.QUIC.Server.Run (
run
, stop
) where
import qualified Network.Socket as NS
import System.Log.FastLogger
import UnliftIO.Async
import UnliftIO.Concurrent
import qualified UnliftIO.Exception as E
import Network.QUIC.Closer
import Network.QUIC.Common
import Network.QUIC.Config
import Network.QUIC.Connection
import Network.QUIC.Crypto
import Network.QUIC.Exception
import Network.QUIC.Handshake
import Network.QUIC.Imports
import Network.QUIC.Logger
import Network.QUIC.Packet
import Network.QUIC.Parameters
import Network.QUIC.QLogger
import Network.QUIC.Qlog
import Network.QUIC.Receiver
import Network.QUIC.Recovery
import Network.QUIC.Sender
import Network.QUIC.Server.Reader
import Network.QUIC.Socket
import Network.QUIC.Types
-- | Run a QUIC server.
--
--   A UDP listening socket is opened for every address in 'scAddresses',
--   and each socket gets its own dispatcher thread. A separate timeout
--   thread is also forked. Connections are then accepted forever, and each
--   accepted connection is handled by 'runServer' (and hence by @server@)
--   in a freshly forked thread.
--
--   When this action is interrupted (e.g. via 'stop'), the dispatch table
--   is cleared and the dispatcher and timeout threads are killed.
--   Diagnostic messages go to stdout only when 'scDebugLog' is set.
run :: ServerConfig -> (Connection -> IO ()) -> IO ()
run conf server = handleLogUnit logger $ do
    -- Recorded on every connection so that 'stop' can interrupt this thread.
    mainTid <- myThreadId
    E.bracket acquire release $ \(disp, _) ->
        forever $ accept disp >>= \acc ->
            void . forkIO $ runServer conf server disp mainTid acc
  where
    logger :: DebugLogger
    logger msg
      | isJust (scDebugLog conf) = stdoutLogger ("run: " <> msg)
      | otherwise                = return ()

    -- Dispatch table, one dispatcher per listening socket, plus the
    -- timeout thread; all thread ids are returned so 'release' can kill them.
    acquire :: IO (Dispatch, [ThreadId])
    acquire = do
        disp      <- newDispatch
        listeners <- traverse udpServerListenSocket (scAddresses conf)
        workers   <- traverse (runDispatcher disp conf) listeners
        timerTid  <- forkIO timeouter
        return (disp, timerTid : workers)

    release :: (Dispatch, [ThreadId]) -> IO ()
    release (disp, threads) = do
        clearDispatch disp
        mapM_ killThread threads
-- | Serve one accepted connection.
--
--   The connection resources are created by 'createServerConnection' and
--   released afterwards inside an 'E.bracket'. While the connection is alive:
--
--   * the socket reader is forked and registered with 'addReader';
--   * the handshaker, sender, receiver, resender and LDCC timer run
--     concurrently, racing against the application action;
--   * the application action waits for 1-RTT readiness, performs the
--     post-handshake steps, and then calls @server0@.
--
--   If the supporting threads finish first, 'MustNotReached' is thrown.
--   The outcome is captured by 'E.trySyncOrAsync' and passed to 'closure'.
runServer :: ServerConfig -> (Connection -> IO ()) -> Dispatch -> ThreadId -> Accept -> IO ()
runServer conf server0 dispatch baseThreadId acc =
    E.bracket acquireConn releaseConn serve
  where
    acquireConn :: IO ConnRes
    acquireConn = createServerConnection conf dispatch acc baseThreadId

    serve :: ConnRes -> IO ()
    serve (ConnRes conn send recv myAuthCIDs reader) =
        handleLogUnit (logTo conn) $ do
            readerTid <- forkIO reader
            addReader conn readerTid
            handshaker <- handshakeServer conf conn myAuthCIDs
            let ldcc = connLDCC conn
                app  = wait1RTTReady conn
                    >> afterHandshakeServer conn
                    >> server0 conn
                helpers = foldr1 concurrently_
                    [ handshaker
                    , sender conn send
                    , receiver conn recv
                    , resender ldcc
                    , ldccTimer ldcc
                    ]
                -- The helpers are not expected to return before the app does.
                body = race helpers app
                   >>= either (\() -> E.throwIO MustNotReached) return
            E.trySyncOrAsync body >>= closure conn ldcc

    -- Mark the connection dead, run its registered cleanups, stop its
    -- readers, and close every socket it holds.
    releaseConn :: ConnRes -> IO ()
    releaseConn res = do
        let conn = connResConnection res
        setDead conn
        freeResources conn
        killReaders conn
        getSockets conn >>= mapM_ NS.close

    -- Sends a message to both the connection's debug log and its qlog.
    logTo :: Connection -> DebugLogger
    logTo conn msg = do
        connDebugLog conn ("runServer: " <> msg)
        qlogDebug conn (Debug (toLogStr msg))
-- | Build the per-connection state for a freshly accepted client.
--
--   This function:
--
--   * opens a UDP socket connected to the peer;
--   * creates the qlog and debug loggers (their cleanups are registered
--     as connection resources);
--   * installs the Initial-level keys;
--   * applies the packet-size and address-validation information from the
--     'Accept';
--   * registers the connection's CID with the dispatcher. Unregistering
--     all of the connection's CIDs is registered as a resource.
--
--   The returned 'ConnRes' carries the send and receive functions and a
--   reader action bound to the connected socket.
--
--   Throws (via 'E.throwString') if the dispatcher did not supply an
--   initial source CID or an original destination CID. This check happens
--   before any allocation.
--
--   If a later step throws, the connected socket is closed before the
--   exception propagates. This function is used as the acquire half of an
--   'E.bracket', and that bracket's release only runs when acquisition
--   succeeds.
createServerConnection :: ServerConfig -> Dispatch -> Accept -> ThreadId
                       -> IO ConnRes
createServerConnection conf@ServerConfig{..} dispatch Accept{..} baseThreadId = do
    -- Validate first: no resource can leak, and a missing CID produces a
    -- descriptive error instead of an irrefutable-pattern failure.
    myCID <- requireCID "initSrcCID" $ initSrcCID accMyAuthCIDs
    ocid  <- requireCID "origDstCID" $ origDstCID accMyAuthCIDs
    s0 <- udpServerConnectedSocket accMySockAddr accPeerSockAddr
    flip E.onException (NS.close s0) $ do
        sref <- newIORef [s0]
        let send buf siz = void $ do
                -- sref starts as [s0]; packets always go out on its head.
                s:_ <- readIORef sref
                NS.sendBuf s buf siz
            recv = recvServer accRecvQ
        (qLog, qclean)     <- dirQLogger scQLog accTime ocid "server"
        (debugLog, dclean) <- dirDebugLogger scDebugLog ocid
        debugLog $ "Original CID: " <> bhow ocid
        conn <- serverConnection conf accVersion accMyAuthCIDs accPeerAuthCIDs
                                 debugLog qLog scHooks sref accRecvQ
        addResource conn qclean
        addResource conn dclean
        -- When a Retry was sent (retrySrcCID is set), the Initial keys are
        -- derived from that CID; otherwise from the original destination CID.
        let cid = fromMaybe ocid $ retrySrcCID accMyAuthCIDs
        initializeCoder conn InitialLevel $ initialSecrets accVersion cid
        setupCryptoStreams conn
        -- Clamp the packet size to [default, maximum] for this address.
        let pktSiz = (defaultPacketSize accMySockAddr `max` accPacketSize)
                         `min` maximumPacketSize accMySockAddr
        setMaxPacketSize conn pktSiz
        setInitialCongestionWindow (connLDCC conn) pktSiz
        debugLog $ "Packet size: " <> bhow pktSiz <> " (" <> bhow accPacketSize <> ")"
        addRxBytes conn accPacketSize
        when accAddressValidated $ setAddressValidated conn
        --
        let retried = isJust $ retrySrcCID accMyAuthCIDs
        when retried $ do
            qlogRecvInitial conn
            qlogSentRetry conn
        --
        let mgr = tokenMgr dispatch
        setTokenManager conn mgr
        --
        setBaseThreadId conn baseThreadId
        --
        setRegister conn accRegister accUnregister
        accRegister myCID conn
        addResource conn $ do
            myCIDs <- getMyCIDs conn
            mapM_ accUnregister myCIDs
        --
        let reader = readerServer s0 conn
        return $ ConnRes conn send recv accMyAuthCIDs reader
  where
    -- Unwrap a CID the dispatcher is expected to provide, failing loudly
    -- with the name of the missing field.
    requireCID :: String -> Maybe a -> IO a
    requireCID field =
        maybe (E.throwString $ "createServerConnection: missing " ++ field) return
-- | Server-side steps performed once 1-RTT keys are ready.
--
--   The steps are:
--
--   1. Allocate a new CID for this connection and register it.
--   2. Encrypt a fresh token for the connection's version using the
--      connection's token manager.
--   3. Send NEW_TOKEN, NEW_CONNECTION_ID (retire-prior-to 0) and
--      HANDSHAKE_DONE at the 1-RTT level.
--
--   Failures are reported through 'handleLogT' with an
--   "afterHandshakeServer: " prefix on the connection's debug log.
afterHandshakeServer :: Connection -> IO ()
afterHandshakeServer conn = handleLogT logAction $ do
    newCID      <- getNewMyCID conn
    registerCID <- getRegister conn
    registerCID (cidInfoCID newCID) conn
    tokenBytes <- do
        ver   <- getVersion conn
        plain <- generateToken ver
        mgr   <- getTokenManager conn
        encryptToken mgr plain
    sendFrames conn RTT1Level
        [ NewToken tokenBytes
        , NewConnectionID newCID 0
        , HandshakeDone
        ]
  where
    logAction :: DebugLogger
    logAction = connDebugLog conn . ("afterHandshakeServer: " <>)
-- | Stop the server by killing the base thread recorded on @conn@.
--
--   The recorded thread is the one that called 'run': 'run' captures it
--   with 'myThreadId', and 'createServerConnection' stores it via
--   'setBaseThreadId'.
stop :: Connection -> IO ()
stop conn = killThread =<< getBaseThreadId conn