-- | Integration of "Freckle.App" tooling with "Network.Wai"
module Freckle.App.Wai
  ( makeRequestMetricsMiddleware
  , noCacheMiddleware
  , corsMiddleware
  , denyFrameEmbeddingMiddleware
  ) where

import Freckle.App.Prelude hiding (decodeUtf8)

import Control.Monad.Reader (runReaderT)
import Data.ByteString (ByteString)
import qualified Data.ByteString as BS
import qualified Data.CaseInsensitive as CI
import Data.Text.Encoding (decodeUtf8With)
import Data.Text.Encoding.Error (lenientDecode)
import Freckle.App.Datadog (HasDogStatsClient, HasDogStatsTags)
import qualified Freckle.App.Datadog as Datadog
import Network.HTTP.Types (ResponseHeaders)
import Network.HTTP.Types.Status (status200, statusCode)
import Network.Wai
import Network.Wai.Middleware.AddHeaders (addHeaders)

-- | Build a 'Middleware' that reports per-request metrics to Datadog
--
-- For each response produced by the wrapped application this increments the
-- @requests@ counter and records a @response_time_ms@ histogram sample
-- relative to the moment this middleware was entered. Both happen before the
-- response is handed to the client. The response itself is passed on
-- unchanged.
--
-- Every metric is tagged with the caller-supplied tags for the request,
-- followed by @method@ (the HTTP method) and @status@ (the numeric response
-- status code).
makeRequestMetricsMiddleware
  :: (HasDogStatsClient env, HasDogStatsTags env)
  => env
  -- ^ Environment providing the DogStatsD client and its tags
  -> (Request -> [(Text, Text)])
  -- ^ Request-specific tags to attach to every metric
  -> Middleware
makeRequestMetricsMiddleware env getTags app req sendResponse' = do
  start <- getCurrentTime
  app req $ \res -> do
    let metricTags = getTags req <> builtinTags res
    (`runReaderT` env) $ do
      Datadog.increment "requests" metricTags
      Datadog.histogramSinceMs "response_time_ms" metricTags start
    sendResponse' res
 where
  -- Tags every request gets regardless of the caller's 'getTags'
  builtinTags :: Response -> [(Text, Text)]
  builtinTags res =
    [ ("method", decodeUtf8 $ requestMethod req)
    , ("status", pack . show . statusCode $ responseStatus res)
    ]

-- | Middleware that asks clients and intermediaries not to cache responses
--
-- Adds @Cache-Control: no-cache, no-store, max-age=0, private@ to every
-- response of the wrapped application.
noCacheMiddleware :: Middleware
noCacheMiddleware =
  addHeaders [("Cache-Control", "no-cache, no-store, max-age=0, private")]

-- | Middleware implementing CORS for requests that carry an @Origin@ header
--
-- * An @OPTIONS@ request with an @Origin@ header is answered directly with an
--   empty @200@ response carrying the CORS headers. The wrapped application
--   is not run.
-- * Any other request with an @Origin@ header is passed to the application,
--   and the CORS headers are added to its response.
-- * Requests without an @Origin@ header are passed through untouched.
corsMiddleware
  :: (ByteString -> Bool)
  -- ^ Predicate that returns 'True' for valid @Origin@ values
  -> [ByteString]
  -- ^ Extra headers to add to @Expose-Headers@
  -> Middleware
corsMiddleware validateOrigin extraExposedHeaders app =
  handleOptions validateOrigin extraExposedHeaders
    $ addCORSHeaders validateOrigin extraExposedHeaders app

-- | Middleware that adds a header to deny all frame embedding
--
-- Every response of the wrapped application gets @X-Frame-Options: DENY@,
-- which tells browsers not to render it inside a frame.
denyFrameEmbeddingMiddleware :: Middleware
denyFrameEmbeddingMiddleware = addHeaders [frameOptions]
 where
  frameOptions :: (ByteString, ByteString)
  frameOptions = ("X-Frame-Options", "DENY")

