module Network.Wai.SAML2.Validation (
    validateResponse,
    decodeResponse,
    validateSAMLResponse,
    ansiX923
) where
import Control.Exception
import Control.Monad.Except
import Crypto.Error
import Crypto.Hash
import qualified Crypto.PubKey.RSA.OAEP as OAEP
import Crypto.PubKey.RSA.PKCS15 as PKCS15
import Crypto.PubKey.RSA.Types (PrivateKey)
import Crypto.Cipher.AES
import Crypto.Cipher.Types
import qualified Data.ByteString as BS
import qualified Data.ByteString.Base64 as BS
import qualified Data.ByteString.Lazy as LBS
import Data.Default.Class
import Data.Time
import Network.Wai.SAML2.XML.Encrypted
import Network.Wai.SAML2.Config
import Network.Wai.SAML2.Error
import Network.Wai.SAML2.XML
import Network.Wai.SAML2.C14N
import Network.Wai.SAML2.Response
import Network.Wai.SAML2.Assertion
import qualified Text.XML as XML
import qualified Text.XML.Cursor as XML
-- | Decodes a base64-encoded SAML2 response and validates it against the
-- service provider configuration.
--
-- The current system time ('getCurrentTime') is used for the assertion's
-- validity-window check. See 'decodeResponse' and 'validateSAMLResponse'
-- for the individual steps and the 'SAML2Error' each of them can report.
--
-- On success, returns the (decrypted, if it was encrypted) assertion
-- together with the parsed response.
validateResponse :: SAML2Config
                 -> BS.ByteString
                 -> IO (Either SAML2Error (Assertion, Response))
validateResponse cfg responseData = runExceptT $ do
    now <- liftIO getCurrentTime
    (responseXmlDoc, samlResponse) <- decodeResponse responseData
    assertion <- validateSAMLResponse cfg responseXmlDoc samlResponse now
    pure (assertion, samlResponse)
-- | Decodes a base64-encoded SAML2 response into both its raw XML document
-- and its parsed 'Response'.
--
-- The raw document is returned as well because signature validation needs
-- the XML itself, not just the parsed representation.
--
-- Base64 decoding uses 'BS.decodeLenient', so malformed base64 is never
-- rejected at this stage. Failures are reported as:
--
-- * 'InvalidResponseXml' when the decoded bytes are not well-formed XML;
-- * 'InvalidResponse' when the XML cannot be parsed as a 'Response'.
decodeResponse :: BS.ByteString -> ExceptT SAML2Error IO (XML.Document, Response)
decodeResponse responseData = do
    let resXmlDocData = BS.decodeLenient responseData

    responseXmlDoc <- either (throwError . InvalidResponseXml) pure
        $ XML.parseLBS def (LBS.fromStrict resXmlDocData)

    -- parseXML is run in IO so that a MonadFail failure becomes an
    -- IOException that 'try' can turn into a 'Left'.
    resParseResult <- liftIO $ try $
        parseXML (XML.fromDocument responseXmlDoc)
    samlResponse <- either (throwError . InvalidResponse) pure resParseResult

    pure (responseXmlDoc, samlResponse)
-- | Validates a decoded SAML2 response and returns its assertion.
--
-- Checks run in this order. The first failure short-circuits with the
-- 'SAML2Error' in parentheses:
--
-- 1. the status code is @Success@ ('Unsuccessful');
-- 2. the destination equals 'saml2ExpectedDestination', when that is set
--    ('UnexpectedDestination');
-- 3. the issuer equals 'saml2ExpectedIssuer', when that is set
--    ('InvalidIssuer');
-- 4. the @SignedInfo@ element can be extracted ('InvalidResponse') and
--    canonicalised ('CanonicalisationFailure');
-- 5. the signature's reference URI equals the response ID
--    ('UnexpectedReference');
-- 6. the SHA256 digest of the canonicalised response, with its signature
--    removed, equals the signed digest value ('InvalidDigest');
-- 7. the RSA PKCS#1 v1.5 / SHA256 signature over the canonicalised
--    @SignedInfo@ verifies with 'saml2PublicKey' ('InvalidSignature');
-- 8. an assertion is obtained:
--    * an encrypted assertion is decrypted with 'saml2PrivateKey'
--      ('EncryptedAssertionNotSupported' if no key is configured);
--    * otherwise the plain assertion is used, unless
--      'saml2RequireEncryptedAssertion' is set ('EncryptedAssertionRequired');
--    * 'InvalidResponse' is reported if neither assertion is present;
-- 9. @NotBefore <= now < NotOnOrAfter@, unless
--    'saml2DisableTimeValidation' is set ('NotValid');
-- 10. when 'saml2Audiences' is non-empty, every audience restriction names
--     at least one of those audiences ('AudienceMismatch').
--
-- NOTE(review): the digest and signature algorithms are fixed to SHA256
-- here. The algorithms declared in the signature are not consulted by
-- this function.
validateSAMLResponse :: SAML2Config
                     -> XML.Document
                     -> Response
                     -> UTCTime
                     -> ExceptT SAML2Error IO Assertion
