--------------------------------------------------------------------------------
-- SAML2 Middleware for WAI                                                   --
--------------------------------------------------------------------------------
-- This source code is licensed under the MIT license found in the LICENSE    --
-- file in the root directory of this source tree.                            --
--------------------------------------------------------------------------------

-- | SAML2 signatures.
module Network.Wai.SAML2.Signature (
    CanonicalisationMethod(..),
    SignatureMethod(..),
    DigestMethod(..),
    SignedInfo(..),
    Reference(..),
    Signature(..)
) where 

--------------------------------------------------------------------------------

import qualified Data.ByteString as BS
import qualified Data.Text as T 
import Data.Text.Encoding

import Text.XML.Cursor

import Network.Wai.SAML2.XML

--------------------------------------------------------------------------------

-- | Enumerates XML canonicalisation methods.
data CanonicalisationMethod
    -- | Original C14N 1.0 specification.
    = C14N_1_0
    -- | Exclusive C14N 1.0 specification, identified by the algorithm URI
    -- @http://www.w3.org/2001/10/xml-exc-c14n#@. This is the only
    -- canonicalisation method the 'FromXML' instance in this module accepts.
    | C14N_EXC_1_0
    -- | C14N 1.1 specification.
    | C14N_1_1
    deriving (Eq, Show)

instance FromXML CanonicalisationMethod where
    -- Reads the @Algorithm@ attribute of the current element. Only the
    -- exclusive C14N 1.0 URI is recognised; any other value, including a
    -- missing attribute (which concatenates to the empty text), is
    -- rejected via 'fail'.
    parseXML cursor =
        case T.concat $ attribute "Algorithm" cursor of
            "http://www.w3.org/2001/10/xml-exc-c14n#" -> pure C14N_EXC_1_0
            _ -> fail "Not a valid CanonicalisationMethod"

-- | Enumerates signature methods.
data SignatureMethod
    -- | RSA signature over a SHA-256 digest, identified by the algorithm URI
    -- @http://www.w3.org/2001/04/xmldsig-more#rsa-sha256@.
    = RSA_SHA256
    deriving (Eq, Show)

instance FromXML SignatureMethod where
    -- Reads the @Algorithm@ attribute of the current element. Only
    -- RSA-SHA256 is recognised; anything else (including a missing
    -- attribute) is rejected via 'fail'.
    parseXML cursor =
        case T.concat $ attribute "Algorithm" cursor of
            "http://www.w3.org/2001/04/xmldsig-more#rsa-sha256" -> pure RSA_SHA256
            _ -> fail "Not a valid SignatureMethod"

--------------------------------------------------------------------------------

-- | Enumerates digest methods.
data DigestMethod
    -- | SHA-256, identified by the algorithm URI
    -- @http://www.w3.org/2001/04/xmlenc#sha256@.
    = DigestSHA256
    deriving (Eq, Show)

instance FromXML DigestMethod where
    -- Reads the @Algorithm@ attribute of the current element. Only SHA-256
    -- is recognised; anything else (including a missing attribute) is
    -- rejected via 'fail'.
    parseXML cursor =
        case T.concat $ attribute "Algorithm" cursor of
            "http://www.w3.org/2001/04/xmlenc#sha256" -> pure DigestSHA256
            _ -> fail "Not a valid DigestMethod"

-- | Represents a reference to some entity along with a digest of it.
data Reference = Reference {
    -- | The URI of the entity that is referenced. For a same-document
    -- reference of the form @#ID@ this holds the ID without the leading @#@.
    referenceURI :: !T.Text,
    -- | The method that was used to calculate the digest for the
    -- entity that is referenced.
    referenceDigestMethod :: !DigestMethod,
    -- | The digest of the entity that was calculated by the IdP, stored as
    -- the UTF-8 encoding of the @DigestValue@ element's text content
    -- (presumably still base64-encoded; decoding is not done here).
    referenceDigestValue :: !BS.ByteString
} deriving (Eq, Show)

instance FromXML Reference where
    -- Parses a reference element: its @URI@ attribute, a required
    -- @DigestMethod@ child (via 'oneOrFail') and the concatenated text of
    -- any @DigestValue@ children. A missing @DigestValue@ yields an empty
    -- digest rather than a parse failure.
    parseXML cursor = do
        -- Same-document references have the form "#<ID>"; strip the '#'
        -- so that only the ID is kept. Previously the first character was
        -- dropped unconditionally, which corrupted URIs lacking the '#'
        -- prefix. Such URIs (including the empty URI) are now kept as-is.
        let rawURI = T.concat $ attribute "URI" cursor
            uri = case T.stripPrefix "#" rawURI of
                Just ident -> ident
                Nothing -> rawURI

        digestMethod <- oneOrFail "DigestMethod is required" $
            cursor $/ element (dsName "DigestMethod") >=> parseXML

        let digestValue = encodeUtf8 $ T.concat $
                cursor $/ element (dsName "DigestValue") &/ content

        pure Reference{
            referenceURI = uri,
            referenceDigestMethod = digestMethod,
            referenceDigestValue = digestValue
        }

--------------------------------------------------------------------------------

-- | Represents references to some entities for which the IdP has calculated
-- digests. The 'SignedInfo' component is then signed by the IdP.
data SignedInfo = SignedInfo {
    -- | The XML canonicalisation method used.
    signedInfoCanonicalisationMethod :: !CanonicalisationMethod,
    -- | The method used to compute the signature for the referenced entity.
    signedInfoSignatureMethod :: !SignatureMethod,
    -- | The reference to some entity, along with a digest.
    signedInfoReference :: !Reference
} deriving (Eq, Show)

instance FromXML SignedInfo where
    -- Parses the three required children of a @SignedInfo@ element. Each
    -- is located with 'dsName', parsed with its own 'FromXML' instance and
    -- reduced to a single value with 'oneOrFail', so a missing or
    -- unparseable child makes the whole parse fail.
    parseXML cursor = do
        canonicalisationMethod <-
            oneOrFail "CanonicalizationMethod is required" $
                cursor $/ element (dsName "CanonicalizationMethod")
                    >=> parseXML

        signatureMethod <-
            oneOrFail "SignatureMethod is required" $
                cursor $/ element (dsName "SignatureMethod")
                    >=> parseXML

        reference <-
            oneOrFail "Reference is required" $
                cursor $/ element (dsName "Reference")
                    >=> parseXML

        pure SignedInfo{
            signedInfoCanonicalisationMethod = canonicalisationMethod,
            signedInfoSignatureMethod = signatureMethod,
            signedInfoReference = reference
        }

-- | Represents response signatures.
data Signature = Signature {
    -- | Information about the data for which the IdP has computed digests.
    signatureInfo :: !SignedInfo,
    -- | The signature of the 'SignedInfo' value, stored as the UTF-8
    -- encoding of the @SignatureValue@ element's text content (presumably
    -- still base64-encoded; decoding is not done here).
    signatureValue :: !BS.ByteString
} deriving (Eq, Show)

instance FromXML Signature where
    -- Parses a signature element: a required @SignedInfo@ child (via
    -- 'oneOrFail') plus the concatenated text content of any
    -- @SignatureValue@ children. A missing @SignatureValue@ yields an
    -- empty byte string rather than a parse failure.
    parseXML cursor = do
        info <- oneOrFail "SignedInfo is required" $
            cursor $/ element (dsName "SignedInfo") >=> parseXML

        let value = encodeUtf8 $ T.concat $
                cursor $/ element (dsName "SignatureValue") &/ content

        pure Signature{
            signatureInfo = info,
            signatureValue = value
        }

--------------------------------------------------------------------------------