{-# LANGUAGE CPP #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}

module Network.Wai.Handler.Warp.Types where

import Control.Lens
import qualified Data.ByteString as S
import Data.IORef (IORef, newIORef, readIORef, writeIORef)
import Data.Typeable (Typeable)
import qualified Control.Exception as E
#ifdef MIN_VERSION_crypton_x509
import Data.X509
#endif
import Network.Socket (SockAddr)
import Network.Socket.BufferPool
import System.Posix.Types (Fd)
import qualified System.TimeManager as T

import qualified Network.Wai.Handler.Warp.Date as D
import qualified Network.Wai.Handler.Warp.FdCache as F
import qualified Network.Wai.Handler.Warp.FileInfoCache as I
import Network.Wai.Handler.Warp.Imports

----------------------------------------------------------------

-- | TCP port number.
--
-- This is a plain alias for 'Int'; the type itself does not restrict the
-- value to the valid port range.
type Port = Int

----------------------------------------------------------------

-- | The type of a header value, used together with 'HeaderName'.
--
-- Header values are kept as raw 'ByteString's.
type HeaderValue = ByteString

----------------------------------------------------------------

-- | Error types for bad 'Request'.
--
-- The type is an 'E.Exception'; its 'Show' instance provides the
-- human-readable message for each constructor.
data InvalidRequest
    = NotEnoughLines [String]
    -- ^ The request headers were incomplete; carries what was received.
    | BadFirstLine String
    -- ^ The first line of the request was invalid; carries the text that is
    -- reported in the error message.
    | NonHttp
    -- ^ The request line specified a non-HTTP request.
    | IncompleteHeaders
    -- ^ The request headers did not finish transmission.
    | ConnectionClosedByPeer
    -- ^ The client closed the connection prematurely.
    | OverLargeHeader
    -- ^ The request headers were too large.
    | BadProxyHeader String
    -- ^ The PROXY protocol header was invalid; carries the text that is
    -- reported in the error message.
    | -- | Since 3.3.22
      PayloadTooLarge
    | -- | Since 3.3.22
      RequestHeaderFieldsTooLarge
    deriving (Eq, Typeable)

-- | Types that contain an 'InvalidRequest' which can be read and replaced
-- through a lens.
class HasInvalidRequest a where
    -- | Lens focusing on the contained 'InvalidRequest'.
    invalidRequest :: Lens' a InvalidRequest

-- | An 'InvalidRequest' trivially contains itself, so the lens is 'id'.
instance HasInvalidRequest InvalidRequest where
    invalidRequest = id

-- | Types that may be (or may embed) an 'InvalidRequest'.
--
-- An instance only has to define '_InvalidRequest'. Each per-constructor
-- prism defaults to '_InvalidRequest' composed with the prism of the same
-- name at type 'InvalidRequest', so the instance for 'InvalidRequest' itself
-- must define every prism explicitly.
class AsInvalidRequest a where
    -- | Prism onto the embedded 'InvalidRequest'.
    _InvalidRequest :: Prism' a InvalidRequest
    -- | Prism onto the payload of 'NotEnoughLines'.
    _NotEnoughLines :: Prism' a [String]
    _NotEnoughLines = _InvalidRequest . _NotEnoughLines
    -- | Prism onto the payload of 'BadFirstLine'.
    _BadFirstLine :: Prism' a String
    _BadFirstLine = _InvalidRequest . _BadFirstLine
    -- | Prism matching 'NonHttp'.
    _NonHttp :: Prism' a ()
    _NonHttp = _InvalidRequest . _NonHttp
    -- | Prism matching 'IncompleteHeaders'.
    _IncompleteHeaders :: Prism' a ()
    _IncompleteHeaders = _InvalidRequest . _IncompleteHeaders
    -- | Prism matching 'ConnectionClosedByPeer'.
    _ConnectionClosedByPeer :: Prism' a ()
    _ConnectionClosedByPeer = _InvalidRequest . _ConnectionClosedByPeer
    -- | Prism matching 'OverLargeHeader'.
    _OverLargeHeader :: Prism' a ()
    _OverLargeHeader = _InvalidRequest . _OverLargeHeader
    -- | Prism onto the payload of 'BadProxyHeader'.
    _BadProxyHeader :: Prism' a String
    _BadProxyHeader = _InvalidRequest . _BadProxyHeader
    -- | Prism matching 'PayloadTooLarge'.
    _PayloadTooLarge :: Prism' a ()
    _PayloadTooLarge = _InvalidRequest . _PayloadTooLarge
    -- | Prism matching 'RequestHeaderFieldsTooLarge'.
    _RequestHeaderFieldsTooLarge :: Prism' a ()
    _RequestHeaderFieldsTooLarge = _InvalidRequest . _RequestHeaderFieldsTooLarge

-- | The base instance: one prism per constructor of 'InvalidRequest'.
--
-- Constructors with a payload are matched explicitly and hand back the
-- unmatched value on failure. Nullary constructors are matched by equality
-- via 'only', which builds with @\\() -> c@ and matches with @c ==@.
instance AsInvalidRequest InvalidRequest where
    _InvalidRequest = id
    _NotEnoughLines = prism NotEnoughLines $ \case
        NotEnoughLines received -> Right received
        other -> Left other
    _BadFirstLine = prism BadFirstLine $ \case
        BadFirstLine line -> Right line
        other -> Left other
    _NonHttp = only NonHttp
    _IncompleteHeaders = only IncompleteHeaders
    _ConnectionClosedByPeer = only ConnectionClosedByPeer
    _OverLargeHeader = only OverLargeHeader
    _BadProxyHeader = prism BadProxyHeader $ \case
        BadProxyHeader header -> Right header
        other -> Left other
    _PayloadTooLarge = only PayloadTooLarge
    _RequestHeaderFieldsTooLarge = only RequestHeaderFieldsTooLarge

-- | Human-readable description of each 'InvalidRequest'.
--
-- Most messages carry a @Warp: @ prefix; the ones for
-- 'RequestHeaderFieldsTooLarge' and 'PayloadTooLarge' do not.
instance Show InvalidRequest where
    show = \case
        NotEnoughLines received ->
            "Warp: Incomplete request headers, received: " ++ show received
        BadFirstLine line ->
            "Warp: Invalid first line of request: " ++ show line
        NonHttp ->
            "Warp: Request line specified a non-HTTP request"
        IncompleteHeaders ->
            "Warp: Request headers did not finish transmission"
        ConnectionClosedByPeer ->
            "Warp: Client closed connection prematurely"
        OverLargeHeader ->
            "Warp: Request headers too large, possible memory attack detected. Closing connection."
        BadProxyHeader header ->
            "Warp: Invalid PROXY protocol header: " ++ show header
        RequestHeaderFieldsTooLarge ->
            "Request header fields too large"
        PayloadTooLarge ->
            "Payload too large"

-- | Allows an 'InvalidRequest' to be thrown and caught as an exception.
instance E.Exception InvalidRequest

----------------------------------------------------------------

-- | Exception thrown if something goes wrong while in the midst of
-- sending a response, since the status code can't be altered at that
-- point.
--
-- Used to determine whether keeping the HTTP\/1.1 connection or HTTP\/2
-- stream alive is safe, or whether the failure is irrecoverable.
--
-- The wrapped 'E.SomeException' is the underlying failure.
newtype ExceptionInsideResponseBody = ExceptionInsideResponseBody E.SomeException
    deriving (Show, Typeable)

-- | Rewrapping is only possible to 'ExceptionInsideResponseBody' itself.
instance (a ~ ExceptionInsideResponseBody) =>
  Rewrapped ExceptionInsideResponseBody a

-- | Isomorphism between the newtype and the 'E.SomeException' it wraps.
instance Wrapped ExceptionInsideResponseBody where
  type Unwrapped ExceptionInsideResponseBody = E.SomeException
  _Wrapped' = iso unwrap ExceptionInsideResponseBody
    where
      unwrap (ExceptionInsideResponseBody inner) = inner

-- | Allows the wrapper to be thrown and caught as an exception.
instance E.Exception ExceptionInsideResponseBody

----------------------------------------------------------------

-- | Data type to abstract file identifiers.
--   On Unix, a file descriptor would be specified to make use of
--   the file descriptor cache.
--
-- Since: 3.1.0
data FileId = FileId
    { fileIdPath :: FilePath
    -- ^ Path of the file.
    , fileIdFd :: Maybe Fd
    -- ^ File descriptor for the file, when one is available.
    }

-- | Types that contain a 'FileId'.
--
-- An instance only has to define 'fileId'; the field lenses default to
-- 'fileId' composed with the corresponding lens at type 'FileId'.
class HasFileId a where
    -- | Lens focusing on the contained 'FileId'.
    fileId :: Lens' a FileId
    -- | Lens focusing on 'fileIdPath'.
    fileIdPathL :: Lens' a FilePath
    fileIdPathL = fileId . fileIdPathL
    -- | Lens focusing on 'fileIdFd'.
    fileIdFdL :: Lens' a (Maybe Fd)
    fileIdFdL = fileId . fileIdFdL

-- | The base instance: field lenses written directly against the
-- 'FileId' constructor.
instance HasFileId FileId where
    fileId = id
    fileIdPathL inject (FileId path mfd) = (`FileId` mfd) <$> inject path
    fileIdFdL inject (FileId path mfd) = FileId path <$> inject mfd

-- | Types that may be (or may embed) a 'FileId'.
class AsFileId a where
    -- | Prism onto the embedded 'FileId'.
    _FileId :: Prism' a FileId

-- | A 'FileId' is trivially itself, so the prism is 'id'.
instance AsFileId FileId where
    _FileId = id

-- | Function for sending a file. The arguments are, in order: the file
-- identifier, the offset, the length, a hook action, and the HTTP headers.
--
-- NOTE(review): offset and length are presumably byte counts, and the point
-- at which the hook action is run is not visible here — confirm against the
-- implementations of this type.
--
-- Since: 3.1.0
type SendFile = FileId -> Integer -> Integer -> IO () -> [ByteString] -> IO ()

-- | A write buffer of a specified size
-- containing bytes and a way to free the buffer.
data WriteBuffer = WriteBuffer
    { bufBuffer :: Buffer
    -- ^ The buffer itself.
    , bufSize :: !BufSize
    -- ^ The size of the write buffer. This field is strict.
    , bufFree :: IO ()
    -- ^ Free the allocated buffer. Warp guarantees it will only be
    -- called once, and no other functions will be called after it.
    }

-- | Types that contain a 'WriteBuffer'.
--
-- An instance only has to define 'writeBuffer'; the field lenses default to
-- 'writeBuffer' composed with the corresponding lens at type 'WriteBuffer'.
class HasWriteBuffer a where
    -- | Lens focusing on the contained 'WriteBuffer'.
    writeBuffer :: Lens' a WriteBuffer
    -- | Lens focusing on 'bufBuffer'.
    bufBufferL :: Lens' a Buffer
    bufBufferL = writeBuffer . bufBufferL
    -- | Lens focusing on 'bufSize'.
    bufSizeL :: Lens' a BufSize
    bufSizeL = writeBuffer . bufSizeL
    -- | Lens focusing on 'bufFree'.
    bufFreeL :: Lens' a (IO ())
    bufFreeL = writeBuffer . bufFreeL

-- | The base instance: each lens reads one field and writes it back with a
-- record update, leaving the other fields untouched.
instance HasWriteBuffer WriteBuffer where
    writeBuffer = id
    bufBufferL inject wb@WriteBuffer{bufBuffer = old} =
        (\new -> wb{bufBuffer = new}) <$> inject old
    bufSizeL inject wb@WriteBuffer{bufSize = old} =
        (\new -> wb{bufSize = new}) <$> inject old
    bufFreeL inject wb@WriteBuffer{bufFree = old} =
        (\new -> wb{bufFree = new}) <$> inject old

-- | Types that may be (or may embed) a 'WriteBuffer'.
class AsWriteBuffer a where
    -- | Prism onto the embedded 'WriteBuffer'.
    _WriteBuffer :: Prism' a WriteBuffer

-- | A 'WriteBuffer' is trivially itself, so the prism is 'id'.
instance AsWriteBuffer WriteBuffer where
    _WriteBuffer = id

-- | Receiving function that is handed a 'Buffer' and a 'BufSize'.
--
-- NOTE(review): presumably it receives into the given buffer and the 'Bool'
-- reports success or whether the buffer was filled — confirm against the
-- implementation. 'connRecvBuf', the only field of this type, is marked
-- as obsoleted.
type RecvBuf = Buffer -> BufSize -> IO Bool

-- | Data type to manipulate IO actions for connections.
--   This is used to abstract IO actions for plain HTTP and HTTP over TLS.
data Connection = Connection
    { connSendMany :: [ByteString] -> IO ()
    -- ^ This is not used at this moment.
    , connSendAll :: ByteString -> IO ()
    -- ^ The sending function.
    , connSendFile :: SendFile
    -- ^ The sending function for files in HTTP/1.1.
    , connClose :: IO ()
    -- ^ The connection closing function. Warp guarantees it will only be
    -- called once. Other functions (like 'connRecv') may be called after
    -- 'connClose' is called.
    , connRecv :: Recv
    -- ^ The connection receiving function. This returns "" for EOF or exceptions.
    , connRecvBuf :: RecvBuf
    -- ^ Obsoleted.
    , connWriteBuffer :: IORef WriteBuffer
    -- ^ Reference to a write buffer. If, while sending a 'Builder' response,
    -- the current 'WriteBuffer' turns out to be too small, it is freed and
    -- a new, bigger buffer is created and written to this reference.
    , connHTTP2 :: IORef Bool
    -- ^ Is this connection HTTP/2? Read with 'getConnHTTP2' and set with
    -- 'setConnHTTP2'.
    , connMySockAddr :: SockAddr
    -- ^ Socket address associated with this connection.
    -- NOTE(review): judging by the name this is the local (server-side)
    -- address — confirm against the code that builds 'Connection'.
    }

-- | Types that contain a 'Connection'.
--
-- An instance only has to define 'connection'. Every field lens defaults to
-- 'connection' composed with the corresponding lens at type 'Connection',
-- in the same way as 'HasFileId' and 'HasWriteBuffer'. The instance for
-- 'Connection' itself therefore has to define every field lens explicitly,
-- otherwise the defaults would refer back to themselves.
class HasConnection a where
    -- | Lens focusing on the contained 'Connection'.
    connection :: Lens' a Connection
    -- | Lens focusing on 'connSendMany'.
    connSendManyL :: Lens' a ([ByteString] -> IO ())
    connSendManyL = connection . connSendManyL
    -- | Lens focusing on 'connSendAll'.
    connSendAllL :: Lens' a (ByteString -> IO ())
    connSendAllL = connection . connSendAllL
    -- | Lens focusing on 'connSendFile'.
    connSendFileL :: Lens' a SendFile
    connSendFileL = connection . connSendFileL
    -- | Lens focusing on 'connClose'.
    connCloseL :: Lens' a (IO ())
    connCloseL = connection . connCloseL
    -- | Lens focusing on 'connRecv'.
    connRecvL :: Lens' a Recv
    connRecvL = connection . connRecvL
    -- | Lens focusing on 'connRecvBuf'.
    connRecvBufL :: Lens' a RecvBuf
    connRecvBufL = connection . connRecvBufL
    -- | Lens focusing on 'connWriteBuffer'.
    connWriteBufferL :: Lens' a (IORef WriteBuffer)
    connWriteBufferL = connection . connWriteBufferL
    -- | Lens focusing on 'connHTTP2'.
    connHTTP2L :: Lens' a (IORef Bool)
    connHTTP2L = connection . connHTTP2L
    -- | Lens focusing on 'connMySockAddr'.
    connMySockAddrL :: Lens' a SockAddr
    connMySockAddrL = connection . connMySockAddrL

-- | The base instance: each lens reads one field and writes it back with a
-- record update, leaving the other fields untouched.
instance HasConnection Connection where
    connection = id
    connSendManyL inject conn@Connection{connSendMany = old} =
        (\new -> conn{connSendMany = new}) <$> inject old
    connSendAllL inject conn@Connection{connSendAll = old} =
        (\new -> conn{connSendAll = new}) <$> inject old
    connSendFileL inject conn@Connection{connSendFile = old} =
        (\new -> conn{connSendFile = new}) <$> inject old
    connCloseL inject conn@Connection{connClose = old} =
        (\new -> conn{connClose = new}) <$> inject old
    connRecvL inject conn@Connection{connRecv = old} =
        (\new -> conn{connRecv = new}) <$> inject old
    connRecvBufL inject conn@Connection{connRecvBuf = old} =
        (\new -> conn{connRecvBuf = new}) <$> inject old
    connWriteBufferL inject conn@Connection{connWriteBuffer = old} =
        (\new -> conn{connWriteBuffer = new}) <$> inject old
    connHTTP2L inject conn@Connection{connHTTP2 = old} =
        (\new -> conn{connHTTP2 = new}) <$> inject old
    connMySockAddrL inject conn@Connection{connMySockAddr = old} =
        (\new -> conn{connMySockAddr = new}) <$> inject old

-- | Types that may be (or may embed) a 'Connection'.
class AsConnection a where
    -- | Prism onto the embedded 'Connection'.
    _Connection :: Prism' a Connection

-- | A 'Connection' is trivially itself, so the prism is 'id'.
instance AsConnection Connection where
    _Connection = id

-- | Read the current value of the connection's 'connHTTP2' flag.
getConnHTTP2 :: Connection -> IO Bool
getConnHTTP2 conn = readIORef (connHTTP2 conn)

-- | Overwrite the connection's 'connHTTP2' flag with the given value.
setConnHTTP2 :: Connection -> Bool -> IO ()
setConnHTTP2 conn isHTTP2 = writeIORef (connHTTP2 conn) isHTTP2

----------------------------------------------------------------

-- | Bundle of shared helpers: a timeout manager plus actions for obtaining
-- the current date, a file descriptor for a path, and file information for
-- a path.
--
-- NOTE(review): the qualified module names suggest the last three are backed
-- by caches (date, fd and file-info caches) — confirm against the code that
-- constructs this record.
data InternalInfo = InternalInfo
    { timeoutManager :: T.Manager
    -- ^ The timeout manager.
    , getDate :: IO D.GMTDate
    -- ^ Action yielding a 'D.GMTDate'.
    , getFd :: FilePath -> IO (Maybe F.Fd, F.Refresh)
    -- ^ Look up a file descriptor for a path. It yields an optional
    -- descriptor together with an 'F.Refresh' value.
    , getFileInfo :: FilePath -> IO I.FileInfo
    -- ^ Look up the 'I.FileInfo' for a path.
    }

----------------------------------------------------------------

-- | Type for input streaming.
--
-- The first field is a reference holding leftover bytes (written by
-- 'leftoverSource', consumed by 'readSource'); the second is the action
-- that produces the next chunk of input. Both fields are strict.
data Source = Source !(IORef ByteString) !(IO ByteString)

-- | Build a 'Source' around a reading action, starting with no leftover
-- bytes. The 'Source' is forced before it is returned.
mkSource :: IO ByteString -> IO Source
mkSource func =
    newIORef S.empty >>= \leftoverRef -> return $! Source leftoverRef func

-- | Read the next chunk from a 'Source'.
--
-- Stored leftover bytes take priority: if there are any, they are returned
-- and the leftover slot is cleared. Otherwise the underlying reading action
-- is run.
readSource :: Source -> IO ByteString
readSource (Source leftoverRef func) = readIORef leftoverRef >>= next
  where
    next leftover
        | S.null leftover = func
        | otherwise = writeIORef leftoverRef S.empty >> return leftover

-- | Run the underlying reading action of a 'Source' directly. Any stored
-- leftover bytes are neither returned nor cleared.
readSource' :: Source -> IO ByteString
readSource' source = case source of
    Source _ func -> func

-- | Store bytes as the leftover of a 'Source', replacing whatever was
-- stored before. The next 'readSource' returns them if they are non-empty.
leftoverSource :: Source -> ByteString -> IO ()
leftoverSource (Source leftoverRef _) = writeIORef leftoverRef

-- | Look at the leftover bytes of a 'Source' without clearing them.
readLeftoverSource :: Source -> IO ByteString
readLeftoverSource (Source leftoverRef _) = readIORef leftoverRef

----------------------------------------------------------------

-- | What kind of transport is used for this connection?
data Transport
    = -- | Plain channel: TCP
      TCP
    | -- | Encrypted channel: TLS or SSL
      TLS
        { tlsMajorVersion :: Int
        , tlsMinorVersion :: Int
        , tlsNegotiatedProtocol :: Maybe ByteString
        -- ^ The result of Application Layer Protocol Negotiation in RFC 7301
        , tlsChiperID :: Word16
#ifdef MIN_VERSION_crypton_x509
        , tlsClientCertificate :: Maybe CertificateChain
        -- ^ Client certificate chain, if any. This field only exists when
        -- the @MIN_VERSION_crypton_x509@ macro is defined.
#endif
        }
    | -- | QUIC channel; treated as secure by 'isTransportSecure'
      QUIC
        { quicNegotiatedProtocol :: Maybe ByteString
        , quicChiperID :: Word16
#ifdef MIN_VERSION_crypton_x509
        , quicClientCertificate :: Maybe CertificateChain
        -- ^ Client certificate chain, if any. This field only exists when
        -- the @MIN_VERSION_crypton_x509@ macro is defined.
#endif
        }

-- | Whether the transport is secure: every transport other than plain
-- 'TCP' counts as secure.
isTransportSecure :: Transport -> Bool
isTransportSecure = \case
    TCP -> False
    _ -> True

-- | Whether the transport is 'QUIC'.
isTransportQUIC :: Transport -> Bool
isTransportQUIC = \case
    QUIC{} -> True
    _ -> False

#ifdef MIN_VERSION_crypton_x509
-- | The client certificate chain recorded for the transport, if any.
-- Plain 'TCP' never has one.
getTransportClientCertificate :: Transport -> Maybe CertificateChain
getTransportClientCertificate = \case
    TCP -> Nothing
    TLS{tlsClientCertificate = chain} -> chain
    QUIC{quicClientCertificate = chain} -> chain
#endif
