{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}

module Network.QUIC.Handshake where

import qualified Network.TLS as TLS
import Network.TLS.QUIC
import qualified UnliftIO.Exception as E

import Network.QUIC.Config
import Network.QUIC.Connection
import Network.QUIC.Connector
import Network.QUIC.Imports
import Network.QUIC.Info
import Network.QUIC.Logger
import Network.QUIC.Parameters
import Network.QUIC.Qlog
import Network.QUIC.Recovery
import Network.QUIC.TLS
import Network.QUIC.Types

----------------------------------------------------------------

-- | Mutable bookkeeping shared by the TLS send and receive callbacks of one
-- handshake.
newtype HndState = HndState
    { hsRecvCnt :: Int
      -- ^ number of 'recv' calls since the counter was last reset
    }

-- | Allocate fresh handshake state with the receive counter at zero.
newHndStateRef :: IO (IORef HndState)
newHndStateRef = newIORef (HndState 0)

-- | Note that handshake data has just been sent: the count of receives
-- since the last send starts over at zero.
sendCompleted :: IORef HndState -> IO ()
sendCompleted hsr = atomicModifyIORef'' hsr resetCount
  where
    resetCount st = st { hsRecvCnt = 0 }

-- | Count one more receive, atomically.
--
-- Returns the counter value /before/ this receive, i.e. how many receives
-- had already happened since the last reset.
recvCompleted :: IORef HndState -> IO Int
recvCompleted hsr = atomicModifyIORef' hsr tick
  where
    -- Store the incremented count, hand back the previous one.
    tick st@HndState { hsRecvCnt = seen } = (st { hsRecvCnt = seen + 1 }, seen)

-- | Invoked by the key-installation callbacks right after the connection's
-- encryption level is switched; resets the receive counter in exactly the
-- same way a completed send does.
rxLevelChanged :: IORef HndState -> IO ()
rxLevelChanged hsr = sendCompleted hsr

----------------------------------------------------------------

-- | Hand an 'Output' to the connection; a thin alias for 'putOutput'.
sendCryptoData :: Connection -> Output -> IO ()
sendCryptoData conn out = putOutput conn out

-- | Take the next received CRYPTO item; a thin alias for 'takeCrypto'.
recvCryptoData :: Connection -> IO Crypto
recvCryptoData conn = takeCrypto conn

-- | 'quicRecv' callback handed to the TLS library.
--
-- The TLS crypt level is mapped to a QUIC encryption level, and the next
-- CRYPTO chunk is then taken from the connection.  The outcomes are:
--
-- * an empty chunk yields @Left 'TLS.Error_EOF'@;
-- * a chunk at a different level than requested yields an internal error;
-- * pre-TLS-1.3 and early-secret levels are always internal errors.
--
-- On the client, each successful receive bumps the 'HndState' counter.
-- When the pre-increment count is 1 modulo 3 (the 2nd, 5th, 8th, ...
-- receive since the last reset), an empty Handshake-level 'OutControl'
-- is handed to 'sendCryptoData' to trigger ACK sending.
recvTLS :: Connection -> IORef HndState -> CryptLevel -> IO (Either TLS.TLSError ByteString)
recvTLS conn hsr level = case level of
    CryptInitial           -> receiveAt InitialLevel
    CryptMasterSecret      -> failure "QUIC does not receive data < TLS 1.3"
    CryptEarlySecret       -> failure "QUIC does not send early data with TLS library"
    CryptHandshakeSecret   -> receiveAt HandshakeLevel
    CryptApplicationSecret -> receiveAt RTT1Level
  where
    failure :: String -> IO (Either TLS.TLSError b)
    failure msg = return (Left (internalError msg))

    receiveAt expected = do
        InpHandshake actual bs <- recvCryptoData conn
        deliver expected actual bs

    deliver expected actual bs
      | bs == ""           = return (Left TLS.Error_EOF)
      | actual /= expected =
          failure ("encryption level mismatch: expected " ++ show expected
                   ++ " but got " ++ show actual)
      | otherwise          = do
          when (isClient conn) ackOnThirdRecv
          return (Right bs)

    -- Sending ACKs for three times rule
    ackOnThirdRecv = do
        n <- recvCompleted hsr
        when (n `mod` 3 == 1) $
            sendCryptoData conn (OutControl HandshakeLevel [] (return ()))

