module Network.Wai.SAML2.Validation (
validateResponse,
decodeResponse,
validateSAMLResponse,
ansiX923
) where
import Control.Exception
import Control.Monad (forM_, when, unless)
import Control.Monad.Except
import Control.Monad.IO.Class (liftIO)
import Crypto.Error
import Crypto.Hash
import qualified Crypto.PubKey.RSA.OAEP as OAEP
import Crypto.PubKey.RSA.PKCS15 as PKCS15
import Crypto.PubKey.RSA.Types (PrivateKey)
import Crypto.Cipher.AES
import Crypto.Cipher.Types
import qualified Data.ByteString as BS
import qualified Data.ByteString.Base64 as BS
import qualified Data.ByteString.Lazy as LBS
import Data.Default.Class
import Data.Time
import Network.Wai.SAML2.XML.Encrypted
import Network.Wai.SAML2.Config
import Network.Wai.SAML2.Error
import Network.Wai.SAML2.XML
import Network.Wai.SAML2.C14N
import Network.Wai.SAML2.Response
import Network.Wai.SAML2.Assertion
import qualified Text.XML as XML
import qualified Text.XML.Cursor as XML
-- | 'validateResponse' @cfg responseData@ decodes and validates the
-- Base64-encoded SAML2 response contained in @responseData@.
--
-- The current time is read once, before anything is decoded, and is the
-- timestamp handed to 'validateSAMLResponse'. On success the validated
-- 'Assertion' is returned together with the parsed 'Response'; errors raised
-- with 'throwError' by 'decodeResponse' or 'validateSAMLResponse' come back
-- as 'Left'.
validateResponse :: SAML2Config
                 -> BS.ByteString
                 -> IO (Either SAML2Error (Assertion, Response))
validateResponse cfg responseData = do
    now <- getCurrentTime
    runExceptT $ do
        (xmlDoc, response) <- decodeResponse responseData
        assertion <- validateSAMLResponse cfg xmlDoc response now
        pure (assertion, response)
-- | 'decodeResponse' @responseData@ turns a Base64-encoded response body into
-- its XML document and its structured 'Response'.
--
-- Base64 decoding is lenient. A payload that does not parse as XML (using
-- 'parseSettings') is reported as 'InvalidResponseXml'. A document that
-- parses but cannot be read as a 'Response' is reported as 'InvalidResponse'.
decodeResponse :: BS.ByteString -> ExceptT SAML2Error IO (XML.Document, Response)
decodeResponse responseData = do
    let xmlBytes = LBS.fromStrict (BS.decodeLenient responseData)
    xmlDoc <- either (throwError . InvalidResponseXml) pure
        (XML.parseLBS parseSettings xmlBytes)
    -- 'parseXML' reports failure via 'fail', which in IO is an IOException
    parsed <- liftIO (try (parseXML (XML.fromDocument xmlDoc)))
    response <- either (throwError . InvalidResponse) pure parsed
    pure (xmlDoc, response)
-- | 'validateSAMLResponse' @cfg doc response now@ validates a decoded SAML2
-- @response@ against @cfg@, treating @now@ as the current time. It returns
-- the assertion the response carries, decrypted first if it is encrypted.
--
-- @doc@ is the XML document @response@ was parsed from (see
-- 'decodeResponse'). It is re-rendered and canonicalised to check the digest
-- and the signature. The checks run in the order below, and the first
-- failure is raised with 'throwError':
--
-- 1. The status code value must be 'Success' ('Unsuccessful').
--
-- 2. If 'saml2ExpectedDestination' is set, the destination must equal it
--    ('UnexpectedDestination').
--
-- 3. If 'saml2ExpectedIssuer' is set, the issuer must equal it
--    ('InvalidIssuer').
--
-- 4. Signature checks:
--
--    * @SignedInfo@ must be extractable ('InvalidResponse').
--    * @SignedInfo@, and the document with its signature removed, must both
--      canonicalise ('CanonicalisationFailure').
--    * The reference URI must equal the response ID ('UnexpectedReference').
--    * The SHA-256 digest of the document must equal the reference digest
--      ('InvalidDigest').
--    * The PKCS#1 v1.5 SHA-256 signature over the canonical @SignedInfo@
--      must verify with 'saml2PublicKey' ('InvalidSignature').
--
-- 5. Assertion selection:
--
--    * An encrypted assertion is decrypted with 'saml2PrivateKey'
--      ('EncryptedAssertionNotSupported' when no key is configured, plus any
--      error from 'decryptAssertion').
--    * Otherwise, if 'saml2RequireEncryptedAssertion' is set, the response is
--      rejected ('EncryptedAssertionRequired').
--    * Otherwise the plain assertion is used, and a missing one yields
--      'InvalidResponse'.
--
-- 6. Unless 'saml2DisableTimeValidation' is set, @now@ must satisfy
--    @NotBefore <= now < NotOnOrAfter@ ('NotValid').
--
-- 7. If 'saml2Audiences' is non-empty, every audience restriction must list
--    at least one of those audiences ('AudienceMismatch').
validateSAMLResponse :: SAML2Config
                     -> XML.Document
                     -> Response
                     -> UTCTime
                     -> ExceptT SAML2Error IO Assertion
