{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE RecordWildCards #-}
{- |
Sometimes incoming requests don't stick to the
\"no duplicate headers\" invariant. This can happen
for a number of reasons: e.g. proxy servers blindly
adding headers, or your application (or other
middleware) blindly adding headers.

In those cases, you can use this 'Middleware'
to make sure that headers that /can/ be combined
/are/ combined. (e.g. applications might only
check the first \"Accept\" header and fail, while
there might be another one that would match)
-}
module Network.Wai.Middleware.CombineHeaders
    ( combineHeaders
    , CombineSettings
    , defaultCombineSettings
    , HeaderMap
    , HandleType
    , defaultHeaderMap
    -- * Adjusting the settings
    , setHeader
    , removeHeader
    , setHeaderMap
    , regular
    , keepOnly
    , setRequestHeaders
    , setResponseHeaders
    ) where

import qualified Data.ByteString as B
import qualified Data.List as L (foldl', reverse)
import qualified Data.Map.Strict as M
import Data.Word8 (_comma, _space, _tab)
import Network.HTTP.Types (Header, HeaderName, RequestHeaders)
import qualified Network.HTTP.Types.Header as H
import Network.Wai (Middleware, requestHeaders, mapResponseHeaders)
import Network.Wai.Util (dropWhileEnd)

-- | Which headers may be combined, and which 'HandleType' to use for
-- each of them. Headers that are absent from the map are never combined.
type HeaderMap = M.Map H.HeaderName HandleType

-- | Configuration for 'combineHeaders'. It decides which headers get
-- merged (and how), and whether the merging is applied to the incoming
-- (request) headers, the outgoing (response) headers, or both.
--
-- Every header in the 'HeaderMap' /will/ have its duplicates joined
-- with commas. There is no check that the header actually permits a
-- comma-separated list, so custom headers can be added freely.
--
-- (See the documentation of 'defaultCombineSettings' for the standard
-- headers that the specifications allow to be combined.)
--
-- @since 3.1.13.0
data CombineSettings = CombineSettings
    { combineHeaderMap :: HeaderMap
    -- ^ Which headers should be combined? And how? (cf. 'HandleType')
    , combineRequestHeaders :: Bool
    -- ^ Should request headers be combined?
    , combineResponseHeaders :: Bool
    -- ^ Should response headers be combined?
    }
    deriving (Eq, Show)

-- | Settings that combine request headers,
-- but don't touch response headers.
--
-- All types of headers that /can/ be combined
-- (as defined in the spec) /will/ be combined.
-- The header map used is 'defaultHeaderMap'.
--
-- To be exact, this is the list:
--
-- * Accept
-- * Accept-CH
-- * Accept-Charset
-- * Accept-Encoding
-- * Accept-Language
-- * Accept-Post
-- * Access-Control-Allow-Headers
-- * Access-Control-Allow-Methods
-- * Access-Control-Expose-Headers
-- * Access-Control-Request-Headers
-- * Allow
-- * Alt-Svc @(KeepOnly \"clear\")@
-- * Cache-Control
-- * Clear-Site-Data @(KeepOnly \"*\")@
-- * Connection
-- * Content-Encoding
-- * Content-Language
-- * Digest
-- * If-Match
-- * If-None-Match @(KeepOnly \"*\")@
-- * Link
-- * Permissions-Policy
-- * TE
-- * Timing-Allow-Origin @(KeepOnly \"*\")@
-- * Trailer
-- * Transfer-Encoding
-- * Upgrade
-- * Vary @(KeepOnly \"*\")@
-- * Via
-- * Want-Digest
--
-- N.B. Any header name that has \"KeepOnly\" after it
-- will be combined like normal, unless one of the values
-- is the one mentioned (\"*\" most of the time), then
-- that value is used and all others are dropped.
--
-- @since 3.1.13.0
defaultCombineSettings :: CombineSettings
defaultCombineSettings = CombineSettings {
    combineHeaderMap = defaultHeaderMap,
    combineRequestHeaders = True,
    combineResponseHeaders = False
}

-- | Replace the whole 'HeaderMap' of some 'CombineSettings'.
-- (default: 'defaultHeaderMap')
--
-- @since 3.1.13.0
setHeaderMap :: HeaderMap -> CombineSettings -> CombineSettings
setHeaderMap newMap settings = settings { combineHeaderMap = newMap }

-- | Toggle whether 'combineHeaders' rewrites the incoming request
-- headers. (default: True)
--
-- @since 3.1.13.0
setRequestHeaders :: Bool -> CombineSettings -> CombineSettings
setRequestHeaders enabled settings = settings { combineRequestHeaders = enabled }

-- | Toggle whether 'combineHeaders' rewrites the outgoing response
-- headers. (default: False)
--
-- @since 3.1.13.0
setResponseHeaders :: Bool -> CombineSettings -> CombineSettings
setResponseHeaders enabled settings = settings { combineResponseHeaders = enabled }

-- | Add a header to the 'HeaderMap' of the settings or, if the header
-- is already present, replace its 'HandleType'.
--
-- @since 3.1.13.0
setHeader :: HeaderName -> HandleType -> CombineSettings -> CombineSettings
setHeader hdrName handling settings =
    settings { combineHeaderMap = M.insert hdrName handling (combineHeaderMap settings) }

-- | Remove a header from the 'HeaderMap' of the settings, so that its
-- duplicates are no longer combined.
--
-- @since 3.1.13.0
removeHeader :: HeaderName -> CombineSettings -> CombineSettings
removeHeader hdrName settings =
    settings { combineHeaderMap = M.delete hdrName (combineHeaderMap settings) }

-- | Reorganise the incoming and/or outgoing headers so that duplicates
-- of any header listed in the 'HeaderMap' are merged into one
-- comma-separated header. Every other header keeps its value(s) and
-- is only moved around.
--
-- This middleware WILL change the global order of headers: they come
-- out ordered by header name (as sorted by the 'Ord' instance of
-- 'HeaderName'). The relative order of headers sharing a name is kept,
-- i.e. if there are 3 \"Set-Cookie\" headers, the first one will still
-- be first, the second one will still be second, etc. But now they are
-- guaranteed to be next to each other.
--
-- N.B. This 'Middleware' assumes the headers it combines
-- are correctly formatted. If one of the to-be-combined
-- headers is malformed, the new combined header will also
-- (probably) be malformed.
--
-- @since 3.1.13.0
combineHeaders :: CombineSettings -> Middleware
combineHeaders CombineSettings{..} app req respond =
    app req' (respond . onResponse)
  where
    req'
        | combineRequestHeaders = req { requestHeaders = rebuild (requestHeaders req) }
        | otherwise = req

    onResponse
        | combineResponseHeaders = mapResponseHeaders rebuild
        | otherwise = id

    -- Group all headers by name (remembering how each name should be
    -- handled), then flatten the groups back into a header list.
    rebuild :: RequestHeaders -> RequestHeaders
    rebuild hdrs = M.foldrWithKey' finishHeaders [] (L.foldl' collect M.empty hdrs)

    -- 'M.alter' keeps the key already stored in the map, so the spelling
    -- of the first occurrence of a header name is the one emitted.
    collect :: M.Map HeaderName HeaderHandling -> Header -> M.Map HeaderName HeaderHandling
    collect groups hdr@(name, _) = M.alter (record hdr) name groups

    -- New values are consed onto the front, so each group holds its
    -- values newest-first; 'finishHeaders' undoes that reversal.
    record :: Header -> Maybe HeaderHandling -> Maybe HeaderHandling
    record (name, val) existing = Just $ case existing of
        Nothing -> (M.lookup name combineHeaderMap, [val])
        Just (handling, vals) -> (handling, val : vals)

-- | Unpack one group of 'HeaderHandling' back into 'Header's, prepending
-- them to the given header list.
--
-- The collected values arrive in reverse order of appearance (they are
-- consed on while being collected in 'combineHeaders').
--
-- * Without a 'HandleType', every value becomes its own header again,
--   in the original order.
-- * With a 'HandleType', the values are joined into a single header.
--   Each value first has leading and trailing commas, spaces and tabs
--   stripped. Values that end up empty are dropped, so the combined
--   header never contains empty list elements (RFC 9110, section 5.6.1:
--   a sender MUST NOT generate them).
-- * For @'KeepOnly' val@, if any stripped value equals @val@, only @val@
--   is kept.
finishHeaders :: HeaderName -> HeaderHandling -> RequestHeaders -> RequestHeaders
finishHeaders name (shouldCombine, xs) hdrs =
    case shouldCombine of
        Just typ -> (name, combinedHeader typ) : hdrs
        Nothing ->
            -- Yes, this reverses the headers, but they were already
            -- reversed while being collected, so the original order is restored.
            L.foldl' (\acc el -> (name, el) : acc) hdrs xs
  where
    combinedHeader :: HandleType -> B.ByteString
    combinedHeader Regular = combineHdrs
    combinedHeader (KeepOnly val)
        | val `elem` cleaned = val
        | otherwise = combineHdrs

    -- values were collected newest-first, so 'reverse' before combining
    cleaned :: [B.ByteString]
    cleaned = fmap clean (L.reverse xs)

    -- Empty elements are skipped so the separator never
    -- produces ", ," or a leading/trailing ", ".
    combineHdrs :: B.ByteString
    combineHdrs = B.intercalate ", " (filter (not . B.null) cleaned)

    clean :: B.ByteString -> B.ByteString
    clean = B.dropWhile isSep . dropWhileEnd isSep

    isSep w = w == _comma || w == _space || w == _tab

-- | Accumulator for one header name while combining.
type HeaderHandling =
    ( Maybe HandleType -- ^ how to combine; 'Nothing' leaves duplicates alone
    , [B.ByteString]   -- ^ values seen so far, most recent first
    )

-- | How duplicates of a header are merged. Both strategies concatenate
-- the values with @,@ (commas). 'KeepOnly' additionally drops every
-- value except the given one if that value is present (e.g. in case of
-- wildcards/special values).
--
-- For example: If there are multiple \"Clear-Site-Data\" headers, but one of
-- them is the wildcard @\"*\"@ value, using @'KeepOnly' \"*\"@ will cause all
-- others to be dropped and only the wildcard value to remain.
-- (The @\"*\"@ wildcard in this case means /ALL site data/ should be cleared,
-- so no need to include more)
--
-- @since 3.1.13.0
data HandleType
    = Regular
    -- ^ Join all values with commas.
    | KeepOnly B.ByteString
    -- ^ Join all values with commas, unless the given value occurs,
    -- in which case it is the only value kept.
    deriving (Eq, Show)

-- | The plain combining strategy: merge every occurrence into a single
-- header whose values are separated by commas.
--
-- @since 3.1.13.0
regular :: HandleType
regular = Regular

-- | Combine like 'regular', except that when a value exactly equal to
-- the given 'B.ByteString' shows up, every other value is discarded
-- and only that one is kept.
--
-- e.g. @keepOnly "*"@ will drop all other encountered values
--
-- @since 3.1.13.0
keepOnly :: B.ByteString -> HandleType
keepOnly special = KeepOnly special

-- | The default set of HTTP headers that may be combined when a request
-- or response carries more than one of them.
--
-- See the documentation of 'defaultCombineSettings' for the exact list.
--
-- @since 3.1.13.0
defaultHeaderMap :: HeaderMap
defaultHeaderMap =
    M.fromList (plain ++ special)
  where
    plain :: [(HeaderName, HandleType)]
    plain = [ (name, Regular) | name <- combinable ]

    special :: [(HeaderName, HandleType)]
    special = [ (name, KeepOnly val) | (name, val) <- withSpecialValue ]

    -- Headers whose duplicates are simply joined with commas.
    combinable :: [HeaderName]
    combinable =
        [ H.hAccept
        , "Accept-CH"
        , H.hAcceptCharset
        , H.hAcceptEncoding
        , H.hAcceptLanguage
        , "Accept-Post"
        , "Access-Control-Allow-Headers"  -- wildcard? yes, but can just add to list
        , "Access-Control-Allow-Methods"  -- wildcard? yes, but can just add to list
        , "Access-Control-Expose-Headers" -- wildcard? yes, but can just add to list
        , "Access-Control-Request-Headers"
        , H.hAllow
        , H.hCacheControl
        -- If "close" and anything else is used together, it's already F-ed,
        -- so just combine them.
        , H.hConnection
        , H.hContentEncoding
        , H.hContentLanguage
        , "Digest"
        -- We could handle "Feature-Policy" ("semicolon ';' separated"),
        -- but it's experimental AND will be replaced by "Permissions-Policy".
        , H.hIfMatch
        , "Link"
        , "Permissions-Policy"
        , H.hTE
        , H.hTrailer
        , H.hTransferEncoding
        , H.hUpgrade
        , H.hVia
        , "Want-Digest"
        ]

    -- Headers with a special value: if any occurrence is exactly that
    -- value, only that value is kept.
    withSpecialValue :: [(HeaderName, B.ByteString)]
    withSpecialValue =
        [ ("Alt-Svc", "clear")
        -- NOTE(review): Clear-Site-Data directives are quoted-strings in the
        -- spec, so a conforming wildcard is the bytes "*" including the
        -- double quotes and will not equal the unquoted * used here.
        -- Confirm whether this should be KeepOnly "\"*\"".
        , ("Clear-Site-Data", "*")
        , (H.hIfNoneMatch, "*")
        , ("Timing-Allow-Origin", "*")
        , (H.hVary, "*")
        ]