-- | TODO: Can we use <https://hackage.haskell.org/package/wai-cors> instead?
module Network.Wai.Middleware.Cors
  ( corsMiddleware
  ) where

import Freckle.App.Prelude

import Data.ByteString (ByteString)
import qualified Data.ByteString as BS
import qualified Data.CaseInsensitive as CI
import Network.HTTP.Types (ResponseHeaders)
import Network.HTTP.Types.Status (status200)
import Network.Wai
import Network.Wai.Middleware.AddHeaders

-- | Middleware adding CORS support to an application.
--
-- Preflight requests (method @OPTIONS@ with an @Origin@ header) are answered
-- directly by 'handleOptions' with @200 OK@ and the CORS headers; the wrapped
-- application is never invoked for them. Every other request is passed to
-- 'addCORSHeaders', which adds the CORS headers to the application's response
-- when an @Origin@ header is present and otherwise leaves it untouched.
corsMiddleware
  :: (ByteString -> Bool)
  -- ^ Predicate that returns 'True' for valid @Origin@ values
  -> [ByteString]
  -- ^ Extra headers to add to @Expose-Headers@
  -> Middleware
corsMiddleware isAllowedOrigin extraHeaders app =
  -- Preflight handling must be the outer layer so it can short-circuit.
  handleOptions isAllowedOrigin extraHeaders $
    addCORSHeaders isAllowedOrigin extraHeaders app

-- | Answer CORS preflight requests without consulting the application.
--
-- When the request method is @OPTIONS@ and an @Origin@ header is present, the
-- response is an empty @200 OK@ carrying the headers from
-- 'corsResponseHeaders'. Any other request is delegated to the wrapped
-- application unchanged.
handleOptions :: (ByteString -> Bool) -> [ByteString] -> Middleware
handleOptions validateOrigin extraExposedHeaders app req sendResponse
  | requestMethod req == "OPTIONS"
  , Just origin <- lookup "Origin" (requestHeaders req) =
      sendResponse $ responseLBS status200 (preflightHeaders origin) mempty
  | otherwise =
      app req sendResponse
 where
  -- 'corsResponseHeaders' yields plain names; WAI wants case-insensitive ones.
  preflightHeaders :: ByteString -> ResponseHeaders
  preflightHeaders =
    map (first CI.mk) . corsResponseHeaders validateOrigin extraExposedHeaders

-- | Add CORS headers to the application's response for requests that carry
-- an @Origin@ header.
--
-- Requests without @Origin@ are passed to the application untouched. Otherwise
-- the headers from 'corsResponseHeaders' are added to the response via
-- 'addHeaders'.
addCORSHeaders :: (ByteString -> Bool) -> [ByteString] -> Middleware
addCORSHeaders validateOrigin extraExposedHeaders app req sendResponse =
  maybe app withCors (lookup "Origin" (requestHeaders req)) req sendResponse
 where
  -- The application wrapped with the CORS headers computed for @origin@.
  withCors :: ByteString -> Application
  withCors origin =
    addHeaders
      (corsResponseHeaders validateOrigin extraExposedHeaders origin)
      app

-- | The CORS headers sent in response to a request with the given @Origin@.
--
-- The request's @Origin@ is echoed in @Access-Control-Allow-Origin@ only when
-- the predicate accepts it. Otherwise the placeholder @BADORIGIN@ is sent in
-- its place, which presumably makes the browser's CORS check fail.
--
-- The @Access-Control-Allow-Origin@ value depends on the request's @Origin@,
-- so @Vary: Origin@ is included. Without it, a shared cache could serve a
-- response computed for one origin to a request from another origin.
corsResponseHeaders
  :: (ByteString -> Bool)
  -- ^ Predicate that returns 'True' for valid @Origin@ values
  -> [ByteString]
  -- ^ Extra headers appended to the default @Expose-Headers@ list
  -> ByteString
  -- ^ Value of the request's @Origin@ header
  -> [(ByteString, ByteString)]
corsResponseHeaders validateOrigin extraExposedHeaders origin =
  [ ("Access-Control-Allow-Origin", validatedOrigin)
  , ("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE, PATCH")
  , ("Access-Control-Allow-Credentials", "true")
  , -- NOTE(review): with credentials allowed, the Fetch standard treats @*@
    -- here as a literal header name rather than a wildcard. Confirm that the
    -- clients actually rely only on @Content-Type@.
    ("Access-Control-Allow-Headers", "Content-Type, *")
  , ("Access-Control-Expose-Headers", BS.intercalate ", " exposedHeaders)
  , ("Vary", "Origin")
  ]
 where
  validatedOrigin :: ByteString
  validatedOrigin = if validateOrigin origin then origin else "BADORIGIN"

  exposedHeaders :: [ByteString]
  exposedHeaders =
    ["Set-Cookie", "Content-Disposition", "Link"] <> extraExposedHeaders