--------------------------------------------------------------------------------
-- SAML2 Middleware for WAI                                                   --
--------------------------------------------------------------------------------
-- This source code is licensed under the MIT license found in the LICENSE    --
-- file in the root directory of this source tree.                            --
--------------------------------------------------------------------------------

-- | Types to represent SAML2 responses.
module Network.Wai.SAML2.Response (
    -- * SAML2 responses
    Response(..),
    removeSignature,
    extractSignedInfo,

    -- * Re-exports
    module Network.Wai.SAML2.StatusCode,
    module Network.Wai.SAML2.Signature
) where 

--------------------------------------------------------------------------------

import qualified Data.Text as T
import Data.Time

import Text.XML
import Text.XML.Cursor

import Network.Wai.SAML2.XML
import Network.Wai.SAML2.XML.Encrypted
import Network.Wai.SAML2.StatusCode
import Network.Wai.SAML2.Signature

--------------------------------------------------------------------------------

-- | Represents SAML2 responses.
--
-- All fields are strict, so a 'Response' value is fully evaluated to weak
-- head normal form in every field once constructed.
data Response = Response {
    -- | The intended destination of this response.
    responseDestination :: !T.Text,
    -- | The unique ID of the response.
    responseId :: !T.Text,
    -- | The timestamp when the response was issued.
    responseIssueInstant :: !UTCTime,
    -- | The SAML version.
    responseVersion :: !T.Text,
    -- | The name of the issuer.
    responseIssuer :: !T.Text,
    -- | The status of the response.
    responseStatusCode :: !StatusCode,
    -- | The response signature.
    responseSignature :: !Signature,
    -- | The (encrypted) assertion.
    responseEncryptedAssertion :: !EncryptedAssertion
} deriving (Eq, Show)

-- | Parses a 'Response' from a cursor positioned on the response element.
--
-- Parsing fails (via 'MonadFail') when:
--
--   * the @IssueInstant@ attribute cannot be parsed by 'parseUTCTime';
--   * the status code cannot be parsed (reported as @"Invalid status code"@);
--   * 'oneOrFail' rejects the candidates found for the
--     @EncryptedAssertion/EncryptedData@ or @Signature@ children
--     (e.g. when none are present).
--
-- The @Destination@, @ID@ and @Version@ attributes and the @Issuer@ text are
-- not validated: when absent they are stored as the empty text.
instance FromXML Response where
    parseXML cursor = do
        issueInstant <- parseUTCTime $ textAttr "IssueInstant"

        -- The status code is parsed in 'Maybe', so any failure message of
        -- its own parser is replaced by a single uniform one.
        statusCode <- maybe (fail "Invalid status code") pure
                    $ parseXML cursor

        encAssertion <- oneOrFail "EncryptedAssertion is required"
                    $   cursor
                    $/  element (saml2Name "EncryptedAssertion")
                    &/  element (xencName "EncryptedData")
                    >=> parseXML

        signature <- oneOrFail "Signature is required"
                   $ cursor $/ element (dsName "Signature") >=> parseXML

        pure Response{
            responseDestination = textAttr "Destination",
            responseId = textAttr "ID",
            responseIssueInstant = issueInstant,
            responseVersion = textAttr "Version",
            responseIssuer = T.concat $
                cursor $/ element (saml2Name "Issuer") &/ content,
            responseStatusCode = statusCode,
            responseSignature = signature,
            responseEncryptedAssertion = encAssertion
        }
      where
        -- Value of the named attribute on the response element, or the
        -- empty text when the attribute is missing.
        textAttr name = T.concat $ attribute name cursor
    
--------------------------------------------------------------------------------

-- | 'isNotSignature' @node@ is 'False' exactly when @node@ is an element whose
-- name equals @'dsName' "Signature"@, and 'True' for every other node
-- (other elements as well as non-element nodes such as text or comments).
isNotSignature :: Node -> Bool
isNotSignature (NodeElement e) = elementName e /= dsName "Signature"
isNotSignature _ = True

-- | 'removeSignature' @document@ removes every @<Signature>@ element that is
-- a /direct child/ of @document@'s root element and returns the resulting
-- document.
--
-- Only the root's immediate children are filtered: @<Signature>@ elements
-- nested deeper in the tree are kept. The prologue, epilogue, and the root
-- element's name and attributes are preserved unchanged.
removeSignature :: Document -> Document
removeSignature doc =
    doc { documentRoot = root { elementNodes = kept } }
  where
    root = documentRoot doc
    kept = filter isNotSignature (elementNodes root)
          
-- | 'nodes' @cursor@ returns a singleton list containing the node that
-- @cursor@ points at. Its list-returning shape lets it be composed onto an
-- axis with '>=>' to turn matched cursors into their underlying nodes.
nodes :: Cursor -> [Node]
nodes cursor = [node cursor]

-- | 'extractSignedInfo' @cursor@ extracts the @<SignedInfo>@ element found
-- at @Signature/SignedInfo@ directly beneath the element that @cursor@
-- points at, in the document represented by @cursor@.
--
-- Fails with @"SignedInfo is required"@ when 'oneOrFail' rejects the
-- candidates (e.g. when there is no such element).
extractSignedInfo :: MonadFail m => Cursor -> m Element
extractSignedInfo cursor = do
    signedInfoNode <- oneOrFail "SignedInfo is required"
                    $ cursor
                    $/ element (dsName "Signature")
                    &/ element (dsName "SignedInfo")
                    >=> nodes
    -- The 'element' axis only matches element nodes, so the second branch
    -- should be unreachable; it is explicit rather than relying on an
    -- implicit failable pattern bind.
    case signedInfoNode of
        NodeElement signedInfo -> pure signedInfo
        _ -> fail "SignedInfo is not an element"

--------------------------------------------------------------------------------