module HTTPMethodInvalid (canonicalizeHTTPMethods, limitHTTPMethods) where

import qualified Network.Wai as Wai
import qualified Network.HTTP.Types as HTTP
import qualified Network.HTTP.Types.Method as Method
import qualified Data.ByteString.Char8 as ByteString
import qualified Data.ByteString.Lazy.Char8 as LBS

-- | Middleware that normalises the request method before the wrapped
-- application sees it.
--
-- If 'Method.parseMethod' recognises the method as a 'Method.StdMethod'
-- (the RFC 2616 methods plus PATCH from RFC 5789), it is re-rendered with
-- 'Method.renderStdMethod' and passed through. Any other method is replaced
-- by the 'invalid' sentinel (@"INVALID"@).
--
-- Postcondition: the method seen by the inner application is either a
-- standard method or 'invalid'.
canonicalizeHTTPMethods :: Wai.Middleware
canonicalizeHTTPMethods app request respond =
  app (request { Wai.requestMethod = canonical }) respond
  where
    -- Non-standard methods collapse to a single sentinel so that a later
    -- middleware (e.g. 'limitHTTPMethods') only has one value to check.
    canonical =
      case Method.parseMethod (Wai.requestMethod request) of
        Left _    -> invalid
        Right std -> Method.renderStdMethod std

-- | Middleware that rejects requests whose method is the 'invalid'
-- sentinel, without calling the wrapped application.
--
-- A rejected request is answered with @400 Bad Request@, a
-- @Content-Type: application/json@ header and the body @{}@. Every other
-- request is forwarded unchanged.
--
-- Precondition: the request method has already been canonicalised (e.g. by
-- 'canonicalizeHTTPMethods'). Without that step, only a literal @"INVALID"@
-- method is rejected.
limitHTTPMethods :: Wai.Middleware
limitHTTPMethods app request respond
  | Wai.requestMethod request /= invalid = app request respond
  | otherwise                            = respond rejection
  where
    rejection =
      Wai.responseLBS
        HTTP.badRequest400
        [(HTTP.hContentType, ByteString.pack "application/json")]
        (LBS.pack "{}")

-- | Sentinel method value, the bytes @"INVALID"@.
--
-- 'canonicalizeHTTPMethods' substitutes it for any non-standard method, and
-- 'limitHTTPMethods' rejects any request that carries it.
invalid :: Method.Method
invalid = ByteString.pack "INVALID"