module Network.Wai.SAML2 (
Result(..),
saml2Callback,
assertionKey,
errorKey,
saml2Vault,
relayStateKey,
module Network.Wai.SAML2.Config,
module Network.Wai.SAML2.Error,
module Network.Wai.SAML2.Assertion
) where
import qualified Data.ByteString as BS
import Data.Maybe (fromMaybe)
import qualified Data.Vault.Lazy as V
import Network.Wai
import Network.Wai.Parse
import Network.Wai.SAML2.Config
import Network.Wai.SAML2.Validation
import Network.Wai.SAML2.Assertion
import Network.Wai.SAML2.Error
import System.IO.Unsafe (unsafePerformIO)
-- | Is the request's HTTP method exactly @POST@?
isPOST :: Request -> Bool
isPOST req = requestMethod req == "POST"
-- | Middleware that intercepts SAML2 assertion callbacks.
--
-- A request is intercepted only when its raw path equals
-- 'saml2AssertionPath' of the configuration and its method is @POST@. For
-- such a request the body is parsed into form parameters (the parser is
-- configured to allow zero files of size zero). The @SAMLResponse@
-- parameter is then checked with 'validateResponse', and the supplied
-- callback is invoked with the outcome:
--
-- * @Right ('Result' relayState assertion)@ when validation succeeds, where
--   @relayState@ is the optional @RelayState@ form parameter;
-- * @Left err@ when 'validateResponse' reports an error;
-- * @Left 'InvalidRequest'@ when no @SAMLResponse@ parameter is present.
--
-- All other requests are passed to the wrapped application unchanged.
--
-- NOTE(review): the request body has already been consumed when the
-- callback runs, and exceptions raised by 'parseRequestBodyEx' are not
-- caught here — confirm both are intended.
saml2Callback :: SAML2Config
              -> (Either SAML2Error Result -> Middleware)
              -> Middleware
saml2Callback cfg callback app req sendResponse
    | isAssertionPost = handleAssertion
    | otherwise = app req sendResponse
  where
    -- Only POSTs to the configured assertion endpoint are SAML callbacks.
    isAssertionPost :: Bool
    isAssertionPost = rawPathInfo req == saml2AssertionPath cfg && isPOST req

    -- The SAML response is read from the form parameters, so no file
    -- uploads are permitted.
    bodyOpts :: ParseRequestBodyOptions
    bodyOpts = setMaxRequestNumFiles 0
             . setMaxRequestFileSize 0
             $ defaultParseRequestBodyOptions

    handleAssertion :: IO ResponseReceived
    handleAssertion = do
        (params, _) <- parseRequestBodyEx bodyOpts lbsBackEnd req
        outcome <- case lookup "SAMLResponse" params of
            Nothing -> pure (Left InvalidRequest)
            Just samlResponse ->
                fmap (Result (lookup "RelayState" params))
                    <$> validateResponse cfg samlResponse
        callback outcome app req sendResponse
-- | Vault key under which 'saml2Vault' stores the validated 'Assertion'.
--
-- The key is created once, at top level, via 'unsafePerformIO'. The
-- @NOINLINE@ pragma is required: without it GHC may inline the
-- 'unsafePerformIO' call. Each use site could then allocate its own
-- distinct key, and lookups would never find what was inserted.
assertionKey :: V.Key Assertion
assertionKey = unsafePerformIO V.newKey
{-# NOINLINE assertionKey #-}
-- | Vault key under which 'saml2Vault' stores the @RelayState@ parameter.
-- It is only set when the intercepted request carried one.
--
-- The key is created once, at top level, via 'unsafePerformIO'. The
-- @NOINLINE@ pragma is required: without it GHC may inline the
-- 'unsafePerformIO' call. Each use site could then allocate its own
-- distinct key, and lookups would never find what was inserted.
relayStateKey :: V.Key BS.ByteString
relayStateKey = unsafePerformIO V.newKey
{-# NOINLINE relayStateKey #-}
-- | Vault key under which 'saml2Vault' stores the 'SAML2Error'. This happens
-- when an intercepted request fails validation or lacks a @SAMLResponse@.
--
-- The key is created once, at top level, via 'unsafePerformIO'. The
-- @NOINLINE@ pragma is required: without it GHC may inline the
-- 'unsafePerformIO' call. Each use site could then allocate its own
-- distinct key, and lookups would never find what was inserted.
errorKey :: V.Key SAML2Error
errorKey = unsafePerformIO V.newKey
{-# NOINLINE errorKey #-}
-- | Like 'saml2Callback', but instead of taking a callback it records the
-- outcome in the request's 'vault' and forwards to the wrapped application:
--
-- * on failure, the 'SAML2Error' is stored under 'errorKey';
-- * on success, the 'Assertion' is stored under 'assertionKey'. If the
--   'Result' carries a relay state, that value is also stored under
--   'relayStateKey'.
--
-- Requests that 'saml2Callback' does not intercept reach the application
-- with their vault untouched.
saml2Vault :: SAML2Config -> Middleware
saml2Vault cfg = saml2Callback cfg storeOutcome
  where
    storeOutcome :: Either SAML2Error Result -> Middleware
    storeOutcome outcome app req =
        app req { vault = stash outcome (vault req) }

    -- Insert the callback outcome into an existing vault.
    stash :: Either SAML2Error Result -> V.Vault -> V.Vault
    stash (Left err) v = V.insert errorKey err v
    stash (Right result) v =
        V.insert assertionKey (assertion result)
            . fromMaybe v
            $ (\rs -> V.insert relayStateKey rs v) <$> relayState result
-- | Successful outcome of an intercepted SAML2 callback, as passed (wrapped
-- in 'Right') to the 'saml2Callback' callback.
data Result = Result
    { relayState :: !(Maybe BS.ByteString)
      -- ^ The @RelayState@ form parameter sent with the response, if any.
    , assertion  :: !Assertion
      -- ^ The assertion produced by 'validateResponse'.
    } deriving (Eq, Show)