-- This was written for one specific use case and then generalized.

-- The specific use case was a JSON API with a consumer that would choke on the
-- "Set-Cookie" response header. The solution was to test for the API's
-- `pathInfo` in the Request and if it matched, filter the response headers.

-- When using this, take care not to strip headers that the client needs in
-- order to operate correctly (e.g. Content-Type).

module Network.Wai.Middleware.StripHeaders
    ( stripHeader
    , stripHeaders
    , stripHeaderIf
    , stripHeadersIf
    ) where

import Data.ByteString (ByteString)
import qualified Data.CaseInsensitive as CI
import Network.Wai (Middleware, Request, modifyResponse, mapResponseHeaders, ifRequest)
import Network.Wai.Internal (Response)


-- | Remove every response header whose name matches the given header name.
--
-- The name is wrapped with 'CI.mk', so the comparison against each header
-- name is case-insensitive (e.g. @\"set-cookie\"@ also removes
-- @\"Set-Cookie\"@). All non-matching headers are kept in their original
-- order.
stripHeader :: ByteString -> (Response -> Response)
stripHeader h = mapResponseHeaders (filter ((/= target) . fst))
  where
    -- Build the case-insensitive key once, not once per header.
    target = CI.mk h

-- | Remove every response header whose name appears in the given list of
-- header names.
--
-- Names are wrapped with 'CI.mk', so matching is case-insensitive. All
-- non-matching headers are kept in their original order. An empty list
-- leaves the headers unchanged.
stripHeaders :: [ByteString] -> (Response -> Response)
stripHeaders hs = mapResponseHeaders (filter ((`notElem` hnames) . fst))
  where
    -- Convert the names once when the transformer is built, rather than
    -- once per header inspected.
    hnames = map CI.mk hs

-- | If the request satisifes the provided predicate, strip headers matching
-- the provided header name.
--
-- Since 3.0.8
stripHeaderIf :: ByteString -> (Request -> Bool) -> Middleware
stripHeaderIf :: ByteString -> (Request -> Bool) -> Middleware
stripHeaderIf ByteString
h Request -> Bool
rpred =
  (Request -> Bool) -> Middleware -> Middleware
ifRequest Request -> Bool
rpred ((Response -> Response) -> Middleware
modifyResponse ((Response -> Response) -> Middleware)
-> (Response -> Response) -> Middleware
forall a b. (a -> b) -> a -> b
$ ByteString -> Response -> Response
stripHeader ByteString
h)

-- | If the request satisifes the provided predicate, strip all headers whose
-- header name is in the list of provided header names.
--
-- Since 3.0.8
stripHeadersIf :: [ByteString] -> (Request -> Bool) -> Middleware
stripHeadersIf :: [ByteString] -> (Request -> Bool) -> Middleware
stripHeadersIf [ByteString]
hs Request -> Bool
rpred
  = (Request -> Bool) -> Middleware -> Middleware
ifRequest Request -> Bool
rpred ((Response -> Response) -> Middleware
modifyResponse ((Response -> Response) -> Middleware)
-> (Response -> Response) -> Middleware
forall a b. (a -> b) -> a -> b
$ [ByteString] -> Response -> Response
stripHeaders [ByteString]
hs)