{-# LANGUAGE ScopedTypeVariables #-}
module SAML2.Core.Signature
( signSAMLProtocol
, verifySAMLProtocol
, verifySAMLProtocol'
) where
import Control.Exception
import Control.Lens ((^.), (.~))
import Control.Monad (unless)
import qualified Data.ByteString.Lazy as BSL
import Data.List.NonEmpty (NonEmpty((:|)))
import Network.URI (URI(uriFragment), nullURI)
import Text.XML.HXT.DOM.TypeDefs
import SAML2.XML
import qualified SAML2.XML.Canonical as C14N
import qualified SAML2.XML.Signature as DS
import qualified SAML2.Core.Protocols as SAMLP
-- | Sign a SAML protocol message with the given key.
--
-- A single XML-DSig reference pointing at the message's own ID (@#<ID>@) is
-- generated over the pickled document ('samlToDoc'). Its transform chain is
-- the enveloped-signature transform followed by exclusive canonicalisation
-- (@CanonicalXMLExcl10 False@), and its digest method is SHA-1.
--
-- The signature is produced over a SignedInfo built from that reference,
-- using the signature algorithm reported by the signing key. If the message
-- already carries a signature, that signature's existing SignedInfo is
-- re-signed instead.
--
-- Returns the input message with its signature field set to the new
-- signature.
signSAMLProtocol :: SAMLP.SAMLProtocol a => DS.SigningKey -> a -> IO a
signSAMLProtocol sk m = do
  ref <- DS.generateReference referenceTemplate (samlToDoc m)
  -- NOTE(review): when a signature is already present, its old SignedInfo is
  -- reused and 'ref' (just computed over the current document) is unused.
  -- Confirm this "re-sign existing template" behaviour is intended.
  let signedInfo = maybe (freshSignedInfo ref) DS.signatureSignedInfo
                         (SAMLP.protocolSignature proto)
  sig <- DS.generateSignature sk signedInfo
  return ((DS.signature' .~ Just sig) m)
  where
    -- Protocol-level fields common to every SAML protocol message.
    proto = m ^. SAMLP.samlProtocol'

    -- Exclusive canonicalisation, shared by the reference and SignedInfo.
    exclC14N = C14N.CanonicalXMLExcl10 False

    -- Strip the enveloped signature, then canonicalise.
    transformChain =
      DS.simpleTransform DS.TransformEnvelopedSignature
        :| [DS.simpleTransform (DS.TransformCanonicalization exclC14N)]

    -- The digest value is a placeholder handed to 'DS.generateReference';
    -- if it is ever forced unfilled, the error names this field.
    referenceTemplate = DS.Reference
      { DS.referenceId = Nothing
      , DS.referenceURI =
          Just (nullURI { uriFragment = '#' : SAMLP.protocolID proto })
      , DS.referenceType = Nothing
      , DS.referenceTransforms = Just (DS.Transforms transformChain)
      , DS.referenceDigestMethod = DS.simpleDigest DS.DigestSHA1
      , DS.referenceDigestValue = error "signSAMLProtocol: referenceDigestValue"
      }

    -- SignedInfo used when the message carries no prior signature.
    freshSignedInfo ref = DS.SignedInfo
      { DS.signedInfoId = Nothing
      , DS.signedInfoCanonicalizationMethod = DS.simpleCanonicalization exclC14N
      , DS.signedInfoSignatureMethod = DS.SignatureMethod
          { DS.signatureMethodAlgorithm =
              Identified (DS.signingKeySignatureAlgorithm sk)
          , DS.signatureMethodHMACOutputLength = Nothing
          , DS.signatureMethod = []
          }
      , DS.signedInfoReference = ref :| []
      }
-- | Parse raw XML bytes as a SAML protocol message and check its signature.
--
-- Fails (via 'fail' in 'IO') with:
--
-- * @"invalid XML"@ if the bytes do not parse as XML;
-- * the unpickler's message if the document is not the requested message
--   type;
-- * @"verifySAMLProtocol: invalid or missing signature"@ unless
--   'DS.verifySignature' reports @'Just' 'True'@.
--
-- Exceptions thrown by 'DS.verifySignature' itself propagate unchanged.
--
-- NOTE(review): no trusted keys are supplied here ('mempty'). Confirm which
-- key material 'DS.verifySignature' falls back to. If it trusts keys embedded
-- in the message, callers should prefer 'verifySAMLProtocol'' with pinned
-- keys.
verifySAMLProtocol :: SAMLP.SAMLProtocol a => BSL.ByteString -> IO a
verifySAMLProtocol b = do
  doc <- case xmlToDoc b of
    Nothing   -> fail "invalid XML"
    Just tree -> return tree
  msg <- case docToSAML doc of
    Left err     -> fail err
    Right parsed -> return parsed
  verdict <- DS.verifySignature mempty (DS.signedID msg) doc
  if verdict == Just True
    then return msg
    else fail "verifySAMLProtocol: invalid or missing signature"
-- | Unpickle a SAML protocol message from a parsed XML document and verify
-- its signature against the supplied public keys.
--
-- Fails (via 'fail' in 'IO') when:
--
-- * the document does not unpickle as the requested message type
--   (with the unpickler's message);
-- * 'DS.verifySignature' throws a synchronous exception
--   (@"signature verification failed: "@ followed by the shown exception);
-- * 'DS.verifySignature' returns 'Nothing'
--   (no matching key/algorithm pair);
-- * 'DS.verifySignature' returns @'Just' 'False'@.
--
-- Asynchronous exceptions (for example from @timeout@ or @killThread@) are
-- re-thrown unchanged rather than being reported as a verification failure.
verifySAMLProtocol' :: SAMLP.SAMLProtocol a => DS.PublicKeys -> XmlTree -> IO a
verifySAMLProtocol' pubkeys x = do
  m <- either fail return $ docToSAML x
  -- The verdict is forced inside 'try'. Otherwise an exception hidden in a
  -- lazy result would surface in the 'case' below and bypass the handler.
  v :: Either SomeException (Maybe Bool) <- try $
    DS.verifySignature pubkeys (DS.signedID m) x >>= evaluate . forceVerdict
  case v of
    Left e
      -- Never turn cancellation/timeouts into an ordinary failure.
      | Just (_ :: SomeAsyncException) <- fromException e -> throwIO e
      | otherwise -> fail $ "signature verification failed: " ++ show e
    Right Nothing -> fail "signature verification failed: no matching key/alg pair."
    Right (Just False) -> fail "signature verification failed: verification failed."
    Right (Just True) -> pure m
  where
    -- Force both the 'Maybe' constructor and the 'Bool' it may contain.
    forceVerdict :: Maybe Bool -> Maybe Bool
    forceVerdict r = maybe r (`seq` r) r