-- |
-- Copyright        : (c) Raghu Kaippully, 2020
-- License          : MPL-2.0
-- Maintainer       : rkaippully@gmail.com
--
-- Middlewares related to HTTP headers.
module WebGear.Middlewares.Header
  ( requestContentType
  ) where

import Control.Arrow (Kleisli (..))
import Control.Monad ((>=>))
import Data.ByteString.Lazy (ByteString)
import Data.HashMap.Strict (fromList)
import Data.String (fromString)
import GHC.TypeLits (KnownSymbol)
import Network.HTTP.Types (badRequest400)
import Text.Printf (printf)

import WebGear.Route (MonadRouter (..))
import WebGear.Trait (linkplus)
import WebGear.Trait.Header (HeaderMatch, HeaderMismatch (..))
import WebGear.Types (RequestMiddleware, Response (..))


-- | A middleware to check that the Content-Type header in the request
-- has a specific value. It will fail the handler if the header did
-- not match.
--
-- Typical usage:
--
-- > requestContentType @"application/json" handler
--
-- If the header is missing or does not match, @handler@ is not run.
-- Instead 'failHandler' is called with a @400 Bad Request@ response.
-- Its body names the expected value and, if the header was present, the
-- value that was actually received.
--
-- NOTE(review): the exact matching rule (e.g. whether media-type
-- parameters such as @; charset=utf-8@ are tolerated) is defined by the
-- 'HeaderMatch' trait instance, not here.
requestContentType :: forall c m req res a. (KnownSymbol c, MonadRouter m)
                   => RequestMiddleware m req (HeaderMatch "Content-Type" c:req) res a
requestContentType handler = Kleisli $
  -- Try to attach the HeaderMatch trait to the request. On success the
  -- enriched request goes to the wrapped handler; on failure we
  -- short-circuit with an error response.
  linkplus @(HeaderMatch "Content-Type" c)
    >=> either (failHandler . mkError) (runKleisli handler)
  where
    -- Build the 400 response describing why the header check failed.
    mkError :: HeaderMismatch -> Response ByteString
    mkError err = Response
      { respStatus  = badRequest400
      , respHeaders = fromList []
      , respBody    = Just $ fromString $
          case (expectedHeader err, actualHeader err) of
            (ex, Nothing) ->
              printf "Expected Content-Type header %s but not found" (show ex)
            (ex, Just h)  ->
              printf "Expected Content-Type header %s but found %s" (show ex) (show h)
      }