{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards #-}

-- | Running an HTTP\/2 client over TLS.
module Network.HTTP2.TLS.Client (
    -- * Runners
    run,
    runH2C,
    Client,
    HostName,
    PortNumber,
    runTLS,
) where

import Data.ByteString (ByteString)
import qualified Data.ByteString.Char8 as C8
import Data.Default.Class (def)
import Network.HTTP2.Client (
    Client,
    ClientConfig (..),
 )
import qualified Network.HTTP2.Client as H2Client
import Network.Socket
import Network.TLS hiding (HostName)
import qualified UnliftIO.Exception as E

import Network.HTTP2.TLS.Config
import Network.HTTP2.TLS.IO
import Network.HTTP2.TLS.Settings
import Network.HTTP2.TLS.Supported

----------------------------------------------------------------

-- | Running a TLS client.
--
-- Opens a TCP connection to the given host and port, creates a TLS
-- 'Context' on that socket, performs the client handshake (offering the
-- given protocol via ALPN) and then hands the context to the action.
--
-- Cleanup happens even if the action throws: 'bye' is sent on the TLS
-- context first, then the socket is closed.
--
-- NOTE(review): the client parameters are built with certificate
-- validation turned off (the 'False' passed to 'getClientParams'), so any
-- server certificate is accepted. Confirm this is intended before talking
-- to untrusted servers.
runTLS
    :: HostName
    -> PortNumber
    -> ByteString
    -- ^ ALPN
    -> (Context -> IO a)
    -> IO a
runTLS serverName port alpn action =
    E.bracket (openTCP serverName port) close withSocket
  where
    -- Last argument: server certificate checking is disabled.
    clientParams :: ClientParams
    clientParams = getClientParams serverName alpn False

    withSocket sock =
        E.bracket (contextNew sock clientParams) bye withContext

    withContext ctx = handshake ctx >> action ctx

-- | Running an HTTP\/2 client over TLS (over TCP).
--
-- Establishes the TLS connection with 'runTLS', offering @h2@ via ALPN,
-- and drives the client with the @https@ scheme on top of the TLS
-- send\/receive functions.
run :: HostName -> PortNumber -> Client a -> IO a
run serverName port client = runTLS serverName port "h2" overTLS
  where
    overTLS ctx = run' "https" serverName (sendTLS ctx) (recvTLS ctx) client

-- | Running an HTTP\/2 client over TCP.
--
-- No TLS is involved: the client talks directly over the TCP socket using
-- the @http@ scheme. The socket is closed when the client finishes or
-- throws.
runH2C :: HostName -> PortNumber -> Client a -> IO a
runH2C serverName port client =
    E.bracket (openTCP serverName port) close $ \sock ->
        mkRecvTCP defaultSettings sock >>= \recvFn ->
            run' "http" serverName (sendTCP sock) recvFn client

-- | Shared driver behind 'run' and 'runH2C'.
--
-- Allocates an HTTP\/2 'Config' around the given send and receive
-- functions, runs the client with it, and frees the config afterwards
-- (also when the client throws).
run'
    :: ByteString
    -- ^ Request scheme, e.g. @\"https\"@ or @\"http\"@
    -> HostName
    -- ^ Server name, sent as the authority
    -> (ByteString -> IO ())
    -- ^ Send function
    -> IO ByteString
    -- ^ Receive function
    -> Client a
    -> IO a
run' schm serverName send recv client =
    E.bracket acquire freeConfigForClient runWith
  where
    acquire = allocConfigForClient send recv

    runWith conf = H2Client.run cliconf conf client

    cliconf =
        ClientConfig
            { scheme = schm
            , authority = C8.pack serverName
            , cacheLimit = 20
            }

-- | Resolve the host and port, open a socket for the resulting address and
-- connect it.
--
-- If 'connect' throws, the freshly opened socket is closed before the
-- exception propagates, so no file descriptor is leaked. On success the
-- caller owns the returned socket and is responsible for closing it.
openTCP :: HostName -> PortNumber -> IO Socket
openTCP h p = do
    ai <- makeAddrInfo h p
    -- 'bracketOnError' runs 'close' only when the body throws.
    E.bracketOnError (openSocket ai) close $ \sock -> do
        connect sock $ addrAddress ai
        return sock

-- | Resolve a host (name or numeric address) and port to the first
-- stream-socket address returned by 'getAddrInfo'.
--
-- The port is always rendered numerically, so 'AI_NUMERICSERV' is set and
-- no service lookup happens. 'AI_NUMERICHOST' is deliberately NOT set:
-- the same server name is used for SNI and the HTTP\/2 authority, so it is
-- expected to be an ordinary host name, and that flag would make
-- resolution of any non-literal name fail.
--
-- Throws an 'IOError' if resolution fails or yields no address.
makeAddrInfo :: HostName -> PortNumber -> IO AddrInfo
makeAddrInfo nh p = do
    let hints =
            defaultHints
                { addrFlags = [AI_ADDRCONFIG, AI_NUMERICSERV]
                , addrSocketType = Stream
                }
        np = show p
    addrs <- getAddrInfo (Just hints) (Just nh) (Just np)
    -- Avoid the partial 'head'; report a descriptive error instead.
    case addrs of
        ai : _ -> return ai
        [] ->
            ioError $
                userError $
                    "makeAddrInfo: no address found for " ++ nh ++ ":" ++ np

----------------------------------------------------------------

-- | Build the TLS client parameters used by 'runTLS'.
--
-- * SNI is enabled; parameters start from 'defaultParamsClient' for the
--   server name.
-- * Only the given protocol is offered via ALPN.
-- * No session resumption is requested.
-- * Supported versions\/ciphers come from 'strongSupported'.
--
-- When the flag is 'False', the validation cache is replaced by one that
-- reports every certificate as passing and stores nothing. In effect,
-- server certificate checking is skipped. When it is 'True', the default
-- validation cache is kept.
--
-- NOTE(review): the shared settings are otherwise 'def' and no CA store is
-- set here. Confirm how the 'True' path is expected to find trust anchors.
getClientParams
    :: HostName
    -> ByteString
    -- ^ ALPN
    -> Bool
    -- ^ Checking server certificates
    -> ClientParams
getClientParams serverName alpn validate =
    baseParams
        { clientSupported = strongSupported
        , clientWantSessionResume = Nothing
        , clientUseServerNameIndication = True
        , clientShared = def{sharedValidationCache = cache}
        , clientHooks = def{onSuggestALPN = pure (Just [alpn])}
        }
  where
    baseParams = defaultParamsClient serverName ""

    cache
        | validate = def
        | otherwise = ValidationCache acceptEverything rememberNothing

    acceptEverything _ _ _ = return ValidationCachePass
    rememberNothing _ _ _ = return ()