-- | PASETO token claim validation.
module Crypto.Paseto.Token.Validation
  ( -- * Errors
    ValidationError (..)
  , renderValidationError
  , renderValidationErrors

    -- * Rules
  , ValidationRule (..)
  , ClaimMustExist (..)
    -- ** Simple rules
  , forAudience
  , identifiedBy
  , issuedBy
  , notExpired
  , subject
  , validAt
  , customClaimEq
    -- ** Recommended default rules
  , getDefaultValidationRules

    -- * Validation
  , validate
  , validateDefault
  ) where

import Crypto.Paseto.Token.Claim
  ( Audience (..)
  , ClaimKey (..)
  , Expiration (..)
  , IssuedAt (..)
  , Issuer (..)
  , NotBefore (..)
  , Subject (..)
  , TokenIdentifier (..)
  , UnregisteredClaimKey
  , renderClaimKey
  , renderExpiration
  , renderIssuedAt
  , renderNotBefore
  )
import Crypto.Paseto.Token.Claims
  ( Claims
  , lookupAudience
  , lookupCustom
  , lookupExpiration
  , lookupIssuedAt
  , lookupIssuer
  , lookupNotBefore
  , lookupSubject
  , lookupTokenIdentifier
  )
import qualified Data.Aeson as Aeson
import qualified Data.ByteString as BS
import Data.Either ( lefts )
import qualified Data.List as L
import Data.List.NonEmpty ( NonEmpty )
import qualified Data.List.NonEmpty as NE
import Data.Text ( Text )
import qualified Data.Text as T
import qualified Data.Text.Encoding as TE
import Data.Time.Clock ( UTCTime, getCurrentTime )
import Prelude hiding ( exp, lookup )

-- | Validation error.
data ValidationError
  = -- | Expected claim does not exist.
    ValidationClaimNotFoundError
      -- | Claim key which could not be found.
      !ClaimKey
  | -- | Token claim is invalid.
    ValidationInvalidClaimError
      -- | Claim key.
      !ClaimKey
      -- | Expected claim value (rendered as 'Text').
      !Text
      -- | Actual claim value (rendered as 'Text').
      !Text
  | -- | Token is expired. Carries the token's 'Expiration' claim.
    ValidationExpirationError !Expiration
  | -- | Token's 'IssuedAt' time is in the future. Carries the token's
    -- 'IssuedAt' claim.
    ValidationIssuedAtError !IssuedAt
  | -- | Token is not yet valid as its 'NotBefore' time is in the future.
    -- Carries the token's 'NotBefore' claim.
    ValidationNotBeforeError !NotBefore
  | -- | Custom validation error.
    ValidationCustomError !Text
  deriving stock (Show, Eq)

-- | Render a 'ValidationError' as human-readable 'Text'.
--
-- Claim keys are rendered with 'renderClaimKey' and wrapped in double
-- quotes; time-based errors render the offending claim's time.
renderValidationError :: ValidationError -> Text
renderValidationError err =
  case err of
    ValidationClaimNotFoundError k ->
      "\"" <> renderClaimKey k <> "\" claim does not exist"
    ValidationInvalidClaimError k expected actual ->
      "expected value \""
        <> expected
        <> "\" for \""
        <> renderClaimKey k
        <> "\" claim but encountered \""
        <> actual
        <> "\""
    ValidationExpirationError exp ->
      "token expired at " <> renderExpiration exp
    ValidationIssuedAtError iat ->
      "token is not issued until " <> renderIssuedAt iat
    ValidationNotBeforeError nbf ->
      "token is not valid before " <> renderNotBefore nbf
    ValidationCustomError e -> e

-- | Render a non-empty list of 'ValidationError's as 'Text'.
--
-- Each error is rendered with 'renderValidationError', and the results are
-- joined, in order, with @", "@.
renderValidationErrors :: NonEmpty ValidationError -> Text
renderValidationErrors =
  T.intercalate ", " . map renderValidationError . NE.toList

-- | Token claim validation rule.
--
-- A rule inspects a collection of 'Claims' and either succeeds with @()@ or
-- fails with a single 'ValidationError'.
newtype ValidationRule = ValidationRule
  { unValidationRule :: Claims -> Either ValidationError () }

