{-# LANGUAGE OverloadedStrings #-}
-- | Add <http://www.w3.org/TR/cors/ CORS> (cross-origin resource sharing)
-- headers to a Snap application. CORS headers can be added either conditionally
-- or unconditionally to the entire site, or you can apply CORS headers to a
-- single route.
--
-- To use in a snaplet, simply use 'wrapSite':
--
-- @
-- wrapSite $ applyCORS defaultOptions
-- @
module Snap.Util.CORS
  ( -- * Applying CORS to a specific response
    applyCORS

    -- * Option Specification
  , CORSOptions(..)
  , defaultOptions

    -- ** Origin lists
  , OriginList(..)
  , OriginSet, mkOriginSet, origins

    -- * Internals
  , HashableURI(..), HashableMethod (..)
  ) where

import Control.Applicative
import Control.Monad (join, unless, when)
import Data.CaseInsensitive (CI)
import Data.Char (toLower)
import Data.Hashable (Hashable(..))
import Data.Maybe (fromMaybe)
import Data.Text.Encoding (decodeUtf8, decodeUtf8', encodeUtf8)
import Network.URI (URI (..), URIAuth (..), parseURI, uriToString)

import qualified Data.Attoparsec.ByteString.Char8 as Attoparsec
import qualified Data.ByteString.Char8 as S
import qualified Data.CaseInsensitive as CI
import qualified Data.HashSet as HashSet
import qualified Data.Text as Text

import qualified Snap.Core as Snap
import Snap.Internal.Parsing (pTokens)

-- | A set of origins. RFC 6454 specifies that origins are a scheme, host and
-- port, so the 'OriginSet' wrapper around a 'HashSet.HashSet' ensures that each
-- 'URI' consists of nothing more than this.
--
-- The data constructor is not exported; build values with 'mkOriginSet',
-- which reduces every 'URI' before inserting it.
newtype OriginSet = OriginSet { origins :: HashSet.HashSet HashableURI }

-- | Used to specify which origins may receive the @Access-Control-Allow-Origin@
-- header. When a request's origin is allowed, the header is set to that
-- request's own origin (it is echoed back, never sent as a literal @*@).
data OriginList
  = Everywhere
  -- ^ Allow any origin to access this resource. Every parseable @Origin@
  -- is accepted and echoed back in @Access-Control-Allow-Origin@.
  | Nowhere
  -- ^ Do not allow cross-origin requests: the wrapped handler runs without
  -- any CORS headers being added.
  | Origins OriginSet
  -- ^ Allow cross-origin requests only from these origins.

-- | Specify the options to use when building CORS headers for a response. Most
-- of these options are monadic actions (typically 'Snap.Snap' handlers) so
-- that each header can be determined conditionally, per request.
data CORSOptions m = CORSOptions
  { corsAllowOrigin :: m OriginList
    -- ^ Which origins are allowed to make cross-origin requests. Requests
    -- from any other origin are passed to the wrapped handler untouched.

  , corsAllowCredentials :: m Bool
    -- ^ Whether or not to allow exposing the response when the omit
    -- credentials flag is unset. When this yields 'True', the header
    -- @Access-Control-Allow-Credentials: true@ is added to CORS responses.

  , corsExposeHeaders :: m (HashSet.HashSet (CI S.ByteString))
    -- ^ Headers that are exposed to clients (sent as
    -- @Access-Control-Expose-Headers@). This allows clients to read the
    -- values of these headers, if the response includes them.

  , corsAllowedMethods :: m (HashSet.HashSet HashableMethod)
    -- ^ Request methods that are allowed in a preflight request. A preflight
    -- asking for any other method falls through to the wrapped handler.

  , corsAllowedHeaders
      :: HashSet.HashSet S.ByteString -> m (HashSet.HashSet S.ByteString)
    -- ^ An action to determine which of the request headers are allowed.
    -- It is supplied the parsed contents of @Access-Control-Request-Headers@.
    -- If any requested header is missing from the returned set, the
    -- preflight falls through to the wrapped handler without CORS headers.
  }

-- | Liberal default options. Specifies that:
--
-- * All origins may make cross-origin requests.
-- * @allow-credentials@ is true.
-- * No extra headers beyond simple headers are exposed.
-- * @GET@, @POST@, @PUT@, @DELETE@ and @HEAD@ are all allowed.
-- * All request headers are allowed: the requested set is returned as-is.
--
-- All options are determined unconditionally.
defaultOptions :: Monad m => CORSOptions m
defaultOptions = CORSOptions
  { corsAllowOrigin      = return Everywhere
  , corsAllowCredentials = return True
  , corsExposeHeaders    = return HashSet.empty
  , corsAllowedMethods   = return $! defaultAllowedMethods
    -- Echo back whatever was requested, i.e. allow every header.
  , corsAllowedHeaders   = return
  }

-- | The request methods permitted by 'defaultOptions':
-- @GET@, @POST@, @PUT@, @DELETE@ and @HEAD@.
defaultAllowedMethods :: HashSet.HashSet HashableMethod
defaultAllowedMethods =
  HashSet.fromList $
    map HashableMethod [Snap.GET, Snap.POST, Snap.PUT, Snap.DELETE, Snap.HEAD]


-- | Apply CORS headers to a specific request. This is useful if you only have
-- a single action that needs CORS headers, and you don't want to pay for
-- conditional checks on every request.
--
-- Requests without a parseable @Origin@ header, or whose origin is not
-- permitted by 'corsAllowOrigin', are passed to the wrapped handler untouched.
-- For permitted origins, the origin is echoed back in
-- @Access-Control-Allow-Origin@.
--
-- You should note that 'applyCORS' needs to be used before you add any
-- 'Snap.method' combinators. For example, the following won't do what you want:
--
-- > method POST $ applyCORS defaultOptions $ myHandler
--
-- This fails to work as CORS requires an @OPTIONS@ request in the preflighting
-- stage, but this would get filtered out. Instead, use
--
-- > applyCORS defaultOptions $ method POST $ myHandler
applyCORS :: Snap.MonadSnap m => CORSOptions m -> m () -> m ()
applyCORS options m = do
  maybeOrigin <- join . fmap decodeOrigin <$> getHeader "Origin"
  maybe m corsRequestFrom maybeOrigin

 where
  -- Only act on origins in the configured allow-list; anything else just runs
  -- the wrapped handler.
  corsRequestFrom origin = do
    originList <- corsAllowOrigin options
    if origin `inOriginList` originList
      then Snap.method Snap.OPTIONS (preflightRequestFrom origin)
             <|> handleRequestFrom origin
      else m

  -- Preflight (OPTIONS) request. Any check that fails falls through to the
  -- wrapped handler without CORS headers.
  preflightRequestFrom origin = do
    maybeMethod <- fmap (parseMethod . S.unpack) <$>
                     getHeader "Access-Control-Request-Method"
    case maybeMethod of
      Nothing -> m
      Just method -> do
        allowedMethods <- corsAllowedMethods options
        if method `HashSet.member` allowedMethods
          then do
            -- A missing header means "no extra headers"; an unparseable one
            -- yields Nothing and rejects the preflight.
            maybeHeaders <-
              fromMaybe (Just HashSet.empty) . fmap splitHeaders
                <$> getHeader "Access-Control-Request-Headers"
            case maybeHeaders of
              Nothing -> m
              Just headers -> do
                allowedHeaders <- corsAllowedHeaders options headers
                if HashSet.null (headers `HashSet.difference` allowedHeaders)
                  then do
                    addAccessControlAllowOrigin origin
                    addAccessControlAllowCredentials
                    commaSepHeader "Access-Control-Allow-Headers"
                      id (HashSet.toList allowedHeaders)
                    commaSepHeader "Access-Control-Allow-Methods"
                      renderMethod (HashSet.toList allowedMethods)
                  else m
          else m

  -- Actual (non-preflight) cross-origin request: add the headers, then run
  -- the wrapped handler.
  handleRequestFrom origin = do
    addAccessControlAllowOrigin origin
    addAccessControlAllowCredentials
    exposeHeaders <- corsExposeHeaders options
    unless (HashSet.null exposeHeaders) $
      commaSepHeader "Access-Control-Expose-Headers"
        CI.original (HashSet.toList exposeHeaders)
    m

  -- Serialise the URI itself rather than relying on its 'Show' instance.
  -- The URI only contains ASCII characters, so Text.pack/encodeUtf8 is
  -- byte-for-byte the same string.
  addAccessControlAllowOrigin origin =
    addHeader "Access-Control-Allow-Origin"
              (encodeUtf8 (Text.pack (uriToString id origin "")))

  addAccessControlAllowCredentials = do
    allowCredentials <- corsAllowCredentials options
    when allowCredentials $
      addHeader "Access-Control-Allow-Credentials" "true"

  -- The Origin header is attacker-controlled. Use the total decodeUtf8' so
  -- that invalid UTF-8 is treated as "no usable origin" instead of throwing
  -- an exception out of the handler.
  decodeOrigin :: S.ByteString -> Maybe URI
  decodeOrigin raw = case decodeUtf8' raw of
    Left _     -> Nothing
    Right text -> simplifyURI <$> parseURI (Text.unpack text)

  -- Wire form of a method. 'show' would render a custom method as
  -- @Method "FOO"@, so custom methods emit their raw name instead.
  renderMethod :: HashableMethod -> S.ByteString
  renderMethod (HashableMethod (Snap.Method name)) = name
  renderMethod other                               = S.pack (show other)

  addHeader :: Snap.MonadSnap n => CI S.ByteString -> S.ByteString -> n ()
  addHeader k v = Snap.modifyResponse (Snap.addHeader k v)

  -- Add a single comma-separated header; emit nothing for an empty list.
  commaSepHeader
    :: Snap.MonadSnap n => CI S.ByteString -> (a -> S.ByteString) -> [a] -> n ()
  commaSepHeader _ _      [] = return ()
  commaSepHeader k render vs = addHeader k (S.intercalate ", " (map render vs))

  getHeader :: Snap.MonadSnap n => CI S.ByteString -> n (Maybe S.ByteString)
  getHeader name = Snap.getsRequest (Snap.getHeader name)

  splitHeaders :: S.ByteString -> Maybe (HashSet.HashSet S.ByteString)
  splitHeaders =
    either (const Nothing) (Just . HashSet.fromList)
      . Attoparsec.parseOnly pTokens

-- | Build an 'OriginSet' from a list of URIs. Each URI is first reduced with
-- 'simplifyURI', so any user info, path, query or fragment is discarded.
-- Duplicates (after reduction) collapse into one entry.
mkOriginSet :: [URI] -> OriginSet
mkOriginSet = OriginSet . HashSet.fromList . map (HashableURI . simplifyURI)

-- | Reduce a 'URI' to the parts that make up an RFC 6454 origin: scheme, host
-- and port.
--
-- User info, path, query and fragment are cleared. The scheme and host are
-- lower-cased, as RFC 6454 section 4 requires. Without this, an allow-list
-- entry such as @http://Example.com@ would never match the lower-case origin
-- browsers actually send.
simplifyURI :: URI -> URI
simplifyURI uri =
  uri { uriScheme    = map toLower (uriScheme uri)
      , uriAuthority = fmap simplifyURIAuth (uriAuthority uri)
      , uriPath      = ""
      , uriQuery     = ""
      , uriFragment  = ""
      }
 where
  simplifyURIAuth :: URIAuth -> URIAuth
  simplifyURIAuth auth =
    auth { uriUserInfo = ""
         , uriRegName  = map toLower (uriRegName auth)
         }

--------------------------------------------------------------------------------
-- | Parse a request-method name (e.g. the value of
-- @Access-Control-Request-Method@) into a 'HashableMethod'.
-- The standard method names are matched exactly, case-sensitively. Any other
-- string becomes a custom 'Snap.Method' carrying its bytes unchanged.
parseMethod :: String -> HashableMethod
parseMethod name = HashableMethod $ case name of
  "GET"     -> Snap.GET
  "POST"    -> Snap.POST
  "HEAD"    -> Snap.HEAD
  "PUT"     -> Snap.PUT
  "DELETE"  -> Snap.DELETE
  "TRACE"   -> Snap.TRACE
  "OPTIONS" -> Snap.OPTIONS
  "CONNECT" -> Snap.CONNECT
  "PATCH"   -> Snap.PATCH
  _         -> Snap.Method (S.pack name)

--------------------------------------------------------------------------------
-- | A @newtype@ over 'URI' with a 'Hashable' instance, so URIs can be stored
-- in a 'HashSet.HashSet'. Equality is the underlying 'URI' equality.
newtype HashableURI = HashableURI URI
  deriving (Eq)

-- | Shows exactly like the wrapped 'URI'.
instance Show HashableURI where
  show (HashableURI u) = show u

-- | Hashes every component of the URI, so the hash is consistent with the
-- derived 'Eq' (equal URIs have equal components).
instance Hashable HashableURI where
  hashWithSalt s (HashableURI (URI scheme authority path query fragment)) =
    s `hashWithSalt` scheme
      `hashWithSalt` fmap hashAuthority authority
      `hashWithSalt` path
      `hashWithSalt` query
      `hashWithSalt` fragment
   where
    -- The authority is folded to an Int first, since 'URIAuth' itself has no
    -- 'Hashable' instance.
    hashAuthority :: URIAuth -> Int
    hashAuthority (URIAuth userInfo regName port) =
      s `hashWithSalt` userInfo
        `hashWithSalt` regName
        `hashWithSalt` port

-- | Is the given origin permitted by the 'OriginList'?
--
-- For 'Origins', the URI is looked up as-is. Callers are expected to pass an
-- already-reduced origin, as produced by 'simplifyURI'.
inOriginList :: URI -> OriginList -> Bool
inOriginList _      Nowhere                    = False
inOriginList _      Everywhere                 = True
inOriginList origin (Origins (OriginSet xs))   = HashableURI origin `HashSet.member` xs


--------------------------------------------------------------------------------
-- | A @newtype@ over 'Snap.Method' with a 'Hashable' instance, so request
-- methods can be stored in a 'HashSet.HashSet'. Equality is the underlying
-- 'Snap.Method' equality.
newtype HashableMethod = HashableMethod Snap.Method
  deriving (Eq)

-- | Each standard method hashes to its own tag. A custom 'Snap.Method' whose
-- name spells a standard method (e.g. @Snap.Method "GET"@) hashes exactly
-- like that standard constructor.
--
-- The second rule matters because snap-core's 'Snap.Method' equality
-- normalises such names (NOTE(review): confirm against the pinned snap-core
-- version). Hashing them differently would break the Eq/Hashable contract and
-- make 'HashSet.member' miss. If equality does not normalise, the rule only
-- adds a harmless hash collision.
instance Hashable HashableMethod where
  hashWithSalt s (HashableMethod method) = case method of
    Snap.GET     -> tag 0
    Snap.HEAD    -> tag 1
    Snap.POST    -> tag 2
    Snap.PUT     -> tag 3
    Snap.DELETE  -> tag 4
    Snap.TRACE   -> tag 5
    Snap.OPTIONS -> tag 6
    Snap.CONNECT -> tag 7
    Snap.PATCH   -> tag 8
    Snap.Method name ->
      case parseMethod (S.unpack name) of
        -- Genuinely custom method: tag it and hash its name.
        HashableMethod (Snap.Method _) -> tag 9 `hashWithSalt` name
        -- The name spells a standard method: hash it as that constructor.
        standard -> s `hashWithSalt` standard
   where
    tag :: Int -> Int
    tag n = s `hashWithSalt` n

-- | Shows exactly like the wrapped 'Snap.Method'. Delegating 'showsPrec'
-- rather than 'show' keeps the inner value parenthesised when nested (e.g.
-- @Just (Method "FOO")@).
instance Show HashableMethod where
  showsPrec d (HashableMethod m) = showsPrec d m