module Network.Wai.SAML2 (
Result(..),
saml2Callback,
assertionKey,
errorKey,
saml2Vault,
relayStateKey,
module Network.Wai.SAML2.Config,
module Network.Wai.SAML2.Error,
module Network.Wai.SAML2.Assertion
) where
import qualified Data.ByteString as BS
import Data.Functor ((<&>))
import Data.Maybe (fromMaybe)
import qualified Data.Vault.Lazy as V
import Network.Wai
import Network.Wai.Parse
import Network.Wai.SAML2.Config
import Network.Wai.SAML2.Validation
import Network.Wai.SAML2.Assertion
import Network.Wai.SAML2.Error
import qualified Network.Wai.SAML2.Response as SAML2
import System.IO.Unsafe (unsafePerformIO)
-- | @'isPOST' req@ holds exactly when the HTTP method of @req@ is @POST@.
isPOST :: Request -> Bool
isPOST req = requestMethod req == "POST"
-- | @'saml2Callback' cfg callback@ is a WAI middleware that intercepts
-- @POST@ requests whose raw path equals @'saml2AssertionPath' cfg@. Every
-- other request is passed straight to the inner application.
--
-- For an intercepted request, the form body is parsed with file uploads
-- disabled: both the maximum number of files and the maximum file size are
-- set to 0. Then:
--
-- * If a @SAMLResponse@ parameter is present, its value is checked with
--   'validateResponse'. On success, the assertion, the validated response
--   and the optional @RelayState@ parameter are packaged into a 'Result'.
-- * If the parameter is absent, the outcome is @'Left' 'InvalidRequest'@.
--
-- In both cases the outcome is given to @callback@, which is then applied
-- to the inner application, the request and the responder.
saml2Callback :: SAML2Config
              -> (Either SAML2Error Result -> Middleware)
              -> Middleware
saml2Callback cfg callback app req sendResponse
    | isAssertionPost = do
        (params, _files) <- parseRequestBodyEx noFiles lbsBackEnd req
        outcome <- maybe (pure $ Left InvalidRequest)
                         (checkWith (lookup "RelayState" params))
                         (lookup "SAMLResponse" params)
        callback outcome app req sendResponse
    | otherwise = app req sendResponse
    where
        isAssertionPost =
            rawPathInfo req == saml2AssertionPath cfg && isPOST req

        -- Only form fields are expected here, so refuse any file parts.
        noFiles = setMaxRequestNumFiles 0
                . setMaxRequestFileSize 0
                $ defaultParseRequestBodyOptions

        -- Validate the encoded response. On success, bundle it with the
        -- relay state taken from the same form.
        checkWith relay samlResponse =
            validateResponse cfg samlResponse <&> fmap (\(a, r) ->
                Result{ relayState = relay, assertion = a, response = r })
-- | Vault key under which 'saml2Vault' stores the validated 'Assertion'.
--
-- The key is a process-wide global created with 'unsafePerformIO'. The
-- @NOINLINE@ pragma is required for that to be sound. Without it, GHC may
-- inline the binding so that separate use sites each run 'V.newKey' and get
-- different keys. A value inserted by 'saml2Vault' could then never be found
-- by a lookup elsewhere.
assertionKey :: V.Key Assertion
assertionKey = unsafePerformIO V.newKey
{-# NOINLINE assertionKey #-}
-- | Vault key under which 'saml2Vault' stores the @RelayState@ form
-- parameter, when the request carried one.
--
-- Like the other vault keys, this is a global created with
-- 'unsafePerformIO'. @NOINLINE@ guarantees that a single key is shared by
-- every use site; inlining could otherwise create a fresh key per use.
relayStateKey :: V.Key BS.ByteString
relayStateKey = unsafePerformIO V.newKey
{-# NOINLINE relayStateKey #-}
-- | Vault key under which 'saml2Vault' stores the 'SAML2Error' when
-- validation fails or the assertion request is malformed.
--
-- This is a global created with 'unsafePerformIO'. @NOINLINE@ prevents GHC
-- from duplicating the 'V.newKey' call, which would leave writers and
-- readers holding different keys.
errorKey :: V.Key SAML2Error
errorKey = unsafePerformIO V.newKey
{-# NOINLINE errorKey #-}
-- | @'saml2Vault' cfg@ is 'saml2Callback' specialised to record the outcome
-- in the request's 'vault' and then always call the inner application:
--
-- * On failure, the 'SAML2Error' is stored under 'errorKey'.
-- * On success, the 'Assertion' is stored under 'assertionKey'. If the
--   request carried a @RelayState@, that value is also stored under
--   'relayStateKey'.
saml2Vault :: SAML2Config -> Middleware
saml2Vault cfg = saml2Callback cfg storeOutcome
    where
        storeOutcome outcome app req =
            app req{ vault = annotate outcome (vault req) }

        -- Extend an existing vault with whatever the validation produced.
        annotate :: Either SAML2Error Result -> V.Vault -> V.Vault
        annotate (Left err) vlt = V.insert errorKey err vlt
        annotate (Right result) vlt =
            V.insert assertionKey (assertion result) $
                fromMaybe vlt
                    ((\rs -> V.insert relayStateKey rs vlt) <$> relayState result)
-- | The data produced by 'saml2Callback' when a SAML2 response passes
-- 'validateResponse'.
data Result = Result
    { -- | The @RelayState@ form parameter sent alongside the response, if any.
      relayState :: !(Maybe BS.ByteString)
      -- | The assertion returned by 'validateResponse'.
    , assertion :: !Assertion
      -- | The SAML2 response returned by 'validateResponse'.
    , response :: !SAML2.Response
    }
    deriving (Eq, Show)