-- | 'quicSend' callback handed to the TLS library.
--
-- Every (crypt level, bytes) pair is translated to its QUIC encryption
-- level, and the batch is sent as one 'OutHandshake'.  The receive counter
-- is then reset via 'sendCompleted'.
--
-- Pre-TLS-1.3 and early-secret levels are rejected through 'errorTLS'.
-- Because all pairs are converted before sending, nothing is sent when
-- one of them is rejected.
sendTLS :: Connection -> IORef HndState -> [(CryptLevel, ByteString)] -> IO ()
sendTLS conn hsr x = do
    chunks <- traverse toEncryptionLevel x
    sendCryptoData conn (OutHandshake chunks)
    sendCompleted hsr
  where
    toEncryptionLevel :: (CryptLevel, b) -> IO (EncryptionLevel, b)
    toEncryptionLevel (lvl, bs) = case lvl of
        CryptInitial           -> return (InitialLevel, bs)
        CryptMasterSecret      -> errorTLS "QUIC does not send data < TLS 1.3"
        CryptEarlySecret       -> errorTLS "QUIC does not receive early data with TLS library"
        CryptHandshakeSecret   -> return (HandshakeLevel, bs)
        CryptApplicationSecret -> return (RTT1Level, bs)

-- | Build a TLS protocol error carrying the given message and the
-- 'TLS.InternalError' alert description.
internalError :: String -> TLS.TLSError
internalError msg = TLS.Error_Protocol (msg, True, alert)
  where
    alert = TLS.InternalError
-- unexpectedMessage msg = TLS.Error_Protocol (msg, True, TLS.UnexpectedMessage)

----------------------------------------------------------------

