{-# LANGUAGE CPP #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PatternGuards #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RecordWildCards #-}

-- | HTTP over TLS support for Warp via the TLS package.
--
--   If HTTP\/2 is negotiated by ALPN, HTTP\/2 over TLS is used.
--   Otherwise HTTP\/1.1 over TLS is used.
--
--   Support for SSL is now obsoleted.
module Network.Wai.Handler.WarpTLS (
    -- * Runner
    runTLS,
    runTLSSocket,

    -- * Settings
    TLSSettings,
    defaultTlsSettings,

    -- * Smart constructors

    -- ** From files
    tlsSettings,
    tlsSettingsChain,

    -- ** From memory
    tlsSettingsMemory,
    tlsSettingsChainMemory,

    -- ** From references
    tlsSettingsRef,
    tlsSettingsChainRef,
    CertSettings,

    -- * Accessors
    tlsCredentials,
    tlsLogging,
    tlsAllowedVersions,
    tlsCiphers,
    tlsWantClientCert,
    tlsServerHooks,
    tlsServerDHEParams,
    tlsSessionManagerConfig,
    tlsSessionManager,
    onInsecure,
    OnInsecure (..),

    -- * Exception
    WarpTLSException (..),
) where

import Control.Applicative ((<|>))
import Control.Monad (guard, void)
import qualified Data.ByteString as S
import qualified Data.ByteString.Lazy as L
import Data.Default.Class (def)
import qualified Data.IORef as I
import Data.Streaming.Network (bindPortTCP, safeRecv)
import Data.Typeable (Typeable)
import GHC.IO.Exception (IOErrorType (..))
import Network.Socket (
    SockAddr,
    Socket,
    close,
    getSocketName,
#if MIN_VERSION_network(3,1,1)
    gracefulClose,
#endif
    withSocketsDo,
 )
import Network.Socket.BufferPool
import Network.Socket.ByteString (sendAll)
import qualified Network.TLS as TLS
import qualified Network.TLS.SessionManager as SM
import Network.Wai (Application)
import Network.Wai.Handler.Warp
import Network.Wai.Handler.Warp.Internal
import Network.Wai.Handler.WarpTLS.Internal
import System.IO.Error (ioeGetErrorType, isEOFError)
import UnliftIO.Exception (
    Exception,
    IOException,
    SomeException (..),
    bracket,
    finally,
    fromException,
    handle,
    handleAny,
    handleJust,
    onException,
    throwIO,
    try,
 )
import qualified UnliftIO.Exception as E

----------------------------------------------------------------

-- | A smart constructor for 'TLSSettings' based on 'defaultTlsSettings'.
--
-- The certificate and key are given as file paths; no chain certificates
-- are configured.
tlsSettings
    :: FilePath
    -- ^ Certificate file
    -> FilePath
    -- ^ Key file
    -> TLSSettings
tlsSettings cert key =
    defaultTlsSettings
        { certSettings = CertFromFile cert [] key
        }

-- | A smart constructor for 'TLSSettings' that allows specifying
-- chain certificates based on 'defaultTlsSettings'.
--
-- The certificate, chain certificates and key are all given as file paths.
--
-- @since 3.0.3
tlsSettingsChain
    :: FilePath
    -- ^ Certificate file
    -> [FilePath]
    -- ^ Chain certificate files
    -> FilePath
    -- ^ Key file
    -> TLSSettings
tlsSettingsChain cert chainCerts key =
    defaultTlsSettings
        { certSettings = CertFromFile cert chainCerts key
        }

-- | A smart constructor for 'TLSSettings', but uses in-memory representations
-- of the certificate and key based on 'defaultTlsSettings'.
--
-- No chain certificates are configured.
--
-- @since 3.0.1
tlsSettingsMemory
    :: S.ByteString
    -- ^ Certificate bytes
    -> S.ByteString
    -- ^ Key bytes
    -> TLSSettings
tlsSettingsMemory cert key =
    defaultTlsSettings
        { certSettings = CertFromMemory cert [] key
        }

-- | A smart constructor for 'TLSSettings', but uses in-memory representations
-- of the certificate, chain certificates and key based on
-- 'defaultTlsSettings'.
--
-- @since 3.0.3
tlsSettingsChainMemory
    :: S.ByteString
    -- ^ Certificate bytes
    -> [S.ByteString]
    -- ^ Chain certificate bytes
    -> S.ByteString
    -- ^ Key bytes
    -> TLSSettings
tlsSettingsChainMemory cert chainCerts key =
    defaultTlsSettings
        { certSettings = CertFromMemory cert chainCerts key
        }

-- | A smart constructor for 'TLSSettings', but uses references to in-memory
-- representations of the certificate and key based on 'defaultTlsSettings'.
--
-- The references are only stored here, not read. No chain certificates are
-- configured.
--
-- @since 3.3.0
tlsSettingsRef
    :: I.IORef S.ByteString
    -- ^ Reference to certificate bytes
    -> I.IORef S.ByteString
    -- ^ Reference to key bytes
    -> TLSSettings
tlsSettingsRef cert key =
    defaultTlsSettings
        { certSettings = CertFromRef cert [] key
        }

-- | A smart constructor for 'TLSSettings', but uses references to in-memory
-- representations of the certificate, chain certificates and key based on
-- 'defaultTlsSettings'.
--
-- The references are only stored here, not read.
--
-- @since 3.3.0
tlsSettingsChainRef
    :: I.IORef S.ByteString
    -- ^ Reference to certificate bytes
    -> [I.IORef S.ByteString]
    -- ^ Reference to chain certificate bytes
    -> I.IORef S.ByteString
    -- ^ Reference to key bytes
    -> TLSSettings
tlsSettingsChainRef cert chainCerts key =
    defaultTlsSettings
        { certSettings = CertFromRef cert chainCerts key
        }

----------------------------------------------------------------

-- | Running 'Application' with 'TLSSettings' and 'Settings'.
--
-- A listening TCP socket is bound using the port and host taken from
-- 'Settings' ('getPort' and 'getHost'), marked close-on-exec, and handed to
-- 'runTLSSocket'. The socket is closed by 'bracket' when 'runTLSSocket'
-- returns or throws.
runTLS :: TLSSettings -> Settings -> Application -> IO ()
runTLS tset set app =
    withSocketsDo $
        bracket
            (bindPortTCP (getPort set) (getHost set))
            close
            serve
  where
    -- Serve on the freshly bound listening socket.
    serve :: Socket -> IO ()
    serve sock = do
        setSocketCloseOnExec sock
        runTLSSocket tset set sock app

----------------------------------------------------------------

-- | Resolve the TLS credentials described by the given 'TLSSettings'.
--
-- If 'tlsCredentials' is set, it is returned unchanged. Otherwise a single
-- credential is loaded according to 'certSettings': from files, from
-- references that are read at this point, or from in-memory bytes.
--
-- A credential that cannot be loaded aborts via 'error' with the message
-- reported by the loader. For every kind of 'certSettings' the failure is
-- raised while this action runs, rather than being left in an unevaluated
-- thunk that only fails when the credential is first used.
loadCredentials :: TLSSettings -> IO TLS.Credentials
loadCredentials TLSSettings{tlsCredentials = Just creds} = return creds
loadCredentials TLSSettings{certSettings = source} = do
    loaded <- case source of
        CertFromFile cert chainFiles key ->
            TLS.credentialLoadX509Chain cert chainFiles key
        CertFromRef certRef chainCertsRef keyRef ->
            -- The references are read in order: certificate, chain, key.
            TLS.credentialLoadX509ChainFromMemory
                <$> I.readIORef certRef
                <*> mapM I.readIORef chainCertsRef
                <*> I.readIORef keyRef
        CertFromMemory certMemory chainCertsMemory keyMemory ->
            return $
                TLS.credentialLoadX509ChainFromMemory
                    certMemory
                    chainCertsMemory
                    keyMemory
    -- Eliminating the 'Either' as an IO action (instead of mapping a pure
    -- function over the result) makes a load failure surface right here.
    cred <- either error return loaded
    return $ TLS.Credentials [cred]

-- | Pick the TLS session manager for the given 'TLSSettings'.
--
-- An explicit 'tlsSessionManager' takes precedence. Otherwise a new manager
-- is created from 'tlsSessionManagerConfig' when one is given, and
-- 'TLS.noSessionManager' is used when neither is set.
getSessionManager :: TLSSettings -> IO TLS.SessionManager
getSessionManager tlsset =
    case (tlsSessionManager tlsset, tlsSessionManagerConfig tlsset) of
        (Just mgr, _) -> return mgr
        (Nothing, Just config) -> SM.newSessionManager config
        (Nothing, Nothing) -> return TLS.noSessionManager

-- | Running 'Application' with 'TLSSettings' and 'Settings' using
--   specified 'Socket'.
--
--   The action @close sock@ is registered through
--   'settingsInstallShutdownHandler'. Credentials and the session manager
--   are resolved once, before any connection is served.
runTLSSocket :: TLSSettings -> Settings -> Socket -> Application -> IO ()
runTLSSocket tlsset set sock app = do
    settingsInstallShutdownHandler set (close sock)
    credentials <- loadCredentials tlsset
    mgr <- getSessionManager tlsset
    runTLSSocket' tlsset set credentials mgr sock app

-- | Serve on the given listening socket with already resolved credentials
-- and session manager.
--
-- Builds the 'TLS.ServerParams' from the 'TLSSettings' fields and passes a
-- connection maker (see 'getter') to 'runSettingsConnectionMakerSecure'.
-- The 'Application' argument is forwarded to that function unchanged.
runTLSSocket'
    :: TLSSettings
    -> Settings
    -> TLS.Credentials
    -> TLS.SessionManager
    -> Socket
    -> Application
    -> IO ()
runTLSSocket' tlsset@TLSSettings{..} set credentials mgr sock =
    runSettingsConnectionMakerSecure set get
  where
    -- Accepts one connection per invocation.
    get :: IO (IO (Connection, Transport), SockAddr)
    get = getter tlsset set sock params

    params :: TLS.ServerParams
    params =
        def -- TLS.ServerParams
            { TLS.serverWantClientCert = tlsWantClientCert
            , TLS.serverCACertificates = []
            , TLS.serverDHEParams = tlsServerDHEParams
            , TLS.serverHooks = hooks
            , TLS.serverShared = shared
            , TLS.serverSupported = supported
#if MIN_VERSION_tls(1,5,0)
            -- NOTE(review): 2018 is kept verbatim; confirm whether 2048 was
            -- intended.
            , TLS.serverEarlyDataSize = 2018
#endif
            }

    -- Adding alpn to user's tlsServerHooks. A user-supplied
    -- 'TLS.onALPNClientSuggest' wins; 'alpn' is only the fallback, and only
    -- when HTTP/2 is enabled in the 'Settings'.
    hooks :: TLS.ServerHooks
    hooks =
        tlsServerHooks
            { TLS.onALPNClientSuggest =
                TLS.onALPNClientSuggest tlsServerHooks
                    <|> (if settingsHTTP2Enabled set then Just alpn else Nothing)
            }

    shared :: TLS.Shared
    shared =
        def
            { TLS.sharedCredentials = credentials
            , TLS.sharedSessionManager = mgr
            }

    supported :: TLS.Supported
    supported =
        def -- TLS.Supported
            { TLS.supportedVersions = tlsAllowedVersions
            , TLS.supportedCiphers = tlsCiphers
            , TLS.supportedCompressions = [TLS.nullCompression]
            , TLS.supportedSecureRenegotiation = True
            , TLS.supportedClientInitiatedRenegotiation = False
            , TLS.supportedSession = True
            , TLS.supportedFallbackScsv = True
            , TLS.supportedHashSignatures = tlsSupportedHashSignatures
#if MIN_VERSION_tls(1,5,0)
            , TLS.supportedGroups = [TLS.X25519,TLS.P256,TLS.P384]
#endif
            }

-- | ALPN protocol selection used as the default 'TLS.onALPNClientSuggest'.
--
-- Chooses @h2@ whenever the client offers it. In every other case,
-- including an empty offer or one without @http/1.1@, the answer is
-- @http/1.1@.
alpn :: [S.ByteString] -> IO S.ByteString
alpn xs
    | "h2" `elem` xs = return "h2"
    | otherwise = return "http/1.1"

----------------------------------------------------------------

-- | Accept one connection on the listening socket.
--
-- Uses the accept function stored in 'settingsAccept', marks the accepted
-- socket close-on-exec, and returns a not yet executed action that builds
-- the 'Connection' (see 'mkConn') together with the address returned by
-- the accept function.
getter
    :: TLS.TLSParams params
    => TLSSettings
    -> Settings
    -> Socket
    -> params
    -> IO (IO (Connection, Transport), SockAddr)
getter tlsset set@Settings{settingsAccept = accept'} sock params = do
    (s, sa) <- accept' sock
    setSocketCloseOnExec s
    return (mkConn tlsset set s params, sa)

-- | Build a 'Connection' for an accepted socket by inspecting the first
-- bytes sent by the client.
--
-- Up to 4096 bytes are read from the socket:
--
-- * no data at all: 'ClientClosedConnectionPrematurely' is thrown;
--
-- * first byte @0x16@ (the TLS handshake record type): the connection is
--   set up by 'httpOverTls';
--
-- * anything else: the connection is handed to 'plainHTTP'.
--
-- The bytes already read are passed on to the chosen handler. If any step
-- throws, the socket is closed exactly once by the 'onException' handler.
mkConn
    :: TLS.TLSParams params
    => TLSSettings
    -> Settings
    -> Socket
    -> params
    -> IO (Connection, Transport)
mkConn tlsset set s params =
    (safeRecv s 4096 >>= dispatch) `onException` close s
  where
    dispatch :: S.ByteString -> IO (Connection, Transport)
    dispatch firstBS = case S.uncons firstBS of
        -- No explicit 'close' here: the 'onException' handler above closes
        -- the socket once the exception propagates.
        Nothing -> throwIO ClientClosedConnectionPrematurely
        Just (0x16, _) -> httpOverTls tlsset set s firstBS params
        Just _ -> plainHTTP tlsset set s firstBS

----------------------------------------------------------------

httpOverTls
    :: TLS.TLSParams params
    => TLSSettings
    -> Settings
    -> Socket
    -> S.ByteString
    -> params
    -> IO (Connection, Transport)
httpOverTls :: forall params.
TLSParams params =>
TLSSettings
-> Settings
-> Socket
-> ByteString
-> params
-> IO (Connection, Transport)
httpOverTls TLSSettings{Bool
[HashAndSignatureAlgorithm]
[Version]
[Cipher]
Maybe DHParams
Maybe SessionManager
Maybe Credentials
Maybe Config
Logging
ServerHooks
OnInsecure
CertSettings
tlsCredentials :: TLSSettings -> Maybe Credentials
tlsLogging :: TLSSettings -> Logging
tlsAllowedVersions :: TLSSettings -> [Version]
tlsCiphers :: TLSSettings -> [Cipher]
tlsWantClientCert :: TLSSettings -> Bool
tlsServerHooks :: TLSSettings -> ServerHooks
tlsServerDHEParams :: TLSSettings -> Maybe DHParams
tlsSessionManagerConfig :: TLSSettings -> Maybe Config
tlsSessionManager :: TLSSettings -> Maybe SessionManager
onInsecure :: TLSSettings -> OnInsecure
certSettings :: TLSSettings -> CertSettings
tlsSupportedHashSignatures :: TLSSettings -> [HashAndSignatureAlgorithm]
certSettings :: CertSettings
onInsecure :: OnInsecure
tlsLogging :: Logging
tlsAllowedVersions :: [Version]
tlsCiphers :: [Cipher]
tlsWantClientCert :: Bool
tlsServerHooks :: ServerHooks
tlsServerDHEParams :: Maybe DHParams
tlsSessionManagerConfig :: Maybe Config
tlsCredentials :: Maybe Credentials
tlsSessionManager :: Maybe SessionManager
tlsSupportedHashSignatures :: [HashAndSignatureAlgorithm]
..} Settings
_set Socket
s ByteString
bs0 params
params = do
    BufferPool
pool <- Int -> Int -> IO BufferPool
newBufferPool Int
2048 Int
16384
    Int -> IO ByteString
rawRecvN <- ByteString -> IO ByteString -> IO (Int -> IO ByteString)
makeRecvN ByteString
bs0 (IO ByteString -> IO (Int -> IO ByteString))
-> IO ByteString -> IO (Int -> IO ByteString)
forall a b. (a -> b) -> a -> b
$ Socket -> BufferPool -> IO ByteString
receive Socket
s BufferPool
pool
    let recvN :: Int -> IO ByteString
recvN = (Int -> IO ByteString) -> Int -> IO ByteString
forall {t}. (t -> IO ByteString) -> t -> IO ByteString
wrappedRecvN Int -> IO ByteString
rawRecvN
    Context
ctx <- Backend -> params -> IO Context
forall (m :: * -> *) backend params.
(MonadIO m, HasBackend backend, TLSParams params) =>
backend -> params -> m Context
TLS.contextNew ((Int -> IO ByteString) -> Backend
backend Int -> IO ByteString
recvN) params
params
    Context -> Logging -> IO ()
TLS.contextHookSetLogging Context
ctx Logging
tlsLogging
    Context -> IO ()
forall (m :: * -> *). MonadIO m => Context -> m ()
TLS.handshake Context
ctx
    Bool
h2 <- (Maybe ByteString -> Maybe ByteString -> Bool
forall a. Eq a => a -> a -> Bool
== ByteString -> Maybe ByteString
forall a. a -> Maybe a
Just ByteString
"h2") (Maybe ByteString -> Bool) -> IO (Maybe ByteString) -> IO Bool
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Context -> IO (Maybe ByteString)
forall (m :: * -> *). MonadIO m => Context -> m (Maybe ByteString)
TLS.getNegotiatedProtocol Context
ctx
    IORef Bool
isH2 <- Bool -> IO (IORef Bool)
forall a. a -> IO (IORef a)
I.newIORef Bool
h2
    WriteBuffer
writeBuffer <- Int -> IO WriteBuffer
createWriteBuffer Int
16384
    IORef WriteBuffer
writeBufferRef <- WriteBuffer -> IO (IORef WriteBuffer)
forall a. a -> IO (IORef a)
I.newIORef WriteBuffer
writeBuffer
    -- Creating a cache for leftover input data.
    Transport
tls <- Context -> IO Transport
getTLSinfo Context
ctx
    SockAddr
mysa <- Socket -> IO SockAddr
getSocketName Socket
s
    (Connection, Transport) -> IO (Connection, Transport)
forall a. a -> IO a
forall (m :: * -> *) a. Monad m => a -> m a
return (Context
-> IORef WriteBuffer -> IORef Bool -> SockAddr -> Connection
conn Context
ctx IORef WriteBuffer
writeBufferRef IORef Bool
isH2 SockAddr
mysa, Transport
tls)
  where
    backend :: (Int -> IO ByteString) -> Backend
backend Int -> IO ByteString
recvN =
        TLS.Backend
            { backendFlush :: IO ()
TLS.backendFlush = () -> IO ()
forall a. a -> IO a
forall (m :: * -> *) a. Monad m => a -> m a
return ()
#if MIN_VERSION_network(3,1,1)
            , backendClose :: IO ()
TLS.backendClose =
                Socket -> Int -> IO ()
gracefulClose Socket
s Int
5000 IO () -> (SomeException -> IO ()) -> IO ()
forall (m :: * -> *) e a.
(MonadUnliftIO m, Exception e) =>
m a -> (e -> m a) -> m a
`E.catch` \(SomeException e
_) -> () -> IO ()
forall a. a -> IO a
forall (m :: * -> *) a. Monad m => a -> m a
return ()
#else
            , TLS.backendClose = close s
#endif
            , backendSend :: ByteString -> IO ()
TLS.backendSend = Socket -> ByteString -> IO ()
sendAll' Socket
s
            , backendRecv :: Int -> IO ByteString
TLS.backendRecv = Int -> IO ByteString
recvN
            }
    sendAll' :: Socket -> ByteString -> IO ()
sendAll' Socket
sock ByteString
bs =
        (IOError -> Maybe InvalidRequest)
-> (InvalidRequest -> IO ()) -> IO () -> IO ()
forall (m :: * -> *) e b a.
(MonadUnliftIO m, Exception e) =>
(e -> Maybe b) -> (b -> m a) -> m a -> m a
E.handleJust
            ( \IOError
e ->
                if IOError -> IOErrorType
ioeGetErrorType IOError
e IOErrorType -> IOErrorType -> Bool
forall a. Eq a => a -> a -> Bool
== IOErrorType
ResourceVanished
                    then InvalidRequest -> Maybe InvalidRequest
forall a. a -> Maybe a
Just InvalidRequest
ConnectionClosedByPeer
                    else Maybe InvalidRequest
forall a. Maybe a
Nothing
            )
            InvalidRequest -> IO ()
forall (m :: * -> *) e a. (MonadIO m, Exception e) => e -> m a
throwIO
            (IO () -> IO ()) -> IO () -> IO ()
forall a b. (a -> b) -> a -> b
$ Socket -> ByteString -> IO ()
sendAll Socket
sock ByteString
bs
    conn :: Context
-> IORef WriteBuffer -> IORef Bool -> SockAddr -> Connection
conn Context
ctx IORef WriteBuffer
writeBufferRef IORef Bool
isH2 SockAddr
mysa =
        Connection
            { connSendMany :: [ByteString] -> IO ()
connSendMany = Context -> ByteString -> IO ()
forall (m :: * -> *). MonadIO m => Context -> ByteString -> m ()
TLS.sendData Context
ctx (ByteString -> IO ())
-> ([ByteString] -> ByteString) -> [ByteString] -> IO ()
forall b c a. (b -> c) -> (a -> b) -> a -> c
. [ByteString] -> ByteString
L.fromChunks
            , connSendAll :: ByteString -> IO ()
connSendAll = ByteString -> IO ()
sendall
            , connSendFile :: SendFile
connSendFile = SendFile
sendfile
            , connClose :: IO ()
connClose = IO ()
close'
            , connRecv :: IO ByteString
connRecv = IO ByteString
recv
            , connRecvBuf :: RecvBuf
connRecvBuf = \Buffer
_ Int
_ -> Bool -> IO Bool
forall a. a -> IO a
forall (m :: * -> *) a. Monad m => a -> m a
return Bool
True -- obsoleted
            , connWriteBuffer :: IORef WriteBuffer
connWriteBuffer = IORef WriteBuffer
writeBufferRef
            , connHTTP2 :: IORef Bool
connHTTP2 = IORef Bool
isH2
            , connMySockAddr :: SockAddr
connMySockAddr = SockAddr
mysa
            }
      where
        sendall :: ByteString -> IO ()
sendall = Context -> ByteString -> IO ()
forall (m :: * -> *). MonadIO m => Context -> ByteString -> m ()
TLS.sendData Context
ctx (ByteString -> IO ())
-> (ByteString -> ByteString) -> ByteString -> IO ()
forall b c a. (b -> c) -> (a -> b) -> a -> c
. [ByteString] -> ByteString
L.fromChunks ([ByteString] -> ByteString)
-> (ByteString -> [ByteString]) -> ByteString -> ByteString
forall b c a. (b -> c) -> (a -> b) -> a -> c
. ByteString -> [ByteString]
forall a. a -> [a]
forall (m :: * -> *) a. Monad m => a -> m a
return
        recv :: IO ByteString
recv = (SomeException -> IO ByteString) -> IO ByteString -> IO ByteString
forall (m :: * -> *) e a.
(MonadUnliftIO m, Exception e) =>
(e -> m a) -> m a -> m a
handle SomeException -> IO ByteString
forall {m :: * -> *}. MonadIO m => SomeException -> m ByteString
onEOF (IO ByteString -> IO ByteString) -> IO ByteString -> IO ByteString
forall a b. (a -> b) -> a -> b
$ Context -> IO ByteString
forall (m :: * -> *). MonadIO m => Context -> m ByteString
TLS.recvData Context
ctx
          where
            onEOF :: SomeException -> m ByteString
onEOF SomeException
e
#if MIN_VERSION_tls(1,8,0)
                | Just (TLS.PostHandshake TLSError
TLS.Error_EOF) <- SomeException -> Maybe TLSException
forall e. Exception e => SomeException -> Maybe e
E.fromException SomeException
e = ByteString -> m ByteString
forall a. a -> m a
forall (m :: * -> *) a. Monad m => a -> m a
return ByteString
S.empty
#else
                | Just TLS.Error_EOF <- fromException e = return S.empty
#endif
                | Just IOError
ioe <- SomeException -> Maybe IOError
forall e. Exception e => SomeException -> Maybe e
fromException SomeException
e, IOError -> Bool
isEOFError IOError
ioe = ByteString -> m ByteString
forall a. a -> m a
forall (m :: * -> *) a. Monad m => a -> m a
return ByteString
S.empty
                | Bool
otherwise = SomeException -> m ByteString
forall (m :: * -> *) e a. (MonadIO m, Exception e) => e -> m a
throwIO SomeException
e
        sendfile :: SendFile
sendfile FileId
fid Integer
offset Integer
len IO ()
hook [ByteString]
headers = do
            WriteBuffer
writeBuffer <- IORef WriteBuffer -> IO WriteBuffer
forall a. IORef a -> IO a
I.readIORef IORef WriteBuffer
writeBufferRef
            Buffer -> Int -> (ByteString -> IO ()) -> SendFile
readSendFile
                (WriteBuffer -> Buffer
bufBuffer WriteBuffer
writeBuffer)
                (WriteBuffer -> Int
bufSize WriteBuffer
writeBuffer)
                ByteString -> IO ()
sendall
                FileId
fid
                Integer
offset
                Integer
len
                IO ()
hook
                [ByteString]
headers

        -- Close the TLS connection.
        --
        -- 'sendBye' is attempted first; an 'IOException' raised by it is
        -- captured by 'tryIO' and discarded. 'TLS.contextClose' is attached
        -- with 'finally', so it runs whether or not the bye step succeeded.
        close' =
            void (tryIO sendBye)
                `finally` TLS.contextClose ctx

        -- Send the TLS close notification via 'TLS.bye'.
        --
        -- Only the 'ConnectionClosedByPeer' exception is intercepted (and
        -- ignored); every other exception propagates to the caller.
        sendBye =
            -- It's fine if the connection was closed by the other side before
            -- receiving close_notify, see RFC 5246 section 7.2.1.
            handleJust
                (\e -> guard (e == ConnectionClosedByPeer) >> return e)
                (const (return ()))
                (TLS.bye ctx)

    -- Wrap a "receive N" style function so that any exception caught by
    -- 'handleAny' is routed to 'handler' instead of propagating to the
    -- caller. The wrapped function is applied to its argument unchanged.
    wrappedRecvN recvN n = handleAny handler $ recvN n
    -- Fallback used by 'wrappedRecvN': ignore the exception and report an
    -- empty chunk to the caller.
    handler :: SomeException -> IO S.ByteString
    handler _ = return ""

-- | Build the 'Transport' description for an established TLS context.
--
-- The negotiated (ALPN) protocol and the context information are queried
-- from the given context. When no information is available the result is
-- plain 'TCP'. Otherwise a 'TLS' transport is returned carrying:
--
-- * the protocol version as a @(major, minor)@ pair,
-- * the negotiated protocol, if any,
-- * the identifier of the cipher in use, and
-- * the client certificate chain, if any.
getTLSinfo :: TLS.Context -> IO Transport
getTLSinfo ctx = do
    proto <- TLS.getNegotiatedProtocol ctx
    minfo <- TLS.contextGetInformation ctx
    case minfo of
        Nothing -> return TCP
        Just TLS.Information{..} -> do
            -- Map the TLS library's version constructor to the wire-level
            -- (major, minor) pair. Any version not listed explicitly falls
            -- through to (3, 4) -- presumably TLS 1.3; confirm if the tls
            -- package ever adds further versions.
            let (major, minor) = case infoVersion of
                    TLS.SSL2 -> (2, 0)
                    TLS.SSL3 -> (3, 0)
                    TLS.TLS10 -> (3, 1)
                    TLS.TLS11 -> (3, 2)
                    TLS.TLS12 -> (3, 3)
                    _ -> (3, 4)
            clientCert <- TLS.getClientCertificateChain ctx
            return
                TLS
                    { tlsMajorVersion = major
                    , tlsMinorVersion = minor
                    , tlsNegotiatedProtocol = proto
                    , tlsChiperID = TLS.cipherID infoCipher
                    , tlsClientCertificate = clientCert
                    }

-- | Run an action and capture an 'IOException' as a 'Left' value.
--
-- This is 'try' with the exception type fixed to 'IOException'; exceptions
-- of other types are not captured and continue to propagate.
tryIO :: IO a -> IO (Either IOException a)
tryIO = try

----------------------------------------------------------------

-- | Handle a connection on which no TLS record was detected.
--
-- The behaviour is selected by the 'onInsecure' field of the settings:
--
-- * 'AllowInsecure': a plain socket connection is created. The bytes already
--   read from the socket (@bs0@) are stored in an 'I.IORef' and the
--   connection's receive function is wrapped with 'recvPlain', so those
--   bytes are delivered first. The transport is reported as 'TCP'.
--
-- * 'DenyInsecure': a @426 Upgrade Required@ response is written to the
--   socket, followed by the supplied lazy body; the socket is then closed
--   and 'InsecureConnectionDenied' is thrown.
plainHTTP
    :: TLSSettings -> Settings -> Socket -> S.ByteString -> IO (Connection, Transport)
plainHTTP TLSSettings{..} set s bs0 = case onInsecure of
    AllowInsecure -> do
        conn' <- socketConnection set s
        cachedRef <- I.newIORef bs0
        let conn'' =
                conn'
                    { connRecv = recvPlain cachedRef (connRecv conn')
                    }
        return (conn'', TCP)
    DenyInsecure lbs -> do
        -- Listening port 443 but TLS records do not arrive.
        -- We want to let the browser know that TLS is required.
        -- So, we use 426.
        --     http://tools.ietf.org/html/rfc2817#section-4.2
        --     https://tools.ietf.org/html/rfc7231#section-6.5.15
        -- FIXME: should we distinguish HTTP/1.1 and HTTP/2?
        --        In the case of HTTP/2, should we send
        --        GOAWAY + INADEQUATE_SECURITY?
        -- FIXME: Content-Length:
        -- FIXME: TLS/<version>
        sendAll
            s
            "HTTP/1.1 426 Upgrade Required\
            \\r\nUpgrade: TLS/1.0, HTTP/1.1\
            \\r\nConnection: Upgrade\
            \\r\nContent-Type: text/plain\r\n\r\n"
        -- The body is lazy; send it strict chunk by strict chunk.
        mapM_ (sendAll s) $ L.toChunks lbs
        close s
        throwIO InsecureConnectionDenied

----------------------------------------------------------------

-- | Modify the given receive function to first check the given @IORef@ for a
-- chunk of data. If present, takes the chunk of data from the @IORef@ and
-- empties out the @IORef@. Otherwise, calls the supplied receive function.
--
-- The read and the subsequent reset of the @IORef@ are two separate
-- operations, not one atomic step.
recvPlain :: I.IORef S.ByteString -> IO S.ByteString -> IO S.ByteString
recvPlain ref fallback = do
    bs <- I.readIORef ref
    if S.null bs
        then fallback
        else do
            -- Hand out the cached chunk exactly once.
            I.writeIORef ref S.empty
            return bs

----------------------------------------------------------------

-- | Exceptions raised by this module.
data WarpTLSException
    = -- | Thrown by 'plainHTTP' after a non-TLS client has been answered with
      -- @426 Upgrade Required@ because insecure connections are denied.
      InsecureConnectionDenied
    | -- | NOTE(review): the throw site is not in this part of the file;
      -- judging by the name, the client went away before the connection was
      -- fully set up -- confirm against the caller.
      ClientClosedConnectionPrematurely
    deriving (Show, Typeable)

instance Exception WarpTLSException