{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RecordWildCards #-}

-- | Running an HTTP\/2 client over TLS.
module Network.HTTP2.TLS.Client (
    -- * Runners
    run,
    runH2C,
    Client,
    HostName,
    PortNumber,
    runTLS,

    -- * Settings
    Settings,
    defaultSettings,
    settingsKeyLogger,
    settingsValidateCert,
    settingsCAStore,
    settingsAddrInfoFlags,
) where

import Data.ByteString (ByteString)
import qualified Data.ByteString.Char8 as C8
import Data.Default.Class (def)
import Data.X509.Validation (validateDefault)
import Network.HTTP2.Client (
    Client,
    ClientConfig (..),
 )
import qualified Network.HTTP2.Client as H2Client
import Network.Socket
import Network.TLS hiding (HostName)
import qualified UnliftIO.Exception as E

import Network.HTTP2.TLS.Client.Settings
import Network.HTTP2.TLS.Config
import Network.HTTP2.TLS.IO
import Network.HTTP2.TLS.Internal (gclose)
import qualified Network.HTTP2.TLS.Server.Settings as Server
import Network.HTTP2.TLS.Supported

----------------------------------------------------------------

-- | Running a TLS client.
--
-- Opens a TCP connection with 'openTCP' (using the address-info flags
-- from the settings), creates a TLS context on the socket with the
-- parameters built by 'getClientParams', performs the handshake and
-- then runs the action with the context, the local socket address and
-- the peer socket address.
--
-- Both resources are managed with 'E.bracket': 'bye' is run on the TLS
-- context and 'gclose' on the socket when the action returns or throws.
runTLS
    :: Settings
    -> HostName
    -- ^ Server name; used both for connecting and for the TLS parameters
    -> PortNumber
    -> ByteString
    -- ^ ALPN
    -> (Context -> SockAddr -> SockAddr -> IO a)
    -- ^ Action taking the TLS context, my address and the peer address
    -> IO a
runTLS settings serverName port alpn action =
    E.bracket open gclose $ \sock -> do
        -- The addresses are captured once, before the TLS handshake.
        mysa <- getSocketName sock
        peersa <- getPeerName sock
        E.bracket (contextNew sock params) bye $ \ctx -> do
            handshake ctx
            action ctx mysa peersa
  where
    open = openTCP (settingsAddrInfoFlags settings) serverName port
    params = getClientParams settings serverName alpn

-- | Running an HTTP\/2 client over TLS (over TCP).
--
-- This is 'runTLS' with the ALPN fixed to @\"h2\"@; the HTTP\/2 client is
-- then run with the @\"https\"@ scheme, sending and receiving through
-- 'sendTLS' and 'recvTLS' on the established TLS context.
run :: Settings -> HostName -> PortNumber -> Client a -> IO a
run settings serverName port client =
    runTLS settings serverName port "h2" $ \ctx mysa peersa ->
        run' "https" serverName (sendTLS ctx) (recvTLS ctx) mysa peersa client

-- | Running an HTTP\/2 client over TCP.
--
-- No TLS is involved: the client is run with the @\"http\"@ scheme
-- directly over the socket, which is released with 'close' when the
-- client returns or throws.
--
-- This function takes no 'Settings'. The address-info flags come from
-- the client 'defaultSettings', and the receive function is built by
-- 'mkRecvTCP' from the server-side default settings.
runH2C :: HostName -> PortNumber -> Client a -> IO a
runH2C serverName port client =
    E.bracket open close $ \sock -> do
        mysa <- getSocketName sock
        peersa <- getPeerName sock
        recv <- mkRecvTCP Server.defaultSettings sock
        run' "http" serverName (sendTCP sock) recv mysa peersa client
  where
    open = openTCP (settingsAddrInfoFlags defaultSettings) serverName port

-- | Common driver shared by 'run' and 'runH2C'.
--
-- Allocates an HTTP\/2 configuration from the given send and receive
-- functions and the two socket addresses, runs the HTTP\/2 client with
-- it, and frees the configuration afterwards (also on exception, via
-- 'E.bracket').
run'
    :: ByteString
    -- ^ Scheme placed in the client configuration
    -> HostName
    -- ^ Server name; packed into the authority of the client configuration
    -> (ByteString -> IO ())
    -- ^ Send function
    -> IO ByteString
    -- ^ Receive function
    -> SockAddr
    -- ^ My address
    -> SockAddr
    -- ^ Peer address
    -> Client a
    -> IO a
run' schm serverName send recv mysa peersa client =
    E.bracket
        (allocConfigForClient send recv mysa peersa)
        freeConfigForClient
        (\conf -> H2Client.run cliconf conf client)
  where
    cliconf =
        ClientConfig
            { scheme = schm
            , authority = C8.pack serverName
            , cacheLimit = 20
            }

-- | Resolve the host and port with 'makeAddrInfo', open a socket for the
-- resulting address and connect it.
--
-- The connected socket is returned to the caller, who is responsible
-- for closing it. If 'connect' throws, the freshly opened socket is
-- closed here before the exception propagates, so a failed connection
-- attempt does not leak a file descriptor.
openTCP :: [AddrInfoFlag] -> HostName -> PortNumber -> IO Socket
openTCP flags h p = do
    ai <- makeAddrInfo flags h p
    -- bracketOnError runs 'close' only when the body throws; on success
    -- the socket is handed back still open.
    E.bracketOnError (openSocket ai) close $ \sock -> do
        connect sock $ addrAddress ai
        return sock

-- | Resolve a host name and port number to a single stream-socket
-- address.
--
-- The hints are 'defaultHints' with the given flags and the 'Stream'
-- socket type; the port is passed to 'getAddrInfo' as its decimal
-- string. Only the first address of the result is used.
--
-- Throws an 'IOError' if the result list is empty, in addition to any
-- exception raised by 'getAddrInfo' itself.
makeAddrInfo :: [AddrInfoFlag] -> HostName -> PortNumber -> IO AddrInfo
makeAddrInfo flags nh p = do
    let hints =
            defaultHints
                { addrFlags = flags
                , addrSocketType = Stream
                }
        np = show p
    ais <- getAddrInfo (Just hints) (Just nh) (Just np)
    -- Pattern match instead of the partial 'head' so that an empty
    -- result yields a descriptive error rather than a Prelude crash.
    case ais of
        ai : _ -> return ai
        [] ->
            ioError $
                userError $
                    "makeAddrInfo: no address found for " ++ nh ++ ":" ++ np

----------------------------------------------------------------

-- | Build the TLS client parameters from the settings.
--
-- Starting from 'defaultParamsClient' for the server name, this:
--
-- * uses 'strongSupported' as the supported parameters,
-- * requests no session resumption,
-- * enables Server Name Indication,
-- * installs the CA store and the key logger from the settings,
-- * suggests exactly one ALPN protocol, the one given.
--
-- When 'settingsValidateCert' is 'True', the default validation cache
-- and 'validateDefault' are used. When it is 'False', the cache query
-- always answers 'ValidationCachePass' and the certificate hook returns
-- an empty list, so the server certificate is not checked at all —
-- which is insecure and presumably intended for testing only.
getClientParams
    :: Settings
    -> HostName
    -- ^ Server name
    -> ByteString
    -- ^ ALPN
    -> ClientParams
getClientParams Settings{..} serverName alpn =
    (defaultParamsClient serverName "")
        { clientSupported = supported
        , clientWantSessionResume = Nothing
        , clientUseServerNameIndication = True
        , clientShared = shared
        , clientHooks = hooks
        , clientDebug = debug
        }
  where
    shared =
        def
            { sharedValidationCache = validateCache
            , sharedCAStore = settingsCAStore
            }
    supported = strongSupported
    hooks =
        def
            { onSuggestALPN = return $ Just [alpn]
            , onServerCertificate = validateCert
            }
    validateCache
        | settingsValidateCert = def
        | otherwise =
            -- Validation disabled: every query passes and nothing is
            -- recorded by the add callback.
            ValidationCache
                (\_ _ _ -> return ValidationCachePass)
                (\_ _ _ -> return ())
    validateCert
        | settingsValidateCert = validateDefault
        -- Validation disabled: return an empty list for any chain.
        | otherwise = \_ _ _ _ -> return []
    debug =
        def
            { debugKeyLogger = settingsKeyLogger
            }