{-# LANGUAGE CPP #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards #-}
module Network.Wai.Middleware.EnforceHTTPS
(
EnforceHTTPSConfig(..)
, defaultConfig
, def
, withResolver
, withConfig
, HTTPSResolver
, xForwardedProto
, azure
, forwarded
, customProtoHeader
) where
import Data.ByteString (ByteString)
import Data.Maybe (fromMaybe)
import Data.Monoid ((<>))
import Network.HTTP.Types (Method, Status)
import Network.Wai (Application, Middleware, Request)
#if __GLASGOW_HASKELL__ < 710
import Data.Monoid (mappend, mempty)
#endif
import qualified Data.ByteString as ByteString
import qualified Data.CaseInsensitive as CaseInsensitive
import qualified Data.Text as Text
import qualified Data.Text.Encoding as Text
import qualified Network.HTTP.Forwarded as Forwarded
import qualified Network.HTTP.Types as HTTP
import qualified Network.Wai as Wai
-- | Settings for the HTTPS-enforcing middleware.
--
-- Typically built by record-updating 'defaultConfig' and handed to
-- 'withConfig'.
data EnforceHTTPSConfig = EnforceHTTPSConfig
  { httpsIsSecure        :: !HTTPSResolver
    -- ^ Decides whether a request already arrived over HTTPS; such requests
    -- are passed to the wrapped application unchanged.
  , httpsHostRewrite     :: !(ByteString -> ByteString)
    -- ^ Applied to the request host (with any port removed) when building
    -- the redirect @Location@.
  , httpsPort            :: !Int
    -- ^ Port appended to the redirect host as @:port@.
  , httpsIgnoreURL       :: !Bool
    -- ^ When 'True', the redirect drops the original path and query string.
  , httpsTemporary       :: !Bool
    -- ^ When 'True', redirect with @307@; otherwise with @301@.
  , httpsSkipDefaultPort :: !Bool
    -- ^ When 'True' and 'httpsPort' is 443, omit the port from the redirect.
  , httpsRedirectMethods :: ![Method]
    -- ^ Request methods that are redirected; all other methods are rejected.
  , httpsDisallowStatus  :: !Status
    -- ^ Status returned for methods not in 'httpsRedirectMethods'. When it
    -- is @405@, an @Allow@ header listing 'httpsRedirectMethods' is added.
  }
-- | Configuration used by 'def' and as the base for 'withResolver'.
--
-- It has the following settings:
--
-- * detects HTTPS with 'Wai.isSecure';
-- * leaves the host unchanged;
-- * redirects to port 443 without writing the port into the URL;
-- * keeps the path and query string;
-- * issues @301@ redirects for @GET@ and @HEAD@;
-- * answers every other method with @405 Method Not Allowed@.
defaultConfig :: EnforceHTTPSConfig
defaultConfig =
  EnforceHTTPSConfig
    { httpsRedirectMethods = ["GET", "HEAD"]
    , httpsDisallowStatus  = HTTP.methodNotAllowed405
    , httpsIsSecure        = Wai.isSecure
    , httpsHostRewrite     = id
    , httpsPort            = 443
    , httpsSkipDefaultPort = True
    , httpsIgnoreURL       = False
    , httpsTemporary       = False
    }
{-# INLINE defaultConfig #-}
-- | Build the middleware from an explicit configuration.
--
-- If 'httpsIsSecure' accepts the request, it is forwarded to the wrapped
-- application. Otherwise the middleware answers it itself (redirect or
-- rejection), and the application is never invoked.
withConfig :: EnforceHTTPSConfig -> Middleware
withConfig conf app req =
  if httpsIsSecure conf req
    then app req
    else redirect conf req
{-# INLINE withConfig #-}
-- | Answer a request that did not arrive over HTTPS.
--
-- * No @Host@ header: @400@ with an empty body.
-- * Method listed in 'httpsRedirectMethods': @301@ (or @307@ when
--   'httpsTemporary' is set) with a @Location@ header pointing at the
--   HTTPS URL.
-- * Any other method: 'httpsDisallowStatus'. When that status is @405@, an
--   @Allow@ header lists the redirectable methods.
--
-- How the @Location@ host is built:
--
-- * the port is removed from the request host;
-- * the result is passed through 'httpsHostRewrite';
-- * @:port@ is appended, unless the port is 443 and 'httpsSkipDefaultPort'
--   is set.
--
-- A bracketed IPv6 literal host (e.g. @[::1]:8080@) keeps its brackets and
-- full address.
redirect :: EnforceHTTPSConfig -> Application
redirect EnforceHTTPSConfig { .. } req respond = respond $
  case Wai.requestHeaderHost req of
    Nothing   -> Wai.responseBuilder HTTP.status400 [] mempty
    Just host -> Wai.responseBuilder status (headers $ stripPort host) mempty
  where
    reqMethod = Wai.requestMethod req

    -- Response status, plus a function from the port-less host to the
    -- response headers.
    (status, headers)
      | reqMethod `elem` httpsRedirectMethods =
          ( if httpsTemporary then HTTP.status307 else HTTP.status301
          , (: []) . redirectURL
          )
      | otherwise =
          ( httpsDisallowStatus
          , const allowHeaders
          )

    -- RFC 7231 §6.5.5: a 405 response must carry an Allow header.
    allowHeaders
      | httpsDisallowStatus == HTTP.methodNotAllowed405 =
          [ ("Allow", ByteString.intercalate ", " httpsRedirectMethods) ]
      | otherwise = []

    redirectURL host =
      (HTTP.hLocation, "https://" <> fullHost host <> path)

    fullHost host = httpsHostRewrite host <> port

    path
      | httpsIgnoreURL = mempty
      | otherwise      = Wai.rawPathInfo req <> Wai.rawQueryString req

    port
      | httpsPort == 443 && httpsSkipDefaultPort = ""
      | otherwise = Text.encodeUtf8 $ Text.pack (':' : show httpsPort)

    -- Remove an optional ":port" suffix from a Host value.
    --
    -- A value starting with '[' is an IPv6 literal (RFC 3986 §3.2.2). Its
    -- address itself contains colons, so cut right after the closing ']'
    -- rather than at the first ':'. Values without a closing ']' fall back
    -- to the plain rule.
    stripPort host
      | Just (91, _) <- ByteString.uncons host      -- '['
      , Just close <- ByteString.elemIndex 93 host  -- ']'
      = ByteString.take (close + 1) host
      | otherwise = ByteString.takeWhile (/= 58) host  -- ':'
-- | The middleware configured with 'defaultConfig'.
def :: Middleware
def app = withConfig defaultConfig app
{-# INLINE def #-}
-- | Like 'def', but the given 'HTTPSResolver' decides whether a request is
-- already secure. Every other setting comes from 'defaultConfig'.
withResolver :: HTTPSResolver -> Middleware
withResolver resolver = withConfig config
  where
    config = defaultConfig { httpsIsSecure = resolver }
{-# INLINE withResolver #-}
-- | Predicate deciding whether a request counts as having arrived over HTTPS.
--
-- The resolvers defined in this module ('xForwardedProto', 'forwarded',
-- 'azure', 'customProtoHeader') decide by inspecting request headers.
type HTTPSResolver = Request -> Bool
-- | Treat a request as secure when its @X-Forwarded-Proto@ header is
-- @https@. A missing header means "not secure".
--
-- The value is compared case-insensitively, for two reasons:
--
-- * URI schemes are case-insensitive (RFC 3986 §3.1);
-- * 'forwarded' already compares its @proto@ parameter this way.
--
-- Without this, a proxy sending @HTTPS@ would cause an endless redirect loop.
xForwardedProto :: HTTPSResolver
xForwardedProto req =
  maybe False isHttps $ lookup "x-forwarded-proto" (Wai.requestHeaders req)
  where
    isHttps value = CaseInsensitive.mk value == "https"
{-# INLINE xForwardedProto #-}
-- | Treat a request as secure whenever an @X-ARR-SSL@ header is present.
--
-- The header's value is ignored; only its presence matters.
azure :: HTTPSResolver
azure req =
  case lookup "x-arr-ssl" (Wai.requestHeaders req) of
    Just _  -> True
    Nothing -> False
{-# INLINE azure #-}
-- | Build a resolver from the name of a header that carries the original
-- request scheme.
--
-- The request counts as secure when that header's value is @https@. A
-- missing header means "not secure". Both the header name and the value are
-- matched case-insensitively; the value is compared the same way 'forwarded'
-- compares @proto@.
customProtoHeader :: ByteString -> HTTPSResolver
customProtoHeader header req =
  maybe False isHttps $
    lookup (CaseInsensitive.mk header) (Wai.requestHeaders req)
  where
    isHttps value = CaseInsensitive.mk value == "https"
{-# INLINE customProtoHeader #-}
-- | Treat a request as secure when its @Forwarded@ header parses to a
-- @proto@ parameter equal to @https@.
--
-- The comparison is case-insensitive because the parsed @proto@ is a @CI@
-- value. A missing header means "not secure".
forwarded :: HTTPSResolver
forwarded req =
  case lookup "forwarded" (Wai.requestHeaders req) of
    Nothing    -> False
    Just value -> protoOf value == Just "https"
  where
    protoOf = Forwarded.forwardedProto . Forwarded.parseForwarded
{-# INLINE forwarded #-}