module Freckle.App.Wai
( makeRequestMetricsMiddleware
, noCacheMiddleware
, corsMiddleware
, denyFrameEmbeddingMiddleware
) where
import Freckle.App.Prelude hiding (decodeUtf8)
import Control.Monad.Reader (runReaderT)
import Data.ByteString (ByteString)
import qualified Data.ByteString as BS
import qualified Data.CaseInsensitive as CI
import Data.Text.Encoding (decodeUtf8With)
import Data.Text.Encoding.Error (lenientDecode)
import Freckle.App.Datadog (HasDogStatsClient, HasDogStatsTags)
import qualified Freckle.App.Datadog as Datadog
import Network.HTTP.Types (ResponseHeaders)
import Network.HTTP.Types.Status (status200, statusCode)
import Network.Wai
import Network.Wai.Middleware.AddHeaders (addHeaders)
-- | Build a 'Middleware' that reports per-request metrics to Datadog.
--
-- For every response produced by the wrapped application this increments
-- the @requests@ counter and records a @response_time_ms@ sample via
-- 'Datadog.histogramSinceMs', relative to a timestamp taken just before the
-- application is invoked. Both metrics carry the caller-supplied tags for
-- the request plus:
--
-- * @method@ – the request method, leniently decoded from UTF-8
-- * @status@ – the numeric response status code
--
-- The metrics are emitted before the response is handed on to the original
-- response callback.
--
-- NOTE(review): if either Datadog call throws, the response is never sent;
-- confirm the Datadog helpers cannot fail here.
makeRequestMetricsMiddleware
  :: (HasDogStatsClient env, HasDogStatsTags env)
  => env
  -- ^ Environment providing the DogStatsD client and its tags
  -> (Request -> [(Text, Text)])
  -- ^ Extra tags derived from each request
  -> Middleware
makeRequestMetricsMiddleware env getTags app req respond = do
  startedAt <- getCurrentTime
  app req $ \res -> do
    let metricTags = requestTags res
        report = do
          Datadog.increment "requests" metricTags
          Datadog.histogramSinceMs "response_time_ms" metricTags startedAt
    runReaderT report env
    respond res
 where
  -- Caller-provided tags followed by the method and status tags.
  requestTags res =
    getTags req
      ++ [ ("method", decodeUtf8 (requestMethod req))
         , ("status", pack (show (statusCode (responseStatus res))))
         ]
-- | Add the header
-- @Cache-Control: no-cache, no-store, max-age=0, private@ to responses
-- (via 'addHeaders') so they are not cached or stored.
noCacheMiddleware :: Middleware
noCacheMiddleware =
  addHeaders [("Cache-Control", "no-cache, no-store, max-age=0, private")]
-- | CORS support for an application.
--
-- Requests that carry an @Origin@ header get CORS response headers added.
-- @OPTIONS@ requests that carry an @Origin@ header are answered directly
-- with an empty @200@ response containing those headers and never reach the
-- wrapped application. All other requests pass through unchanged.
corsMiddleware
  :: (ByteString -> Bool)
  -- ^ Predicate deciding whether an @Origin@ value is allowed
  -> [ByteString]
  -- ^ Header names exposed in addition to the defaults
  -> Middleware
corsMiddleware validateOrigin extraExposedHeaders app =
  handleOptions validateOrigin extraExposedHeaders
    (addCORSHeaders validateOrigin extraExposedHeaders app)
-- | Add @X-Frame-Options: DENY@ to responses (via 'addHeaders') so browsers
-- refuse to render them inside a frame.
denyFrameEmbeddingMiddleware :: Middleware
denyFrameEmbeddingMiddleware = addHeaders [frameOptions]
 where
  frameOptions :: (ByteString, ByteString)
  frameOptions = ("X-Frame-Options", "DENY")
-- | Short-circuit CORS preflight requests.
--
-- An @OPTIONS@ request carrying an @Origin@ header is answered immediately
-- with an empty-bodied @200@ response whose headers come from
-- 'corsResponseHeaders'; the wrapped application is not invoked. Every
-- other request is forwarded to the application untouched.
handleOptions :: (ByteString -> Bool) -> [ByteString] -> Middleware
handleOptions validateOrigin extraExposedHeaders app req respond
  | requestMethod req == "OPTIONS"
  , Just origin <- lookup "Origin" (requestHeaders req)
  = respond (responseLBS status200 (preflightHeaders origin) mempty)
  | otherwise
  = app req respond
 where
  -- CORS headers for the origin, with names lifted to case-insensitive
  -- 'HeaderName's as required by 'responseLBS'.
  preflightHeaders :: ByteString -> ResponseHeaders
  preflightHeaders =
    map (first CI.mk) . corsResponseHeaders validateOrigin extraExposedHeaders
-- | Attach CORS headers (from 'corsResponseHeaders') to the response of any
-- request that sends an @Origin@ header. Requests without one are handed to
-- the application unchanged.
addCORSHeaders :: (ByteString -> Bool) -> [ByteString] -> Middleware
addCORSHeaders validateOrigin extraExposedHeaders app req =
  maybe app withCorsHeaders (lookup "Origin" (requestHeaders req)) req
 where
  -- The application wrapped so its responses carry CORS headers for origin.
  withCorsHeaders origin =
    addHeaders (corsResponseHeaders validateOrigin extraExposedHeaders origin) app
-- | CORS headers for a response to a request whose @Origin@ header had the
-- given value.
--
-- The origin is echoed back in @Access-Control-Allow-Origin@ when
-- @validateOrigin@ accepts it. Otherwise the literal placeholder
-- @BADORIGIN@ is sent instead.
--
-- @Access-Control-Allow-Origin@ depends on the request's @Origin@, so
-- @Vary: Origin@ is included. This keeps shared caches from replaying a
-- response computed for one origin to a request from a different origin.
--
-- NOTE(review): with @Access-Control-Allow-Credentials: true@, browsers treat
-- the @*@ in @Access-Control-Allow-Headers@ as a literal header name rather
-- than a wildcard; confirm that is intended.
corsResponseHeaders
  :: (ByteString -> Bool)
  -- ^ Predicate deciding whether the origin is allowed
  -> [ByteString]
  -- ^ Header names exposed in addition to the defaults
  -> ByteString
  -- ^ Value of the request's @Origin@ header
  -> [(ByteString, ByteString)]
corsResponseHeaders validateOrigin extraExposedHeaders origin =
  [ ("Access-Control-Allow-Origin", validatedOrigin)
  , ("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE, PATCH")
  , ("Access-Control-Allow-Credentials", "true")
  , ("Access-Control-Allow-Headers", "Content-Type, *")
  , ("Access-Control-Expose-Headers", BS.intercalate ", " exposedHeaders)
  , ("Vary", "Origin")
  ]
 where
  validatedOrigin :: ByteString
  validatedOrigin = if validateOrigin origin then origin else "BADORIGIN"

  exposedHeaders :: [ByteString]
  exposedHeaders =
    ["Set-Cookie", "Content-Disposition", "Link"] <> extraExposedHeaders
-- | Decode UTF-8 bytes to 'Text' using 'lenientDecode'. Invalid sequences
-- are replaced with the Unicode replacement character rather than causing
-- an exception.
decodeUtf8 :: ByteString -> Text
decodeUtf8 bytes = decodeUtf8With lenientDecode bytes