{-# LANGUAGE CPP #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE PatternGuards #-}

-- | HTTP over TLS support for Warp via the TLS package.
--
--   If HTTP\/2 is negotiated by ALPN, HTTP\/2 over TLS is used.
--   Otherwise HTTP\/1.1 over TLS is used.
--
--   Support for SSL is now obsoleted.

module Network.Wai.Handler.WarpTLS (
    -- * Runner
      runTLS
    , runTLSSocket
    -- * Settings
    , TLSSettings
    , defaultTlsSettings
    -- * Smart constructors
    -- ** From files
    , tlsSettings
    , tlsSettingsChain
    -- ** From memory
    , tlsSettingsMemory
    , tlsSettingsChainMemory
    -- ** From references
    , tlsSettingsRef
    , tlsSettingsChainRef
    , CertSettings
    -- * Accessors
    , tlsCredentials
    , tlsLogging
    , tlsAllowedVersions
    , tlsCiphers
    , tlsWantClientCert
    , tlsServerHooks
    , tlsServerDHEParams
    , tlsSessionManagerConfig
    , tlsSessionManager
    , onInsecure
    , OnInsecure (..)
    -- * Exception
    , WarpTLSException (..)
    ) where

import Control.Applicative ((<|>))
import UnliftIO.Exception (Exception, throwIO, bracket, finally, handleAny, try, IOException, onException, SomeException(..), handleJust)
import qualified UnliftIO.Exception as E
import Control.Monad (void, guard)
import qualified Data.ByteString as S
import qualified Data.ByteString.Lazy as L
import Data.Default.Class (def)
import qualified Data.IORef as I
import Data.Streaming.Network (bindPortTCP, safeRecv)
import Data.Typeable (Typeable)
import GHC.IO.Exception (IOErrorType(..))
import Network.Socket (
    SockAddr,
    Socket,
    close,
#if MIN_VERSION_network(3,1,1)
    gracefulClose,
#endif
    withSocketsDo,
    getSocketName,
 )
import Network.Socket.BufferPool
import Network.Socket.ByteString (sendAll)
import qualified Network.TLS as TLS
import qualified Network.TLS.Extra as TLSExtra
import qualified Network.TLS.SessionManager as SM
import Network.Wai (Application)
import Network.Wai.Handler.Warp
import Network.Wai.Handler.Warp.Internal
import Network.Wai.Handler.WarpTLS.Internal(CertSettings(..), TLSSettings(..), OnInsecure(..))
import System.IO.Error (ioeGetErrorType, isEOFError)
import UnliftIO.Exception (handle, fromException)

-- | The default 'CertSettings'.
--
--   The certificate is read from the file @certificate.pem@ and the key
--   from @key.pem@, with no chain certificates.
defaultCertSettings :: CertSettings
defaultCertSettings = CertFromFile "certificate.pem" [] "key.pem"

----------------------------------------------------------------

-- | Default 'TLSSettings'. Use this to create 'TLSSettings' with the field record name (aka accessors).
--
--   Defaults visible here: certificates come from 'defaultCertSettings',
--   non-TLS connections are denied with a fixed message, client
--   certificates are not requested, and no DHE parameters, session
--   manager configuration, preloaded credentials or session manager are
--   set.
defaultTlsSettings :: TLSSettings
defaultTlsSettings = TLSSettings {
    certSettings = defaultCertSettings
  , onInsecure = DenyInsecure "This server only accepts secure HTTPS connections."
  , tlsLogging = def
    -- TLS 1.3 is only offered when the tls package is new enough to
    -- provide it.
#if MIN_VERSION_tls(1,5,0)
  , tlsAllowedVersions = [TLS.TLS13,TLS.TLS12,TLS.TLS11,TLS.TLS10]
#else
  , tlsAllowedVersions = [TLS.TLS12,TLS.TLS11,TLS.TLS10]
#endif
  , tlsCiphers = ciphers
  , tlsWantClientCert = False
  , tlsServerHooks = def
  , tlsServerDHEParams = Nothing
  , tlsSessionManagerConfig = Nothing
  , tlsCredentials = Nothing
  , tlsSessionManager = Nothing
    -- Reuse the hash/signature list of the tls library's default
    -- 'TLS.Supported' value.
  , tlsSupportedHashSignatures = TLS.supportedHashSignatures def
  }

-- | Cipher suites used by 'defaultTlsSettings': the @ciphersuite_strong@
--   list exported by "Network.TLS.Extra".
--
--   (Originally taken from the stunnel example in tls-extra.)
ciphers :: [TLS.Cipher]
ciphers = TLSExtra.ciphersuite_strong

----------------------------------------------------------------

-- | A smart constructor for 'TLSSettings' based on 'defaultTlsSettings'.
--
--   Only 'certSettings' is overridden: the certificate and key are taken
--   from the given files, with an empty chain.
tlsSettings :: FilePath -- ^ Certificate file
            -> FilePath -- ^ Key file
            -> TLSSettings
tlsSettings cert key = defaultTlsSettings {
    certSettings = CertFromFile cert [] key
  }

-- | A smart constructor for 'TLSSettings' that allows specifying
-- chain certificates based on 'defaultTlsSettings'.
--
-- Only 'certSettings' is overridden; every other field keeps its default.
--
-- Since 3.0.3
tlsSettingsChain
            :: FilePath -- ^ Certificate file
            -> [FilePath] -- ^ Chain certificate files
            -> FilePath -- ^ Key file
            -> TLSSettings
tlsSettingsChain cert chainCerts key = defaultTlsSettings {
    certSettings = CertFromFile cert chainCerts key
  }

-- | A smart constructor for 'TLSSettings', but uses in-memory representations
-- of the certificate and key based on 'defaultTlsSettings'.
--
-- No chain certificates are supplied (the chain is the empty list).
--
-- Since 3.0.1
tlsSettingsMemory
    :: S.ByteString -- ^ Certificate bytes
    -> S.ByteString -- ^ Key bytes
    -> TLSSettings
tlsSettingsMemory cert key = defaultTlsSettings {
    certSettings = CertFromMemory cert [] key
  }

-- | A smart constructor for 'TLSSettings', but uses in-memory representations
-- of the certificate, chain certificates and key based on
-- 'defaultTlsSettings'.
--
-- Since 3.0.3
tlsSettingsChainMemory
    :: S.ByteString -- ^ Certificate bytes
    -> [S.ByteString] -- ^ Chain certificate bytes
    -> S.ByteString -- ^ Key bytes
    -> TLSSettings
tlsSettingsChainMemory cert chainCerts key = defaultTlsSettings {
    certSettings = CertFromMemory cert chainCerts key
  }

-- | A smart constructor for 'TLSSettings', but uses references to in-memory
-- representations of the certificate and key based on 'defaultTlsSettings'.
--
-- No chain certificate references are supplied (the chain is the empty
-- list).
--
-- @since 3.3.0
tlsSettingsRef
    :: I.IORef S.ByteString -- ^ Reference to certificate bytes
    -> I.IORef S.ByteString -- ^ Reference to key bytes
    -> TLSSettings
tlsSettingsRef cert key = defaultTlsSettings {
    certSettings = CertFromRef cert [] key
  }

-- | A smart constructor for 'TLSSettings', but uses references to in-memory
-- representations of the certificate, chain certificates and key based on
-- 'defaultTlsSettings'.
--
-- @since 3.3.0
tlsSettingsChainRef
    :: I.IORef S.ByteString -- ^ Reference to certificate bytes
    -> [I.IORef S.ByteString] -- ^ Reference to chain certificate bytes
    -> I.IORef S.ByteString -- ^ Reference to key bytes
    -> TLSSettings
tlsSettingsChainRef cert chainCerts key = defaultTlsSettings {
    certSettings = CertFromRef cert chainCerts key
  }

----------------------------------------------------------------

-- | Running 'Application' with 'TLSSettings' and 'Settings'.
--
--   Binds a TCP listening socket on the port and host taken from the
--   'Settings', marks it close-on-exec, and hands it to 'runTLSSocket'.
--   The listening socket is closed when 'runTLSSocket' returns or throws.
runTLS :: TLSSettings -> Settings -> Application -> IO ()
runTLS tset set app = withSocketsDo $
    bracket
        (bindPortTCP (getPort set) (getHost set))
        close
        (\sock -> do
            setSocketCloseOnExec sock
            runTLSSocket tset set sock app)

----------------------------------------------------------------

-- | Obtain the server credentials described by the settings.
--
--   If 'tlsCredentials' is set it is returned unchanged. Otherwise the
--   credential is built according to 'certSettings': from files, from
--   'I.IORef's (read once, at the time of this call), or from in-memory
--   bytes.
--
--   A loading failure reported by the tls library is raised with 'error'
--   while this action runs, for every source. Previously the file-based
--   case wrapped the failure in a lazy thunk inside the returned
--   'TLS.Credentials', so it only surfaced when the credential was first
--   used.
loadCredentials :: TLSSettings -> IO TLS.Credentials
loadCredentials TLSSettings{ tlsCredentials = Just creds } = return creds
loadCredentials TLSSettings{..} = case certSettings of
    CertFromFile cert chainFiles key ->
      single =<< TLS.credentialLoadX509Chain cert chainFiles key
    CertFromRef certRef chainCertsRef keyRef -> do
      cert <- I.readIORef certRef
      chainCerts <- mapM I.readIORef chainCertsRef
      key <- I.readIORef keyRef
      single $ TLS.credentialLoadX509ChainFromMemory cert chainCerts key
    CertFromMemory certMemory chainCertsMemory keyMemory ->
      single $ TLS.credentialLoadX509ChainFromMemory certMemory chainCertsMemory keyMemory
  where
    -- Turn the library's result into a one-element credential store,
    -- failing immediately (in IO) on a load error.
    single :: Either String TLS.Credential -> IO TLS.Credentials
    single = either error (\cred -> return (TLS.Credentials [cred]))

-- | Choose the TLS session manager for the server.
--
--   An explicit 'tlsSessionManager' wins. Otherwise a new manager is
--   created from 'tlsSessionManagerConfig' when one is given, and
--   'TLS.noSessionManager' is used when it is not.
getSessionManager :: TLSSettings -> IO TLS.SessionManager
getSessionManager TLSSettings{ tlsSessionManager = Just mgr } = return mgr
getSessionManager TLSSettings{..} =
    maybe (return TLS.noSessionManager) SM.newSessionManager tlsSessionManagerConfig

-- | Running 'Application' with 'TLSSettings' and 'Settings' using
--   specified 'Socket'.
--
--   Credentials and the session manager are resolved once, up front
--   (via 'loadCredentials' and 'getSessionManager'), before any
--   connection is served.
runTLSSocket :: TLSSettings -> Settings -> Socket -> Application -> IO ()
runTLSSocket tlsset set sock app = do
    credentials <- loadCredentials tlsset
    mgr <- getSessionManager tlsset
    runTLSSocket' tlsset set credentials mgr sock app

-- | Worker for 'runTLSSocket': builds the TLS server parameters from the
--   settings plus already-resolved credentials and session manager, and
--   runs Warp with a connection maker ('getter') that uses them.
--
--   The 'Application' argument is passed on point-free to
--   'runSettingsConnectionMakerSecure'.
runTLSSocket' :: TLSSettings -> Settings -> TLS.Credentials -> TLS.SessionManager -> Socket -> Application -> IO ()
runTLSSocket' tlsset@TLSSettings{..} set credentials mgr sock =
    runSettingsConnectionMakerSecure set get
  where
    get :: IO (IO (Connection, Transport), SockAddr)
    get = getter tlsset set sock params

    params :: TLS.ServerParams
    params = def { -- TLS.ServerParams
        TLS.serverWantClientCert = tlsWantClientCert
      , TLS.serverCACertificates = []
      , TLS.serverDHEParams      = tlsServerDHEParams
      , TLS.serverHooks          = hooks
      , TLS.serverShared         = shared
      , TLS.serverSupported      = supported
#if MIN_VERSION_tls(1,5,0)
      , TLS.serverEarlyDataSize  = 2018
#endif
      }

    -- Adding alpn to user's tlsServerHooks. A user-supplied
    -- onALPNClientSuggest takes precedence; the built-in 'alpn' is only
    -- installed when the user gave none and HTTP/2 is enabled in Warp.
    hooks :: TLS.ServerHooks
    hooks = tlsServerHooks {
        TLS.onALPNClientSuggest = TLS.onALPNClientSuggest tlsServerHooks <|>
          (if settingsHTTP2Enabled set then Just alpn else Nothing)
      }

    shared :: TLS.Shared
    shared = def {
        TLS.sharedCredentials    = credentials
      , TLS.sharedSessionManager = mgr
      }

    supported :: TLS.Supported
    supported = def { -- TLS.Supported
        TLS.supportedVersions       = tlsAllowedVersions
      , TLS.supportedCiphers        = tlsCiphers
      , TLS.supportedCompressions   = [TLS.nullCompression]
      , TLS.supportedSecureRenegotiation = True
      , TLS.supportedClientInitiatedRenegotiation = False
      , TLS.supportedSession             = True
      , TLS.supportedFallbackScsv        = True
      , TLS.supportedHashSignatures      = tlsSupportedHashSignatures
#if MIN_VERSION_tls(1,5,0)
      , TLS.supportedGroups              = [TLS.X25519,TLS.P256,TLS.P384]
#endif
      }

-- | ALPN selection hook: pick @h2@ if the client offered it, otherwise
--   fall back to @http/1.1@ (also when the offered list is empty or
--   contains only other protocols).
alpn :: [S.ByteString] -> IO S.ByteString
alpn xs
  | "h2" `elem` xs = return "h2"
  | otherwise      = return "http/1.1"

----------------------------------------------------------------

-- | Accept one connection on the listening socket using the accept
--   function from the 'Settings', mark the accepted socket close-on-exec,
--   and return the peer address together with a deferred action
--   ('mkConn') that builds the actual 'Connection'. The deferred action
--   is not run here.
getter :: TLS.TLSParams params => TLSSettings -> Settings -> Socket -> params -> IO (IO (Connection, Transport), SockAddr)
getter tlsset set@Settings{settingsAccept = accept'} sock params = do
    (s, sa) <- accept' sock
    setSocketCloseOnExec s
    return (mkConn tlsset set s params, sa)

-- | Build a 'Connection' for an accepted socket by sniffing the first
--   chunk of input (up to 4096 bytes):
--
--   * nothing received: 'ClientClosedConnectionPrematurely' is thrown;
--   * first byte is @0x16@: the connection is handled by 'httpOverTls';
--   * anything else: the connection is handled by 'plainHTTP'.
--
--   If any step throws, the socket is closed by the 'onException'
--   handler. The empty-read case relies on that handler instead of
--   closing the socket itself, so the socket is closed exactly once.
mkConn :: TLS.TLSParams params => TLSSettings -> Settings -> Socket -> params -> IO (Connection, Transport)
mkConn tlsset set s params = (safeRecv s 4096 >>= switch) `onException` close s
  where
    switch firstBS = case S.uncons firstBS of
        Nothing        -> throwIO ClientClosedConnectionPrematurely
        Just (0x16, _) -> httpOverTls tlsset set s firstBS params
        Just _         -> plainHTTP tlsset set s firstBS

----------------------------------------------------------------

httpOverTls :: TLS.TLSParams params => TLSSettings -> Settings -> Socket -> S.ByteString -> params -> IO (Connection, Transport)
httpOverTls :: forall params.
TLSParams params =>
TLSSettings
-> Settings
-> Socket
-> ByteString
-> params
-> IO (Connection, Transport)
httpOverTls TLSSettings{Bool
[HashAndSignatureAlgorithm]
[Cipher]
[Version]
Maybe DHParams
Maybe Credentials
Maybe SessionManager
Maybe Config
ServerHooks
Logging
OnInsecure
CertSettings
tlsSupportedHashSignatures :: [HashAndSignatureAlgorithm]
tlsSessionManager :: Maybe SessionManager
tlsCredentials :: Maybe Credentials
tlsSessionManagerConfig :: Maybe Config
tlsServerDHEParams :: Maybe DHParams
tlsServerHooks :: ServerHooks
tlsWantClientCert :: Bool
tlsCiphers :: [Cipher]
tlsAllowedVersions :: [Version]
tlsLogging :: Logging
onInsecure :: OnInsecure
certSettings :: CertSettings
tlsSupportedHashSignatures :: TLSSettings -> [HashAndSignatureAlgorithm]
certSettings :: TLSSettings -> CertSettings
onInsecure :: TLSSettings -> OnInsecure
tlsSessionManager :: TLSSettings -> Maybe SessionManager
tlsSessionManagerConfig :: TLSSettings -> Maybe Config
tlsServerDHEParams :: TLSSettings -> Maybe DHParams
tlsServerHooks :: TLSSettings -> ServerHooks
tlsWantClientCert :: TLSSettings -> Bool
tlsCiphers :: TLSSettings -> [Cipher]
tlsAllowedVersions :: TLSSettings -> [Version]
tlsLogging :: TLSSettings -> Logging
tlsCredentials :: TLSSettings -> Maybe Credentials
..} Settings
_set Socket
s ByteString
bs0 params
params = do
    BufferPool
pool <- Int -> Int -> IO BufferPool
newBufferPool Int
2048 Int
16384
    Int -> IO ByteString
rawRecvN <- ByteString -> IO ByteString -> IO (Int -> IO ByteString)
makeRecvN ByteString
bs0 forall a b. (a -> b) -> a -> b
$ Socket -> BufferPool -> IO ByteString
receive Socket
s BufferPool
pool
    let recvN :: Int -> IO ByteString
recvN = forall {t}. (t -> IO ByteString) -> t -> IO ByteString
wrappedRecvN Int -> IO ByteString
rawRecvN
    Context
ctx <- forall (m :: * -> *) backend params.
(MonadIO m, HasBackend backend, TLSParams params) =>
backend -> params -> m Context
TLS.contextNew ((Int -> IO ByteString) -> Backend
backend Int -> IO ByteString
recvN) params
params
    Context -> Logging -> IO ()
TLS.contextHookSetLogging Context
ctx Logging
tlsLogging
    forall (m :: * -> *). MonadIO m => Context -> m ()
TLS.handshake Context
ctx
    Bool
h2 <- (forall a. Eq a => a -> a -> Bool
== forall a. a -> Maybe a
Just ByteString
"h2") forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> forall (m :: * -> *). MonadIO m => Context -> m (Maybe ByteString)
TLS.getNegotiatedProtocol Context
ctx
    IORef Bool
isH2 <- forall a. a -> IO (IORef a)
I.newIORef Bool
h2
    WriteBuffer
writeBuffer <- Int -> IO WriteBuffer
createWriteBuffer Int
16384
    IORef WriteBuffer
writeBufferRef <- forall a. a -> IO (IORef a)
I.newIORef WriteBuffer
writeBuffer
    -- Creating a cache for leftover input data.
    Transport
tls <- Context -> IO Transport
getTLSinfo Context
ctx
    SockAddr
mysa <- Socket -> IO SockAddr
getSocketName Socket
s
    forall (m :: * -> *) a. Monad m => a -> m a
return (Context
-> IORef WriteBuffer -> IORef Bool -> SockAddr -> Connection
conn Context
ctx IORef WriteBuffer
writeBufferRef IORef Bool
isH2 SockAddr
mysa, Transport
tls)
  where
    backend :: (Int -> IO ByteString) -> Backend
backend Int -> IO ByteString
recvN = TLS.Backend {
        backendFlush :: IO ()
TLS.backendFlush = forall (m :: * -> *) a. Monad m => a -> m a
return ()
#if MIN_VERSION_network(3,1,1)
      , backendClose :: IO ()
TLS.backendClose = Socket -> Int -> IO ()
gracefulClose Socket
s Int
5000 forall (m :: * -> *) e a.
(MonadUnliftIO m, Exception e) =>
m a -> (e -> m a) -> m a
`E.catch` \(SomeException e
_) -> forall (m :: * -> *) a. Monad m => a -> m a
return ()
#else
      , TLS.backendClose = close s
#endif
      , backendSend :: ByteString -> IO ()
TLS.backendSend  = Socket -> ByteString -> IO ()
sendAll' Socket
s
      , backendRecv :: Int -> IO ByteString
TLS.backendRecv  = Int -> IO ByteString
recvN
      }
    sendAll' :: Socket -> ByteString -> IO ()
sendAll' Socket
sock ByteString
bs = forall (m :: * -> *) e b a.
(MonadUnliftIO m, Exception e) =>
(e -> Maybe b) -> (b -> m a) -> m a -> m a
E.handleJust
      (\ IOError
e -> if IOError -> IOErrorType
ioeGetErrorType IOError
e forall a. Eq a => a -> a -> Bool
== IOErrorType
ResourceVanished
        then forall a. a -> Maybe a
Just InvalidRequest
ConnectionClosedByPeer
        else forall a. Maybe a
Nothing)
      forall (m :: * -> *) e a. (MonadIO m, Exception e) => e -> m a
throwIO
      forall a b. (a -> b) -> a -> b
$ Socket -> ByteString -> IO ()
sendAll Socket
sock ByteString
bs
    conn :: Context
-> IORef WriteBuffer -> IORef Bool -> SockAddr -> Connection
conn Context
ctx IORef WriteBuffer
writeBufferRef IORef Bool
isH2 SockAddr
mysa = Connection {
        connSendMany :: [ByteString] -> IO ()
connSendMany         = forall (m :: * -> *). MonadIO m => Context -> ByteString -> m ()
TLS.sendData Context
ctx forall b c a. (b -> c) -> (a -> b) -> a -> c
. [ByteString] -> ByteString
L.fromChunks
      , connSendAll :: ByteString -> IO ()
connSendAll          = ByteString -> IO ()
sendall
      , connSendFile :: SendFile
connSendFile         = SendFile
sendfile
      , connClose :: IO ()
connClose            = IO ()
close'
      , connRecv :: IO ByteString
connRecv             = IO ByteString
recv
      , connRecvBuf :: RecvBuf
connRecvBuf          = \Buffer
_ Int
_ -> forall (m :: * -> *) a. Monad m => a -> m a
return Bool
True -- obsoleted
      , connWriteBuffer :: IORef WriteBuffer
connWriteBuffer      = IORef WriteBuffer
writeBufferRef
      , connHTTP2 :: IORef Bool
connHTTP2            = IORef Bool
isH2
      , connMySockAddr :: SockAddr
connMySockAddr       = SockAddr
mysa
      }
      where
        sendall :: ByteString -> IO ()
sendall = forall (m :: * -> *). MonadIO m => Context -> ByteString -> m ()
TLS.sendData Context
ctx forall b c a. (b -> c) -> (a -> b) -> a -> c
. [ByteString] -> ByteString
L.fromChunks forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall (m :: * -> *) a. Monad m => a -> m a
return
        recv :: IO ByteString
recv = forall (m :: * -> *) e a.
(MonadUnliftIO m, Exception e) =>
(e -> m a) -> m a -> m a
handle forall {m :: * -> *}. MonadIO m => SomeException -> m ByteString
onEOF forall a b. (a -> b) -> a -> b
$ forall (m :: * -> *). MonadIO m => Context -> m ByteString
TLS.recvData Context
ctx
          where
            onEOF :: SomeException -> m ByteString
onEOF SomeException
e
#if MIN_VERSION_tls(1,8,0)
              | Just (TLS.PostHandshake TLSError
TLS.Error_EOF) <- forall e. Exception e => SomeException -> Maybe e
E.fromException SomeException
e = forall (m :: * -> *) a. Monad m => a -> m a
return ByteString
S.empty
#else
              | Just TLS.Error_EOF <- fromException e       = return S.empty
#endif
              | Just IOError
ioe <- forall e. Exception e => SomeException -> Maybe e
fromException SomeException
e, IOError -> Bool
isEOFError IOError
ioe = forall (m :: * -> *) a. Monad m => a -> m a
return ByteString
S.empty                  | Bool
otherwise                                   = forall (m :: * -> *) e a. (MonadIO m, Exception e) => e -> m a
throwIO SomeException
e
        sendfile :: SendFile
sendfile FileId
fid Integer
offset Integer
len IO ()
hook [ByteString]
headers = do
            WriteBuffer
writeBuffer <- forall a. IORef a -> IO a
I.readIORef IORef WriteBuffer
writeBufferRef
            Buffer -> Int -> (ByteString -> IO ()) -> SendFile
readSendFile (WriteBuffer -> Buffer
bufBuffer WriteBuffer
writeBuffer) (WriteBuffer -> Int
bufSize WriteBuffer
writeBuffer) ByteString -> IO ()
sendall FileId
fid Integer
offset Integer
len IO ()
hook [ByteString]
headers

        close' :: IO ()
close' = forall (f :: * -> *) a. Functor f => f a -> f ()
void (forall a. IO a -> IO (Either IOError a)
tryIO IO ()
sendBye) forall (m :: * -> *) a b. MonadUnliftIO m => m a -> m b -> m a
`finally`
                 Context -> IO ()
TLS.contextClose Context
ctx

        -- Send the TLS closing message ('TLS.bye') on this context.
        --
        -- Only a 'ConnectionClosedByPeer' exception is intercepted and
        -- ignored; every other exception propagates to the caller.
        sendBye :: IO ()
        sendBye =
          -- It's fine if the connection was closed by the other side before
          -- receiving close_notify, see RFC 5246 section 7.2.1.
          handleJust
            (\e -> guard (e == ConnectionClosedByPeer) >> return e)
            (const (return ()))
            (TLS.bye ctx)


    -- Wrap a receive function so that any exception intercepted by
    -- 'handleAny' is replaced with the result of 'handler' instead of
    -- propagating to the caller.
    wrappedRecvN :: (t -> IO S.ByteString) -> t -> IO S.ByteString
    wrappedRecvN recvN n = handleAny handler $ recvN n
    -- Exception handler used by 'wrappedRecvN': the exception itself is
    -- deliberately ignored and an empty chunk is returned in its place.
    handler :: SomeException -> IO S.ByteString
    handler _ = return ""

-- | Describe the transport of an established TLS context.
--
-- Queries the context for the negotiated application protocol and for its
-- session information. When the context reports no information the result
-- is 'TCP'; otherwise a 'TLS' record is built from the protocol version
-- (as a major\/minor pair), the negotiated protocol, the cipher id and the
-- client certificate chain (if any).
getTLSinfo :: TLS.Context -> IO Transport
getTLSinfo ctx = do
    proto <- TLS.getNegotiatedProtocol ctx
    minfo <- TLS.contextGetInformation ctx
    case minfo of
        Nothing   -> return TCP
        Just TLS.Information{..} -> do
            -- Protocol version expressed as (major, minor) numbers.
            let (major, minor) = case infoVersion of
                    TLS.SSL2  -> (2,0)
                    TLS.SSL3  -> (3,0)
                    TLS.TLS10 -> (3,1)
                    TLS.TLS11 -> (3,2)
                    TLS.TLS12 -> (3,3)
#if MIN_VERSION_tls(1,5,0)
                    TLS.TLS13 -> (3,4)
#endif
            clientCert <- TLS.getClientCertificateChain ctx
            return TLS {
                tlsMajorVersion = major
              , tlsMinorVersion = minor
              , tlsNegotiatedProtocol = proto
                -- NOTE: "Chiper" is the field's actual spelling; it is
                -- declared outside this module, so it must not be renamed here.
              , tlsChiperID = TLS.cipherID infoCipher
              , tlsClientCertificate = clientCert
              }

-- | 'try' specialised to 'IOException': run the action and return
-- @Left@ with the exception if it raised an 'IOException', or @Right@ with
-- its result otherwise. Exceptions of other types are not caught.
tryIO :: IO a -> IO (Either IOException a)
tryIO = try

----------------------------------------------------------------

-- | Handle a connection on which plain HTTP, not TLS, arrived.
--
-- @bs0@ is the data already read from the socket. The behaviour is chosen
-- by the 'onInsecure' field of the TLS settings:
--
-- * 'AllowInsecure': a plain socket connection is created whose 'connRecv'
--   first yields @bs0@ (via 'recvPlain') and then reads from the socket as
--   usual. The transport is reported as 'TCP'.
--
-- * 'DenyInsecure': a @426 Upgrade Required@ response head followed by the
--   configured body is written to the socket, the socket is closed, and
--   'InsecureConnectionDenied' is thrown.
plainHTTP :: TLSSettings -> Settings -> Socket -> S.ByteString -> IO (Connection, Transport)
plainHTTP TLSSettings{..} set s bs0 = case onInsecure of
    AllowInsecure -> do
        conn' <- socketConnection set s
        -- Bytes consumed while sniffing the protocol must be replayed to
        -- the HTTP parser before anything else is read from the socket.
        cachedRef <- I.newIORef bs0
        let conn'' = conn'
                { connRecv = recvPlain cachedRef (connRecv conn')
                }
        return (conn'', TCP)
    DenyInsecure lbs -> do
        -- Listening port 443 but TLS records do not arrive.
        -- We want to let the browser know that TLS is required.
        -- So, we use 426.
        --     http://tools.ietf.org/html/rfc2817#section-4.2
        --     https://tools.ietf.org/html/rfc7231#section-6.5.15
        -- FIXME: should we distinguish HTTP/1.1 and HTTP/2?
        --        In the case of HTTP/2, should we send
        --        GOAWAY + INADEQUATE_SECURITY?
        -- FIXME: Content-Length:
        -- FIXME: TLS/<version>
        --
        -- The response head is assembled from one literal per line instead
        -- of a backslash-continued literal: this module is compiled with
        -- CPP, and the bytes of a continued literal depend on how the
        -- preprocessor treats backslash-newline.
        sendAll s $ S.concat
            [ "HTTP/1.1 426 Upgrade Required\r\n"
            , "Upgrade: TLS/1.0, HTTP/1.1\r\n"
            , "Connection: Upgrade\r\n"
            , "Content-Type: text/plain\r\n"
            , "\r\n"
            ]
        mapM_ (sendAll s) $ L.toChunks lbs
        close s
        throwIO InsecureConnectionDenied

----------------------------------------------------------------

-- | Modify the given receive function to first check the given @IORef@ for a
-- chunk of data. If present, takes the chunk of data from the @IORef@ and
-- empties out the @IORef@. Otherwise, calls the supplied receive function.
--
-- An empty chunk in the @IORef@ counts as "not present", so the fallback
-- is used for that call and for every later one.
recvPlain :: I.IORef S.ByteString -> IO S.ByteString -> IO S.ByteString
recvPlain ref fallback = do
    bs <- I.readIORef ref
    if S.null bs
        then fallback
        else do
            I.writeIORef ref S.empty
            return bs

----------------------------------------------------------------

-- | Exceptions specific to this module.
data WarpTLSException
    = InsecureConnectionDenied
    -- ^ Thrown by 'plainHTTP' after a non-TLS connection has been refused
    --   with a 426 response.
    | ClientClosedConnectionPrematurely
    -- ^ NOTE(review): not raised in this part of the file; presumably
    --   signals that the peer disconnected early — confirm at the throw site.
    deriving (Show, Typeable)
instance Exception WarpTLSException