-- | Integration of "Freckle.App" tooling with "Network.Wai"
module Freckle.App.Wai
  ( noCacheMiddleware
  , corsMiddleware
  , denyFrameEmbeddingMiddleware
  ) where

import Freckle.App.Prelude

import Data.ByteString (ByteString)
import qualified Data.ByteString as BS
import qualified Data.CaseInsensitive as CI
import Network.HTTP.Types (ResponseHeaders)
import Network.HTTP.Types.Status (status200)
import Network.Wai
import Network.Wai.Middleware.AddHeaders (addHeaders)

-- | Middleware that adds a @Cache-Control@ header disabling caching
--
-- Every response passed through this middleware gets
-- @Cache-Control: no-cache, no-store, max-age=0, private@ added via
-- 'addHeaders'.
noCacheMiddleware :: Middleware
noCacheMiddleware = addHeaders [cacheControlHeader]
 where
  cacheControlHeader :: (ByteString, ByteString)
  cacheControlHeader =
    ("Cache-Control", "no-cache, no-store, max-age=0, private")

-- | Middleware that answers CORS preflight requests and adds CORS headers
--
-- A request with method @OPTIONS@ and an @Origin@ header is answered
-- directly with a @200@ carrying the CORS headers; the wrapped application
-- is not called. Any other request reaches the wrapped application, and if
-- it carries an @Origin@ header the CORS headers are added to its response.
corsMiddleware
  :: (ByteString -> Bool)
  -- ^ Predicate that returns 'True' for valid @Origin@ values
  -> [ByteString]
  -- ^ Extra headers to add to @Access-Control-Expose-Headers@
  -> Middleware
corsMiddleware validateOrigin extraExposedHeaders =
  -- handleOptions is outermost so preflights short-circuit before the
  -- header-adding layer (and the application) ever run.
  handleOptions validateOrigin extraExposedHeaders
    . addCORSHeaders validateOrigin extraExposedHeaders

-- | Middleware that adds a header to deny all frame embedding
--
-- Adds @X-Frame-Options: DENY@ to every response via 'addHeaders'.
denyFrameEmbeddingMiddleware :: Middleware
denyFrameEmbeddingMiddleware = addHeaders [("X-Frame-Options", "DENY")]

-- | Answer CORS preflight requests without calling the wrapped application
--
-- A request whose method is @OPTIONS@ and which carries an @Origin@ header
-- is answered immediately with an empty-bodied @200@ response whose headers
-- are 'corsResponseHeaders' for that origin. Every other request is passed
-- to the wrapped application unchanged.
handleOptions :: (ByteString -> Bool) -> [ByteString] -> Middleware
handleOptions validateOrigin extraExposedHeaders app req sendResponse =
  case (requestMethod req, lookup "Origin" (requestHeaders req)) of
    ("OPTIONS", Just origin) ->
      sendResponse $ responseLBS status200 (preflightHeaders origin) mempty
    _ -> app req sendResponse
 where
  -- 'corsResponseHeaders' yields plain 'ByteString' header names, while
  -- 'responseLBS' takes 'ResponseHeaders', whose names are case-insensitive;
  -- 'CI.mk' performs that conversion for each name.
  preflightHeaders :: ByteString -> ResponseHeaders
  preflightHeaders =
    map (first CI.mk) . corsResponseHeaders validateOrigin extraExposedHeaders

-- | Add 'corsResponseHeaders' to responses for requests carrying an @Origin@
--
-- Requests without an @Origin@ header are passed to the wrapped application
-- untouched; otherwise the application runs under 'addHeaders' with the CORS
-- headers computed for that origin.
addCORSHeaders :: (ByteString -> Bool) -> [ByteString] -> Middleware
addCORSHeaders validateOrigin extraExposedHeaders app req sendResponse =
  case lookup "Origin" (requestHeaders req) of
    Nothing -> app req sendResponse
    Just origin ->
      addHeaders
        (corsResponseHeaders validateOrigin extraExposedHeaders origin)
        app
        req
        sendResponse

-- | The CORS headers sent in response to a request from the given @Origin@
--
-- The origin is echoed back in @Access-Control-Allow-Origin@ only when the
-- predicate accepts it; otherwise the sentinel @BADORIGIN@ is sent instead of
-- the requested origin. @Set-Cookie@, @Content-Disposition@ and @Link@ are
-- always exposed, followed by the caller's extra headers.
--
-- NOTE(review): the allowed origin varies per request but no @Vary: Origin@
-- header is emitted; confirm no shared cache can replay one origin's response
-- to another.
--
-- NOTE(review): with @Access-Control-Allow-Credentials: true@, browsers treat
-- @*@ in @Access-Control-Allow-Headers@ literally rather than as a wildcard;
-- confirm credentialed requests only need @Content-Type@.
corsResponseHeaders
  :: (ByteString -> Bool)
  -- ^ Predicate that returns 'True' for valid @Origin@ values
  -> [ByteString]
  -- ^ Extra headers to add to @Access-Control-Expose-Headers@
  -> ByteString
  -- ^ Value of the request's @Origin@ header
  -> [(ByteString, ByteString)]
corsResponseHeaders validateOrigin extraExposedHeaders origin =
  [ ("Access-Control-Allow-Origin", validatedOrigin)
  , ("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE, PATCH")
  , ("Access-Control-Allow-Credentials", "true")
  , ("Access-Control-Allow-Headers", "Content-Type, *")
  , ("Access-Control-Expose-Headers", BS.intercalate ", " exposedHeaders)
  ]
 where
  validatedOrigin :: ByteString
  validatedOrigin
    | validateOrigin origin = origin
    | otherwise = "BADORIGIN"

  exposedHeaders :: [ByteString]
  exposedHeaders =
    ["Set-Cookie", "Content-Disposition", "Link"] <> extraExposedHeaders