{-# LANGUAGE CPP #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards #-}

module Network.HTTP2.TLS.IO where

import Control.Monad (void, when)
import Data.ByteString (ByteString)
import qualified Data.ByteString as BS
import qualified Data.ByteString.Lazy as LBS
import Network.Socket
import Network.Socket.BufferPool
import qualified Network.Socket.ByteString as NSB
import Network.TLS hiding (HostName)
import System.IO.Error (isEOFError)
import qualified System.TimeManager as T
import qualified UnliftIO.Exception as E

import Network.HTTP2.TLS.Server.Settings

----------------------------------------------------------------

-- HTTP2: confReadN == recvTLS
-- TLS:   recvData  == contextRecv == backendRecv

----------------------------------------------------------------

-- | Create a TCP receive action backed by a fresh buffer pool.
--   The pool is created with 'settingsReadBufferLowerLimit' and
--   'settingsReadBufferSize'; each call of the returned action reads
--   from @sock@ via 'receive'.
mkRecvTCP :: Settings -> Socket -> IO (IO ByteString)
mkRecvTCP Settings{..} sock =
    receive sock <$> newBufferPool settingsReadBufferLowerLimit settingsReadBufferSize

-- | Send an entire strict chunk over a TCP socket, looping via
--   'NSB.sendAll' until every byte has been written.
sendTCP :: Socket -> ByteString -> IO ()
sendTCP = NSB.sendAll

----------------------------------------------------------------

-- | Sending and receiving functions.
--   When wrapped with 'timeoutIOBackend', the timeout is reset each time
--   one of them returns. The one exception is receiving: for slowloris
--   attack prevention it resets the timeout only for large enough reads.
--   See 'settingsSlowlorisSize'.
data IOBackend = IOBackend
    { send :: ByteString -> IO ()
    -- ^ Sending.
    , sendMany :: [ByteString] -> IO ()
    -- ^ Sending several chunks.
    , recv :: IO ByteString
    -- ^ Receiving.
    , mySockAddr :: SockAddr
    -- ^ Local address of the underlying socket.
    , peerSockAddr :: SockAddr
    -- ^ Remote (peer) address of the underlying socket.
    }

-- | Wrap an 'IOBackend' so that it resets the timeout handle @th@.
--   Every 'send' and 'sendMany' resets the timer after it completes.
--   A 'recv' resets it only when it returns more than
--   'settingsSlowlorisSize' bytes (slowloris attack prevention).
--   The socket addresses are passed through unchanged.
timeoutIOBackend :: T.Handle -> Settings -> IOBackend -> IOBackend
timeoutIOBackend th Settings{..} IOBackend{..} =
    IOBackend
        (thenTickle . send)
        (thenTickle . sendMany)
        recvAndMaybeTickle
        mySockAddr
        peerSockAddr
  where
    -- Run the action, then reset the timer.
    thenTickle :: IO () -> IO ()
    thenTickle action = action >> T.tickle th

    recvAndMaybeTickle :: IO ByteString
    recvAndMaybeTickle =
        recv >>= \bs ->
            -- Small reads do not reset the timer, so a peer that trickles
            -- a few bytes at a time still runs into the timeout.
            bs <$ when (BS.length bs > settingsSlowlorisSize) (T.tickle th)

-- | Build an 'IOBackend' that sends and receives through the TLS
--   context @ctx@. The local and peer addresses are looked up on
--   @sock@, local address first.
tlsIOBackend :: Context -> Socket -> IO IOBackend
tlsIOBackend ctx sock =
    IOBackend (sendTLS ctx) (sendManyTLS ctx) (recvTLS ctx)
        <$> getSocketName sock
        <*> getPeerName sock

-- | Build an 'IOBackend' over a plain (non-TLS) connected TCP socket.
--   Receiving goes through a buffer pool created by 'mkRecvTCP'. The
--   local and peer addresses are looked up on @sock@.
tcpIOBackend :: Settings -> Socket -> IO IOBackend
tcpIOBackend settings sock = do
    recv' <- mkRecvTCP settings sock
    mysa <- getSocketName sock
    peersa <- getPeerName sock
    return $
        IOBackend
            { -- A single 'NSB.send' may write only a prefix of the buffer.
              -- 'sendTCP' ('NSB.sendAll') keeps writing until everything
              -- has been sent, so no bytes are lost.
              send = sendTCP sock
            , -- Write every chunk, in order, as if they were concatenated.
              sendMany = NSB.sendMany sock
            , recv = recv'
            , mySockAddr = mysa
            , peerSockAddr = peersa
            }

----------------------------------------------------------------

-- | Send one strict chunk as application data on the TLS context,
--   converting it to a lazy 'LBS.ByteString' for 'sendData'.
sendTLS :: Context -> ByteString -> IO ()
sendTLS ctx bs = sendData ctx (LBS.fromStrict bs)

-- | Send several strict chunks with a single 'sendData' call, by joining
--   them into one lazy 'LBS.ByteString' first.
sendManyTLS :: Context -> [ByteString] -> IO ()
sendManyTLS ctx chunks = sendData ctx $ LBS.fromChunks chunks

{- FOURMOLU_DISABLE -}
-- | TLS version of recv (decrypting) without a cache.
--   End-of-stream, whether signalled as a TLS @Error_EOF@ or as an EOF
--   'IOError', is reported as an empty 'ByteString'. Any other caught
--   exception is re-thrown.
recvTLS :: Context -> IO ByteString
recvTLS ctx = recvData ctx `E.catch` onEOF
  where
    onEOF :: E.SomeException -> IO ByteString
    onEOF e
#if MIN_VERSION_tls(1,8,0)
        | Just (PostHandshake Error_EOF) <- E.fromException e = return ""
#else
        | Just Error_EOF <- E.fromException e = return ""
#endif
        | Just ioe <- E.fromException e, isEOFError ioe = return ""
        | otherwise = E.throwIO e
{- FOURMOLU_ENABLE -}

----------------------------------------------------------------

-- | Build a TLS-library 'Backend' on top of a plain TCP socket.
--   * Sends go straight to the socket via 'sendTCP'.
--   * Receives are served by a buffer-pool reader ('mkRecvTCP') wrapped
--     with 'makeRecvN'.
--   * Flushing is a no-op.
--   * Closing is best-effort.
mkBackend :: Settings -> Socket -> IO Backend
mkBackend settings sock = do
    recvN <- mkRecvTCP settings sock >>= makeRecvN ""
    return
        Backend
            { backendFlush = return ()
            , backendClose = E.handle ignoreAny (gracefulClose sock closeTimeout)
            , backendSend = sendTCP sock
            , backendRecv = recvN
            }
  where
    -- Timeout handed to 'gracefulClose' (the network package documents
    -- this argument in milliseconds).
    closeTimeout :: Int
    closeTimeout = 5000

    -- Closing is best-effort: an exception caught while closing is dropped.
    ignoreAny :: E.SomeException -> IO ()
    ignoreAny (E.SomeException _) = return ()