-- | Flag stating whether a claim is required to be present.
--
-- With 'True', an absent claim is a validation failure; with 'False', an
-- absent claim is accepted (see 'customClaimEq').
newtype ClaimMustExist
  = ClaimMustExist Bool

-- | Build a simple validation rule which checks whether a value extracted
-- from the 'Claims' is equal to a given expected value.
--
-- The rule fails with 'ValidationClaimNotFoundError' when the extractor
-- returns 'Nothing'. It fails with 'ValidationInvalidClaimError' (carrying
-- the rendered expected and actual values) when the extracted value differs
-- from the expected one.
mkEqValidationRule
  :: Eq a
  => (Claims -> Maybe a)
  -- ^ Extract a value from the claims (i.e. the actual value).
  -> ClaimKey
  -- ^ Claim key which corresponds to the extracted value (this is used in
  -- constructing errors).
  -> (a -> Text)
  -- ^ Render a claim value as 'Text'. Both the expected and the actual value
  -- are rendered with this when constructing errors.
  -> a
  -- ^ Expected value.
  -> ValidationRule
mkEqValidationRule extract claimKey render expected = ValidationRule $ \cs ->
  case extract cs of
    Nothing -> Left (ValidationClaimNotFoundError claimKey)
    Just actual
      | expected == actual -> Right ()
      | otherwise ->
          Left $
            ValidationInvalidClaimError
              claimKey
              (render expected)
              (render actual)

-- | Validate that a token is intended for a given audience.
--
-- The audience claim must be present and equal to the given 'Audience'.
-- Otherwise the rule fails with 'ValidationClaimNotFoundError' or
-- 'ValidationInvalidClaimError' respectively.
forAudience :: Audience -> ValidationRule
forAudience = mkEqValidationRule lookupAudience AudienceClaimKey unAudience

-- | Validate a token's identifier.
--
-- The token identifier claim must be present and equal to the given
-- 'TokenIdentifier'. Otherwise the rule fails with
-- 'ValidationClaimNotFoundError' or 'ValidationInvalidClaimError'
-- respectively.
identifiedBy :: TokenIdentifier -> ValidationRule
identifiedBy =
  mkEqValidationRule
    lookupTokenIdentifier
    TokenIdentifierClaimKey
    unTokenIdentifier

-- | Validate a token's issuer.
--
-- The issuer claim must be present and equal to the given 'Issuer'.
-- Otherwise the rule fails with 'ValidationClaimNotFoundError' or
-- 'ValidationInvalidClaimError' respectively.
issuedBy :: Issuer -> ValidationRule
issuedBy = mkEqValidationRule lookupIssuer IssuerClaimKey unIssuer

-- | Validate that a token is not expired at the given time.
--
-- That is, if the 'Crypto.Paseto.Token.Claim.ExpirationClaim' is present,
-- check that it isn't in the past (relative to the given time). A token whose
-- expiration equals the given time is still accepted. A token without an
-- expiration claim always passes this rule.
notExpired :: UTCTime -> ValidationRule
notExpired now = ValidationRule $ \cs ->
  case lookupExpiration cs of
    Nothing -> Right ()
    Just exp@(Expiration expiresAt)
      | now <= expiresAt -> Right ()
      | otherwise -> Left (ValidationExpirationError exp)

-- | Validate the subject of a token.
--
-- The subject claim must be present and equal to the given 'Subject'.
-- Otherwise the rule fails with 'ValidationClaimNotFoundError' or
-- 'ValidationInvalidClaimError' respectively.
subject :: Subject -> ValidationRule
subject = mkEqValidationRule lookupSubject SubjectClaimKey unSubject

