module Amazon.SNS.Verify.Validate
( validateSnsMessage
, handleSubscription
, SNSNotificationValidationError (..)
, ValidSNSMessage (..)
) where
import Amazon.SNS.Verify.Prelude
import Amazon.SNS.Verify.Payload
import Amazon.SNS.Verify.ValidURI (validRegPattern, validScheme)
import Control.Error (ExceptT, catMaybes, headMay, runExceptT, throwE)
import Control.Monad (when)
import Data.ByteArray.Encoding (Base (Base64), convertFromBase)
import Data.PEM (pemContent, pemParseLBS)
import qualified Data.Text as T
import Data.Text.Encoding (encodeUtf8)
import Data.X509
( HashALG (..)
, PubKeyALG (..)
, SignatureALG (..)
, SignedCertificate
, certPubKey
, decodeSignedCertificate
, getCertificate
)
import Data.X509.Validation
( SignatureFailure
, SignatureVerification (..)
, verifySignature
)
import Network.HTTP.Simple
( getResponseBody
, getResponseStatusCode
, httpLbs
, parseRequest_
)
import Network.URI (parseURI, uriAuthority, uriRegName, uriScheme)
import Text.Regex.TDFA ((=~))
-- | An SNS payload whose signature has been verified, tagged by message type.
data ValidSNSMessage
  = SNSMessage Text
  -- ^ A @Notification@; carries the message body.
  | SNSSubscribe SNSSubscription
  -- ^ A @SubscriptionConfirmation@ with its token and subscribe URL.
  | SNSUnsubscribe SNSSubscription
  -- ^ An @UnsubscribeConfirmation@ with its token and subscribe URL.
  deriving stock (Eq, Show)
-- | Verify an SNS payload's signature and classify the message.
--
-- Steps, short-circuiting on the first failure:
--
-- 1. Pick the signing hash from @SignatureVersion@: @"1"@ selects SHA1 and
--    @"2"@ selects SHA256. Any other version is rejected with
--    'BadSignature' before any network request is made.
-- 2. Base64-decode the @Signature@ field ('BadSignature' on failure).
-- 3. Download and decode the signing certificate ('retrieveCertificate').
-- 4. Check the RSA signature over 'unsignedSignature' using the
--    certificate's public key ('InvalidPayload' on failure).
--
-- On success the payload is mapped onto the matching 'ValidSNSMessage'
-- constructor.
validateSnsMessage
  :: MonadIO m
  => SNSPayload
  -> m (Either SNSNotificationValidationError ValidSNSMessage)
validateSnsMessage payload@SNSPayload {..} = runExceptT $ do
  hashAlg <- unTryE BadSignature $ signatureHash snsSignatureVersion
  signature <-
    unTryE BadSignature $ convertFromBase Base64 $ encodeUtf8 snsSignature
  signedCert <- retrieveCertificate payload
  let
    verification =
      verifySignature
        (SignatureALG hashAlg PubKeyALG_RSA)
        (certPubKey $ getCertificate signedCert)
        (unsignedSignature payload)
        signature
  case verification of
    SignaturePass -> pure $ case snsTypePayload of
      Notification {} -> SNSMessage snsMessage
      SubscriptionConfirmation x -> SNSSubscribe x
      UnsubscribeConfirmation x -> SNSUnsubscribe x
    SignatureFailed e -> throwE $ InvalidPayload e
 where
  -- SNS signs with SHA1-RSA for SignatureVersion 1 and SHA256-RSA for
  -- SignatureVersion 2. Always using SHA1 rejected every version-2 message.
  signatureHash :: Text -> Either String HashALG
  signatureHash = \case
    "1" -> Right HashSHA1
    "2" -> Right HashSHA256
    other -> Left $ "Unsupported SignatureVersion: " <> T.unpack other
-- | Download and decode the X.509 certificate named by the payload's
-- @SigningCertURL@.
--
-- The URL is checked with 'validateCertUrl' before any request is made.
-- Only the first PEM block of the response body is decoded. A body without
-- any PEM block is reported as @BadPem "Empty List"@. PEM and DER decoding
-- failures are reported as 'BadPem' and 'BadCert' respectively.
retrieveCertificate
  :: MonadIO m
  => SNSPayload
  -> ExceptT SNSNotificationValidationError m SignedCertificate
retrieveCertificate payload = do
  trustedUrl <- unTryE id $ validateCertUrl $ snsSigningCertURL payload
  body <- getResponseBody <$> httpLbs (parseRequest_ trustedUrl)
  pemBlocks <- unTryE BadPem $ pemParseLBS body
  maybe (throwE $ BadPem "Empty List") decodeFirst (headMay pemBlocks)
 where
  decodeFirst = unTryE BadCert . decodeSignedCertificate . pemContent