-- | Prepare the client-side handshake.
--
-- This logs the local transport parameters to qlog, reads the connection's
-- version and allocates fresh handshake state.  It returns the action that
-- actually runs the TLS handshake.
handshakeClient :: ClientConfig -> Connection -> AuthCIDs -> IO (IO ())
handshakeClient conf conn myAuthCIDs = do
    qlogParamsSet conn (ccParameters conf, "local") -- fixme
    ver <- getVersion conn
    hsr <- newHndStateRef
    return (handshakeClient' conf conn myAuthCIDs ver hsr)

-- | Run the client side of the TLS handshake for a QUIC connection.
--
-- The TLS library drives the handshake through 'QUICCallbacks':
--
-- * outgoing bytes go through 'sendTLS' and incoming bytes through 'recvTLS';
-- * the peer's transport parameters are processed by 'setPeerParams';
-- * keys are installed level by level as the key schedule advances.
--
-- A 'TLS.TLSException' escaping the handshake is turned into a connection
-- close by 'sendCCTLSError'.
handshakeClient' :: ClientConfig -> Connection -> AuthCIDs -> Version -> IORef HndState -> IO ()
handshakeClient' conf conn myAuthCIDs ver hsr =
    clientHandshaker callbacks conf ver myAuthCIDs saveSession wants0RTT
        `E.catch` sendCCTLSError
  where
    wants0RTT = ccUse0RTT conf

    -- Session-resumption setter passed to 'clientHandshaker'.
    saveSession = setResumptionSession conn

    callbacks = QUICCallbacks
        { quicSend             = sendTLS conn hsr
        , quicRecv             = recvTLS conn hsr
        , quicInstallKeys      = installKeys
        , quicNotifyExtensions = setPeerParams conn
        , quicDone             = onDone
        }

    installKeys ctx event = case event of
        InstallEarlyKeys Nothing -> return ()
        InstallEarlyKeys (Just (EarlySecretInfo cphr cts)) -> do
            -- Early data only comes with a client secret; the server half
            -- of the coder's secrets is left empty.
            setCipher conn RTT0Level cphr
            initializeCoder conn RTT0Level (cts, ServerTrafficSecret "")
            setConnection0RTTReady conn
        InstallHandshakeKeys (HandshakeSecretInfo cphr tss) -> do
            -- The handshake cipher is recorded for the 1-RTT level too.
            setCipher conn HandshakeLevel cphr
            setCipher conn RTT1Level cphr
            initializeCoder conn HandshakeLevel tss
            setEncryptionLevel conn HandshakeLevel
            rxLevelChanged hsr
        InstallApplicationKeys appSecInf@(ApplicationSecretInfo tss) -> do
            storeNegotiated conn ctx appSecInf
            initializeCoder1RTT conn tss
            setEncryptionLevel conn RTT1Level
            rxLevelChanged hsr
            setConnection1RTTReady conn
            cidInfo <- getNewMyCID conn
            putOutput conn $ OutHandshake [] -- for h3spec testing
            sendFrames conn RTT1Level [NewConnectionID cidInfo 0]

    -- On completion just log the negotiated connection info.
    onDone _ctx = getConnectionInfo conn >>= connDebugLog conn . bhow

----------------------------------------------------------------

-- | Prepare the server-side handshake.
--
-- This reads the connection's version and allocates fresh handshake state.
-- It returns the action that runs the TLS handshake.
handshakeServer :: ServerConfig -> Connection -> AuthCIDs -> IO (IO ())
handshakeServer conf conn myAuthCIDs = do
    ver <- getVersion conn
    hsr <- newHndStateRef
    return $ handshakeServer' conf conn myAuthCIDs ver hsr

-- | Run the server side of the TLS handshake for a QUIC connection.
--
-- This is wired like the client.  The differences are:
--
-- * application keys are installed without switching the encryption level;
-- * the switch to 1-RTT happens in the done callback;
-- * the done callback also records the client certificate chain and, via
--   'fire', schedules the discard of 0-RTT/Handshake state.
--
-- A 'TLS.TLSException' escaping the handshake is turned into a connection
-- close by 'sendCCTLSError'.
handshakeServer' :: ServerConfig -> Connection -> AuthCIDs -> Version -> IORef HndState -> IO ()
handshakeServer' conf conn myAuthCIDs ver hsr =
    serverHandshaker callbacks conf ver myAuthCIDs `E.catch` sendCCTLSError
  where
    callbacks = QUICCallbacks
        { quicSend             = sendTLS conn hsr
        , quicRecv             = recvTLS conn hsr
        , quicInstallKeys      = installKeys
        , quicNotifyExtensions = setPeerParams conn
        , quicDone             = onDone
        }

    installKeys ctx event = case event of
        InstallEarlyKeys Nothing -> return ()
        InstallEarlyKeys (Just (EarlySecretInfo cphr cts)) -> do
            -- Early data only comes with a client secret; the server half
            -- of the coder's secrets is left empty.
            setCipher conn RTT0Level cphr
            initializeCoder conn RTT0Level (cts, ServerTrafficSecret "")
            setConnection0RTTReady conn
        InstallHandshakeKeys (HandshakeSecretInfo cphr tss) -> do
            -- The handshake cipher is recorded for the 1-RTT level too.
            setCipher conn HandshakeLevel cphr
            setCipher conn RTT1Level cphr
            initializeCoder conn HandshakeLevel tss
            setEncryptionLevel conn HandshakeLevel
            rxLevelChanged hsr
        InstallApplicationKeys appSecInf@(ApplicationSecretInfo tss) -> do
            storeNegotiated conn ctx appSecInf
            initializeCoder1RTT conn tss
            -- will switch to RTT1Level after client Finished
            -- is received and verified

    onDone ctx = do
        setEncryptionLevel conn RTT1Level
        TLS.getClientCertificateChain ctx >>= setCertificateChain conn
        -- NOTE(review): presumably 'fire' runs the action after the given
        -- delay (100000 us); confirm against its definition.
        fire conn (Microseconds 100000) discardOldSpaces
        setConnection1RTTReady conn
        setConnectionEstablished conn
        -- sendFrames conn RTT1Level [HandshakeDone]
        info <- getConnectionInfo conn
        connDebugLog conn (bhow info)

    -- Drop 0-RTT and Handshake secrets at most once each (guarded by
    -- 'getAndSetPacketNumberSpaceDiscarded'), then clear the Handshake and
    -- 1-RTT crypto streams.
    discardOldSpaces = do
        let ldcc = connLDCC conn
        discarded0 <- getAndSetPacketNumberSpaceDiscarded ldcc RTT0Level
        unless discarded0 $ dropSecrets conn RTT0Level
        discarded1 <- getAndSetPacketNumberSpaceDiscarded ldcc HandshakeLevel
        unless discarded1 $ do
            dropSecrets conn HandshakeLevel
            onPacketNumberSpaceDiscarded ldcc HandshakeLevel
        clearCryptoStream conn HandshakeLevel
        clearCryptoStream conn RTT1Level

----------------------------------------------------------------

-- | 'quicNotifyExtensions' callback.  It finds the peer's QUIC transport
-- parameters among the TLS extensions, validates them and applies them to
-- the connection.
--
-- QUIC version 1 uses 'extensionID_QuicTransportParameters'; any other
-- version uses the code point 0xffa5.
--
-- Undecodable or invalid parameters abort through 'sendCCParamError'.
-- That function throws, so the first failing check wins and nothing after
-- it runs.
setPeerParams :: Connection -> TLS.Context -> [ExtensionRaw] -> IO ()
setPeerParams conn _ctx ps0 = do
    ver <- getVersion conn
    let mps | ver == Version1 = getTP extensionID_QuicTransportParameters ps0
            | otherwise       = getTP 0xffa5 ps0
    setPP mps
  where
    -- Payload of the first extension with the given ID, if any.
    getTP _ [] = Nothing
    getTP n (ExtensionRaw extid bs : ps)
      | extid == n = Just bs
      | otherwise  = getTP n ps

    -- NOTE(review): RFC 9001 Section 8.2 requires closing the connection when
    -- the transport-parameters extension is absent.  This clause silently
    -- accepts its absence; confirm whether the TLS layer enforces that.
    setPP Nothing = return ()
    setPP (Just bs) = do
        let mparams = decodeParameters bs
        case mparams of
          Nothing     -> sendCCParamError
          Just params -> do
              checkAuthCIDs params
              checkInvalid params
              setParams params
              qlogParamsSet conn (params,"remote")

    -- Compare the connection-ID-carrying parameters against the peer's
    -- recorded 'AuthCIDs'; the original-DCID and retry-SCID checks apply
    -- only on the client.
    checkAuthCIDs params = do
        peerAuthCIDs <- getPeerAuthCIDs conn
        ensure (initialSourceConnectionId params) $ initSrcCID peerAuthCIDs
        when (isClient conn) $ do
            ensure (originalDestinationConnectionId params) $ origDstCID peerAuthCIDs
            ensure (retrySourceConnectionId params) $ retrySrcCID peerAuthCIDs

    -- @ensure received expected@: nothing is checked when nothing is expected.
    -- NOTE(review): RFC 9000 Section 7.3 also forbids retry_source_connection_id
    -- when no Retry was received.  With an expected 'Nothing' this clause
    -- accepts it; confirm whether that leniency is intentional.
    ensure _ Nothing = return ()
    ensure v0 v1
      | v0 == v1  = return ()
      | otherwise = sendCCParamError

    -- Value constraints from RFC 9000 Sections 4.6, 7.4 and 18.2.
    checkInvalid params = do
        when (maxUdpPayloadSize params < 1200) sendCCParamError
        when (ackDelayExponent params > 20) sendCCParamError
        when (maxAckDelay params >= 2^(14 :: Int)) sendCCParamError
        -- active_connection_id_limit MUST be at least 2 (RFC 9000 Section 18.2).
        when (activeConnectionIdLimit params < 2) sendCCParamError
        -- max_streams values above 2^60 are a TRANSPORT_PARAMETER_ERROR
        -- (RFC 9000 Section 4.6).
        when (exceedsMaxStreams $ initialMaxStreamsBidi params) sendCCParamError
        when (exceedsMaxStreams $ initialMaxStreamsUni params) sendCCParamError
        -- A client must not send these server-only parameters.
        when (isServer conn) $ do
            when (isJust $ originalDestinationConnectionId params) sendCCParamError
            when (isJust $ preferredAddress params) sendCCParamError
            when (isJust $ retrySourceConnectionId params) sendCCParamError
            when (isJust $ statelessResetToken params) sendCCParamError

    -- Compared as Integer so the 2^60 bound cannot overflow a narrow Int.
    exceedsMaxStreams n = toInteger n > (2 :: Integer) ^ (60 :: Int)

    setParams params = do
        setPeerParameters conn params
        mapM_ (setPeerStatelessResetToken conn) $ statelessResetToken params
        setTxMaxData conn $ initialMaxData params
        setMinIdleTimeout conn $ milliToMicro $ maxIdleTimeout params
        setMaxAckDaley (connLDCC conn) $ milliToMicro $ maxAckDelay params
        setMyMaxStreams conn $ initialMaxStreamsBidi params
        setMyUniMaxStreams conn $ initialMaxStreamsUni params

-- | Record on the connection what TLS negotiated: the TLS 1.3 handshake
-- mode, the negotiated application protocol and the application secret info.
--
-- The mode falls back to 'FullHandshake' in two cases: the context has no
-- information, or the information carries no TLS 1.3 handshake mode.
storeNegotiated :: Connection -> TLS.Context -> ApplicationSecretInfo -> IO ()
storeNegotiated conn ctx appSecInf = do
    appPro <- TLS.getNegotiatedProtocol ctx
    minfo  <- TLS.contextGetInformation ctx
    let mode = case minfo >>= TLS.infoTLS13HandshakeMode of
            Just m  -> m
            Nothing -> FullHandshake
    setNegotiated conn mode appPro appSecInf

----------------------------------------------------------------

-- | Reject the peer's transport parameters by throwing
-- 'WrongTransportParameter'.
sendCCParamError :: IO ()
sendCCParamError = E.throwIO paramError
  where
    paramError = WrongTransportParameter

-- | Close the connection in response to a TLS exception from the handshake.
--
-- A @HandshakeFailed (Error_Misc "WrongTransportParameter")@ is reported as
-- 'TransportParameterError'.  That string is what 'sendCCParamError'
-- produces once the TLS library wraps it, presumably via 'show'.
--
-- Every other exception is reduced to its 'TLS.TLSError' by 'getErrorCause'.
-- It is then closed with the transport error obtained from 'cryptoError' of
-- the error's alert description, using the alert message as reason phrase.
sendCCTLSError :: TLS.TLSException -> IO ()
sendCCTLSError (TLS.HandshakeFailed (TLS.Error_Misc "WrongTransportParameter")) =
    -- Reason phrase sent to the peer; previously misspelled "parametter".
    closeConnection TransportParameterError "Transport parameter error"
sendCCTLSError e = closeConnection err msg
  where
    tlserr = getErrorCause e
    err    = cryptoError $ errorToAlertDescription tlserr
    msg    = shortpack $ errorToAlertMessage tlserr

-- | Extract the underlying 'TLS.TLSError' from a 'TLS.TLSException'.
--
-- Any other exception constructor is reported as an internal protocol
-- error whose message names the exception.
getErrorCause :: TLS.TLSException -> TLS.TLSError
getErrorCause ex = case ex of
    TLS.HandshakeFailed err -> err
    TLS.Terminated _ _ err  -> err
    _                       -> internalError ("unexpected TLS exception: " ++ show ex)