module Snap.CORS
(
wrapCORS
, wrapCORSWithOptions
, applyCORS
, CORSOptions(..)
, defaultOptions
, OriginList(..)
, OriginSet, mkOriginSet, origins
, HashableURI(..)
) where
import Control.Applicative
import Control.Monad (guard, mzero, void, when)
import Data.Char (toLower)

import Control.Monad.Trans.Class (lift)
import Control.Monad.Trans.Maybe (MaybeT(..))
import Data.Hashable (Hashable(..))
import qualified Data.HashSet as HashSet
import qualified Data.Text as Text
import Data.Text.Encoding (decodeUtf8, decodeUtf8', encodeUtf8)
import Network.URI (URI (..), URIAuth (..), parseURI)
import qualified Snap
-- | A set of allowed origins.
--
-- The constructor is not exported; build values with 'mkOriginSet', which
-- normalises every URI with the same helper that 'applyCORS' applies to the
-- incoming @Origin@ header before looking it up.
newtype OriginSet = OriginSet
  { origins :: HashSet.HashSet HashableURI
    -- ^ The normalised origins held by the set.
  }
-- | Which request origins are granted CORS access by 'applyCORS'.
data OriginList
  = Everywhere
    -- ^ Accept any origin whose @Origin@ header parses as a URI.
  | Nowhere
    -- ^ Accept no origin; no CORS headers are ever added.
  | Origins OriginSet
    -- ^ Accept only origins that are members of the given set.
-- | Settings consulted by 'applyCORS'.
--
-- Both fields are actions in @m@ that 'applyCORS' runs while handling a
-- request, so they may inspect whatever state @m@ gives access to.
data CORSOptions m = CORSOptions
  { corsAllowOrigin :: m OriginList
    -- ^ Which origins to accept.
  , corsAllowCredentials :: m Bool
    -- ^ Whether to send @Access-Control-Allow-Credentials: true@ for an
    -- accepted origin.
  }
-- | Permissive defaults: every origin is accepted and credentials are allowed.
--
-- 'applyCORS' echoes the request's own origin back in
-- @Access-Control-Allow-Origin@. Combined with credentials being allowed,
-- these defaults therefore let any site make credentialed cross-origin
-- requests. Pass a restricted 'corsAllowOrigin' to 'wrapCORSWithOptions' if
-- that is not what you want.
defaultOptions :: Monad m => CORSOptions m
defaultOptions = CORSOptions (return Everywhere) (return True)
-- | Wrap the site handler with CORS handling configured by 'defaultOptions'.
wrapCORS :: Snap.Initializer b v ()
wrapCORS = wrapCORSWithOptions permissive
  where
    permissive = defaultOptions
-- | Wrap the site handler with CORS handling configured by @options@.
--
-- 'applyCORS' runs first, so its headers are already on the response when
-- the wrapped handler starts.
wrapCORSWithOptions :: CORSOptions (Snap.Handler b v) -> Snap.Initializer b v ()
wrapCORSWithOptions options =
  Snap.wrapSite (\site -> applyCORS options >> site)
-- | Add CORS response headers for the current request when its origin is
-- accepted.
--
-- The request's @Origin@ header is decoded as UTF-8, parsed as a URI and
-- reduced by 'simplifyURI' (user info, path, query and fragment dropped).
-- The result is then checked against 'corsAllowOrigin':
--
-- * 'Everywhere' accepts it;
-- * 'Nowhere' rejects it;
-- * 'Origins' accepts it only if it is in the set.
--
-- For an accepted origin, @Access-Control-Allow-Origin@ is set to that
-- origin (echoed back, never @*@). @Access-Control-Allow-Credentials: true@
-- is added as well when 'corsAllowCredentials' yields 'True'.
--
-- The response is left untouched in any of these cases:
--
-- * the header is missing;
-- * it is not valid UTF-8;
-- * it does not parse as a URI;
-- * the origin is rejected.
--
-- NOTE(review): the allowed-origin header varies with the request's
-- @Origin@, but no @Vary: Origin@ is added here; shared caches may need
-- one. Confirm for your deployment.
applyCORS :: Snap.MonadSnap m => CORSOptions m -> m ()
applyCORS options = void $ runMaybeT $ do
  origin <- MaybeT $ Snap.getsRequest (Snap.getHeader "Origin")
  -- The header is client-controlled. Decode it totally (decodeUtf8 throws
  -- on malformed input), so bad bytes only skip CORS handling.
  originText <- MaybeT $ pure $ either (const Nothing) Just (decodeUtf8' origin)
  originUri <- MaybeT $ pure $
    simplifyURI <$> parseURI (Text.unpack originText)
  originList <- lift $ corsAllowOrigin options
  case originList of
    Everywhere -> return ()
    Nowhere -> mzero
    Origins (OriginSet xs) ->
      guard (HashableURI originUri `HashSet.member` xs)
  lift $ do
    addHeader "Access-Control-Allow-Origin"
      (encodeUtf8 $ Text.pack $ show originUri)
    allowCredentials <- corsAllowCredentials options
    when allowCredentials $
      addHeader "Access-Control-Allow-Credentials" "true"
  where
    addHeader k v = Snap.modifyResponse (Snap.addHeader k v)
-- | Build an 'OriginSet' from a list of URIs.
--
-- Every URI is passed through 'simplifyURI' before insertion, so any user
-- info, path, query or fragment in the input is ignored.
mkOriginSet :: [URI] -> OriginSet
mkOriginSet uris =
  OriginSet (HashSet.fromList [HashableURI (simplifyURI uri) | uri <- uris])
-- | Reduce a URI to the parts that identify a web origin, so that configured
-- origins and incoming @Origin@ headers compare equal.
--
-- Path, query, fragment and authority user info are cleared.
--
-- The scheme and host are lower-cased because both are case-insensitive
-- (RFC 3986, sections 3.1 and 3.2.2). Without this, an origin configured as
-- @https://Example.com@ could never match the lower-case serialisation that
-- browsers send.
--
-- The port is kept exactly as given.
simplifyURI :: URI -> URI
simplifyURI uri = uri
  { uriScheme = map toLower (uriScheme uri)
  , uriAuthority = simplifyURIAuth <$> uriAuthority uri
  , uriPath = ""
  , uriQuery = ""
  , uriFragment = ""
  }
  where
    simplifyURIAuth auth = auth
      { uriUserInfo = ""
      , uriRegName = map toLower (uriRegName auth)
      }
-- | A 'URI' wrapped so it can carry the 'Hashable' instance needed to store
-- it in a 'HashSet.HashSet'.
newtype HashableURI = HashableURI URI
  deriving (Show, Eq)
-- | Hashes every 'URI' field in order: scheme, authority, path, query and
-- fragment. The authority, when present, is first reduced to an 'Int' from
-- its user info, host name and port. URIs that are equal field by field
-- therefore hash equally.
instance Hashable HashableURI where
  hashWithSalt salt (HashableURI uri) =
    let afterScheme = salt `hashWithSalt` uriScheme uri
        afterAuthority =
          afterScheme `hashWithSalt` (authorityHash <$> uriAuthority uri)
        afterPath = afterAuthority `hashWithSalt` uriPath uri
        afterQuery = afterPath `hashWithSalt` uriQuery uri
    in afterQuery `hashWithSalt` uriFragment uri
    where
      -- The authority's own hash is seeded with the same outer salt.
      authorityHash auth =
        salt `hashWithSalt` uriUserInfo auth
          `hashWithSalt` uriRegName auth
          `hashWithSalt` uriPort auth