-- | Check that a @SigningCertURL@ may be fetched.
--
-- The URL must parse as an absolute URI whose scheme equals 'validScheme'.
-- Its host name must also match 'validRegPattern'. On success the URL is
-- returned as a 'String' for the HTTP client. Otherwise 'BadUri' carries the
-- rejected URL.
validateCertUrl :: Text -> Either SNSNotificationValidationError String
validateCertUrl certUrl =
  case parseURI urlString of
    Nothing -> rejected
    Just uri
      | isTrusted uri -> Right urlString
      | otherwise -> rejected
 where
  urlString = T.unpack certUrl
  rejected = Left $ BadUri urlString
  isTrusted uri =
    (uriScheme uri == validScheme)
      && (maybe "" uriRegName (uriAuthority uri) =~ validRegPattern)
-- | Build the canonical string that SNS signs, UTF-8 encoded.
--
-- Every present field contributes its name and then its value, each
-- terminated by a newline. Fields appear in alphabetical order of name.
-- @Subject@ appears only for notifications that carry one. @SubscribeURL@
-- and @Token@ appear only for subscription and unsubscription
-- confirmations.
unsignedSignature :: SNSPayload -> ByteString
unsignedSignature SNSPayload {..} =
  encodeUtf8 $ T.concat
    [ entry
    | (name, value) <- fields
    , entry <- [name <> "\n", value <> "\n"]
    ]
 where
  fields :: [(Text, Text)]
  fields =
    catMaybes
      [ Just ("Message", snsMessage)
      , Just ("MessageId", snsMessageId)
      , (,) "SubscribeURL" <$> subscribeUrl
      , (,) "Subject" <$> subject
      , Just ("Timestamp", snsTimestamp)
      , (,) "Token" <$> token
      , Just ("TopicArn", snsTopicArn)
      , Just ("Type", snsType)
      ]

  (subject, token, subscribeUrl) = case snsTypePayload of
    Notification n -> (snsSubject n, Nothing, Nothing)
    SubscriptionConfirmation s -> confirmationFields s
    UnsubscribeConfirmation s -> confirmationFields s

  confirmationFields s =
    (Nothing, Just (snsToken s), Just (snsSubscribeURL s))
-- | Act on a verified SNS message.
--
-- * 'SNSMessage': returns the message body as 'Right'.
-- * 'SNSSubscribe': issues a request to the @SubscribeURL@. The result is
--   'BadSubscription' if the response status is 300 or higher, otherwise
--   'SubscribeMessageResponded'.
-- * 'SNSUnsubscribe': returns 'UnsubscribeMessage'.
--
-- Only 'SNSMessage' yields a 'Right'. Both subscription outcomes are
-- reported on the 'Left' side.
handleSubscription
  :: MonadIO m
  => ValidSNSMessage
  -> m (Either SNSNotificationValidationError Text)
handleSubscription message = runExceptT $ case message of
  SNSMessage body -> pure body
  SNSSubscribe subscription -> do
    let confirmUrl = T.unpack $ snsSubscribeURL subscription
    status <- getResponseStatusCode <$> httpLbs (parseRequest_ confirmUrl)
    when (status >= 300) $ throwE BadSubscription
    throwE SubscribeMessageResponded
  SNSUnsubscribe {} -> throwE UnsubscribeMessage
-- | Reasons an SNS message was not accepted, or an outcome that is not a
-- plain notification.
data SNSNotificationValidationError
  = BadPem String
  -- ^ The downloaded certificate could not be parsed as PEM, or it
  -- contained no PEM block (@"Empty List"@).
  | BadUri String
  -- ^ The @SigningCertURL@ did not parse, or failed the scheme/host check;
  -- carries the URL.
  | BadSignature String
  -- ^ The payload's signature fields could not be used (e.g. @Signature@
  -- is not valid base64).
  | BadCert String
  -- ^ The PEM contents could not be decoded as an X.509 certificate.
  | BadJSONParse String
  -- ^ Not raised in this module; presumably used by callers that parse the
  -- JSON payload — TODO confirm.
  | BadSubscription
  -- ^ The subscription-confirmation request returned a status >= 300.
  | InvalidPayload SignatureFailure
  -- ^ The signature did not verify against the certificate.
  | UnsubscribeMessage
  -- ^ The message was an @UnsubscribeConfirmation@.
  | SubscribeMessageResponded
  -- ^ The @SubscribeURL@ was visited successfully.
  deriving stock (Show, Eq)
  deriving anyclass (Exception)