{-# LANGUAGE CPP #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE OverloadedStrings #-}

module Network.Wai.Handler.Warp.Types where

import qualified UnliftIO
import qualified Data.ByteString as S
import Data.IORef (IORef, readIORef, writeIORef, newIORef)
import Data.Typeable (Typeable)
#ifdef MIN_VERSION_crypton_x509
import Data.X509
#endif
import Network.Socket (SockAddr)
import Network.Socket.BufferPool
import System.Posix.Types (Fd)
import qualified System.TimeManager as T

import qualified Network.Wai.Handler.Warp.Date as D
import qualified Network.Wai.Handler.Warp.FdCache as F
import qualified Network.Wai.Handler.Warp.FileInfoCache as I
import Network.Wai.Handler.Warp.Imports

----------------------------------------------------------------

-- | TCP port number, represented as a plain 'Int'.
type Port = Int

----------------------------------------------------------------

-- | The type for header value used with 'HeaderName'.
--   This is an alias for 'ByteString'.
type HeaderValue = ByteString

----------------------------------------------------------------

-- | Error types for bad 'Request'.
--
-- The 'Show' instance in this module turns each constructor into a
-- human-readable message.
data InvalidRequest
    = NotEnoughLines [String]
      -- ^ Incomplete request headers; carries the lines that were received.
    | BadFirstLine String
      -- ^ Invalid first line of the request; carries the offending line.
    | NonHttp
      -- ^ The request line specified a non-HTTP request.
    | IncompleteHeaders
      -- ^ The request headers did not finish transmission.
    | ConnectionClosedByPeer
      -- ^ The client closed the connection prematurely.
    | OverLargeHeader
      -- ^ The request headers were too large.
    | BadProxyHeader String
      -- ^ Invalid PROXY protocol header; carries the offending header.
    | PayloadTooLarge
      -- ^ Since 3.3.22
    | RequestHeaderFieldsTooLarge
      -- ^ Since 3.3.22
    deriving (Eq, Typeable)

-- | Hand-written 'Show' instance producing a human-readable message for
-- each 'InvalidRequest'. Constructors carrying a payload append it
-- rendered with 'show'.
instance Show InvalidRequest where
    show (NotEnoughLines xs) =
        "Warp: Incomplete request headers, received: " ++ show xs
    show (BadFirstLine s) =
        "Warp: Invalid first line of request: " ++ show s
    show NonHttp =
        "Warp: Request line specified a non-HTTP request"
    show IncompleteHeaders =
        "Warp: Request headers did not finish transmission"
    show ConnectionClosedByPeer =
        "Warp: Client closed connection prematurely"
    show OverLargeHeader =
        "Warp: Request headers too large, possible memory attack detected. Closing connection."
    show (BadProxyHeader s) =
        "Warp: Invalid PROXY protocol header: " ++ show s
    -- The two messages below intentionally keep their original wording,
    -- which has no "Warp: " prefix.
    show RequestHeaderFieldsTooLarge =
        "Request header fields too large"
    show PayloadTooLarge =
        "Payload too large"

-- | Makes 'InvalidRequest' an exception type. No methods are overridden,
-- so the class defaults apply.
instance UnliftIO.Exception InvalidRequest

----------------------------------------------------------------

-- | Exception thrown if something goes wrong while in the midst of
-- sending a response, since the status code can't be altered at that
-- point.
--
-- Used to determine whether keeping the HTTP1.1 connection / HTTP2 stream alive is safe
-- or irrecoverable.
--
-- The wrapped 'UnliftIO.SomeException' is the underlying exception.
newtype ExceptionInsideResponseBody = ExceptionInsideResponseBody UnliftIO.SomeException
    deriving (Show, Typeable)

-- | Makes 'ExceptionInsideResponseBody' an exception type. No methods are
-- overridden, so the class defaults apply.
instance UnliftIO.Exception ExceptionInsideResponseBody

----------------------------------------------------------------

-- | Data type to abstract file identifiers.
--   On Unix, a file descriptor would be specified to make use of
--   the file descriptor cache.
--
-- Since: 3.1.0
data FileId = FileId {
    fileIdPath :: FilePath  -- ^ Path of the file.
  , fileIdFd   :: Maybe Fd  -- ^ File descriptor for the file, if one is given.
  }

-- | Type of a function that sends a file. Its arguments are, in order:
--   fileid, offset, length, hook action, HTTP headers.
--
-- Since: 3.1.0
type SendFile = FileId -> Integer -> Integer -> IO () -> [ByteString] -> IO ()

-- | A write buffer of a specified size
-- containing bytes and a way to free the buffer.
data WriteBuffer = WriteBuffer {
      -- | The buffer itself.
      bufBuffer :: Buffer
      -- | The size of the write buffer.
    , bufSize :: !BufSize
      -- | Free the allocated buffer. Warp guarantees it will only be
      -- called once, and no other functions will be called after it.
    , bufFree :: IO ()
    }

-- | Type of a receive function that is given a 'Buffer' and a 'BufSize'
-- and yields a 'Bool' in 'IO'.
--
-- NOTE(review): the meaning of the 'Bool' result is not visible in this
-- module - confirm against the implementations before relying on it.
type RecvBuf = Buffer -> BufSize -> IO Bool

-- | Data type to manipulate IO actions for connections.
--   This is used to abstract IO actions for plain HTTP and HTTP over TLS.
data Connection = Connection {
    -- | This is not used at this moment.
      connSendMany    :: [ByteString] -> IO ()
    -- | The sending function.
    , connSendAll     :: ByteString -> IO ()
    -- | The sending function for files in HTTP/1.1.
    , connSendFile    :: SendFile
    -- | The connection closing function. Warp guarantees it will only be
    -- called once. Other functions (like 'connRecv') may be called after
    -- 'connClose' is called.
    , connClose       :: IO ()
    -- | The connection receiving function. This returns "" for EOF or exceptions.
    , connRecv        :: Recv
    -- | Obsoleted.
    , connRecvBuf     :: RecvBuf
    -- | Reference to a write buffer. When during sending of a 'Builder'
    -- response it's detected the current 'WriteBuffer' is too small it will be
    -- freed and a new bigger buffer will be created and written to this
    -- reference.
    , connWriteBuffer :: IORef WriteBuffer
    -- | Is this connection HTTP/2?
    -- Read with 'getConnHTTP2' and written with 'setConnHTTP2'.
    , connHTTP2       :: IORef Bool
    -- | A 'SockAddr' attached to this connection. The field name suggests
    -- it is the local (server-side) address; this module does not show
    -- how it is populated, so confirm at the construction sites.
    , connMySockAddr  :: SockAddr
    }

-- | Read the current value of the 'connHTTP2' flag of a connection.
getConnHTTP2 :: Connection -> IO Bool
getConnHTTP2 conn = readIORef (connHTTP2 conn)

-- | Overwrite the 'connHTTP2' flag of a connection with the given value.
setConnHTTP2 :: Connection -> Bool -> IO ()
setConnHTTP2 conn b = writeIORef (connHTTP2 conn) b

----------------------------------------------------------------

-- | Record bundling a timeout manager with three lookup actions (date,
-- file descriptor, file information).
--
-- NOTE(review): the field descriptions below are read off the field
-- types only; how the values are produced is not visible in this module.
data InternalInfo = InternalInfo {
    -- | The timeout manager.
    timeoutManager :: T.Manager
    -- | Action yielding a 'D.GMTDate'.
  , getDate        :: IO D.GMTDate
    -- | Given a file path, yield an optional 'F.Fd' together with an
    -- 'F.Refresh' value.
  , getFd          :: FilePath -> IO (Maybe F.Fd, F.Refresh)
    -- | Given a file path, yield its 'I.FileInfo'.
  , getFileInfo    :: FilePath -> IO I.FileInfo
  }

----------------------------------------------------------------

-- | Type for input streaming.
--
-- The first field is a reference holding leftover bytes (written by
-- 'leftoverSource', consumed by 'readSource'); the second is the action
-- used to obtain more input. Both fields are strict.
data Source = Source !(IORef ByteString) !(IO ByteString)

-- | Build a 'Source' from an input action. The leftover reference starts
-- out empty. The 'Source' is forced with '$!' before it is returned.
mkSource :: IO ByteString -> IO Source
mkSource func = do
    ref <- newIORef S.empty
    return $! Source ref func

-- | Read from a 'Source'.
--
-- If leftover bytes are stored, they are returned and the leftover
-- reference is reset to empty. Otherwise the underlying input action is
-- run and its result returned.
readSource :: Source -> IO ByteString
readSource (Source ref func) = do
    bs <- readIORef ref
    if S.null bs
        then func
        else do
            -- Clear the leftovers so they are handed out only once.
            writeIORef ref S.empty
            return bs

-- | Read from a 'Source', ignoring any leftovers: the underlying input
-- action is run directly and the leftover reference is left untouched.
readSource' :: Source -> IO ByteString
readSource' (Source _ func) = func

-- | Store the given bytes as the leftovers of a 'Source', replacing
-- whatever was stored before.
leftoverSource :: Source -> ByteString -> IO ()
leftoverSource (Source ref _) bs = writeIORef ref bs

-- | Return the leftovers currently stored in a 'Source'. The leftovers
-- are not cleared and the input action is not run.
readLeftoverSource :: Source -> IO ByteString
readLeftoverSource (Source ref _) = readIORef ref

----------------------------------------------------------------

-- | What kind of transport is used for this connection?
--
-- The client-certificate fields exist only when the module is compiled
-- with @MIN_VERSION_crypton_x509@ defined.
--
-- The spelling \"Chiper\" in 'tlsChiperID' and 'quicChiperID' is part of
-- the public field names and is kept for compatibility.
data Transport = TCP -- ^ Plain channel: TCP
               | TLS {
                   tlsMajorVersion :: Int
                 , tlsMinorVersion :: Int
                 , tlsNegotiatedProtocol :: Maybe ByteString -- ^ The result of Application Layer Protocol Negotiation in RFC 7301
                 , tlsChiperID :: Word16
#ifdef MIN_VERSION_crypton_x509
                 , tlsClientCertificate :: Maybe CertificateChain
#endif
                 }  -- ^ Encrypted channel: TLS or SSL
               | QUIC {
                   quicNegotiatedProtocol :: Maybe ByteString
                 , quicChiperID :: Word16
#ifdef MIN_VERSION_crypton_x509
                 , quicClientCertificate :: Maybe CertificateChain
#endif
                 }

-- | Is the transport something other than plain 'TCP'? Returns 'False'
-- for 'TCP' and 'True' for every other constructor.
isTransportSecure :: Transport -> Bool
isTransportSecure TCP = False
isTransportSecure _   = True

-- | Is the transport 'QUIC'? Returns 'True' only for the 'QUIC'
-- constructor.
isTransportQUIC :: Transport -> Bool
isTransportQUIC QUIC{} = True
isTransportQUIC _      = False

#ifdef MIN_VERSION_crypton_x509
-- | Extract the client certificate chain stored in a 'Transport'.
--
-- 'TCP' has no such field and always yields 'Nothing'; for 'TLS' and
-- 'QUIC' the stored field value is returned as is. Only available when
-- @MIN_VERSION_crypton_x509@ is defined.
getTransportClientCertificate :: Transport -> Maybe CertificateChain
getTransportClientCertificate TCP = Nothing
-- Record patterns are used so these equations do not depend on the
-- number or order of the other fields.
getTransportClientCertificate TLS{tlsClientCertificate = cc}   = cc
getTransportClientCertificate QUIC{quicClientCertificate = cc} = cc
#endif