--------------------------------------------------------------------------------
-- SAML2 Middleware for WAI                                                   --
--------------------------------------------------------------------------------
-- This source code is licensed under the MIT license found in the LICENSE    --
-- file in the root directory of this source tree.                            --
--------------------------------------------------------------------------------

-- | Types to represent SAML2 responses.
module Network.Wai.SAML2.Response (
    -- * SAML2 responses
    Response(..),
    removeSignature,
    extractSignedInfo,
    extractPrefixList,

    -- * Re-exports
    module Network.Wai.SAML2.StatusCode,
    module Network.Wai.SAML2.Signature
) where

--------------------------------------------------------------------------------

import Data.Maybe (listToMaybe)
import qualified Data.Text as T
import Data.Time

import Text.XML
import Text.XML.Cursor

import Network.Wai.SAML2.Assertion
import Network.Wai.SAML2.XML
import Network.Wai.SAML2.XML.Encrypted
import Network.Wai.SAML2.StatusCode
import Network.Wai.SAML2.Signature

--------------------------------------------------------------------------------

-- | Represents SAML2 responses.

-- Reference [StatusResponseType]
data Response = Response {
    -- | The intended destination of this response.
    responseDestination :: !T.Text,
    -- | The ID of the request this response corresponds to, if any.
    --
    -- @since 0.4
    responseInResponseTo :: !(Maybe T.Text),
    -- | The unique ID of the response.
    responseId :: !T.Text,
    -- | The timestamp when the response was issued.
    responseIssueInstant :: !UTCTime,
    -- | The SAML version.
    responseVersion :: !T.Text,
    -- | The name of the issuer.
    responseIssuer :: !T.Text,
    -- | The status of the response.
    responseStatusCode :: !StatusCode,
    -- | The response signature.
    responseSignature :: !Signature,
    -- | The unencrypted assertion.
    --
    -- @since 0.4
    responseAssertion :: !(Maybe Assertion),
    -- | The encrypted assertion.
    --
    -- @since 0.4
    responseEncryptedAssertion :: !(Maybe EncryptedAssertion)
} deriving (Eq, Show)

-- | Builds a 'Response' from a cursor positioned on the response element:
-- attributes are read from the cursor itself, everything else from its
-- direct children.
instance FromXML Response where
    -- Reference [StatusResponseType]
    parseXML cursor = do
        -- Any failure reported by 'parseUTCTime' aborts the whole parse.
        issueInstant <- parseUTCTime
                      $ T.concat
                      $ attribute "IssueInstant" cursor

        -- The status code is parsed in 'Maybe' first so that a failure can
        -- be reported with a dedicated error message.
        statusCode <- case parseXML cursor of
            Nothing -> fail "Invalid status code"
            Just sc -> pure sc

        -- Parsed in the list monad: child elements that fail to parse are
        -- dropped and only the first successfully parsed one is kept, so a
        -- missing assertion yields 'Nothing' rather than a failure.
        let assertion = listToMaybe
                    $ ( cursor
                    $/  element (saml2Name "Assertion")
                    ) >>= parseXML

        -- Same strategy as for the plain assertion above.
        let encAssertion = listToMaybe
                    $ ( cursor
                    $/  element (saml2Name "EncryptedAssertion")
                    ) >>= parseXML

        -- Unlike the assertions, a signature is mandatory: 'oneOrFail'
        -- reports the given message when it cannot produce a cursor.
        signature <- oneOrFail "Signature is required" (
            cursor $/ element (dsName "Signature") ) >>= parseXML

        pure Response{
            -- 'T.concat' turns a missing attribute into the empty text.
            responseDestination = T.concat $ attribute "Destination" cursor,
            responseId = T.concat $ attribute "ID" cursor,
            responseInResponseTo = listToMaybe $ attribute "InResponseTo" cursor,
            responseIssueInstant = issueInstant,
            responseVersion = T.concat $ attribute "Version" cursor,
            -- Concatenates the text content of every child @<Issuer>@.
            responseIssuer = T.concat $
                cursor $/ element (saml2Name "Issuer") &/ content,
            responseStatusCode = statusCode,
            responseSignature = signature,
            responseAssertion = assertion,
            responseEncryptedAssertion = encAssertion
        }

--------------------------------------------------------------------------------

-- | Returns 'True' if the argument is not a @<Signature>@ element.
--
-- Element nodes are compared by name against @dsName "Signature"@; every
-- other kind of node is never a signature.
isNotSignature :: Node -> Bool
isNotSignature (NodeElement e) = elementName e /= dsName "Signature"
isNotSignature _ = True

-- | 'removeSignature' @document@ removes every @<Signature>@ element that
-- is a direct child of the root element of @document@ and returns the
-- resulting document. Signatures nested deeper in the tree are left
-- untouched, as are the prologue and the trailing miscellaneous items.
removeSignature :: Document -> Document
removeSignature (Document prologue root misc) =
    let Element n attr ns = root
    in Document prologue (Element n attr (filter isNotSignature ns)) misc

-- | Returns the node that @cursor@ currently points at, lifted into @m@.
-- This never fails by itself; the monadic result type lets it be chained
-- after other parsing steps with '>>='.
nodes :: MonadFail m => Cursor -> m Node
nodes = pure . node

-- | 'extractSignedInfo' @cursor@ extracts the SignedInfo element from the
-- document represented by @cursor@. The element is looked up as a
-- @<SignedInfo>@ child of a @<Signature>@ child of @cursor@; the
-- computation fails in @m@ with \"SignedInfo is required\" when
-- 'oneOrFail' cannot produce a cursor for it.
extractSignedInfo :: MonadFail m => Cursor -> m Element
extractSignedInfo cursor = do
    -- The refutable pattern relies on 'MonadFail': should the selected node
    -- not be an element, the bind fails in @m@ instead of crashing.
    NodeElement signedInfo <- oneOrFail "SignedInfo is required"
                            ( cursor
                           $/ element (dsName "Signature")
                           &/ element (dsName "SignedInfo")
                            ) >>= nodes
    pure signedInfo

-- | Obtain a list of InclusiveNamespaces entries used for exclusive XML canonicalisation.
--
-- Starting from the children of @cursor@, this follows the path
-- @Reference@ → @Transforms@ → @Transform@ → @InclusiveNamespaces@,
-- collects every @PrefixList@ attribute found there and splits each value
-- on whitespace. The result is empty when no such attribute exists.
--
-- @since 0.5
extractPrefixList :: Cursor -> [T.Text]
extractPrefixList cursor = concatMap T.words
    $ concatMap (attribute "PrefixList")
    $ cursor
    $/ element (dsName "Reference")
    &/ element (dsName "Transforms")
    &/ element (dsName "Transform")
    &/ element (ecName "InclusiveNamespaces")

--------------------------------------------------------------------------------

-- Reference [StatusResponseType]
--   Source: https://docs.oasis-open.org/security/saml/v2.0/saml-core-2.0-os.pdf#page=38
--   Section: 3.2.2 Complex Type StatusResponseType