{-# LANGUAGE OverloadedStrings #-}
-- |
-- Module      : Network.TLS.QUIC
-- License     : BSD-style
-- Maintainer  : Vincent Hanquez <vincent@snarc.org>
-- Stability   : experimental
-- Portability : unknown
--
-- Experimental API to run the TLS handshake establishing a QUIC connection.
--
-- On the northbound API:
--
-- * QUIC starts a TLS client or server thread with 'tlsQUICClient' or
--   'tlsQUICServer'.
--
-- On the southbound API, TLS invokes QUIC callbacks to use the QUIC transport:
--
-- * TLS uses 'quicSend' and 'quicRecv' to send and receive handshake message
--   fragments.
--
-- * TLS calls 'quicInstallKeys' to provide QUIC with the traffic secrets it
--   should use for encryption/decryption.
--
-- * TLS calls 'quicNotifyExtensions' to notify QUIC of the transport
--   parameters exchanged through the handshake protocol.
--
-- * TLS calls 'quicDone' when the handshake is done.
--
module Network.TLS.QUIC (
    -- * Handshakers
      tlsQUICClient
    , tlsQUICServer
    -- * Callback
    , QUICCallbacks(..)
    , CryptLevel(..)
    , KeyScheduleEvent(..)
    -- * Secrets
    , EarlySecretInfo(..)
    , HandshakeSecretInfo(..)
    , ApplicationSecretInfo(..)
    , EarlySecret
    , HandshakeSecret
    , ApplicationSecret
    , TrafficSecrets
    , ServerTrafficSecret(..)
    , ClientTrafficSecret(..)
    -- * Negotiated parameters
    , NegotiatedProtocol
    , HandshakeMode13(..)
    -- * Extensions
    , ExtensionRaw(..)
    , ExtensionID
    , extensionID_QuicTransportParameters
    -- * Errors
    , errorTLS
    , errorToAlertDescription
    , errorToAlertMessage
    , fromAlertDescription
    , toAlertDescription
    -- * Hash
    , hkdfExpandLabel
    , hkdfExtract
    , hashDigestSize
    -- * Constants
    , quicMaxEarlyDataSize
    -- * Supported
    , defaultSupported
    ) where

import Network.TLS.Backend
import Network.TLS.Context
import Network.TLS.Context.Internal
import Network.TLS.Core
import Network.TLS.Crypto (hashDigestSize)
import Network.TLS.Crypto.Types
import Network.TLS.Extension (extensionID_QuicTransportParameters)
import Network.TLS.Extra.Cipher
import Network.TLS.Handshake.Common
import Network.TLS.Handshake.Control
import Network.TLS.Handshake.State
import Network.TLS.Handshake.State13
import Network.TLS.Imports
import Network.TLS.KeySchedule (hkdfExtract, hkdfExpandLabel)
import Network.TLS.Parameters
import Network.TLS.Record.Layer
import Network.TLS.Record.State
import Network.TLS.Struct
import Network.TLS.Types

import Data.Default.Class

-- | A 'Backend' that does nothing: flushing, closing and sending are no-ops,
-- and every receive yields an empty 'ByteString'.  'tlsQUICClient' and
-- 'tlsQUICServer' create their context over it because they replace the
-- record layer with one that talks to 'QUICCallbacks' instead.
nullBackend :: Backend
nullBackend = Backend
    { backendFlush = noop
    , backendClose = noop
    , backendSend  = const noop
    , backendRecv  = const (pure "")
    }
  where
    noop :: IO ()
    noop = pure ()

-- | The argument handed to 'quicInstallKeys' whenever new encryption
-- material becomes available during the handshake.
data KeyScheduleEvent
    -- | Key material and parameters for traffic at the 0-RTT level
    -- (possibly absent, hence the 'Maybe').
    = InstallEarlyKeys (Maybe EarlySecretInfo)
    -- | Key material and parameters for traffic at the handshake level.
    | InstallHandshakeKeys HandshakeSecretInfo
    -- | Key material and parameters for traffic at the application level.
    | InstallApplicationKeys ApplicationSecretInfo

-- | Callbacks implemented by QUIC and invoked by TLS at specific points of
-- the handshake.  TLS may call them from threads other than the one that
-- started the handshake, but never concurrently: at most one callback is
-- running at any given time.
data QUICCallbacks = QUICCallbacks
    { -- | Called by TLS so that QUIC sends one or more handshake fragments.
      -- The fragments are passed as plaintext; it is QUIC's responsibility
      -- to encrypt each payload with the key material installed for the
      -- given level, using an appropriate encryption scheme.
      --
      -- A fragment may exceed QUIC datagram limits, in which case QUIC may
      -- split it into smaller pieces.
      --
      -- A single handshake flight sometimes carries content at two levels.
      -- The TLS library does its best to deliver such content in one
      -- @quicSend@ call with a multi-valued argument, so that QUIC can
      -- decide how to transmit it optimally.
      quicSend              :: [(CryptLevel, ByteString)] -> IO ()
    , -- | Called by TLS to obtain from QUIC the next plaintext handshake
      -- fragment.  The argument is the encryption level with which the
      -- fragment should be decrypted.
      --
      -- QUIC may hand partial fragments to TLS; TLS then keeps calling
      -- @quicRecv@ as long as necessary.  Fragments must nevertheless be
      -- returned in sequence, i.e. in the order the TLS peer emitted them.
      --
      -- The function may return an error to TLS when the end of the stream
      -- is reached, or when a protocol error has been received and the
      -- handshake is believed unable to proceed.  If the TLS handshake
      -- protocol cannot recover from the error, the failure is reported back
      -- to QUIC through the control interface.
      quicRecv              :: CryptLevel -> IO (Either TLSError ByteString)
    , -- | Called by TLS when new encryption material is ready for the
      -- handshake.  The next 'quicSend' or 'quicRecv' may now use the
      -- associated encryption level (the previous level remains possible as
      -- well, since the send and receive directions do not switch at the
      -- same time).
      quicInstallKeys       :: Context -> KeyScheduleEvent -> IO ()
    , -- | Called by TLS when QUIC-specific extensions have been received
      -- from the peer.
      quicNotifyExtensions  :: Context -> [ExtensionRaw] -> IO ()
    , -- | Called once 'handshake' is done.  'tlsQUICServer' finishes right
      -- after this hook; 'tlsQUICClient' goes on to call 'recvData' to wait
      -- for new session tickets.
      quicDone              :: Context -> IO ()
    }

-- | Encryption level currently used for outgoing handshake data: the third
-- component of the context's transmit state.
getTxLevel :: Context -> IO CryptLevel
getTxLevel ctx =
    getTxState ctx >>= \(_, _, lvl, _) -> pure lvl

-- | Encryption level currently expected for incoming handshake data: the
-- third component of the context's receive state.
getRxLevel :: Context -> IO CryptLevel
getRxLevel ctx =
    getRxState ctx >>= \(_, _, lvl, _) -> pure lvl

-- | Build a transparent record layer that bypasses TLS record protection and
-- routes handshake fragments through the QUIC callbacks:
--
-- * each fragment is tagged with the context's current transmit level;
-- * tagged fragments are sent with 'quicSend';
-- * receiving asks 'quicRecv' for data at the current receive level.
newRecordLayer :: Context -> QUICCallbacks
               -> RecordLayer [(CryptLevel, ByteString)]
newRecordLayer ctx callbacks =
    newTransparentRecordLayer
        (getTxLevel ctx)
        (quicSend callbacks)
        (quicRecv callbacks =<< getRxLevel ctx)

-- | Start a TLS handshake thread for a QUIC client.  The client will use the
-- specified TLS parameters and call the provided callback functions to send and
-- receive handshake data.
--
-- The TLS context is created over 'nullBackend'.  Its record layer is
-- replaced so that handshake bytes flow through 'quicSend' and 'quicRecv'.
-- As the handshake progresses, 'quicInstallKeys' is called with:
--
-- * the early keys, when the ClientHello is sent;
-- * the handshake keys, when the ServerHello is received;
-- * the application keys, when the client Finished is sent.  This call is
--   preceded by 'quicNotifyExtensions' with the QUIC transport parameters
--   received from the peer.
--
-- After the handshake, 'quicDone' is called and the client then blocks in
-- 'recvData' once, waiting for new session tickets.
--
-- Throws a TLS protocol error with a 'MissingExtension' alert if the peer's
-- extensions contain no QUIC transport parameters.
tlsQUICClient :: ClientParams -> QUICCallbacks -> IO ()
tlsQUICClient cparams callbacks = do
    ctx0 <- contextNew nullBackend cparams
    let ctx1 = ctx0
            { ctxHandshakeSync = HandshakeSync sync (\_ _ -> return ())
            , ctxFragmentSize  = Nothing
            , ctxQUICMode      = True
            }
        -- The record layer needs the final context, and the final context
        -- contains the record layer: the recursive let ties the knot lazily.
        rl   = newRecordLayer ctx2 callbacks
        ctx2 = updateRecordLayer rl ctx1
    handshake ctx2
    quicDone callbacks ctx2
    void $ recvData ctx2 -- waiting for new session tickets
  where
    -- Client-side handshake hook: forwards key-schedule milestones to QUIC.
    sync :: Context -> ClientState -> IO ()
    sync ctx (SendClientHello mEarlySecInfo) =
        quicInstallKeys callbacks ctx (InstallEarlyKeys mEarlySecInfo)
    sync ctx (RecvServerHello handSecInfo) =
        quicInstallKeys callbacks ctx (InstallHandshakeKeys handSecInfo)
    sync ctx (SendClientFinished exts appSecInfo) = do
        let qexts = filterQTP exts
        -- QUIC cannot operate without the peer's transport parameters.
        when (null qexts) $
            throwCore $ Error_Protocol
                ("QUIC transport parameters are missing", True, MissingExtension)
        quicNotifyExtensions callbacks ctx qexts
        quicInstallKeys callbacks ctx (InstallApplicationKeys appSecInfo)

-- | Start a TLS handshake thread for a QUIC server.  The server will use the
-- specified TLS parameters and call the provided callback functions to send and
-- receive handshake data.
--
-- The TLS context is created over 'nullBackend'.  Its record layer is
-- replaced so that handshake bytes flow through 'quicSend' and 'quicRecv'.
--
-- When the ServerHello is sent, the server calls, in this order:
--
-- * 'quicNotifyExtensions' with the QUIC transport parameters received from
--   the peer;
-- * 'quicInstallKeys' with the early keys;
-- * 'quicInstallKeys' with the handshake keys.
--
-- When the server Finished is sent, 'quicInstallKeys' is called with the
-- application keys.  After the handshake, 'quicDone' is called and the
-- function returns.
--
-- Throws a TLS protocol error with a 'MissingExtension' alert if the peer's
-- extensions contain no QUIC transport parameters.
tlsQUICServer :: ServerParams -> QUICCallbacks -> IO ()
tlsQUICServer sparams callbacks = do
    ctx0 <- contextNew nullBackend sparams
    let ctx1 = ctx0
            { ctxHandshakeSync = HandshakeSync (\_ _ -> return ()) sync
            , ctxFragmentSize  = Nothing
            , ctxQUICMode      = True
            }
        -- The record layer needs the final context, and the final context
        -- contains the record layer: the recursive let ties the knot lazily.
        rl   = newRecordLayer ctx2 callbacks
        ctx2 = updateRecordLayer rl ctx1
    handshake ctx2
    quicDone callbacks ctx2
  where
    -- Server-side handshake hook: forwards key-schedule milestones to QUIC.
    sync :: Context -> ServerState -> IO ()
    sync ctx (SendServerHello exts mEarlySecInfo handSecInfo) = do
        let qexts = filterQTP exts
        -- QUIC cannot operate without the peer's transport parameters.
        when (null qexts) $
            throwCore $ Error_Protocol
                ("QUIC transport parameters are missing", True, MissingExtension)
        quicNotifyExtensions callbacks ctx qexts
        quicInstallKeys callbacks ctx (InstallEarlyKeys mEarlySecInfo)
        quicInstallKeys callbacks ctx (InstallHandshakeKeys handSecInfo)
    sync ctx (SendServerFinished appSecInfo) =
        quicInstallKeys callbacks ctx (InstallApplicationKeys appSecInfo)

-- | Keep only the QUIC transport parameters extensions, preserving their
-- order.  Two codepoints are accepted: 'extensionID_QuicTransportParameters'
-- and 0xffa5.  The latter is slated for removal; presumably it is a
-- pre-standard draft codepoint (TODO confirm).
filterQTP :: [ExtensionRaw] -> [ExtensionRaw]
filterQTP exts =
    [ ext | ext@(ExtensionRaw eid _) <- exts, isQTP eid ]
  where
    isQTP eid = eid `elem` [extensionID_QuicTransportParameters, 0xffa5] -- 0xffa5: to be deleted

-- | Abort from within a callback on an unexpected condition.  The message is
-- thrown as a TLS protocol error carrying an 'InternalError' alert, which the
-- TLS stack then reports as an @internal_error@ alert.
errorTLS :: String -> IO a
errorTLS = throwCore . internalError
  where
    internalError m = Error_Protocol (m, True, InternalError)

-- | The alert description a TLS endpoint would send to its peer for the
-- given library error, i.e. the description half of 'errorToAlert'.
errorToAlertDescription :: TLSError -> AlertDescription
errorToAlertDescription err =
    case errorToAlert err of
        (_, desc) -> desc

-- | Encode an alert description as its assigned numeric value (via
-- 'valOfType').
fromAlertDescription :: AlertDescription -> Word8
fromAlertDescription desc = valOfType desc

-- | Decode an alert description from its assigned numeric value.  The result
-- is 'Nothing' whenever 'valToType' does not recognise the value.
toAlertDescription :: Word8 -> Maybe AlertDescription
toAlertDescription code = valToType code

-- | 'Supported' settings for QUIC.  This is the library default, restricted
-- to:
--
-- * TLS 1.3 only;
-- * the TLS 1.3 AES-GCM and AES-CCM cipher suites, in the order listed;
-- * the X25519, X448, P256, P384 and P521 groups.
defaultSupported :: Supported
defaultSupported =
    def { supportedVersions = [TLS13]
        , supportedCiphers  = tls13Ciphers
        , supportedGroups   = quicGroups
        }
  where
    tls13Ciphers =
        [ cipher_TLS13_AES256GCM_SHA384
        , cipher_TLS13_AES128GCM_SHA256
        , cipher_TLS13_AES128CCM_SHA256
        ]
    quicGroups = [X25519, X448, P256, P384, P521]

-- | Max early data size for QUIC: 2^32 - 1, i.e. 0xffffffff, the value the
-- QUIC-TLS specification requires in the @early_data@ extension.
--
-- NOTE(review): on platforms with a 32-bit 'Int' this value wraps to -1;
-- confirm that callers only target 64-bit platforms or convert it
-- appropriately.
quicMaxEarlyDataSize :: Int
quicMaxEarlyDataSize = 2 ^ (32 :: Int) - 1