-- | TODO: Can we use <https://hackage.haskell.org/package/wai-cors> instead?
module Network.Wai.Middleware.Cors
  ( corsMiddleware
  ) where

import Freckle.App.Prelude

import Data.ByteString (ByteString)
import Data.ByteString qualified as BS
import Data.CaseInsensitive qualified as CI
import Network.HTTP.Types (ResponseHeaders)
import Network.HTTP.Types.Status (status200)
import Network.Wai
import Network.Wai.Middleware.AddHeaders

-- | CORS middleware for a WAI application.
--
-- Composes two layers. The outer layer ('handleOptions') answers @OPTIONS@
-- requests that carry an @Origin@ header by itself; everything else falls
-- through to the inner layer ('addCORSHeaders'), which wraps the application
-- so that responses to requests with an @Origin@ header get the CORS headers.
corsMiddleware
  :: (ByteString -> Bool)
  -- ^ Predicate that returns 'True' for valid @Origin@ values
  -> [ByteString]
  -- ^ Extra headers to add to @Expose-Headers@
  -> Middleware
corsMiddleware validateOrigin extraExposedHeaders =
  handleOptions validateOrigin extraExposedHeaders
    . addCORSHeaders validateOrigin extraExposedHeaders

-- | Answer CORS preflight requests without invoking the wrapped application.
--
-- A request whose method is @OPTIONS@ and that carries an @Origin@ header is
-- answered immediately with a @200@ response, an empty body, and the headers
-- produced by 'corsResponseHeaders' for that origin. Every other request is
-- handed to the wrapped application unchanged.
handleOptions :: (ByteString -> Bool) -> [ByteString] -> Middleware
handleOptions validateOrigin extraExposedHeaders app req sendResponse =
  case (requestMethod req, lookup "Origin" (requestHeaders req)) of
    ("OPTIONS", Just origin) ->
      sendResponse $ responseLBS status200 (preflightHeaders origin) mempty
    _ -> app req sendResponse
 where
  -- CORS headers for the given origin, in the form 'responseLBS' expects
  preflightHeaders :: ByteString -> ResponseHeaders
  preflightHeaders =
    toHeaders . corsResponseHeaders validateOrigin extraExposedHeaders

  -- Header names in 'ResponseHeaders' are case-insensitive, so wrap each
  -- raw name with 'CI.mk'
  toHeaders :: [(ByteString, ByteString)] -> ResponseHeaders
  toHeaders = map (first CI.mk)

-- | Attach CORS headers to responses for requests that carry an @Origin@.
--
-- Requests without an @Origin@ header go straight to the wrapped application.
-- Otherwise the application is run through 'addHeaders' with the headers that
-- 'corsResponseHeaders' produces for the request's origin.
addCORSHeaders :: (ByteString -> Bool) -> [ByteString] -> Middleware
addCORSHeaders validateOrigin extraExposedHeaders app req sendResponse =
  case lookup "Origin" (requestHeaders req) of
    Nothing -> app req sendResponse
    Just origin ->
      addHeaders
        (corsResponseHeaders validateOrigin extraExposedHeaders origin)
        app
        req
        sendResponse

-- | Build the CORS response headers for a request's @Origin@ value.
--
-- The result always contains the same six headers, in this order:
-- @Access-Control-Allow-Origin@, @Access-Control-Allow-Methods@,
-- @Access-Control-Allow-Credentials@, @Access-Control-Allow-Headers@,
-- @Access-Control-Expose-Headers@ and @Vary@.
--
-- @Access-Control-Allow-Origin@ echoes the origin when the predicate accepts
-- it; otherwise it carries the placeholder @BADORIGIN@.
corsResponseHeaders
  :: (ByteString -> Bool)
  -- ^ Predicate that returns 'True' for valid @Origin@ values
  -> [ByteString]
  -- ^ Extra headers appended to the default @Expose-Headers@ list
  -> ByteString
  -- ^ The @Origin@ value taken from the request
  -> [(ByteString, ByteString)]
corsResponseHeaders validateOrigin extraExposedHeaders origin =
  [ ("Access-Control-Allow-Origin", validatedOrigin)
  , ("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE, PATCH")
  , ("Access-Control-Allow-Credentials", "true")
  , ("Access-Control-Allow-Headers", "Content-Type, *")
  , ("Access-Control-Expose-Headers", BS.intercalate ", " exposedHeaders)
  , ("Vary", "Origin")
  ]
 where
  -- Rejected origins are replaced by a fixed placeholder rather than echoed
  validatedOrigin :: ByteString
  validatedOrigin = if validateOrigin origin then origin else "BADORIGIN"

  -- Defaults first, caller-supplied extras after
  exposedHeaders :: [ByteString]
  exposedHeaders =
    ["Set-Cookie", "Content-Disposition", "Link"] <> extraExposedHeaders