--------------------------------------------------------------------------------
-- SAML2 Middleware for WAI                                                   --
--------------------------------------------------------------------------------
-- This source code is licensed under the MIT license found in the LICENSE    --
-- file in the root directory of this source tree.                            --
--------------------------------------------------------------------------------

{-# LANGUAGE LambdaCase #-}

-- | This module provides a data type for identity provider (IdP) metadata,
-- such as the signing certificate, supported NameID formats and SSO URLs.
--
-- @since 0.4
module Network.Wai.SAML2.EntityDescriptor (
    IDPSSODescriptor(..),
    Binding(..)
) where

--------------------------------------------------------------------------------

import qualified Data.ByteString.Base64 as Base64
import qualified Data.X509 as X509
import Data.Text (Text)
import qualified Data.Text as T
import qualified Data.Text.Encoding as T

import Network.Wai.SAML2.XML

import Text.XML.Cursor

--------------------------------------------------------------------------------

-- | Metadata describing a SAML2 identity provider (IdP).
-- See also section 2.4.3 of [Metadata for the OASIS Security Assertion Markup Language (SAML) V2.0](https://docs.oasis-open.org/security/saml/v2.0/saml-metadata-2.0-os.pdf).
data IDPSSODescriptor = IDPSSODescriptor
    { entityID :: Text
      -- ^ The IdP's entity ID.
      -- 'Network.Wai.SAML2.Config.saml2ExpectedIssuer' should be compared
      -- against this identifier.
    , x509Certificate :: X509.SignedExact X509.Certificate
      -- ^ The X.509 certificate for signed assertions.
    , nameIDFormats :: [Text]
      -- ^ The NameID formats the IdP supports.
    , singleSignOnServices :: [(Binding, Text)]
      -- ^ Single sign-on URLs, each paired with the 'Binding' it uses.
    }
    deriving Show

-- | SAML protocol bindings, identified by URIs in the
-- @urn:oasis:names:tc:SAML:2.0:bindings@ namespace.
-- See [Bindings for the OASIS Security Assertion Markup Language (SAML) V2.0](https://docs.oasis-open.org/security/saml/v2.0/saml-bindings-2.0-os.pdf).
data Binding
    = HTTPPost
      -- ^ SAML protocol messages are transmitted within the base64-encoded
      -- content of an HTML form control.
    | HTTPRedirect
      -- ^ SAML protocol messages are transmitted within URL parameters.
    | HTTPArtifact
      -- ^ The request and/or response are transmitted by reference, using a
      -- small stand-in called an artifact.
    | PAOS
      -- ^ Reverse HTTP binding for SOAP.
    | SOAP
      -- ^ SAML protocol messages are exchanged using SOAP, a lightweight
      -- protocol for exchanging structured information in a decentralized,
      -- distributed environment.
    | URLEncodingDEFLATE
      -- ^ SAML protocol messages are encoded into a URL using the DEFLATE
      -- compression method.
    deriving (Eq, Show)

-- | Parses IdP metadata. The cursor is expected to be the element that carries
-- the @entityID@ attribute and has a single @md:IDPSSODescriptor@ child
-- (presumably @md:EntityDescriptor@).
--
-- Fails if there is not exactly one @IDPSSODescriptor@, or not exactly one
-- @ds:X509Certificate@ among the signing-capable @KeyDescriptor@s, or if that
-- certificate cannot be decoded. @SingleSignOnService@ endpoints whose binding
-- URI is not a known 'Binding' are skipped rather than rejected.
instance FromXML IDPSSODescriptor where
    parseXML cursor = do
        -- A missing entityID attribute is not rejected: concatenating an
        -- empty attribute list yields the empty text.
        let entityID = T.concat $ attribute "entityID" cursor
        descriptor <- oneOrFail "IDPSSODescriptor is required"
            $ cursor $/ element (mdName "IDPSSODescriptor")
        -- Only KeyDescriptors usable for signing are considered. Otherwise an
        -- additional encryption key (use="encryption") would produce a second
        -- certificate and make the whole document fail to parse.
        rawCertificate <- oneOrFail "X509Certificate is required"
            $ concatMap certificateTexts
            $ filter isSigningKeyDescriptor
            $ descriptor $/ element (mdName "KeyDescriptor")
        -- Lenient decoding tolerates characters outside the base64 alphabet,
        -- e.g. line breaks inside the certificate text.
        x509Certificate <- either fail pure
            $ X509.decodeSignedObject
            $ Base64.decodeLenient
            $ T.encodeUtf8 rawCertificate
        let nameIDFormats = descriptor
                $/ element (mdName "NameIDFormat")
                &/ content
        singleSignOnServices <- concat <$> traverse parseSupportedService
            (descriptor $/ element (mdName "SingleSignOnService"))
        pure IDPSSODescriptor{..}
      where
        -- A KeyDescriptor without a @use@ attribute may be used for both
        -- signing and encryption; otherwise it must be marked @signing@.
        isSigningKeyDescriptor :: Cursor -> Bool
        isSigningKeyDescriptor = all (== "signing") . attribute "use"

        -- The text content of every ds:X509Certificate in a KeyDescriptor.
        certificateTexts :: Cursor -> [Text]
        certificateTexts keyDescriptor = keyDescriptor
            $/ element (dsName "KeyInfo")
            &/ element (dsName "X509Data")
            &/ element (dsName "X509Certificate")
            &/ content

        -- Parses one SingleSignOnService endpoint. Yields no entry (instead
        -- of failing) when its Binding URI is not one 'parseBinding'
        -- recognises, e.g. HTTP-POST-SimpleSign. Any other problem, such as a
        -- missing Binding or Location attribute, still fails via 'parseService'.
        parseSupportedService :: MonadFail m => Cursor -> m [(Binding, Text)]
        parseSupportedService service = case attribute "Binding" service of
            [uri] | Nothing <- parseBinding uri -> pure []
            _ -> (: []) <$> parseService service

-- | 'parseService' @endpoint@ reads the @Binding@ and @Location@ attributes of
-- the endpoint element at @endpoint@. It returns the parsed 'Binding' paired
-- with the location URL.
--
-- The binding is read (and validated by 'parseBinding') before the location,
-- so when both are missing the binding error is reported. Failure happens
-- through 'oneOrFail' when an attribute is absent.
parseService :: MonadFail m => Cursor -> m (Binding, Text)
parseService endpoint =
    (,) <$> (oneOrFail "Binding is required" (attribute "Binding" endpoint)
                >>= parseBinding)
        <*> oneOrFail "Location is required" (attribute "Location" endpoint)

-- | 'parseBinding' @uri@ looks up the 'Binding' identified by the SAML binding
-- URI @uri@. It fails with @"Unknown Binding: "@ followed by @uri@ for any URI
-- not in the table below.
parseBinding :: MonadFail m => Text -> m Binding
parseBinding uri =
    maybe (fail $ "Unknown Binding: " <> T.unpack uri) pure
        $ lookup uri bindingURIs
  where
    -- Binding URIs from the SAML V2.0 bindings specification.
    bindingURIs :: [(Text, Binding)]
    bindingURIs =
        [ ("urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Artifact", HTTPArtifact)
        , ("urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST", HTTPPost)
        , ("urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect", HTTPRedirect)
        , ("urn:oasis:names:tc:SAML:2.0:bindings:PAOS", PAOS)
        , ("urn:oasis:names:tc:SAML:2.0:bindings:SOAP", SOAP)
        , ( "urn:oasis:names:tc:SAML:2.0:bindings:URL-Encoding:DEFLATE"
          , URLEncodingDEFLATE
          )
        ]

--------------------------------------------------------------------------------