-- | Validate that a token is valid at the given time.
--
-- This involves the following checks (relative to the given time):
--
-- * If the 'Crypto.Paseto.Token.Claim.ExpirationClaim' is present, check that
-- it isn't in the past.
--
-- * If the 'Crypto.Paseto.Token.Claim.IssuedAtClaim' is present, check that it
-- isn't in the future.
--
-- * If the 'Crypto.Paseto.Token.Claim.NotBeforeClaim' is present, check that
-- it isn't in the future.
--
-- The checks run in the order listed and stop at the first failure, so this
-- rule reports at most one 'ValidationError'.
validAt :: UTCTime -> ValidationRule
validAt now = ValidationRule $ \cs -> do
  unValidationRule (notExpired now) cs
  checkIssuedAt cs
  checkNotBefore cs
  where
    -- The 'IssuedAt' time must not be after the given time.
    checkIssuedAt :: Claims -> Either ValidationError ()
    checkIssuedAt cs =
      case lookupIssuedAt cs of
        Nothing -> Right ()
        Just iat@(IssuedAt issuedAt)
          | now >= issuedAt -> Right ()
          | otherwise -> Left (ValidationIssuedAtError iat)

    -- The 'NotBefore' time must not be after the given time.
    checkNotBefore :: Claims -> Either ValidationError ()
    checkNotBefore cs =
      case lookupNotBefore cs of
        Nothing -> Right ()
        Just nbf@(NotBefore notBefore)
          | now >= notBefore -> Right ()
          | otherwise -> Left (ValidationNotBeforeError nbf)

-- | Validate that a custom claim is equal to the given value.
--
-- If the claim is absent, the rule fails with 'ValidationClaimNotFoundError'
-- when the claim must exist and succeeds otherwise. If the claim is present
-- but differs from the expected value, the rule fails with
-- 'ValidationInvalidClaimError'. Both values appear in that error as encoded
-- JSON text.
customClaimEq
  :: ClaimMustExist
  -- ^ Whether the custom claim must exist.
  -> UnregisteredClaimKey
  -- ^ Custom claim key to lookup.
  -> Aeson.Value
  -- ^ Custom claim value to validate (i.e. the expected value).
  -> ValidationRule
customClaimEq (ClaimMustExist mustExist) k expected = ValidationRule $ \cs ->
  case lookupCustom k cs of
    Nothing
      | mustExist -> Left (ValidationClaimNotFoundError (CustomClaimKey k))
      | otherwise -> Right ()
    Just actual
      | expected == actual -> Right ()
      | otherwise ->
          Left $
            ValidationInvalidClaimError
              (CustomClaimKey k)
              (renderJson expected)
              (renderJson actual)
  where
    -- Render a JSON value as its encoded text form. 'Aeson.encode' emits
    -- UTF-8, so 'TE.decodeUtf8' is expected not to fail here.
    renderJson :: Aeson.Value -> Text
    renderJson = TE.decodeUtf8 . BS.toStrict . Aeson.encode

-- | Get a list of
-- [recommended default validation rules](https://github.com/paseto-standard/paseto-spec/blob/af79f25908227555404e7462ccdd8ce106049469/docs/02-Implementation-Guide/05-API-UX.md#secure-defaults).
--
-- At the moment, the only default rule is checking 'validAt' for the current
-- system time ('getCurrentTime'). The time is read once, when this action
-- runs.
getDefaultValidationRules :: IO [ValidationRule]
getDefaultValidationRules = do
  now <- getCurrentTime
  pure (L.singleton (validAt now))

-- | Validate a list of rules against a collection of claims.
--
-- This function will run through all of the provided validation rules and
-- collect all of the errors encountered, if any, in the order of the rules.
-- If there are no validation errors (including when the rule list is
-- empty), @Right ()@ is returned.
validate :: [ValidationRule] -> Claims -> Either (NonEmpty ValidationError) ()
validate rs cs =
  case NE.nonEmpty (lefts (map runRule rs)) of
    Just errs -> Left errs
    Nothing -> Right ()
  where
    -- Apply a single rule to the claims under validation.
    runRule :: ValidationRule -> Either ValidationError ()
    runRule rule = unValidationRule rule cs

-- | Validate a collection of claims against the default validation rules
-- ('getDefaultValidationRules').
--
-- Because the default rules depend on the current system time, this runs in
-- 'IO'.
validateDefault :: Claims -> IO (Either (NonEmpty ValidationError) ())
validateDefault cs = (`validate` cs) <$> getDefaultValidationRules