validateSAMLResponse cfg responseXmlDoc samlResponse now = do
    checkResponseFields
    checkResponseSignature
    assertion <- obtainAssertion
    checkAssertionConditions assertion
    pure assertion
  where
    responseSig = responseSignature samlResponse
    sigReference = signedInfoReference (signatureInfo responseSig)

    -- Status code, destination and issuer of the response itself.
    checkResponseFields :: ExceptT SAML2Error IO ()
    checkResponseFields = do
        let status = responseStatusCode samlResponse
        case statusCodeValue status of
            Success -> pure ()
            _status -> throwError $ Unsuccessful status

        let destination = responseDestination samlResponse
        case saml2ExpectedDestination cfg of
            Just expected | destination /= expected ->
                throwError $ UnexpectedDestination destination
            _ -> pure ()

        let issuer = responseIssuer samlResponse
        case saml2ExpectedIssuer cfg of
            Just expected | issuer /= expected ->
                throwError $ InvalidIssuer issuer
            _ -> pure ()

    -- Reference, digest and signature verification of the response.
    checkResponseSignature :: ExceptT SAML2Error IO ()
    checkResponseSignature = do
        -- 'extractSignedInfo' reports failure through 'fail'. ExceptT's
        -- 'MonadFail' instance delegates that to IO, where it is thrown as an
        -- IOException that escapes 'runExceptT'. Run it in IO instead and
        -- report the failure as an ordinary 'InvalidResponse'.
        signedInfoResult <- liftIO $ try $
            extractSignedInfo (XML.fromDocument responseXmlDoc)
        signedInfo <- either (throwError . InvalidResponse) pure signedInfoResult

        let signedInfoDoc = XML.Document (XML.Prologue [] Nothing []) signedInfo []
            -- the prefix list taken from SignedInfo is used for both
            -- canonicalisations below
            prefixList = extractPrefixList (XML.fromDocument signedInfoDoc)

            canonicaliseDoc :: XML.Document -> ExceptT SAML2Error IO BS.ByteString
            canonicaliseDoc doc = do
                result <- liftIO $ try $
                    canonicalise prefixList (LBS.toStrict $ XML.renderLBS def doc)
                either (throwError . CanonicalisationFailure) pure result

        normalisedSignedInfo <- canonicaliseDoc signedInfoDoc

        let referenceId = referenceURI sigReference
        when (responseId samlResponse /= referenceId) $
            throwError $ UnexpectedReference referenceId

        -- recompute the digest over the document with its signature removed
        normalisedDoc <- canonicaliseDoc (removeSignature responseXmlDoc)
        let documentHash = hashWith SHA256 normalisedDoc
            referenceHash :: Maybe (Digest SHA256)
            referenceHash = digestFromByteString
                          $ BS.decodeLenient
                          $ referenceDigestValue sigReference
        when (Just documentHash /= referenceHash) $
            throwError InvalidDigest

        let sigBytes = BS.decodeLenient (signatureValue responseSig)
            sigValid = PKCS15.verify (Just SHA256) (saml2PublicKey cfg)
                                     normalisedSignedInfo sigBytes
        unless sigValid $
            throwError InvalidSignature

    -- The assertion carried by the response, decrypted if necessary.
    obtainAssertion :: ExceptT SAML2Error IO Assertion
    obtainAssertion = case responseEncryptedAssertion samlResponse of
        Just encrypted -> case saml2PrivateKey cfg of
            Just pk -> decryptAssertion pk encrypted
            Nothing -> throwError EncryptedAssertionNotSupported
        Nothing
            | saml2RequireEncryptedAssertion cfg ->
                throwError EncryptedAssertionRequired
            | otherwise -> case responseAssertion samlResponse of
                Just plain -> pure plain
                Nothing -> throwError $ InvalidResponse $
                    userError "Assertion or EncryptedAssertion is required"

    -- Validity window and audience restrictions of the assertion.
    checkAssertionConditions :: Assertion -> ExceptT SAML2Error IO ()
    checkAssertionConditions assertion = do
        let conds = assertionConditions assertion
            outsideWindow = now < conditionsNotBefore conds
                         || now >= conditionsNotOnOrAfter conds
        when (outsideWindow && not (saml2DisableTimeValidation cfg)) $
            throwError NotValid

        case saml2Audiences cfg of
            [] -> pure ()
            ourAudiences ->
                forM_ (conditionsAudienceRestrictions conds) $
                    \(AudienceRestriction audiences) ->
                        unless (any (`elem` ourAudiences) audiences) $
                            throwError (AudienceMismatch audiences)