-- | Answer CORS preflight-style requests without consulting the application
--
-- A request whose method is @OPTIONS@ and that carries an @Origin@ header is
-- short-circuited with an empty @200 OK@ response. That response holds the
-- headers from 'corsResponseHeaders'. Every other request is handed to the
-- wrapped application unchanged.
handleOptions :: (ByteString -> Bool) -> [ByteString] -> Middleware
handleOptions validateOrigin extraExposedHeaders app req sendResponse
  | requestMethod req == "OPTIONS"
  , Just origin <- lookup "Origin" (requestHeaders req)
  = sendResponse $ responseLBS status200 (preflightHeaders origin) mempty
  | otherwise
  = app req sendResponse
 where
  -- 'corsResponseHeaders' yields raw names; WAI wants case-insensitive ones
  preflightHeaders :: ByteString -> ResponseHeaders
  preflightHeaders origin =
    [ (CI.mk name, value)
    | (name, value) <-
        corsResponseHeaders validateOrigin extraExposedHeaders origin
    ]

-- | Add CORS headers to the application's response when the request carries
-- an @Origin@ header
--
-- The headers come from 'corsResponseHeaders' for that origin. Requests
-- without an @Origin@ header are passed through unchanged.
addCORSHeaders :: (ByteString -> Bool) -> [ByteString] -> Middleware
addCORSHeaders validateOrigin extraExposedHeaders app req =
  maybe app withCorsHeaders (lookup "Origin" (requestHeaders req)) req
 where
  withCorsHeaders :: ByteString -> Application
  withCorsHeaders origin =
    addHeaders
      (corsResponseHeaders validateOrigin extraExposedHeaders origin)
      app

-- | The CORS response headers for a request from the given @Origin@
--
-- The origin is echoed back in @Access-Control-Allow-Origin@ only when the
-- predicate accepts it. Otherwise the literal @BADORIGIN@ is sent; it cannot
-- match the requesting origin, so the browser should refuse access. The
-- exposed headers are @Set-Cookie@, @Content-Disposition@ and @Link@,
-- followed by the caller's extras.
--
-- Because @Access-Control-Allow-Origin@ depends on the request's @Origin@,
-- @Vary: Origin@ is also sent. This keeps shared caches from serving a
-- response built for one origin to a different one.
--
-- NOTE(review): with @Access-Control-Allow-Credentials: true@, browsers treat
-- the @*@ in @Access-Control-Allow-Headers@ literally rather than as a
-- wildcard, so only the explicitly listed headers are effectively allowed.
corsResponseHeaders
  :: (ByteString -> Bool)
  -- ^ Predicate that returns 'True' for valid @Origin@ values
  -> [ByteString]
  -- ^ Extra headers to add to @Expose-Headers@
  -> ByteString
  -- ^ Value of the request's @Origin@ header
  -> [(ByteString, ByteString)]
corsResponseHeaders validateOrigin extraExposedHeaders origin =
  [ ("Access-Control-Allow-Origin", validatedOrigin)
  , ("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE, PATCH")
  , ("Access-Control-Allow-Credentials", "true")
  , ("Access-Control-Allow-Headers", "Content-Type, *")
  , ("Access-Control-Expose-Headers", BS.intercalate ", " exposedHeaders)
  , ("Vary", "Origin")
  ]
 where
  validatedOrigin :: ByteString
  validatedOrigin = if validateOrigin origin then origin else "BADORIGIN"

  exposedHeaders :: [ByteString]
  exposedHeaders =
    ["Set-Cookie", "Content-Disposition", "Link"] <> extraExposedHeaders

-- | Decode UTF-8 bytes to 'Text' without ever failing
--
-- Invalid byte sequences are replaced with U+FFFD rather than raising an
-- exception, so this is safe to use on client-supplied data such as the
-- request method.
decodeUtf8 :: ByteString -> Text
decodeUtf8 = decodeUtf8With replaceInvalid
 where
  -- Substitute the Unicode replacement character for undecodable input
  replaceInvalid = lenientDecode