module Snap.CORS
(
wrapCORS
, wrapCORSWithOptions
, applyCORS
, CORSOptions(..)
, defaultOptions
, OriginList(..)
, OriginSet, mkOriginSet, origins
, HashableURI(..), HashableMethod (..)
) where
import Control.Applicative
import Control.Monad (join, when)
import Data.CaseInsensitive (CI)
import Data.Hashable (Hashable(..))
import Data.Maybe (fromMaybe)
import Data.Text.Encoding (decodeUtf8, decodeUtf8', encodeUtf8)
import Network.URI (URI (..), URIAuth (..), parseURI)
import qualified Data.Attoparsec.Combinator as Attoparsec
import qualified Data.Attoparsec.ByteString.Char8 as Attoparsec
import qualified Data.ByteString.Char8 as Char8
import qualified Data.CaseInsensitive as CI
import qualified Data.HashSet as HashSet
import qualified Data.Text as Text
import qualified Snap
-- | A set of origins, stored as 'HashableURI' values so that membership can
-- be tested with a hash lookup.  Build one with 'mkOriginSet'; read the
-- underlying set back with 'origins'.
newtype OriginSet = OriginSet
  { origins :: HashSet.HashSet HashableURI
  }
-- | Which origins are permitted to make cross-origin requests.
data OriginList
  = Everywhere
    -- ^ Every origin is accepted.
  | Nowhere
    -- ^ No origin is accepted.
  | Origins OriginSet
    -- ^ Only origins that are members of the given set are accepted.
-- | Configuration consulted by 'applyCORS'.  Every setting is an action in
-- @m@, so it is evaluated while the request is being handled.
data CORSOptions m = CORSOptions
  { corsAllowOrigin :: m OriginList
    -- ^ The origins for which CORS headers are added.
  , corsAllowCredentials :: m Bool
    -- ^ When this yields 'True', @Access-Control-Allow-Credentials: true@
    -- is added to the response.
  , corsExposeHeaders :: m (HashSet.HashSet (CI Char8.ByteString))
    -- ^ Header names listed in @Access-Control-Expose-Headers@ on
    -- non-preflight responses.
  , corsAllowedMethods :: m (HashSet.HashSet HashableMethod)
    -- ^ Methods accepted in @Access-Control-Request-Method@ and reported in
    -- @Access-Control-Allow-Methods@ during a preflight.
  , corsAllowedHeaders :: HashSet.HashSet String -> m (HashSet.HashSet String)
    -- ^ Receives the header names requested by a preflight and returns the
    -- set of allowed header names.  The preflight is answered only when the
    -- returned set covers every requested name.
  }
-- | Permissive defaults: every origin is accepted, credentials are allowed,
-- no additional headers are exposed, the methods GET, POST, PUT, DELETE and
-- HEAD are permitted, and whatever headers a preflight asks for are allowed
-- (the requested set is returned unchanged).
defaultOptions :: Monad m => CORSOptions m
defaultOptions =
  CORSOptions
    { corsAllowOrigin = return Everywhere
    , corsAllowCredentials = return True
    , corsExposeHeaders = return HashSet.empty
    , corsAllowedMethods = return permittedMethods
    , corsAllowedHeaders = return
    }
  where
    permittedMethods =
      HashSet.fromList
        (map HashableMethod [Snap.GET, Snap.POST, Snap.PUT, Snap.DELETE, Snap.HEAD])
-- | Wrap the site (via 'Snap.wrapSite') with 'applyCORS' configured by
-- 'defaultOptions'.
wrapCORS :: Snap.Initializer b v ()
wrapCORS = Snap.wrapSite (applyCORS defaultOptions)
-- | Wrap the site (via 'Snap.wrapSite') with 'applyCORS' configured by the
-- supplied options.
wrapCORSWithOptions :: CORSOptions (Snap.Handler b v) -> Snap.Initializer b v ()
wrapCORSWithOptions = Snap.wrapSite . applyCORS
-- | Run the handler @m@ with CORS response headers applied according to
-- @options@.
--
-- * A request without an @Origin@ header, with one that cannot be decoded
--   and parsed as a URI, or with an origin rejected by 'corsAllowOrigin',
--   is passed to @m@ untouched.
--
-- * An @OPTIONS@ request from an accepted origin is treated as a preflight.
--   When the requested method and headers are permitted, the
--   @Access-Control-Allow-*@ headers are set and @m@ is /not/ run; when
--   anything is missing or not permitted, @m@ is run without CORS headers.
--
-- * Any other request from an accepted origin gets
--   @Access-Control-Allow-Origin@ (plus the credentials and expose-headers
--   headers when configured) and then runs @m@.
applyCORS :: Snap.MonadSnap m => CORSOptions m -> m () -> m ()
applyCORS options m =
  (join . fmap decodeOrigin <$> getHeader "Origin") >>= maybe m corsRequestFrom
  where
    corsRequestFrom origin = do
      originList <- corsAllowOrigin options
      if origin `inOriginList` originList
        then Snap.method Snap.OPTIONS (preflightRequestFrom origin)
               <|> handleRequestFrom origin
        else m

    preflightRequestFrom origin = do
      maybeMethod <-
        fmap (parseMethod . Char8.unpack)
          <$> getHeader "Access-Control-Request-Method"
      case maybeMethod of
        Nothing -> m
        Just method -> do
          allowedMethods <- corsAllowedMethods options
          if method `HashSet.member` allowedMethods
            then do
              -- A missing header simply means no extra headers were requested.
              maybeHeaders <-
                fromMaybe (Just HashSet.empty) . fmap splitHeaders
                  <$> getHeader "Access-Control-Request-Headers"
              case maybeHeaders of
                Nothing -> m
                Just headers -> do
                  allowedHeaders <- corsAllowedHeaders options headers
                  if headersPermitted headers allowedHeaders
                    then do
                      addAccessControlAllowOrigin origin
                      addAccessControlAllowCredentials
                      commaSepHeader
                        "Access-Control-Allow-Headers"
                        Char8.pack
                        (HashSet.toList allowedHeaders)
                      commaSepHeader
                        "Access-Control-Allow-Methods"
                        (Char8.pack . show)
                        (HashSet.toList allowedMethods)
                    else m
            else m

    handleRequestFrom origin = do
      addAccessControlAllowOrigin origin
      addAccessControlAllowCredentials
      exposeHeaders <- corsExposeHeaders options
      -- commaSepHeader emits nothing for an empty list, so no guard is needed.
      commaSepHeader
        "Access-Control-Expose-Headers"
        CI.original
        (HashSet.toList exposeHeaders)
      m

    -- HTTP header names are case-insensitive, so a request for
    -- "content-type" must be satisfied by an allowed "Content-Type".
    headersPermitted requested allowed =
      HashSet.null $
        HashSet.map CI.mk requested `HashSet.difference` HashSet.map CI.mk allowed

    addAccessControlAllowOrigin origin =
      addHeader
        "Access-Control-Allow-Origin"
        (encodeUtf8 $ Text.pack $ show origin)

    addAccessControlAllowCredentials = do
      allowCredentials <- corsAllowCredentials options
      when allowCredentials $
        addHeader "Access-Control-Allow-Credentials" "true"

    -- The Origin header is untrusted input: decode it with the total
    -- decodeUtf8' so invalid UTF-8 yields Nothing instead of an exception.
    decodeOrigin =
      either (const Nothing) (fmap simplifyURI . parseURI . Text.unpack)
        . decodeUtf8'

    addHeader k v = Snap.modifyResponse (Snap.addHeader k v)

    -- Render the values with f and join them with ", "; an empty list adds
    -- no header at all.
    commaSepHeader k f vs =
      case vs of
        [] -> return ()
        _ -> addHeader k $ Char8.intercalate ", " (map f vs)

    getHeader = Snap.getsRequest . Snap.getHeader

    -- Split a comma-separated list of header names, trimming surrounding
    -- whitespace.  Empty names (from an empty value or a stray comma) are
    -- dropped rather than treated as a requested header.
    splitHeaders =
      let spaces = Attoparsec.many' Attoparsec.space
          headerC = Attoparsec.satisfy (not . (`elem` (" ," :: String)))
          headerName = Attoparsec.many' headerC
          header = spaces *> headerName <* spaces
          parser =
            HashSet.fromList . filter (not . null)
              <$> header `Attoparsec.sepBy` Attoparsec.char ','
      in either (const Nothing) Just . Attoparsec.parseOnly parser
-- | Build an 'OriginSet' from a list of URIs.  Each URI is passed through
-- 'simplifyURI' before it is stored.
mkOriginSet :: [URI] -> OriginSet
mkOriginSet uris =
  OriginSet (HashSet.fromList [HashableURI (simplifyURI uri) | uri <- uris])
-- | Reduce a URI to the parts kept for origin comparison: the path, query
-- and fragment are blanked, and the user-info part of the authority (when
-- there is an authority) is blanked as well.  Scheme, host and port are
-- left as they were.
simplifyURI :: URI -> URI
simplifyURI uri =
  uri
    { uriAuthority = withoutUserInfo <$> uriAuthority uri
    , uriPath = ""
    , uriQuery = ""
    , uriFragment = ""
    }
  where
    withoutUserInfo authority = authority { uriUserInfo = "" }
-- | Turn a method name into a 'HashableMethod'.  The nine names listed below
-- map to the corresponding 'Snap.Method' constructors (exact, case-sensitive
-- match); any other string is wrapped in 'Snap.Method'.
parseMethod :: String -> HashableMethod
parseMethod name =
  HashableMethod $ case name of
    "GET" -> Snap.GET
    "POST" -> Snap.POST
    "HEAD" -> Snap.HEAD
    "PUT" -> Snap.PUT
    "DELETE" -> Snap.DELETE
    "TRACE" -> Snap.TRACE
    "OPTIONS" -> Snap.OPTIONS
    "CONNECT" -> Snap.CONNECT
    "PATCH" -> Snap.PATCH
    _ -> Snap.Method (Char8.pack name)
-- | A 'URI' wrapper that can be given a 'Hashable' instance, so URIs can be
-- stored in a 'HashSet.HashSet'.
newtype HashableURI
  = HashableURI URI
  deriving (Eq)
-- | Shows exactly as the wrapped 'URI' does.
instance Show HashableURI where
  show (HashableURI uri) = show uri
-- | Hashes every component of the wrapped 'URI'.  The authority, when
-- present, is first reduced to an 'Int' by hashing its user-info, host name
-- and port starting from the same salt.
instance Hashable HashableURI where
  hashWithSalt salt (HashableURI uri) =
    salt
      `hashWithSalt` uriScheme uri
      `hashWithSalt` fmap authorityHash (uriAuthority uri)
      `hashWithSalt` uriPath uri
      `hashWithSalt` uriQuery uri
      `hashWithSalt` uriFragment uri
    where
      authorityHash authority =
        salt
          `hashWithSalt` uriUserInfo authority
          `hashWithSalt` uriRegName authority
          `hashWithSalt` uriPort authority
-- | Decide whether an origin is accepted by an 'OriginList': never for
-- 'Nowhere', always for 'Everywhere', and by set membership for 'Origins'.
inOriginList :: URI -> OriginList -> Bool
inOriginList origin originList =
  case originList of
    Nowhere -> False
    Everywhere -> True
    Origins (OriginSet allowed) -> HashSet.member (HashableURI origin) allowed
-- | A 'Snap.Method' wrapper that can be given a 'Hashable' instance, so
-- methods can be stored in a 'HashSet.HashSet'.
newtype HashableMethod
  = HashableMethod Snap.Method
  deriving (Eq)
-- | Each built-in method constructor hashes as a fixed small integer tag;
-- a custom 'Snap.Method' hashes as tag 9 followed by its name.
--
-- NOTE(review): if the 'Eq' instance of 'Snap.Method' considers a custom
-- @Method "GET"@ equal to @GET@, equal values would hash differently here.
-- Confirm against the snap-core version in use.
instance Hashable HashableMethod where
  hashWithSalt salt (HashableMethod method) =
    case method of
      Snap.GET -> tag 0
      Snap.HEAD -> tag 1
      Snap.POST -> tag 2
      Snap.PUT -> tag 3
      Snap.DELETE -> tag 4
      Snap.TRACE -> tag 5
      Snap.OPTIONS -> tag 6
      Snap.CONNECT -> tag 7
      Snap.PATCH -> tag 8
      Snap.Method name -> tag 9 `hashWithSalt` name
    where
      -- Mix a constructor tag into the incoming salt.
      tag :: Int -> Int
      tag = hashWithSalt salt
-- | Shows exactly as the wrapped 'Snap.Method' does.
instance Show HashableMethod where
  show (HashableMethod method) = show method