-- | 'decryptAssertion' @pk encryptedAssertion@ decrypts an encrypted
-- assertion and parses the plaintext as an 'Assertion'.
--
-- Steps:
--
-- 1. The symmetric key is recovered by RSA-OAEP decryption (parameters from
--    @OAEP.defaultOAEPParams SHA1@) of the Base64-decoded key cipher value,
--    using @pk@.
--
-- 2. The Base64-decoded assertion cipher value is split into a leading
--    one-block IV and the ciphertext. The ciphertext is decrypted with AES in
--    CBC mode. The AES variant is chosen from the recovered key length:
--    24 bytes selects AES-192, 32 bytes selects AES-256, and anything else
--    uses AES-128, which rejects lengths other than 16.
--
-- 3. Padding is stripped with 'ansiX923'.
--
-- 4. The result is parsed as XML and then as an 'Assertion'.
--
-- Failures are reported as:
--
-- * 'DecryptionFailure' — RSA-OAEP failed.
-- * 'CryptoError' — the key has an unusable size.
-- * 'InvalidIV', 'InvalidPadding' — malformed CBC data.
-- * 'InvalidAssertionXml', 'InvalidAssertion' — the plaintext does not parse.
decryptAssertion :: PrivateKey -> EncryptedAssertion -> ExceptT SAML2Error IO Assertion
decryptAssertion pk encryptedAssertion = do
    oaepResult <- liftIO $ OAEP.decryptSafer (OAEP.defaultOAEPParams SHA1) pk
        $ BS.decodeLenient
        $ cipherValue
        $ encryptedKeyCipher
        $ encryptedAssertionKey encryptedAssertion
    aesKey <- either (throwError . DecryptionFailure) pure oaepResult

    let cipherText = BS.decodeLenient
                   $ cipherValue
                   $ encryptedAssertionCipher encryptedAssertion

    -- Previously only AES-128 was attempted, so 24- and 32-byte keys always
    -- failed in 'cipherInit'. Pick the variant matching the key length; any
    -- other length still goes to AES-128, whose 'cipherInit' fails exactly
    -- as before.
    xmlData <- case BS.length aesKey of
        24 -> decryptAesCbc (cipherInit aesKey :: CryptoFailable AES192) cipherText
        32 -> decryptAesCbc (cipherInit aesKey :: CryptoFailable AES256) cipherText
        _  -> decryptAesCbc (cipherInit aesKey :: CryptoFailable AES128) cipherText

    -- NOTE(review): the decrypted assertion is parsed with the default
    -- settings, whereas 'decodeResponse' uses 'parseSettings'. Confirm this
    -- difference is intended.
    case XML.parseLBS def (LBS.fromStrict xmlData) of
        Left err -> throwError $ InvalidAssertionXml err
        Right assertDoc -> do
            -- 'parseXML' reports failure via 'fail', which in IO is an IOException
            assertParseResult <- liftIO $ try $
                parseXML (XML.fromDocument assertDoc)
            either (throwError . InvalidAssertion) pure assertParseResult
  where
    -- Decrypt IV-prefixed CBC data with the result of 'cipherInit', then
    -- strip the padding. The IV is the first block of the input.
    decryptAesCbc :: BlockCipher c
                  => CryptoFailable c
                  -> BS.ByteString
                  -> ExceptT SAML2Error IO BS.ByteString
    decryptAesCbc (CryptoFailed err) _ = throwError $ CryptoError err
    decryptAesCbc (CryptoPassed cipher) input = do
        let (ivBytes, body) = BS.splitAt (blockSize cipher) input
        iv <- maybe (throwError InvalidIV) pure (makeIV ivBytes)
        maybe (throwError InvalidPadding) pure
            $ ansiX923 (cbcDecrypt cipher iv body)
-- | 'ansiX923' @plaintext@ strips trailing padding from @plaintext@. The value
-- of the final byte gives the total number of padding bytes, including the
-- final byte itself.
--
-- Only the final byte is inspected; the values of the other padding bytes are
-- not checked. Returns 'Nothing' if the input is empty, the final byte is
-- zero, or the final byte exceeds the input length.
ansiX923 :: BS.ByteString -> Maybe BS.ByteString
ansiX923 d = do
    (_, lastByte) <- BS.unsnoc d
    let padLen = fromIntegral lastByte :: Int
        keep = BS.length d - padLen
    if padLen == 0 || keep < 0
        then Nothing
        else Just (BS.take keep d)