{-# LANGUAGE OverloadedStrings #-}

-- | Add CORS (cross-origin resource sharing) headers to a Snap application.
-- CORS headers can be added either conditionally or unconditionally to the
-- entire site, or you can apply CORS headers to a single route.
--
-- To use in a snaplet, simply use 'wrapSite':
--
-- @
-- wrapSite $ applyCORS defaultOptions
-- @
module Snap.Util.CORS
  ( -- * Applying CORS to a specific response
    applyCORS

    -- * Option Specification
  , CORSOptions(..)
  , defaultOptions

    -- ** Origin lists
  , OriginList(..)
  , OriginSet, mkOriginSet, origins

    -- * Internals
  , HashableURI(..), HashableMethod (..)
  ) where

import Control.Applicative
import Control.Monad (join, when)
import Data.CaseInsensitive (CI)
import Data.Hashable (Hashable(..))
import Data.Maybe (fromMaybe)
import Data.Text.Encoding (decodeUtf8, encodeUtf8)
import Network.URI (URI (..), URIAuth (..), parseURI)

import qualified Data.Attoparsec.ByteString.Char8 as Attoparsec
import qualified Data.ByteString.Char8 as S
import qualified Data.CaseInsensitive as CI
import qualified Data.HashSet as HashSet
import qualified Data.Text as Text
import qualified Snap.Core as Snap

import Snap.Internal.Parsing (pTokens)

-- | A set of origins. RFC 6454 defines an origin as a scheme, host and port
-- only, so 'OriginSet' wraps a 'HashSet.HashSet' whose members are meant to
-- carry nothing beyond those three parts. Build one with 'mkOriginSet', which
-- strips everything else from each 'URI'.
newtype OriginSet = OriginSet
  { origins :: HashSet.HashSet HashableURI
    -- ^ The underlying set of (simplified) origin URIs.
  }

-- | Used to decide the contents of the @Access-Control-Allow-Origin@ header.
data OriginList
  = Everywhere
  -- ^ Allow any origin to access this resource. 'applyCORS' answers such
  -- requests by echoing the requesting origin back in
  -- @Access-Control-Allow-Origin@.
  | Nowhere
  -- ^ Do not allow cross-origin requests.
  | Origins OriginSet
  -- ^ Allow cross-origin requests from exactly these origins.

-- | Specify the options to use when building CORS headers for a response.
-- Most of these options are actions in the handler monad @m@, which allows
-- you to determine the setting of each header conditionally.
data CORSOptions m = CORSOptions
  { corsAllowOrigin :: m OriginList
  -- ^ Which origins are allowed to make cross-origin requests.

  , corsAllowCredentials :: m Bool
  -- ^ Whether or not to allow exposing the response when the omit credentials
  -- flag is unset. When this yields 'True', the response carries
  -- @Access-Control-Allow-Credentials: true@.

  , corsExposeHeaders :: m (HashSet.HashSet (CI S.ByteString))
  -- ^ The set of headers that are exposed to clients. This allows clients to
  -- read the values of these headers, if the response includes them.

  , corsAllowedMethods :: m (HashSet.HashSet HashableMethod)
  -- ^ The set of request methods that are allowed.

  , corsAllowedHeaders ::
      HashSet.HashSet S.ByteString -> m (HashSet.HashSet S.ByteString)
  -- ^ An action to determine which of the request headers are allowed.
  -- This action is supplied the parsed contents of
  -- @Access-Control-Request-Headers@.
  }

-- | Liberal default options. Specifies that:
--
-- * All origins may make cross-origin requests.
-- * @allow-credentials@ is true.
-- * No extra headers beyond simple headers are exposed.
-- * @GET@, @POST@, @PUT@, @DELETE@ and @HEAD@ are all allowed.
-- * All request headers are allowed (the requested set is returned as-is).
--
-- All options are determined unconditionally.
defaultOptions :: Monad m => CORSOptions m
defaultOptions =
  CORSOptions
    { corsAllowOrigin      = return Everywhere
    , corsAllowCredentials = return True
    , corsExposeHeaders    = return HashSet.empty
    , corsAllowedMethods   = return $! defaultAllowedMethods
    , corsAllowedHeaders   = return
    }

-- | The methods permitted by 'defaultOptions'.
defaultAllowedMethods :: HashSet.HashSet HashableMethod
defaultAllowedMethods =
  HashSet.fromList
    (HashableMethod <$> [Snap.GET, Snap.POST, Snap.PUT, Snap.DELETE, Snap.HEAD])

-- | Apply CORS headers to a specific request. This is useful if you only have
-- a single action that needs CORS headers, and you don't want to pay for
-- conditional checks on every request.
--
-- You should note that 'applyCORS' needs to be used before you add any
-- 'Snap.method' combinators.
-- For example, the following won't do what you want:
--
-- > method POST $ applyCORS defaultOptions $ myHandler
--
-- This fails to work as CORS requires an @OPTIONS@ request in the preflighting
-- stage, but this would get filtered out. Instead, use
--
-- > applyCORS defaultOptions $ method POST $ myHandler
--
-- Requests without a parseable @Origin@ header, requests from origins that are
-- not allowed, and preflight requests that fail one of the checks all fall
-- through to the wrapped handler without any CORS headers being added.
applyCORS :: Snap.MonadSnap m => CORSOptions m -> m () -> m ()
applyCORS options m =
  (join . fmap decodeOrigin <$> getHeader "Origin") >>= maybe m corsRequestFrom
  where
    -- A request carrying a usable Origin header: decide whether that origin
    -- is allowed, then treat OPTIONS as a preflight and anything else as the
    -- actual cross-origin request.
    corsRequestFrom origin = do
      originList <- corsAllowOrigin options
      if origin `inOriginList` originList
        then Snap.method Snap.OPTIONS (preflightRequestFrom origin)
               <|> handleRequestFrom origin
        else m

    -- Preflight handling: both the requested method and the requested
    -- headers must be acceptable before any CORS header is emitted.
    preflightRequestFrom origin = do
      maybeMethod <-
        fmap (parseMethod . S.unpack)
          <$> getHeader "Access-Control-Request-Method"
      case maybeMethod of
        Nothing -> m
        Just method -> do
          allowedMethods <- corsAllowedMethods options
          if method `HashSet.member` allowedMethods
            then do
              -- A missing header means "no extra headers requested"; a
              -- header that fails to parse yields Nothing.
              maybeHeaders <-
                fromMaybe (Just HashSet.empty) . fmap splitHeaders
                  <$> getHeader "Access-Control-Request-Headers"
              case maybeHeaders of
                Nothing -> m
                Just headers -> do
                  allowedHeaders <- corsAllowedHeaders options headers
                  if HashSet.null (headers `HashSet.difference` allowedHeaders)
                    then do
                      addAccessControlAllowOrigin origin
                      addAccessControlAllowCredentials
                      commaSepHeader
                        "Access-Control-Allow-Headers"
                        id
                        (HashSet.toList allowedHeaders)
                      commaSepHeader
                        "Access-Control-Allow-Methods"
                        renderMethod
                        (HashSet.toList allowedMethods)
                    else m
            else m

    -- The actual (non-preflight) request: add the response headers, then run
    -- the wrapped handler.
    handleRequestFrom origin = do
      addAccessControlAllowOrigin origin
      addAccessControlAllowCredentials
      exposeHeaders <- corsExposeHeaders options
      when (not $ HashSet.null exposeHeaders) $
        commaSepHeader
          "Access-Control-Expose-Headers"
          CI.original
          (HashSet.toList exposeHeaders)
      m

    -- The (simplified) request origin is echoed back to the client.
    addAccessControlAllowOrigin origin =
      addHeader
        "Access-Control-Allow-Origin"
        (encodeUtf8 $ Text.pack $ show origin)

    addAccessControlAllowCredentials = do
      allowCredentials <- corsAllowCredentials options
      when allowCredentials $
        addHeader "Access-Control-Allow-Credentials" "true"

    -- Parse the Origin header into a simplified URI. The header is untrusted
    -- input: 'decodeUtf8' throws on malformed UTF-8, so anything that is not
    -- plain ASCII is rejected up front and treated as "no usable origin"
    -- instead of blowing up the handler.
    decodeOrigin :: S.ByteString -> Maybe URI
    decodeOrigin raw
      | S.all (< '\x80') raw =
          fmap simplifyURI . parseURI . Text.unpack . decodeUtf8 $ raw
      | otherwise = Nothing

    -- Render a method as the bare token expected in
    -- @Access-Control-Allow-Methods@. Custom methods carry their own name;
    -- going through 'show' for those would presumably emit the constructor
    -- syntax (e.g. @Method "PROPFIND"@) rather than the token itself.
    renderMethod (HashableMethod (Snap.Method name)) = name
    renderMethod (HashableMethod standard) = S.pack (show standard)

    addHeader k v = Snap.modifyResponse (Snap.addHeader k v)

    -- Emit a comma-separated header, or nothing at all for an empty list.
    commaSepHeader k f vs =
      case vs of
        [] -> return ()
        _ -> addHeader k $ S.intercalate ", " (map f vs)

    getHeader = Snap.getsRequest . Snap.getHeader

    -- Parse a header value as a list of tokens; Nothing on parse failure.
    splitHeaders =
      either (const Nothing) (Just . HashSet.fromList)
        . Attoparsec.parseOnly pTokens

-- | Build an 'OriginSet' from a list of URIs. Each URI is passed through
-- 'simplifyURI' first, so only scheme, host and port take part in matching.
mkOriginSet :: [URI] -> OriginSet
mkOriginSet = OriginSet . HashSet.fromList . map (HashableURI . simplifyURI)

-- | Reduce a 'URI' to its origin: user info, path, query and fragment are
-- blanked, leaving the scheme and the host/port part of the authority.
simplifyURI :: URI -> URI
simplifyURI uri =
  uri
    { uriAuthority = fmap simplifyURIAuth (uriAuthority uri)
    , uriPath = ""
    , uriQuery = ""
    , uriFragment = ""
    }
  where
    simplifyURIAuth auth = auth { uriUserInfo = "" }

--------------------------------------------------------------------------------
-- | Map a method name onto the matching 'Snap.Method' constructor, falling
-- back to the custom 'Snap.Method' constructor for unknown names. Matching is
-- exact (case-sensitive).
parseMethod :: String -> HashableMethod
parseMethod "GET"     = HashableMethod Snap.GET
parseMethod "POST"    = HashableMethod Snap.POST
parseMethod "HEAD"    = HashableMethod Snap.HEAD
parseMethod "PUT"     = HashableMethod Snap.PUT
parseMethod "DELETE"  = HashableMethod Snap.DELETE
parseMethod "TRACE"   = HashableMethod Snap.TRACE
parseMethod "OPTIONS" = HashableMethod Snap.OPTIONS
parseMethod "CONNECT" = HashableMethod Snap.CONNECT
parseMethod "PATCH"   = HashableMethod Snap.PATCH
parseMethod s         = HashableMethod $ Snap.Method (S.pack s)

--------------------------------------------------------------------------------
-- | A @newtype@ over 'URI' with a 'Hashable' instance.
newtype HashableURI = HashableURI URI
  deriving (Eq)

instance Show HashableURI where
  show (HashableURI u) = show u

instance Hashable HashableURI where
  hashWithSalt s (HashableURI (URI scheme authority path query fragment)) =
    s `hashWithSalt`
    scheme `hashWithSalt`
    fmap hashAuthority authority `hashWithSalt`
    path `hashWithSalt`
    query `hashWithSalt`
    fragment
    where
      hashAuthority (URIAuth userInfo regName port) =
        s `hashWithSalt`
        userInfo `hashWithSalt`
        regName `hashWithSalt`
        port

-- | Whether the given (already simplified) origin is permitted by the list.
inOriginList :: URI -> OriginList -> Bool
_ `inOriginList` Nowhere = False
_ `inOriginList` Everywhere = True
origin `inOriginList` (Origins (OriginSet xs)) =
  HashableURI origin `HashSet.member` xs

--------------------------------------------------------------------------------
-- | A @newtype@ over 'Snap.Method' with a 'Hashable' instance.
newtype HashableMethod = HashableMethod Snap.Method
  deriving (Eq)

instance Hashable HashableMethod where
  hashWithSalt s (HashableMethod Snap.GET)     = s `hashWithSalt` (0 :: Int)
  hashWithSalt s (HashableMethod Snap.HEAD)    = s `hashWithSalt` (1 :: Int)
  hashWithSalt s (HashableMethod Snap.POST)    = s `hashWithSalt` (2 :: Int)
  hashWithSalt s (HashableMethod Snap.PUT)     = s `hashWithSalt` (3 :: Int)
  hashWithSalt s (HashableMethod Snap.DELETE)  = s `hashWithSalt` (4 :: Int)
  hashWithSalt s (HashableMethod Snap.TRACE)   = s `hashWithSalt` (5 :: Int)
  hashWithSalt s (HashableMethod Snap.OPTIONS) = s `hashWithSalt` (6 :: Int)
  hashWithSalt s (HashableMethod Snap.CONNECT) = s `hashWithSalt` (7 :: Int)
  hashWithSalt s (HashableMethod Snap.PATCH)   = s `hashWithSalt` (8 :: Int)
  -- A custom method whose name is one of the standard ones (e.g.
  -- @Method "GET"@) is hashed like the standard constructor. If the
  -- underlying 'Eq' instance treats the two spellings as equal
  -- (NOTE(review): confirm against snap-core), equal values must hash
  -- equally; if it does not, the shared hash is merely a harmless collision.
  hashWithSalt s (HashableMethod (Snap.Method m)) =
    case parseMethod (S.unpack m) of
      HashableMethod (Snap.Method custom) ->
        s `hashWithSalt` (9 :: Int) `hashWithSalt` custom
      standard -> hashWithSalt s standard

instance Show HashableMethod where
  show (HashableMethod m) = show m