-- | TLS bindings for [Rustls](https://github.com/rustls/rustls) via
-- [rustls-ffi](https://github.com/rustls/rustls-ffi).
--
-- See the [README on GitHub](https://github.com/amesgen/hs-rustls/tree/main/rustls)
-- for setup instructions.
--
-- Currently, most of the functionality exposed by rustls-ffi is available,
-- while rustls-ffi is still missing some more niche Rustls features.
--
-- Also see [http-client-rustls](https://hackage.haskell.org/package/http-client-rustls)
-- for making HTTPS requests using
-- [http-client](https://hackage.haskell.org/package/http-client) and Rustls.
--
-- == Client example
--
-- Suppose you have already opened a 'Network.Socket.Socket' to @example.org@,
-- port 443 (see e.g. the examples at "Network.Socket"). This small example
-- showcases how to perform a simple HTTP GET request:
--
-- >>> :set -XOverloadedStrings
-- >>> import qualified Rustls
-- >>> import Network.Socket (Socket)
-- >>> import Data.Acquire (withAcquire)
-- >>> :{
-- example :: Socket -> IO ()
-- example socket = do
--   -- It is encouraged to share a single `clientConfig` when creating multiple
--   -- TLS connections.
--   clientConfig <-
--     Rustls.buildClientConfig =<< Rustls.defaultClientConfigBuilder
--   let backend = Rustls.mkSocketBackend socket
--       newConnection =
--         Rustls.newClientConnection backend clientConfig "example.org"
--   withAcquire newConnection $ \conn -> do
--     Rustls.writeBS conn "GET / HTTP/1.1\r\nHost: example.org\r\n\r\n"
--     recv <- Rustls.readBS conn 1000 -- max number of bytes to read
--     print recv
-- :}
--
-- == Using 'Acquire'
--
-- Some API functions (like 'newClientConnection' and 'newServerConnection')
-- return an 'Acquire' from
-- [resourcet](https://hackage.haskell.org/package/resourcet), as it is a
-- convenient abstraction for exposing a value that should be consumed in a
-- "bracketed" manner.
--
-- Usually, it can be used via 'Data.Acquire.with' or 'withAcquire', or via
-- 'allocateAcquire' when a 'Control.Monad.Trans.Resource.MonadResource'
-- constraint is available. If you really need the extra flexibility, you can
-- also access separate @open…@ and @close…@ functions by reaching for
-- "Data.Acquire.Internal".
module Rustls
  ( -- * Client

    -- ** Builder
    ClientConfigBuilder (..),
    defaultClientConfigBuilder,
    ServerCertVerifier (..),

    -- ** Config
    ClientConfig,
    clientConfigLogCallback,
    buildClientConfig,

    -- ** Open a connection
    newClientConnection,

    -- * Server

    -- ** Builder
    ServerConfigBuilder (..),
    defaultServerConfigBuilder,
    ClientCertVerifier (..),
    ClientCertVerifierPolicy (..),

    -- ** Config
    ServerConfig,
    serverConfigLogCallback,
    buildServerConfig,

    -- ** Open a connection
    newServerConnection,

    -- * Connection
    Connection,
    Side (..),

    -- ** Read and write
    readBS,
    writeBS,

    -- ** Handshaking
    handshake,
    HandshakeQuery,
    getALPNProtocol,
    getTLSVersion,
    getNegotiatedCipherSuite,
    getSNIHostname,
    getPeerCertificate,

    -- ** Closing
    sendCloseNotify,

    -- ** Logging
    LogCallback,
    newLogCallback,
    LogLevel (..),

    -- ** Raw 'Ptr'-based API
    readPtr,
    writePtr,

    -- * Misc
    version,

    -- ** Backend
    Backend (..),
    mkSocketBackend,
    mkByteStringBackend,

    -- ** Crypto provider
    CryptoProvider,
    getDefaultCryptoProvider,
    setCryptoProviderCipherSuites,
    cryptoProviderCipherSuites,
    cryptoProviderTLSVersions,

    -- ** Types
    ALPNProtocol (..),
    PEMCertificates (..),
    PEMCertificateParsing (..),
    CertifiedKey (..),
    DERCertificate (..),
    CertificateRevocationList (..),
    TLSVersion (TLS12, TLS13, unTLSVersion),
    CipherSuite (..),
    NegotiatedCipherSuite (..),

    -- ** Exceptions
    RustlsException,
    isCertError,
    RustlsLogException (..),
  )
where

import Control.Concurrent (forkFinally, killThread)
import Control.Concurrent.MVar
import Control.Exception qualified as E
import Control.Monad (forever, void, when)
import Control.Monad.Except (MonadError (..), liftEither)
import Control.Monad.IO.Class
import Control.Monad.Trans.Cont
import Control.Monad.Trans.Reader
import Data.Acquire
import Data.ByteString (ByteString)
import Data.ByteString qualified as B
import Data.ByteString.Internal qualified as BI
import Data.ByteString.Unsafe qualified as BU
import Data.Coerce
import Data.Containers.ListUtils (nubOrd)
import Data.Foldable (for_, toList)
import Data.List.NonEmpty (NonEmpty)
import Data.List.NonEmpty qualified as NE
import Data.Set qualified as Set
import Data.Text (Text)
import Data.Text qualified as T
import Data.Text.Foreign qualified as T
import Data.Traversable (for)
import Data.Word
import Foreign hiding (void)
import Foreign.C
import GHC.Conc (reportError)
import GHC.Generics (Generic)
import Rustls.Internal
import Rustls.Internal.FFI (ConstPtr (..), TLSVersion (..))
import Rustls.Internal.FFI qualified as FFI
import System.IO.Unsafe (unsafePerformIO)

-- $setup
-- >>> import Control.Monad.IO.Class
-- >>> import Data.Acquire

-- | Combined version string of Rustls and rustls-ffi, as well as the Rustls
-- cryptography provider.
--
-- >>> version
-- "rustls-ffi/0.14.0/rustls/0.23.13/aws-lc-rs"
version :: Text
version = unsafePerformIO $ alloca \strPtr -> do
  -- 'FFI.hsVersion' fills the out-parameter with a 'Str', which is converted
  -- to 'Text' while the 'alloca'-ed slot is still live.
  FFI.hsVersion strPtr
  strToText =<< peek strPtr
-- NOINLINE so the 'unsafePerformIO' action is shared as a single CAF rather
-- than being duplicated at inlining sites.
{-# NOINLINE version #-}

-- | Finish a crypto provider builder and wrap the resulting provider in a
-- 'ForeignPtr' whose finalizer is 'FFI.cryptoProviderFree'.
--
-- A non-OK result from rustls-ffi is passed to 'rethrowR' (which callers rely
-- on to throw a 'RustlsException'). This function does not free @builder@;
-- that is left to the caller.
buildCryptoProvider :: Ptr FFI.CryptoProviderBuilder -> IO CryptoProvider
buildCryptoProvider builder = alloca \providerOut -> do
  rethrowR =<< FFI.cryptoProviderBuilderBuild builder providerOut
  -- The out-parameter holds a const pointer; unwrap it for 'newForeignPtr'.
  ConstPtr providerPtr <- peek providerOut
  CryptoProvider <$> newForeignPtr FFI.cryptoProviderFree providerPtr

-- | Get the process-wide default Rustls cryptography provider.
--
-- Runs with asynchronous exceptions masked, so the temporary builder cannot
-- be leaked between its allocation and the error-cleanup bracket.
getDefaultCryptoProvider :: (MonadIO m) => m CryptoProvider
getDefaultCryptoProvider = liftIO $ E.mask_ $ evalContT do
  builderOut <- ContT alloca
  builder <-
    ContT $
      E.bracketOnError
        do
          -- This actually also sets the process-wide default crypto provider if
          -- not already set, which is a side effect.
          rethrowR =<< FFI.cryptoProviderBuilderNewFromDefault builderOut
          peek builderOut
        FFI.cryptoProviderBuilderFree
  -- NOTE(review): the builder is only freed here if an exception occurs;
  -- confirm against the rustls-ffi docs whether a successful
  -- 'FFI.cryptoProviderBuilderBuild' takes ownership of it.
  liftIO $ buildCryptoProvider builder

-- | Create a derived 'CryptoProvider' by restricting the cipher suites to the
-- ones in the given list.
--
-- Cipher suites are matched by 'cipherSuiteID'. The result keeps the base
-- provider's ordering. Requested IDs that the base provider does not support
-- are ignored. Errors reported by rustls-ffi (for example, when no suite
-- remains) are returned through 'MonadError'.
setCryptoProviderCipherSuites ::
  (MonadError RustlsException m) =>
  -- | Must be a subset of 'cryptoProviderCipherSuites'. Only the
  -- 'cipherSuiteID' is used.
  [CipherSuite] ->
  CryptoProvider ->
  m CryptoProvider
setCryptoProviderCipherSuites cipherSuites cryptoProvider =
  -- 'unsafePerformIO' exposes a pure interface; a 'RustlsException' thrown
  -- by the FFI calls is caught by 'E.try' and re-raised via 'liftEither'.
  liftEither . unsafePerformIO . E.try . E.mask_ $ evalContT do
    basePtr <- withCryptoProvider cryptoProvider
    builder <-
      ContT $
        E.bracketOnError
          (FFI.cryptoProviderBuilderNewWithBase basePtr)
          FFI.cryptoProviderBuilderFree

    let wantedIDs = Set.fromList (cipherSuiteID <$> cipherSuites)
        baseLen = FFI.cryptoProviderCiphersuitesLen basePtr
        -- Iterating over [1 .. baseLen] and indexing with (i - 1) avoids
        -- [0 .. baseLen - 1] wrapping around when the unsigned 'CSize'
        -- length is 0.
        keptSuites =
          [ suitePtr
          | i <- [1 .. baseLen],
            let suitePtr = FFI.cryptoProviderCiphersuitesGet basePtr (i - 1),
            FFI.supportedCipherSuiteGetSuite suitePtr `Set.member` wantedIDs
          ]

    -- The suite pointers borrowed from the base provider stay valid while
    -- 'withCryptoProvider' keeps it alive for this whole computation.
    (suitesPtr, suitesLen) <- ContT \k ->
      withArrayLen keptSuites \n arr -> k (ConstPtr arr, intToCSize n)
    liftIO $ rethrowR =<< FFI.cryptoProviderBuilderSetCipherSuites builder suitesPtr suitesLen

    liftIO $ buildCryptoProvider builder

-- | Borrow the raw pointer behind a 'CryptoProvider' for the rest of the
-- 'ContT' computation. Because 'withForeignPtr' is used, the provider is kept
-- alive until the continuation returns.
withCryptoProvider :: CryptoProvider -> ContT a IO (ConstPtr FFI.CryptoProvider)
withCryptoProvider cryptoProvider =
  ConstPtr <$> ContT (withForeignPtr (unCryptoProvider cryptoProvider))

-- | Get the cipher suites supported by the given cryptography provider.
--
-- Suites are listed in the order the provider reports them.
cryptoProviderCipherSuites :: CryptoProvider -> [CipherSuite]
cryptoProviderCipherSuites cryptoProvider = unsafePerformIO $ evalContT do
  providerPtr <- withCryptoProvider cryptoProvider
  let suiteCount = FFI.cryptoProviderCiphersuitesLen providerPtr
  -- Iterating over [1 .. suiteCount] and indexing with (i - 1) avoids
  -- [0 .. suiteCount - 1] wrapping around when the unsigned 'CSize' count
  -- is 0.
  liftIO $ for [1 .. suiteCount] \i ->
    describeSuite (FFI.cryptoProviderCiphersuitesGet providerPtr (i - 1))
  where
    -- Read ID, name and TLS version of one suite. This must run while the
    -- provider is still held by 'withCryptoProvider'.
    describeSuite suitePtr = do
      cipherSuiteName <- alloca \strPtr -> do
        FFI.hsSupportedCipherSuiteGetName suitePtr strPtr
        strToText =<< peek strPtr
      cipherSuiteTLSVersion <- FFI.hsSupportedCiphersuiteProtocolVersion suitePtr
      let cipherSuiteID = FFI.supportedCipherSuiteGetSuite suitePtr
      pure CipherSuite {..}

-- | Get all TLS versions supported by at least one of the cipher suites
-- supported by the given cryptography provider.
--
-- Each version appears once, at the position of its first occurrence in
-- 'cryptoProviderCipherSuites'.
cryptoProviderTLSVersions :: CryptoProvider -> [TLSVersion]
cryptoProviderTLSVersions cryptoProvider =
  nubOrd [cipherSuiteTLSVersion suite | suite <- cryptoProviderCipherSuites cryptoProvider]

-- | A 'ClientConfigBuilder' with good defaults, using the OS certificate store.
--
-- It uses 'getDefaultCryptoProvider', verifies servers with
-- 'PlatformServerCertVerifier', and enables SNI. It configures no ALPN
-- protocols and no client certificates.
defaultClientConfigBuilder :: (MonadIO m) => m ClientConfigBuilder
defaultClientConfigBuilder = withDefaults <$> getDefaultCryptoProvider
  where
    withDefaults cryptoProvider =
      ClientConfigBuilder
        { clientConfigCryptoProvider = cryptoProvider,
          clientConfigServerCertVerifier = PlatformServerCertVerifier,
          clientConfigALPNProtocols = [],
          clientConfigEnableSNI = True,
          clientConfigCertifiedKeys = []
        }

-- | Build a rustls-ffi certified key from every 'CertifiedKey' and expose the
-- results as a C array (pointer and length) for the rest of the 'ContT'
-- computation.
--
-- The certificate chain and private key bytes are handed to rustls-ffi via
-- 'BU.unsafeUseAsCStringLen' (no copy). The array itself is only valid while
-- the continuation runs.
withCertifiedKeys :: [CertifiedKey] -> ContT a IO (ConstPtr (ConstPtr FFI.CertifiedKey), CSize)
withCertifiedKeys certifiedKeys = do
  keyPtrs <- traverse buildKey certifiedKeys
  ContT \k -> withArrayLen keyPtrs \n arr -> k (ConstPtr arr, intToCSize n)
  where
    -- Build one certified key; a non-OK FFI result goes through 'rethrowR'.
    buildKey :: CertifiedKey -> ContT r IO (ConstPtr FFI.CertifiedKey)
    buildKey CertifiedKey {..} = do
      (chainBuf, chainLen) <- ContT $ BU.unsafeUseAsCStringLen certificateChain
      (keyBuf, keyLen) <- ContT $ BU.unsafeUseAsCStringLen privateKey
      keyOut <- ContT alloca
      liftIO do
        rethrowR
          =<< FFI.certifiedKeyBuild
            (ConstPtr $ castPtr chainBuf)
            (intToCSize chainLen)
            (ConstPtr $ castPtr keyBuf)
            (intToCSize keyLen)
            keyOut
        -- NOTE(review): nothing in this module releases the built key;
        -- confirm whether rustls-ffi expects it to be freed once the config
        -- has been built.
        peek keyOut

-- | Expose the given ALPN protocols as a C array of 'FFI.SliceBytes'
-- (pointer and length) for the rest of the 'ContT' computation.
--
-- The protocol bytes are not copied ('BU.unsafeUseAsCStringLen'), so the
-- slices and the array are only valid while the continuation runs.
withALPNProtocols :: [ALPNProtocol] -> ContT a IO (ConstPtr FFI.SliceBytes, CSize)
withALPNProtocols protocols = do
  slices <- traverse toSlice (coerce protocols :: [ByteString])
  ContT \k -> withArrayLen slices \n arr -> k (ConstPtr arr, intToCSize n)
  where
    -- Borrow the bytes of one protocol name as an FFI slice.
    toSlice :: ByteString -> ContT r IO FFI.SliceBytes
    toSlice bs = do
      (buf, len) <- ContT $ BU.unsafeUseAsCStringLen bs
      pure $ FFI.SliceBytes (castPtr buf) (intToCSize len)

-- | Allocate a new config builder via the given rustls-ffi
-- @*_config_builder_new_custom@ function (e.g.
-- 'FFI.clientConfigBuilderNewCustom').
--
-- The given 'CryptoProvider' is used, and every TLS version reported by
-- 'cryptoProviderTLSVersions' is enabled. A non-OK result goes through
-- 'rethrowR'. The returned builder is not freed here.
configBuilderNew ::
  ( ConstPtr FFI.CryptoProvider ->
    ConstPtr TLSVersion ->
    CSize ->
    Ptr (Ptr configBuilder) ->
    IO FFI.Result
  ) ->
  CryptoProvider ->
  IO (Ptr configBuilder)
configBuilderNew configBuilderNewCustom cryptoProvider = evalContT do
  providerPtr <- withCryptoProvider cryptoProvider
  (versionsPtr, versionsLen) <- ContT \k ->
    withArrayLen (cryptoProviderTLSVersions cryptoProvider) \n arr ->
      k (ConstPtr arr, intToCSize n)
  builderOut <- ContT alloca
  liftIO do
    rethrowR =<< configBuilderNewCustom providerPtr versionsPtr versionsLen builderOut
    peek builderOut

-- | Build a root certificate store from the given certificate sources and
-- expose it for the rest of the 'ContT' computation.
--
-- Both the store and its builder are acquired with 'E.bracket', so they are
-- freed when the computation finishes. Errors from adding certificates or
-- building the store go through 'rethrowR'.
withRootCertStore :: [PEMCertificates] -> ContT a IO (ConstPtr FFI.RootCertStore)
withRootCertStore certs = do
  storeBuilder <-
    ContT $ E.bracket FFI.rootCertStoreBuilderNew FFI.rootCertStoreBuilderFree
  for_ certs \case
    PEMCertificatesInMemory bs parsing -> do
      (buf, len) <- ContT $ BU.unsafeUseAsCStringLen bs
      liftIO $
        rethrowR
          =<< FFI.rootCertStoreBuilderAddPem
            storeBuilder
            (ConstPtr $ castPtr buf)
            (intToCSize len)
            (strictFlag parsing)
    PemCertificatesFromFile path parsing -> do
      pathPtr <- ContT $ withCString path
      liftIO $
        rethrowR
          =<< FFI.rootCertStoreBuilderLoadRootsFromFile
            storeBuilder
            (ConstPtr pathPtr)
            (strictFlag parsing)
  storeOut <- ContT alloca
  -- Plain 'IO' action, so no 'liftIO' is needed here.
  let buildStore = do
        rethrowR =<< FFI.rootCertStoreBuilderBuild storeBuilder storeOut
        peek storeOut
  ContT $ E.bracket buildStore FFI.rootCertStoreFree
  where
    -- Map the parsing mode to the C boolean "strict" flag.
    strictFlag :: PEMCertificateParsing -> CBool
    strictFlag PEMCertificateParsingStrict = fromBool True
    strictFlag PEMCertificateParsingLax = fromBool False

-- | Build a 'ClientConfigBuilder' into a 'ClientConfig'.
--
-- This is a relatively expensive operation, so it is a good idea to share one
-- 'ClientConfig' when creating multiple 'Connection's.
buildClientConfig :: (MonadIO m) => ClientConfigBuilder -> m ClientConfig
buildClientConfig :: forall (m :: * -> *).
MonadIO m =>
ClientConfigBuilder -> m ClientConfig
buildClientConfig ClientConfigBuilder {Bool
[CertifiedKey]
[ALPNProtocol]
ServerCertVerifier
CryptoProvider
clientConfigCryptoProvider :: ClientConfigBuilder -> CryptoProvider
clientConfigServerCertVerifier :: ClientConfigBuilder -> ServerCertVerifier
clientConfigALPNProtocols :: ClientConfigBuilder -> [ALPNProtocol]
clientConfigEnableSNI :: ClientConfigBuilder -> Bool
clientConfigCertifiedKeys :: ClientConfigBuilder -> [CertifiedKey]
clientConfigCryptoProvider :: CryptoProvider
clientConfigServerCertVerifier :: ServerCertVerifier
clientConfigALPNProtocols :: [ALPNProtocol]
clientConfigEnableSNI :: Bool
clientConfigCertifiedKeys :: [CertifiedKey]
..} = IO ClientConfig -> m ClientConfig
forall a. IO a -> m a
forall (m :: * -> *) a. MonadIO m => IO a -> m a
liftIO (IO ClientConfig -> m ClientConfig)
-> (IO ClientConfig -> IO ClientConfig)
-> IO ClientConfig
-> m ClientConfig
forall b c a. (b -> c) -> (a -> b) -> a -> c
. IO ClientConfig -> IO ClientConfig
forall a. IO a -> IO a
E.mask_ (IO ClientConfig -> m ClientConfig)
-> IO ClientConfig -> m ClientConfig
forall a b. (a -> b) -> a -> b
$ ContT ClientConfig IO ClientConfig -> IO ClientConfig
forall (m :: * -> *) r. Monad m => ContT r m r -> m r
evalContT do
  builder <-
    ((Ptr ClientConfigBuilder -> IO ClientConfig) -> IO ClientConfig)
-> ContT ClientConfig IO (Ptr ClientConfigBuilder)
forall {k} (r :: k) (m :: k -> *) a.
((a -> m r) -> m r) -> ContT r m a
ContT (((Ptr ClientConfigBuilder -> IO ClientConfig) -> IO ClientConfig)
 -> ContT ClientConfig IO (Ptr ClientConfigBuilder))
-> ((Ptr ClientConfigBuilder -> IO ClientConfig)
    -> IO ClientConfig)
-> ContT ClientConfig IO (Ptr ClientConfigBuilder)
forall a b. (a -> b) -> a -> b
$
      IO (Ptr ClientConfigBuilder)
-> (Ptr ClientConfigBuilder -> IO ())
-> (Ptr ClientConfigBuilder -> IO ClientConfig)
-> IO ClientConfig
forall a b c. IO a -> (a -> IO b) -> (a -> IO c) -> IO c
E.bracketOnError
        ( (ConstPtr CryptoProvider
 -> ConstPtr TLSVersion
 -> CSize
 -> Ptr (Ptr ClientConfigBuilder)
 -> IO Result)
-> CryptoProvider -> IO (Ptr ClientConfigBuilder)
forall configBuilder.
(ConstPtr CryptoProvider
 -> ConstPtr TLSVersion
 -> CSize
 -> Ptr (Ptr configBuilder)
 -> IO Result)
-> CryptoProvider -> IO (Ptr configBuilder)
configBuilderNew
            ConstPtr CryptoProvider
-> ConstPtr TLSVersion
-> CSize
-> Ptr (Ptr ClientConfigBuilder)
-> IO Result
FFI.clientConfigBuilderNewCustom
            CryptoProvider
clientConfigCryptoProvider
        )
        Ptr ClientConfigBuilder -> IO ()
FFI.clientConfigBuilderFree

  cryptoProviderPtr <- withCryptoProvider clientConfigCryptoProvider

  scv <- case clientConfigServerCertVerifier of
    ServerCertVerifier
PlatformServerCertVerifier ->
      CryptoProvider -> ContT ClientConfig IO (ConstPtr CryptoProvider)
forall a. CryptoProvider -> ContT a IO (ConstPtr CryptoProvider)
withCryptoProvider CryptoProvider
clientConfigCryptoProvider
        ContT ClientConfig IO (ConstPtr CryptoProvider)
-> (ConstPtr CryptoProvider
    -> ContT ClientConfig IO (Ptr ServerCertVerifier))
-> ContT ClientConfig IO (Ptr ServerCertVerifier)
forall a b.
ContT ClientConfig IO a
-> (a -> ContT ClientConfig IO b) -> ContT ClientConfig IO b
forall (m :: * -> *) a b. Monad m => m a -> (a -> m b) -> m b
>>= IO (Ptr ServerCertVerifier)
-> ContT ClientConfig IO (Ptr ServerCertVerifier)
forall a. IO a -> ContT ClientConfig IO a
forall (m :: * -> *) a. MonadIO m => IO a -> m a
liftIO (IO (Ptr ServerCertVerifier)
 -> ContT ClientConfig IO (Ptr ServerCertVerifier))
-> (ConstPtr CryptoProvider -> IO (Ptr ServerCertVerifier))
-> ConstPtr CryptoProvider
-> ContT ClientConfig IO (Ptr ServerCertVerifier)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. ConstPtr CryptoProvider -> IO (Ptr ServerCertVerifier)
FFI.platformServerCertVerifierWithProvider
    ServerCertVerifier {[CertificateRevocationList]
NonEmpty PEMCertificates
serverCertVerifierCertificates :: NonEmpty PEMCertificates
serverCertVerifierCRLs :: [CertificateRevocationList]
serverCertVerifierCRLs :: ServerCertVerifier -> [CertificateRevocationList]
serverCertVerifierCertificates :: ServerCertVerifier -> NonEmpty PEMCertificates
..} -> do
      rootCertStore <- [PEMCertificates] -> ContT ClientConfig IO (ConstPtr RootCertStore)
forall a. [PEMCertificates] -> ContT a IO (ConstPtr RootCertStore)
withRootCertStore ([PEMCertificates]
 -> ContT ClientConfig IO (ConstPtr RootCertStore))
-> [PEMCertificates]
-> ContT ClientConfig IO (ConstPtr RootCertStore)
forall a b. (a -> b) -> a -> b
$ NonEmpty PEMCertificates -> [PEMCertificates]
forall a. NonEmpty a -> [a]
forall (t :: * -> *) a. Foldable t => t a -> [a]
toList NonEmpty PEMCertificates
serverCertVerifierCertificates
      scvb <-
        ContT $
          E.bracket
            (FFI.webPkiServerCertVerifierBuilderNewWithProvider cryptoProviderPtr rootCertStore)
            FFI.webPkiServerCertVerifierBuilderFree
      crls :: [CStringLen] <-
        for serverCertVerifierCRLs $
          ContT . BU.unsafeUseAsCStringLen . unCertificateRevocationList
      liftIO $ for_ crls \(Ptr CChar
ptr, Int
len) ->
        Ptr WebPkiServerCertVerifierBuilder
-> ConstPtr Word8 -> CSize -> IO Result
FFI.webPkiServerCertVerifierBuilderAddCrl
          Ptr WebPkiServerCertVerifierBuilder
scvb
          (Ptr Word8 -> ConstPtr Word8
forall a. Ptr a -> ConstPtr a
ConstPtr (Ptr CChar -> Ptr Word8
forall a b. Ptr a -> Ptr b
castPtr Ptr CChar
ptr))
          (Int -> CSize
intToCSize Int
len)
      scvPtr <- ContT alloca
      let buildScv = do
            Result -> IO ()
rethrowR (Result -> IO ()) -> IO Result -> IO ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< Ptr WebPkiServerCertVerifierBuilder
-> Ptr (Ptr ServerCertVerifier) -> IO Result
FFI.webPkiServerCertVerifierBuilderBuild Ptr WebPkiServerCertVerifierBuilder
scvb Ptr (Ptr ServerCertVerifier)
scvPtr
            Ptr (Ptr ServerCertVerifier) -> IO (Ptr ServerCertVerifier)
forall a. Storable a => Ptr a -> IO a
peek Ptr (Ptr ServerCertVerifier)
scvPtr
      ContT $ E.bracket buildScv FFI.serverCertVerifierFree
  liftIO $ FFI.clientConfigBuilderSetServerVerifier builder (ConstPtr scv)

  (alpnPtr, len) <- withALPNProtocols clientConfigALPNProtocols
  liftIO $ rethrowR =<< FFI.clientConfigBuilderSetALPNProtocols builder alpnPtr len

  liftIO $
    FFI.clientConfigBuilderSetEnableSNI builder (fromBool @CBool clientConfigEnableSNI)

  (ptr, len) <- withCertifiedKeys clientConfigCertifiedKeys
  liftIO $ rethrowR =<< FFI.clientConfigBuilderSetCertifiedKey builder ptr len

  let clientConfigLogCallback = Maybe a
forall a. Maybe a
Nothing

  clientConfigPtrPtr <- ContT alloca
  liftIO do
    rethrowR =<< FFI.clientConfigBuilderBuild builder clientConfigPtrPtr
    clientConfigPtr <-
      newForeignPtr FFI.clientConfigFree . unConstPtr
        =<< peek clientConfigPtrPtr
    pure ClientConfig {..}

-- | Build a 'ServerConfigBuilder' into a 'ServerConfig'.
--
-- This is a relatively expensive operation, so it is a good idea to share one
-- 'ServerConfig' when creating multiple 'Connection's.
--
-- All intermediate Rustls objects (config builder, root cert store, client
-- certificate verifier builder and verifier) are scoped via 'ContT' and freed
-- when the build finishes. The config builder itself is only freed explicitly
-- if an exception occurs ('E.bracketOnError'). Any non-OK 'Result' returned by
-- a builder call is rethrown as an exception.
buildServerConfig :: (MonadIO m) => ServerConfigBuilder -> m ServerConfig
buildServerConfig ServerConfigBuilder {..} = liftIO . E.mask_ $ evalContT do
  builder <-
    ContT $
      E.bracketOnError
        ( configBuilderNew
            FFI.serverConfigBuilderNewCustom
            serverConfigCryptoProvider
        )
        FFI.serverConfigBuilderFree

  cryptoProviderPtr <- withCryptoProvider serverConfigCryptoProvider

  (alpnPtr, alpnLen) <- withALPNProtocols serverConfigALPNProtocols
  liftIO $ rethrowR =<< FFI.serverConfigBuilderSetALPNProtocols builder alpnPtr alpnLen

  liftIO $
    rethrowR
      =<< FFI.serverConfigBuilderSetIgnoreClientOrder
        builder
        (fromBool @CBool serverConfigIgnoreClientOrder)

  (keysPtr, keysLen) <- withCertifiedKeys (NE.toList serverConfigCertifiedKeys)
  liftIO $ rethrowR =<< FFI.serverConfigBuilderSetCertifiedKeys builder keysPtr keysLen

  -- Optional client certificate verification.
  for_ serverConfigClientCertVerifier \ClientCertVerifier {..} -> do
    roots <- withRootCertStore $ NE.toList clientCertVerifierCertificates
    ccvb <-
      ContT $
        E.bracket
          (FFI.webPkiClientCertVerifierBuilderNewWithProvider cryptoProviderPtr roots)
          FFI.webPkiClientCertVerifierBuilderFree
    crls :: [CStringLen] <-
      for clientCertVerifierCRLs $
        ContT . BU.unsafeUseAsCStringLen . unCertificateRevocationList
    liftIO do
      case clientCertVerifierPolicy of
        AllowAnyAuthenticatedClient -> pure ()
        AllowAnyAnonymousOrAuthenticatedClient ->
          rethrowR =<< FFI.webPkiClientCertVerifierBuilderAllowUnauthenticated ccvb
      -- Check the 'Result' like every other builder call here. Otherwise a CRL
      -- that Rustls rejects is silently skipped, which weakens revocation
      -- checking without any indication to the caller.
      for_ crls \(crlPtr, crlLen) ->
        rethrowR
          =<< FFI.webPkiClientCertVerifierBuilderAddCrl
            ccvb
            (ConstPtr (castPtr crlPtr))
            (intToCSize crlLen)
    ccvPtr <- ContT alloca
    let buildCcv = do
          rethrowR =<< FFI.webPkiClientCertVerifierBuilderBuild ccvb ccvPtr
          peek ccvPtr
    ccv <- ContT $ E.bracket buildCcv FFI.clientCertVerifierFree
    liftIO $ FFI.serverConfigBuilderSetClientVerifier builder (ConstPtr ccv)

  serverConfigPtrPtr <- ContT alloca
  liftIO do
    rethrowR =<< FFI.serverConfigBuilderBuild builder serverConfigPtrPtr
    serverConfigPtr <-
      newForeignPtr FFI.serverConfigFree . unConstPtr
        =<< peek serverConfigPtrPtr
    let serverConfigLogCallback = Nothing
    pure ServerConfig {..}

-- | A 'ServerConfigBuilder' with good defaults.
--
-- The builder uses the default crypto provider and the given certified keys.
-- It sets no ALPN protocols, does not ignore the client's preference order,
-- and configures no client certificate verifier.
defaultServerConfigBuilder ::
  (MonadIO m) => NonEmpty CertifiedKey -> m ServerConfigBuilder
defaultServerConfigBuilder certifiedKeys =
  withProvider <$> getDefaultCryptoProvider
  where
    withProvider provider =
      ServerConfigBuilder
        { serverConfigCryptoProvider = provider,
          serverConfigCertifiedKeys = certifiedKeys,
          serverConfigALPNProtocols = [],
          serverConfigIgnoreClientOrder = False,
          serverConfigClientCertVerifier = Nothing
        }

-- | Allocate a new logging callback, taking a 'LogLevel' and a message.
--
-- If it throws an exception, it will be wrapped in a 'RustlsLogException' and
-- passed to 'reportError'. A log level that Rustls reports but that does not
-- map to a known 'LogLevel' is reported as 'RustlsUnknownLogLevel' (also via
-- 'reportError'), and the user callback is not invoked for it.
--
-- 🚫 Make sure that its lifetime encloses those of the 'Connection's which you
-- configured to use it.
newLogCallback :: (LogLevel -> Text -> IO ()) -> Acquire LogCallback
newLogCallback cb = LogCallback <$> mkAcquire createCallback freeHaskellFunPtr
  where
    -- The C-callable function pointer; released with 'freeHaskellFunPtr'.
    createCallback =
      FFI.mkLogCallback \_userdata (ConstPtr paramsPtr) -> ignoreExceptions do
        FFI.LogParams {..} <- peek paramsPtr
        case decodeLevel rustlsLogParamsLevel of
          Left unknown ->
            report $ E.SomeException $ RustlsUnknownLogLevel unknown
          Right level -> do
            msg <- strToText rustlsLogParamsMessage
            E.catch (cb level msg) report

    -- Map the raw numeric Rustls level onto 'LogLevel'.
    decodeLevel = \case
      FFI.LogLevel 1 -> Right LogLevelError
      FFI.LogLevel 2 -> Right LogLevelWarn
      FFI.LogLevel 3 -> Right LogLevelInfo
      FFI.LogLevel 4 -> Right LogLevelDebug
      FFI.LogLevel 5 -> Right LogLevelTrace
      other -> Left other

    report = reportError . E.SomeException . RustlsLogException

-- | Shared implementation of 'newClientConnection' and 'newServerConnection'.
--
-- Acquisition proceeds in order:
--
-- 1. Create the Rustls connection via @connectionNew@.
-- 2. Allocate a scratch 'CSize' cell ('malloc').
-- 3. Fork a helper thread that loops forever. On each 'Request' taken from
--    @ioMsgReq@, it runs the corresponding read/write TLS FFI call and then
--    puts 'DoneFFI' into @ioMsgRes@.
--
-- During such a call, the read/write callbacks pass the Rustls-provided
-- buffer to the other side as 'UsingBuffer' and block until a 'Done' arrives
-- on @ioMsgReq@. The callback 'FunPtr's are freed when the helper thread
-- terminates.
--
-- Release frees the Rustls connection, then the scratch cell, then kills the
-- helper thread. It fails (via pattern-match failure in 'IO') if the
-- connection 'MVar' is empty at that point.
--
-- NOTE(review): the Rustls connection is freed before the helper thread is
-- killed. If the helper is ever suspended inside a read/write callback at
-- release time, the Rust side would still reference the freed connection.
-- Confirm that callers always complete pending IO before release.
newConnection ::
  Backend ->
  ForeignPtr config ->
  Maybe LogCallback ->
  (ConstPtr config -> Ptr (Ptr FFI.Connection) -> IO FFI.Result) ->
  Acquire (Connection side)
newConnection backend configPtr logCallback connectionNew =
  mkAcquire acquire release
  where
    acquire = do
      conn <-
        alloca \connOut ->
          withForeignPtr configPtr \rawConfig -> do
            rethrowR =<< connectionNew (ConstPtr rawConfig) connOut
            peek connOut
      ioMsgReq <- newEmptyMVar
      ioMsgRes <- newEmptyMVar
      lenPtr <- malloc
      -- Shared by the read and write callbacks. @asBuf@ normalises the
      -- buffer pointer type, which differs between the two.
      let bufferCallback asBuf _userdata bufArg bufLen writtenPtr = do
            putMVar ioMsgRes $ UsingBuffer (asBuf bufArg) bufLen writtenPtr
            Done ioResult <- takeMVar ioMsgReq
            pure ioResult
      readCallback <- FFI.mkReadCallback $ bufferCallback id
      writeCallback <- FFI.mkWriteCallback $ bufferCallback unConstPtr
      let freeCallbacks = do
            freeHaskellFunPtr readCallback
            freeHaskellFunPtr writeCallback
          pumpIO = forever do
            Request direction <- takeMVar ioMsgReq
            _ <- case direction of
              Read -> FFI.connectionReadTls conn readCallback nullPtr lenPtr
              Write -> FFI.connectionWriteTls conn writeCallback nullPtr lenPtr
            putMVar ioMsgRes DoneFFI
      interactThread <- forkFinally pumpIO (const freeCallbacks)
      for_ logCallback $ FFI.connectionSetLogCallback conn . unLogCallback
      Connection <$> newMVar Connection' {..}

    release (Connection c) = do
      Just Connection' {..} <- tryTakeMVar c
      FFI.connectionFree conn
      free lenPtr
      killThread interactThread

-- | Initialize a TLS connection as a client.
--
-- The hostname is passed to Rustls as a C string when the client connection
-- is created.
newClientConnection ::
  Backend ->
  ClientConfig ->
  -- | Hostname.
  Text ->
  Acquire (Connection Client)
newClientConnection transport ClientConfig {clientConfigPtr, clientConfigLogCallback} hostname =
  newConnection transport clientConfigPtr clientConfigLogCallback openClient
  where
    openClient cfg out =
      T.withCString hostname \cHostname ->
        FFI.clientConnectionNew cfg (ConstPtr cHostname) out

-- | Initialize a TLS connection as a server.
--
-- Uses the Rustls config and the optional log callback stored in the given
-- 'ServerConfig'.
newServerConnection ::
  Backend ->
  ServerConfig ->
  Acquire (Connection Server)
newServerConnection transport ServerConfig {serverConfigPtr, serverConfigLogCallback} =
  newConnection
    transport
    serverConfigPtr
    serverConfigLogCallback
    FFI.serverConnectionNew

-- | Ensure that the connection is handshaked. It is only necessary to call this
-- if you want to obtain connection information. You can do so by providing a
-- 'HandshakeQuery'.
--
-- >>> :{
-- getALPNAndTLSVersion ::
--   MonadIO m =>
--   Connection side ->
--   m (Maybe ALPNProtocol, TLSVersion)
-- getALPNAndTLSVersion conn =
--   handshake conn $ (,) <$> getALPNProtocol <*> getTLSVersion
-- :}
handshake :: (MonadIO m) => Connection side -> HandshakeQuery side a -> m a
handshake conn (HandshakeQuery query) =
  -- Finish any outstanding IO first, then run the query against the
  -- connection state.
  liftIO . withConnection conn $ \c ->
    completePriorIO c *> runReaderT query c

-- | Get the negotiated ALPN protocol, if any.
--
-- An empty protocol reported by Rustls is returned as 'Nothing'.
getALPNProtocol :: HandshakeQuery side (Maybe ALPNProtocol)
getALPNProtocol = handshakeQuery \Connection' {conn, lenPtr} ->
  alloca \protoOut -> do
    FFI.connectionGetALPNProtocol (ConstPtr conn) protoOut lenPtr
    ConstPtr raw <- peek protoOut
    size <- peek lenPtr
    -- Copy the bytes out strictly; the buffer is owned by Rustls.
    !proto <- B.packCStringLen (castPtr raw, cSizeToInt size)
    pure
      if B.null proto
        then Nothing
        else Just (ALPNProtocol proto)

-- | Get the negotiated TLS protocol version.
--
-- Fails (via 'fail' in 'IO') if Rustls reports version @0@, i.e. no version
-- has been negotiated.
getTLSVersion :: HandshakeQuery side TLSVersion
getTLSVersion = handshakeQuery \Connection' {conn} -> do
  !version <- FFI.connectionGetProtocolVersion (ConstPtr conn)
  if unTLSVersion version /= 0
    then pure version
    else fail "internal rustls error: no protocol version negotiated"

-- | Get the negotiated cipher suite.
--
-- Fails (via 'fail' in 'IO') if Rustls reports cipher suite ID @0@ or an
-- empty cipher suite name.
getNegotiatedCipherSuite :: HandshakeQuery side NegotiatedCipherSuite
getNegotiatedCipherSuite = handshakeQuery \Connection' {conn} -> do
  let constConn = ConstPtr conn
      noSuite = fail "internal rustls error: no cipher suite negotiated"

  negotiatedCipherSuiteID <- FFI.connectionGetNegotiatedCipherSuite constConn
  when (negotiatedCipherSuiteID == 0) noSuite

  negotiatedCipherSuiteName <-
    alloca \namePtr ->
      FFI.connectionGetNegotiatedCipherSuiteName constConn namePtr
        *> (peek namePtr >>= strToText)
  when (T.null negotiatedCipherSuiteName) noSuite

  -- The field names above are picked up by RecordWildCards.
  pure NegotiatedCipherSuite {..}

-- | Get the SNI hostname set by the client, if any.
--
-- Starts with a 16-byte buffer and doubles it for as long as Rustls reports
-- 'FFI.resultInsufficientSize'. An empty hostname is returned as 'Nothing'.
getSNIHostname :: HandshakeQuery Server (Maybe Text)
getSNIHostname = handshakeQuery \Connection' {conn, lenPtr} ->
  let attempt capacity =
        allocaBytes (cSizeToInt capacity) \buf -> do
          rc <- FFI.serverConnectionGetSNIHostname (ConstPtr conn) buf capacity lenPtr
          if rc /= FFI.resultInsufficientSize
            then do
              rethrowR rc
              written <- peek lenPtr
              !name <- T.peekCStringLen (castPtr buf, cSizeToInt written)
              pure (if T.null name then Nothing else Just name)
            else attempt (capacity * 2)
   in attempt 16

-- | A DER-encoded certificate, e.g. as returned by 'getPeerCertificate'.
newtype DERCertificate = DERCertificate
  { unDERCertificate :: ByteString
  }
  deriving stock (Eq, Ord, Show, Generic)

-- | Get the @i@-th certificate provided by the peer.
--
-- Index @0@ is the end entity certificate. Higher indices are certificates in
-- the chain. Requesting an index higher than what is available returns
-- 'Nothing'.
getPeerCertificate :: CSize -> HandshakeQuery side (Maybe DERCertificate)
getPeerCertificate i = handshakeQuery \Connection' {conn, lenPtr} -> do
  cert <- FFI.connectionGetPeerCertificate (ConstPtr conn) i
  -- Rustls signals "no certificate at this index" with a null pointer.
  if cert /= ConstPtr nullPtr
    then alloca \derOut -> do
      rethrowR =<< FFI.certificateGetDER cert derOut lenPtr
      ConstPtr der <- peek derOut
      derLen <- peek lenPtr
      -- Copy the DER bytes strictly out of Rustls-owned memory.
      !bytes <- B.packCStringLen (castPtr der, cSizeToInt derLen)
      pure (Just (DERCertificate bytes))
    else pure Nothing

-- | Send a @close_notify@ warning alert. This informs the peer that the
-- connection is being closed.
--
-- After queueing the alert, pending IO is completed so that it is actually
-- flushed; the EOF indication from that IO is discarded.
sendCloseNotify :: (MonadIO m) => Connection side -> m ()
sendCloseNotify connection =
  liftIO $
    withConnection connection \c@Connection' {conn = rawConn} -> do
      FFI.connectionSendCloseNotify rawConn
      _ <- completeIO c
      pure ()

-- | Read data from the Rustls 'Connection' into the given buffer.
--
-- First any prior I/O is completed. Then, as long as 'getWantsRead' reports
-- 'True', 'completeIO' is run repeatedly until it reports EOF. Finally the
-- plaintext is copied into @buf@ (capacity @len@), and the number of bytes
-- written there is returned.
readPtr :: (MonadIO m) => Connection side -> Ptr Word8 -> CSize -> m CSize
readPtr connection buf len = liftIO $ withConnection connection readWith
  where
    readWith c@Connection' {conn = rawConn, lenPtr = outLen} = do
      completePriorIO c
      loopWhileTrue (pumpOnce c)
      rethrowR =<< FFI.connectionRead rawConn buf len outLen
      peek outLen

    -- One round of the read loop; 'True' means "go around again".
    pumpOnce c = do
      wantsRead <- getWantsRead c
      if wantsRead
        then (NotEOF ==) <$> completeIO c
        else pure False

-- | Read data from the Rustls 'Connection' into a 'ByteString'. The result will
-- not be longer than the given length.
readBS ::
  (MonadIO m) =>
  Connection side ->
  -- | Maximum result length. Note that a buffer of this size will be allocated.
  Int ->
  m ByteString
readBS conn maxLen = liftIO (BI.createAndTrim maxLen fillBuffer)
  where
    -- Let 'readPtr' fill the freshly allocated buffer; 'BI.createAndTrim'
    -- then trims the result down to the byte count it reports.
    fillBuffer buf = cSizeToInt <$> readPtr conn buf (intToCSize maxLen)

-- | Write data to the Rustls 'Connection' from the given buffer.
--
-- Any prior I/O is completed first. Then the plaintext in @buf@ (of length
-- @len@) is handed to Rustls, and 'completeIO' is run once. The result is the
-- number of bytes Rustls accepted from @buf@, which may be less than @len@;
-- 'writeBS' loops over this to write a whole 'ByteString'.
writePtr :: (MonadIO m) => Connection side -> Ptr Word8 -> CSize -> m CSize
writePtr connection buf len = liftIO $
  withConnection connection \c@Connection' {conn = rawConn, lenPtr = outLen} -> do
    completePriorIO c
    rethrowR =<< FFI.connectionWrite rawConn buf len outLen
    -- Capture the accepted byte count immediately, as 'readPtr' does.
    -- 'completeIO' is handed the same Connection' (including its length
    -- slot), so reading the slot only afterwards could observe a value
    -- written by that I/O step rather than by connectionWrite.
    written <- peek outLen
    _ <- completeIO c
    pure written

-- | Write a 'ByteString' to the Rustls 'Connection'.
--
-- 'writePtr' may accept only part of the data, so this keeps calling it on
-- the not-yet-written remainder until every byte has been accepted.
writeBS :: (MonadIO m) => Connection side -> ByteString -> m ()
writeBS conn bs = liftIO $ BU.unsafeUseAsCStringLen bs go
  where
    go (buf, len) = do
      written <- cSizeToInt <$> writePtr conn (castPtr buf) (intToCSize len)
      -- Resume right after the bytes that were just accepted. Advancing by
      -- 'len' instead would point past the end of the buffer and skip the
      -- unsent tail.
      when (written < len) $
        go (buf `plusPtr` written, len - written)