{-# LANGUAGE CPP #-}
-- | Middleware that redirects requests to a canonical domain.
--
-- @since 3.0.14
module Network.Wai.Middleware.ForceDomain where

import Data.ByteString (ByteString)
#if __GLASGOW_HASKELL__ < 804
import Data.Monoid ((<>))
#if __GLASGOW_HASKELL__ < 710
import Data.Monoid (mempty)
#endif
#endif
import Network.HTTP.Types (hLocation, methodGet, status301, status307)
import Network.Wai (Middleware, Request (..), responseBuilder)

import Network.Wai.Request (appearsSecure)

-- | Force a domain by redirecting.
--
-- The @checkDomain@ function receives the value of the request's @Host@
-- header and checks whether it is correct. It should return 'Nothing' if
-- the domain is correct, or @Just "domain.com"@ (the domain to redirect to)
-- if it is incorrect.
--
-- Requests that carry no @Host@ header are passed to the wrapped
-- application unchanged. Redirects preserve the request's raw path and
-- query string, use @https://@ when the request 'appearsSecure' and
-- @http://@ otherwise, and answer with 301 for @GET@ requests and 307 for
-- every other method.
--
-- @since 3.0.14
forceDomain :: (ByteString -> Maybe ByteString) -> Middleware
forceDomain checkDomain app req sendResponse =
    case requestHeaderHost req >>= checkDomain of
        Nothing ->
            app req sendResponse
        Just domain ->
            sendResponse $ redirectResponse domain

    where
        -- From: Network.Wai.Middleware.ForceSSL
        -- Empty-bodied redirect whose Location header points at the same
        -- path and query string on the given domain.
        redirectResponse domain =
            responseBuilder status [(hLocation, location domain)] mempty

        -- Absolute URL on host @h@, keeping the original path and query.
        location h = scheme <> h <> rawPathInfo req <> rawQueryString req

        -- Keep the scheme the client appears to have used.
        scheme
            | appearsSecure req = "https://"
            | otherwise = "http://"

        -- Only GET gets a 301: per RFC 7231, clients may switch other
        -- methods (e.g. POST) to GET when following a 301, whereas a 307
        -- requires the method and body to be kept.
        status
            | requestMethod req == methodGet = status301
            | otherwise = status307