validateSAMLResponse cfg responseXmlDoc samlResponse now = do
    -- 1. the IdP must report success
    case statusCodeValue $ responseStatusCode samlResponse of
        Success -> pure ()
        _status -> throwError $ Unsuccessful $ responseStatusCode samlResponse

    -- 2. the response must be addressed to us, if a destination is configured
    let destination = responseDestination samlResponse
    case saml2ExpectedDestination cfg of
        Just expectedDestination
            | destination /= expectedDestination ->
                throwError $ UnexpectedDestination destination
        _ -> pure ()

    -- 3. the response must come from the expected IdP, if one is configured
    let issuer = responseIssuer samlResponse
    case saml2ExpectedIssuer cfg of
        Just expectedIssuer
            | issuer /= expectedIssuer -> throwError $ InvalidIssuer issuer
        _ -> pure ()

    -- 4. Extract <SignedInfo> in IO, not directly in ExceptT: ExceptT's
    -- MonadFail delegates to IO, so a failure there would escape as an
    -- exception instead of being reported as a Left.
    signedInfoResult <- liftIO $ try $
        extractSignedInfo (XML.fromDocument responseXmlDoc)
    signedInfo <- either (throwError . InvalidResponse) pure signedInfoResult

    -- canonicalise SignedInfo as a standalone document
    let signedInfoDoc = XML.Document (XML.Prologue [] Nothing []) signedInfo []
    normalisedSignedInfo <- canonicaliseOrThrow $ XML.renderLBS def signedInfoDoc

    let respSignature = responseSignature samlResponse
        sigReference = signedInfoReference $ signatureInfo respSignature

    -- 5. the signature must cover this response, not some other element
    let documentId = responseId samlResponse
        referenceId = referenceURI sigReference
    when (documentId /= referenceId) $
        throwError $ UnexpectedReference referenceId

    -- 6. the digest of the response minus its signature must match
    normalised <- canonicaliseOrThrow
        $ XML.renderLBS def
        $ removeSignature responseXmlDoc
    let documentHash = hashWith SHA256 normalised
        referenceHash = digestFromByteString
                      $ BS.decodeLenient
                      $ referenceDigestValue sigReference
    when (Just documentHash /= referenceHash) $
        throwError InvalidDigest

    -- 7. the IdP's signature over SignedInfo must verify
    let sigBytes = BS.decodeLenient $ signatureValue respSignature
    unless (PKCS15.verify (Just SHA256) (saml2PublicKey cfg) normalisedSignedInfo sigBytes) $
        throwError InvalidSignature

    -- 8. obtain the assertion; an encrypted one takes precedence
    assertion <- case responseEncryptedAssertion samlResponse of
        Just encrypted -> case saml2PrivateKey cfg of
            Just pk -> decryptAssertion pk encrypted
            Nothing -> throwError EncryptedAssertionNotSupported
        Nothing
            | saml2RequireEncryptedAssertion cfg ->
                throwError EncryptedAssertionRequired
            | otherwise -> case responseAssertion samlResponse of
                Just plain -> pure plain
                Nothing -> throwError $ InvalidResponse $
                    userError "Assertion or EncryptedAssertion is required"

    -- 9. the assertion must be within its validity window
    let conditions = assertionConditions assertion
        outsideWindow = now < conditionsNotBefore conditions
                     || now >= conditionsNotOnOrAfter conditions
    when (outsideWindow && not (saml2DisableTimeValidation cfg)) $
        throwError NotValid

    -- 10. each audience restriction must name one of our audiences
    case saml2Audiences cfg of
        -- no audiences configured: audience restrictions are not checked
        [] -> pure ()
        ourAudiences ->
            forM_ (conditionsAudienceRestrictions conditions) $
                \(AudienceRestriction audiences) ->
                    unless (any (`elem` ourAudiences) audiences) $
                        throwError (AudienceMismatch audiences)

    pure assertion
  where
    -- Canonicalises rendered XML, reporting an IO failure of the
    -- canonicaliser as 'CanonicalisationFailure'.
    canonicaliseOrThrow :: LBS.ByteString -> ExceptT SAML2Error IO BS.ByteString
    canonicaliseOrThrow xml = do
        result <- liftIO $ try $ canonicalise $ LBS.toStrict xml
        either (throwError . CanonicalisationFailure) pure result
