{-# LANGUAGE CPP #-}
{-# LANGUAGE OverloadedStrings #-}
module Network.Wai.Middleware.ForceSSL (
    forceSSL,
) where
#if __GLASGOW_HASKELL__ < 710
import Control.Applicative ((<$>))
import Data.Monoid (mempty)
#endif
#if __GLASGOW_HASKELL__ < 804
import Data.Monoid ((<>))
#endif
import Network.HTTP.Types (hLocation, methodGet, status301, status307)
import Network.Wai (Middleware, Request (..), Response, responseBuilder)
import Network.Wai.Request (appearsSecure)
-- | Redirect requests that do not appear to be secure to @https://@.
--
-- The request is handed to the wrapped application unchanged when either
--
-- * 'appearsSecure' (from "Network.Wai.Request") reports it as secure, or
-- * 'redirectResponse' cannot build a redirect for it (the request carries
--   no @Host@ header).
--
-- Otherwise the redirect response is sent directly and the wrapped
-- application is never invoked for this request.
forceSSL :: Middleware
forceSSL app req sendResponse
    | appearsSecure req = passThrough
    | otherwise = maybe passThrough sendResponse (redirectResponse req)
  where
    -- Run the wrapped application as if this middleware were absent.
    passThrough = app req sendResponse
-- | Build a redirect pointing at the @https://@ version of the requested URL.
--
-- The @Location@ target is @https://@ followed by the request's @Host@
-- header, its raw path and its raw query string; the response body is
-- empty. Returns 'Nothing' when the request has no @Host@ header, because
-- no absolute URL can be formed in that case.
--
-- @GET@ requests receive @301 Moved Permanently@; every other method
-- receives @307 Temporary Redirect@. A 307 obliges the client to repeat
-- the request with the same method and body, whereas a 301 historically
-- lets clients rewrite e.g. @POST@ into @GET@.
redirectResponse :: Request -> Maybe Response
redirectResponse req = redirectTo <$> requestHeaderHost req
  where
    redirectTo host =
        responseBuilder status [(hLocation, location host)] mempty

    -- The Host header is copied verbatim.
    -- NOTE(review): a non-default port in Host (e.g. "example.com:8080")
    -- is kept, so the redirect targets https on that same port — confirm
    -- this is intended.
    location h = "https://" <> h <> rawPathInfo req <> rawQueryString req

    status
        | requestMethod req == methodGet = status301
        | otherwise = status307