{-# LANGUAGE OverloadedStrings #-}
-- | Entry points and supporting types for running a TLS handshake on behalf
-- of a QUIC implementation.  The handshakers 'tlsQUICClient' and
-- 'tlsQUICServer' drive the handshake and talk to QUIC exclusively through a
-- 'QUICCallbacks' record.
module Network.TLS.QUIC (
    -- * Handshakers
    tlsQUICClient,
    tlsQUICServer,
    -- * Callbacks
    QUICCallbacks (..),
    CryptLevel (..),
    KeyScheduleEvent (..),
    -- * Secrets
    EarlySecretInfo (..),
    HandshakeSecretInfo (..),
    ApplicationSecretInfo (..),
    EarlySecret,
    HandshakeSecret,
    ApplicationSecret,
    TrafficSecrets,
    ServerTrafficSecret (..),
    ClientTrafficSecret (..),
    -- * Negotiated parameters
    NegotiatedProtocol,
    HandshakeMode13 (..),
    -- * Extensions
    ExtensionRaw (..),
    ExtensionID (ExtensionID, EID_QuicTransportParameters),
    -- * Errors and alerts
    errorTLS,
    errorToAlertDescription,
    errorToAlertMessage,
    fromAlertDescription,
    toAlertDescription,
    -- * Hash and key-derivation helpers
    hkdfExpandLabel,
    hkdfExtract,
    hashDigestSize,
    -- * Constants
    quicMaxEarlyDataSize,
    -- * Supported parameters
    defaultSupported,
) where
import Network.TLS.Backend
import Network.TLS.Context
import Network.TLS.Context.Internal
import Network.TLS.Core
import Network.TLS.Crypto (hashDigestSize)
import Network.TLS.Crypto.Types
import Network.TLS.Extra.Cipher
import Network.TLS.Handshake.Common
import Network.TLS.Handshake.Control
import Network.TLS.Handshake.State
import Network.TLS.Handshake.State13
import Network.TLS.Imports
import Network.TLS.KeySchedule (hkdfExpandLabel, hkdfExtract)
import Network.TLS.Parameters hiding (defaultSupported)
import Network.TLS.Record.Layer
import Network.TLS.Record.State
import Network.TLS.Struct
import Network.TLS.Types
import Data.Default (def)
-- | A 'Backend' whose operations are all no-ops: flushing, closing and
-- sending return immediately, and receiving always yields an empty
-- 'ByteString'.
--
-- The QUIC handshakers hand it to 'contextNew' and then install their own
-- record layer built from 'QUICCallbacks', so handshake traffic is expected
-- to flow through the callbacks rather than through this backend.
nullBackend :: Backend
nullBackend =
    Backend
        { backendFlush = return ()
        , backendClose = return ()
        , backendSend = \_ -> return ()
        , backendRecv = \_ -> return ""
        }
-- | Event passed to 'quicInstallKeys' by the handshakers to announce which
-- secrets are being made available to QUIC.
data KeyScheduleEvent
    = -- | Early-data secret information; carried as a 'Maybe', so it may be
      -- absent.
      InstallEarlyKeys (Maybe EarlySecretInfo)
    | -- | Handshake-level secret information.
      InstallHandshakeKeys HandshakeSecretInfo
    | -- | Application-level secret information.
      InstallApplicationKeys ApplicationSecretInfo
-- | Callbacks supplied by the QUIC implementation and invoked by the TLS
-- handshakers ('tlsQUICClient' and 'tlsQUICServer') at specific points of
-- the handshake.
data QUICCallbacks = QUICCallbacks
    { quicSend :: [(CryptLevel, ByteString)] -> IO ()
    -- ^ Installed as the send function of the record layer used by the
    -- handshakers.  Each list element pairs a payload with a 'CryptLevel';
    -- several elements may be delivered in a single call.
    , quicRecv :: CryptLevel -> IO (Either TLSError ByteString)
    -- ^ Installed as the receive function of the record layer.  It is
    -- called with the context's current receive level and returns either
    -- the received bytes or a 'TLSError'.
    , quicInstallKeys :: Context -> KeyScheduleEvent -> IO ()
    -- ^ Called whenever the handshake reaches a point where new secrets are
    -- available; the 'KeyScheduleEvent' says which ones.
    , quicNotifyExtensions :: Context -> [ExtensionRaw] -> IO ()
    -- ^ Called with the peer-related extensions kept by the handshakers,
    -- i.e. those whose identifier is 'EID_QuicTransportParameters'.  The
    -- handshakers only call it with a non-empty list.
    , quicDone :: Context -> IO ()
    -- ^ Called once after 'handshake' has returned.
    }
-- | Build the record layer used in QUIC mode from the QUIC callbacks.
--
-- The layer is a transparent one ('newTransparentRecordLayer'): outgoing
-- data is tagged with the context's transmit level ('getTxLevel') and passed
-- to 'quicSend'; incoming data is requested from 'quicRecv' at the context's
-- current receive level ('getRxLevel').
newRecordLayer
    :: QUICCallbacks
    -> RecordLayer [(CryptLevel, ByteString)]
newRecordLayer callbacks = newTransparentRecordLayer get send recv
  where
    -- Level used to annotate outgoing data.
    get = getTxLevel
    -- Outgoing data goes straight to QUIC.
    send = quicSend callbacks
    -- Look up the receive level first, then ask QUIC for data at that level.
    recv ctx = getRxLevel ctx >>= quicRecv callbacks
-- | Run the client side of a TLS handshake for QUIC.
--
-- A fresh 'Context' is created over 'nullBackend', switched to QUIC mode and
-- given a record layer built from @callbacks@.  The handshake then runs in
-- the calling thread; 'quicDone' is invoked once it returns, after which the
-- function keeps reading from the context with 'recvData'.
--
-- During the handshake the client-state hook forwards key material to
-- 'quicInstallKeys' and the QUIC transport parameters extension to
-- 'quicNotifyExtensions'.  A protocol error with the 'MissingExtension'
-- alert is thrown via 'throwCore' when that extension is absent.
tlsQUICClient :: ClientParams -> QUICCallbacks -> IO ()
tlsQUICClient cparams callbacks = do
    ctx0 <- contextNew nullBackend cparams
    let ctx1 =
            ctx0
                { -- Only the client-side hook is meaningful here; the
                  -- server-side hook is a no-op.
                  ctxHandshakeSync = HandshakeSync sync (\_ _ -> return ())
                , -- NOTE(review): presumably disables TLS-level
                  -- fragmentation because QUIC does its own framing --
                  -- confirm.
                  ctxFragmentSize = Nothing
                , ctxQUICMode = True
                }
        rl = newRecordLayer callbacks
        ctx2 = updateRecordLayer rl ctx1
    handshake ctx2
    quicDone callbacks ctx2
    -- NOTE(review): presumably keeps processing post-handshake messages
    -- (e.g. new session tickets); the result is discarded -- confirm.
    void $ recvData ctx2
  where
    sync ctx (SendClientHello mEarlySecInfo) =
        quicInstallKeys callbacks ctx (InstallEarlyKeys mEarlySecInfo)
    sync ctx (RecvServerHello handSecInfo) =
        quicInstallKeys callbacks ctx (InstallHandshakeKeys handSecInfo)
    sync ctx (SendClientFinished exts appSecInfo) = do
        let qexts = filterQTP exts
        -- The QUIC transport parameters extension is mandatory.
        when (null qexts) $ do
            throwCore $
                Error_Protocol "QUIC transport parameters are missing" MissingExtension
        quicNotifyExtensions callbacks ctx qexts
        quicInstallKeys callbacks ctx (InstallApplicationKeys appSecInfo)
-- | Run the server side of a TLS handshake for QUIC.
--
-- A fresh 'Context' is created over 'nullBackend', switched to QUIC mode and
-- given a record layer built from @callbacks@.  The handshake then runs in
-- the calling thread and 'quicDone' is invoked once it returns.
--
-- During the handshake the server-state hook forwards the QUIC transport
-- parameters extension to 'quicNotifyExtensions' and key material to
-- 'quicInstallKeys'.  A protocol error with the 'MissingExtension' alert is
-- thrown via 'throwCore' when that extension is absent.
tlsQUICServer :: ServerParams -> QUICCallbacks -> IO ()
tlsQUICServer sparams callbacks = do
    ctx0 <- contextNew nullBackend sparams
    let ctx1 =
            ctx0
                { -- Only the server-side hook is meaningful here; the
                  -- client-side hook is a no-op.
                  ctxHandshakeSync = HandshakeSync (\_ _ -> return ()) sync
                , -- NOTE(review): presumably disables TLS-level
                  -- fragmentation because QUIC does its own framing --
                  -- confirm.
                  ctxFragmentSize = Nothing
                , ctxQUICMode = True
                }
        rl = newRecordLayer callbacks
        ctx2 = updateRecordLayer rl ctx1
    handshake ctx2
    quicDone callbacks ctx2
  where
    sync ctx (SendServerHello exts mEarlySecInfo handSecInfo) = do
        let qexts = filterQTP exts
        -- The QUIC transport parameters extension is mandatory.
        when (null qexts) $ do
            throwCore $
                Error_Protocol "QUIC transport parameters are missing" MissingExtension
        quicNotifyExtensions callbacks ctx qexts
        -- Early keys are announced before handshake keys.
        quicInstallKeys callbacks ctx (InstallEarlyKeys mEarlySecInfo)
        quicInstallKeys callbacks ctx (InstallHandshakeKeys handSecInfo)
    sync ctx (SendServerFinished appSecInfo) =
        quicInstallKeys callbacks ctx (InstallApplicationKeys appSecInfo)
-- | Keep only the extensions whose identifier is
-- 'EID_QuicTransportParameters', preserving their order.
filterQTP :: [ExtensionRaw] -> [ExtensionRaw]
filterQTP =
    filter
        (\(ExtensionRaw eid _) -> eid == EID_QuicTransportParameters)
-- | Throw, via 'throwCore', an 'Error_Protocol' carrying the given message
-- and the 'InternalError' alert description.  Never returns normally.
errorTLS :: String -> IO a
errorTLS msg = throwCore $ Error_Protocol msg InternalError
-- | Map a 'TLSError' to an 'AlertDescription': the second component of the
-- (level, description) pair computed by 'errorToAlert'.
errorToAlertDescription :: TLSError -> AlertDescription
errorToAlertDescription = snd . errorToAlert
-- | Wrap a raw 'Word8' value in the 'AlertDescription' constructor.
toAlertDescription :: Word8 -> AlertDescription
toAlertDescription = AlertDescription
-- | Default 'Supported' parameters for QUIC: the library defaults ('def')
-- with the versions restricted to TLS 1.3 and explicit lists of TLS 1.3
-- cipher suites and key-exchange groups.
--
-- This shadows the @defaultSupported@ of "Network.TLS.Parameters", which
-- this module hides on import.
defaultSupported :: Supported
defaultSupported =
    def
        { supportedVersions = [TLS13]
        , supportedCiphers =
            [ cipher_TLS13_AES256GCM_SHA384
            , cipher_TLS13_AES128GCM_SHA256
            , cipher_TLS13_AES128CCM_SHA256
            ]
        , supportedGroups = [X25519, X448, P256, P384, P521]
        }
-- | Maximum early data size used for QUIC: @0xffffffff@ (2^32 - 1).
--
-- NOTE(review): the value only fits in 'Int' where 'Int' is wider than
-- 32 bits; on a 32-bit 'Int' the literal would wrap -- confirm target
-- platforms.
quicMaxEarlyDataSize :: Int
quicMaxEarlyDataSize = 0xffffffff