--------------------------------------------------------------------------------
-- SAML2 Middleware for WAI                                                   --
--------------------------------------------------------------------------------
-- This source code is licensed under the MIT license found in the LICENSE    --
-- file in the root directory of this source tree.                            --
--------------------------------------------------------------------------------

-- | Implements WAI 'Middleware' for SAML2 service providers. Two different
-- interfaces are supported (with equivalent functionality): one which simply
-- stores the outcome of the validation process in the request vault and one
-- which passes the outcome to a callback.
module Network.Wai.SAML2 (
    -- * Callback-based middleware
    --
    -- $callbackBasedMiddleware
    Result(..),
    saml2Callback,

    -- * Vault-based middleware
    --
    -- $vaultBasedMiddleware
    assertionKey,
    errorKey,
    saml2Vault,
    relayStateKey,

    -- * Re-exports
    module Network.Wai.SAML2.Config,
    module Network.Wai.SAML2.Error,
    module Network.Wai.SAML2.Assertion
) where

--------------------------------------------------------------------------------

import qualified Data.ByteString as BS
import Data.Functor ((<&>))
import Data.Maybe (fromMaybe)
import qualified Data.Vault.Lazy as V

import Network.Wai
import Network.Wai.Parse
import Network.Wai.SAML2.Config
import Network.Wai.SAML2.Validation
import Network.Wai.SAML2.Assertion
import Network.Wai.SAML2.Error
import qualified Network.Wai.SAML2.Response as SAML2

import System.IO.Unsafe (unsafePerformIO)

--------------------------------------------------------------------------------

-- | 'isPOST' @request@ checks whether the request method of @request@ is
-- @"POST"@. The raw method bytes are compared directly, so the check is
-- case-sensitive.
isPOST :: Request -> Bool
isPOST = (== "POST") . requestMethod

--------------------------------------------------------------------------------

-- $callbackBasedMiddleware
--
-- This 'Middleware' provides a SAML2 service provider (SP) implementation
-- that can be wrapped around an existing WAI 'Application'. The middleware is
-- parameterised over the SAML2 configuration and a callback. If the middleware
-- intercepts a request made to the endpoint given by the SAML2 configuration,
-- the result of validating the SAML2 response contained in the request body
-- will be passed to the callback.
--
-- > saml2Callback cfg callback mainApp
-- >  where callback (Left err) app req sendResponse = do
-- >            -- a POST request was made to the assertion endpoint, but
-- >            -- something went wrong, details of which are provided by
-- >            -- the error: this should probably be logged as it may
-- >            -- indicate that an attack was attempted against the
-- >            -- endpoint, but you *must* not show the error
-- >            -- to the client as it would severely compromise
-- >            -- system security
-- >            --
-- >            -- you may also want to return e.g. a HTTP 400 or 401 status
-- >
-- >        callback (Right result) app req sendResponse = do
-- >            -- a POST request was made to the assertion endpoint and the
-- >            -- SAML2 response was successfully validated:
-- >            -- you *must* check that you have not encountered the
-- >            -- assertion ID before; we assume that there is a
-- >            -- computation tryRetrieveAssertion which looks up
-- >            -- assertions by ID in e.g. a database
-- >            existing <- tryRetrieveAssertion (assertionId (assertion result))
-- >
-- >            case existing of
-- >                Just something -> -- a replay attack has occurred
-- >                Nothing -> do
-- >                    -- store the assertion id somewhere
-- >                    storeAssertion (assertionId (assertion result))
-- >
-- >                    -- the assertion is valid and you can now e.g.
-- >                    -- retrieve user data from your database
-- >                    -- before proceeding with the request by e.g.
-- >                    -- redirecting them to the main view

-- | 'saml2Callback' @config callback@ produces SAML2 'Middleware' for
-- the given @config@. If the middleware intercepts a @POST@ request to the
-- endpoint given by @config@, the request body is parsed and the outcome
-- is passed to @callback@:
--
-- * @'Left' 'InvalidRequest'@ if the body has no @SAMLResponse@ parameter;
-- * otherwise, the result of 'validateResponse' on that parameter, where a
--   successful validation is packaged as a 'Result' together with the
--   optional @RelayState@ parameter of the request body.
--
-- All other requests are passed on to the wrapped application unchanged.
-- @callback@ receives the original request, whose body has already been
-- read by 'parseRequestBodyEx'.
saml2Callback :: SAML2Config
              -> (Either SAML2Error Result -> Middleware)
              -> Middleware
saml2Callback cfg callback app req sendResponse
    -- only POST requests made to the configured assertion endpoint are
    -- handled by this middleware
    | rawPathInfo req == saml2AssertionPath cfg && isPOST req = do
        -- parse the request; any uploaded files are discarded (and are
        -- not permitted by 'bodyOpts' in the first place)
        -- NOTE(review): failures raised while parsing the body are not
        -- turned into 'InvalidRequest' here; confirm this is intended
        (body, _) <- parseRequestBodyEx bodyOpts lbsBackEnd req

        let callbackWith result = callback result app req sendResponse

        case lookup "SAMLResponse" body of
            -- the request does not contain the expected payload
            Nothing -> callbackWith (Left InvalidRequest)
            Just val -> do
                let rs = lookup "RelayState" body
                result <- validateResponse cfg val <&> fmap (toResult rs)
                callbackWith result
    -- not one of the paths we need to handle, pass the request on to the
    -- inner application
    | otherwise = app req sendResponse
    where
        -- default request parse options, but do not allow files;
        -- we are not expecting any
        bodyOpts = setMaxRequestNumFiles 0
                 . setMaxRequestFileSize 0
                 $ defaultParseRequestBodyOptions

        -- package a successfully validated response with the relay state;
        -- the binders are named so as not to shadow the record selectors
        toResult rs (validAssertion, fullResponse) = Result{
            relayState = rs,
            assertion = validAssertion,
            response = fullResponse
        }