-- | Decrypts an encrypted assertion with the service provider's RSA private
-- key and parses the resulting XML as an 'Assertion'.
--
-- Steps:
--
-- 1. Recover the symmetric content key from the @EncryptedKey@ using
--    RSA-OAEP with SHA1 parameters.
-- 2. Decrypt the assertion ciphertext with AES in CBC mode. The first
--    block of the ciphertext is the IV.
-- 3. Strip the padding with 'ansiX923'.
--
-- The AES variant is chosen from the length of the recovered key:
-- 32 bytes selects AES-256, 24 bytes AES-192, and any other length AES-128.
-- For AES-128, a key of the wrong size is reported by 'cipherInit' as a
-- 'CryptoError'.
--
-- Failures are reported as:
--
-- * 'DecryptionFailure' when RSA-OAEP fails;
-- * 'CryptoError' when AES key setup fails;
-- * 'InvalidIV' when the ciphertext is shorter than one block;
-- * 'InvalidPadding' when the remaining ciphertext is not whole blocks, or
--   the padding is malformed;
-- * 'InvalidAssertionXml' or 'InvalidAssertion' when the plaintext does not
--   parse.
--
-- NOTE(review): this function does not look at the declared encryption
-- algorithm. Presumably only CBC modes are expected (a GCM-encrypted
-- assertion would fail here) — confirm.
decryptAssertion :: PrivateKey -> EncryptedAssertion -> ExceptT SAML2Error IO Assertion
decryptAssertion pk encryptedAssertion = do
    -- recover the content key; decryptSafer needs randomness, hence IO
    oaepResult <- liftIO $ OAEP.decryptSafer (OAEP.defaultOAEPParams SHA1) pk
        $ BS.decodeLenient
        $ cipherValue
        $ encryptedKeyCipher
        $ encryptedAssertionKey encryptedAssertion
    aesKey <- either (throwError . DecryptionFailure) pure oaepResult

    xmlData <- case BS.length aesKey of
        32 -> cbcDecryptUnpad (cipherInit aesKey :: CryptoFailable AES256)
        24 -> cbcDecryptUnpad (cipherInit aesKey :: CryptoFailable AES192)
        _  -> cbcDecryptUnpad (cipherInit aesKey :: CryptoFailable AES128)

    case XML.parseLBS def (LBS.fromStrict xmlData) of
        Left err -> throwError $ InvalidAssertionXml err
        Right assertDoc -> do
            -- parse in IO so that a MonadFail failure is caught by 'try'
            assertParseResult <- liftIO $ try $
                parseXML (XML.fromDocument assertDoc)
            either (throwError . InvalidAssertion) pure assertParseResult
  where
    -- the base64-decoded assertion ciphertext, IV block first
    cipherText = BS.decodeLenient
               $ cipherValue
               $ encryptedAssertionCipher encryptedAssertion

    -- Decrypts 'cipherText' in CBC mode with an initialised block cipher
    -- and removes the padding.
    cbcDecryptUnpad :: BlockCipher c => CryptoFailable c -> ExceptT SAML2Error IO BS.ByteString
    cbcDecryptUnpad (CryptoFailed err) = throwError $ CryptoError err
    cbcDecryptUnpad (CryptoPassed cipher) = do
        let (ivBytes, encryptedBytes) = BS.splitAt (blockSize cipher) cipherText
        iv <- maybe (throwError InvalidIV) pure (makeIV ivBytes)
        -- CBC works on whole blocks: reject misaligned input here instead
        -- of handing it to cbcDecrypt.
        -- NOTE(review): cryptonite's AES CBC is believed to call 'error'
        -- on misaligned input, which would escape ExceptT as an exception.
        when (BS.length encryptedBytes `mod` blockSize cipher /= 0) $
            throwError InvalidPadding
        maybe (throwError InvalidPadding) pure
            $ ansiX923
            $ cbcDecrypt cipher iv encryptedBytes
-- | Removes trailing block-cipher padding whose final byte holds the number
-- of padding bytes, counting the final byte itself.
--
-- Returns 'Nothing' in these cases:
--
-- * the input is empty;
-- * the final byte is zero;
-- * the final byte is larger than the whole input.
--
-- Otherwise returns the input with that many trailing bytes dropped.
--
-- Despite the name, only the final byte is inspected. The other padding
-- bytes are not required to be zero, as strict ANSI X.923 would demand.
-- This is presumably deliberate for XML Encryption, whose padding bytes
-- are arbitrary. The pad count is bounded by the input length, not by the
-- cipher block size.
ansiX923 :: BS.ByteString -> Maybe BS.ByteString
ansiX923 d = case BS.unsnoc d of
    Nothing -> Nothing
    Just (_, lastByte)
        | padLen < 1 || padLen > len -> Nothing
        | otherwise -> Just (BS.take (len - padLen) d)
      where
        len = BS.length d
        padLen = fromIntegral lastByte