--------------------------------------------------------------------------------

-- $vaultBasedMiddleware
--
-- This is a simpler-to-use 'Middleware' which stores the outcome of a request
-- made to the assertion endpoint in the request vault. The inner WAI
-- application can then check for the presence of an assertion or an error with
-- 'V.lookup' and 'assertionKey' or 'errorKey' respectively. At most one of
-- the two locations will be populated for a given request, i.e. it is not
-- possible for an assertion to be validated and an error to occur. If the
-- assertion is valid and the request carried a relay state, the relay state
-- is additionally available under 'relayStateKey'.
--
-- > saml2Vault cfg $ \req sendResponse -> do
-- >    case V.lookup errorKey (vault req) of
-- >        Just err ->
-- >            -- log the error, but you *must* not show the error
-- >            -- to the client as it would severely compromise
-- >            -- system security
-- >        Nothing -> pure () -- carry on
-- >
-- >    case V.lookup assertionKey (vault req) of
-- >        Nothing -> pure () -- carry on
-- >        Just assertion -> do
-- >            -- a valid assertion was processed by the middleware,
-- >            -- you *must* check that you have not encountered the
-- >            -- assertion ID before; we assume that there is a
-- >            -- computation tryRetrieveAssertion which looks up
-- >            -- assertions by ID in e.g. a database
-- >            result <- tryRetrieveAssertion (assertionId assertion)
-- >
-- >            case result of
-- >                Just something -> -- a replay attack has occurred
-- >                Nothing -> do
-- >                    -- store the assertion id somewhere
-- >                    storeAssertion (assertionId assertion)
-- >
-- >                    -- the assertion is valid

-- | 'assertionKey' is a vault key for retrieving assertions from
-- request vaults if the 'saml2Vault' 'Middleware' is used.
assertionKey :: V.Key Assertion
assertionKey = unsafePerformIO V.newKey
-- NOINLINE makes 'V.newKey' run exactly once, so every use of
-- 'assertionKey' refers to the same key; if the definition were inlined,
-- separate use sites could obtain distinct keys and vault lookups would fail.
{-# NOINLINE assertionKey #-}

-- | 'relayStateKey' is a vault key for retrieving the relay state
-- from request vaults if the 'saml2Vault' 'Middleware' is used
-- and the assertion is valid.
relayStateKey :: V.Key BS.ByteString
relayStateKey = unsafePerformIO V.newKey
-- NOINLINE makes 'V.newKey' run exactly once, so every use of
-- 'relayStateKey' refers to the same key; if the definition were inlined,
-- separate use sites could obtain distinct keys and vault lookups would fail.
{-# NOINLINE relayStateKey #-}

-- | 'errorKey' is a vault key for retrieving SAML2 errors from request vaults
-- if the 'saml2Vault' 'Middleware' is used.
errorKey :: V.Key SAML2Error
errorKey = unsafePerformIO V.newKey
-- NOINLINE makes 'V.newKey' run exactly once, so every use of
-- 'errorKey' refers to the same key; if the definition were inlined,
-- separate use sites could obtain distinct keys and vault lookups would fail.
{-# NOINLINE errorKey #-}

-- | 'saml2Vault' @config@ produces SAML2 'Middleware' for the given @config@.
-- It is built on 'saml2Callback': when a request to the assertion endpoint
-- is intercepted, the outcome is stored in the request vault and the
-- request is then passed on to the inner application:
--
-- * on failure, the 'SAML2Error' is stored under 'errorKey';
-- * on success, the 'Assertion' is stored under 'assertionKey' and, if the
--   request carried one, the relay state is stored under 'relayStateKey'.
saml2Vault :: SAML2Config -> Middleware
saml2Vault cfg = saml2Callback cfg storeInVault
    where
        -- if the middleware intercepts a request containing a SAML2 response
        -- at the configured endpoint, the outcome of processing the response
        -- is passed to this callback: we store the result in the
        -- corresponding entry in the request vault
        storeInVault :: Either SAML2Error Result -> Middleware
        storeInVault (Left err) app req =
            app req{ vault = V.insert errorKey err (vault req) }
        storeInVault (Right result) app req =
            app req{ vault = insertAssertion (insertRelayState (vault req)) }
            where
                insertAssertion = V.insert assertionKey (assertion result)
                -- leave the vault untouched if there is no relay state
                insertRelayState =
                    fromMaybe id (V.insert relayStateKey <$> relayState result)

--------------------------------------------------------------------------------

-- | Represents the result of validating a SAML2 response, as passed to the
-- callback of 'saml2Callback' on success.
data Result = Result {
    -- | An optional relay state, as provided in the POST request.
    relayState :: !(Maybe BS.ByteString),
    -- | The assertion obtained from the response that has been validated.
    assertion :: !Assertion,
    -- | The full response obtained from the IdP.
    --
    -- @since 0.4
    response :: !SAML2.Response
} deriving (Eq, Show)

--------------------------------------------------------------------------------