{-# LANGUAGE OverloadedStrings, FlexibleContexts #-}
{-# OPTIONS_HADDOCK prune #-}

-- | High-level JWT encoding and decoding.
--
-- See the "Jose.Jws" and "Jose.Jwe" modules for JWS- and JWE-specific examples.
--
-- Example usage with a key stored as a JWK:
--
-- >>> import Jose.Jwe
-- >>> import Jose.Jwa
-- >>> import Jose.Jwk
-- >>> import Data.ByteString
-- >>> import Data.Aeson (decodeStrict)
-- >>> let jsonJwk = "{\"kty\":\"RSA\", \"kid\":\"mykey\", \"n\":\"ofgWCuLjybRlzo0tZWJjNiuSfb4p4fAkd_wWJcyQoTbji9k0l8W26mPddxHmfHQp-Vaw-4qPCJrcS2mJPMEzP1Pt0Bm4d4QlL-yRT-SFd2lZS-pCgNMsD1W_YpRPEwOWvG6b32690r2jZ47soMZo9wGzjb_7OMg0LOL-bSf63kpaSHSXndS5z5rexMdbBYUsLA9e-KXBdQOS-UTo7WTBEMa2R2CapHg665xsmtdVMTBQY4uDZlxvb3qCo5ZwKh9kG4LT6_I5IhlJH7aGhyxXFvUK-DWNmoudF8NAco9_h9iaGNj8q2ethFkMLs91kzk2PAcDTW9gb54h4FRWyuXpoQ\", \"e\":\"AQAB\", \"d\":\"Eq5xpGnNCivDflJsRQBXHx1hdR1k6Ulwe2JZD50LpXyWPEAeP88vLNO97IjlA7_GQ5sLKMgvfTeXZx9SE-7YwVol2NXOoAJe46sui395IW_GO-pWJ1O0BkTGoVEn2bKVRUCgu-GjBVaYLU6f3l9kJfFNS3E0QbVdxzubSu3Mkqzjkn439X0M_V51gfpRLI9JYanrC4D4qAdGcopV_0ZHHzQlBjudU2QvXt4ehNYTCBr6XCLQUShb1juUO1ZdiYoFaFQT5Tw8bGUl_x_jTj3ccPDVZFD9pIuhLhBOneufuBiB4cS98l2SR_RQyGWSeWjnczT0QU91p1DhOVRuOopznQ\"}" :: ByteString
-- >>> let Just jwk = decodeStrict jsonJwk :: Maybe Jwk
-- >>> Right (Jwt jwtEncoded) <- encode [jwk] (JwsEncoding RS256) (Claims "public claims")
-- >>> Right jwtDecoded <- Jose.Jwt.decode [jwk] (Just (JwsEncoding RS256)) jwtEncoded
-- >>> jwtDecoded
-- Jws (JwsHeader {jwsAlg = RS256, jwsTyp = Nothing, jwsCty = Nothing, jwsKid = Just (KeyId "mykey")},"public claims")

module Jose.Jwt
    ( module Jose.Types
    , encode
    , decode
    , decodeClaims
    )
where

import Control.Monad (msum, when, unless)
import Control.Monad.Trans (lift)
import Control.Monad.Trans.Except
import qualified Crypto.PubKey.ECC.ECDSA as ECDSA
import Crypto.PubKey.RSA (PrivateKey(..))
import Crypto.Random (MonadRandom)
import Data.Aeson (decodeStrict',FromJSON)
import Data.ByteString (ByteString)
import Data.Maybe (isNothing)
import qualified Data.ByteString.Char8 as BC

import qualified Jose.Internal.Base64 as B64
import qualified Jose.Internal.Parser as P
import Jose.Types
import Jose.Jwk
import Jose.Jwa

import qualified Jose.Jws as Jws
import qualified Jose.Jwe as Jwe


-- | Use the supplied JWKs to create a JWT.
-- The list of keys will be searched to locate one which is
-- consistent with the chosen encoding algorithms.
--
-- * @JwsEncoding None@ ignores the keys and builds an unsecured token from
--   the fixed header @{"alg":"none"}@ and the encoded claims. A 'Nested'
--   payload cannot be sent unsecured and yields 'BadClaims'.
--
-- * Any other 'JwsEncoding' uses the first key accepted by 'canEncodeJws'.
--
-- * A 'JweEncoding' uses the first key accepted by 'canEncodeJwe'.
--
-- If no key matches, a 'KeyError' is returned.
encode :: MonadRandom m
    => [Jwk]                     -- ^ The key or keys. At least one must be consistent with the chosen algorithm
    -> JwtEncoding               -- ^ The encoding algorithm(s) used to encode the payload
    -> Payload                   -- ^ The payload (claims)
    -> m (Either JwtError Jwt)   -- ^ The encoded JWT, if successful
encode jwks encoding msg = runExceptT $ case encoding of
    JwsEncoding None -> case msg of
        Claims p -> return $ Jwt $ BC.intercalate "." [unsecuredHdr, B64.encode p]
        Nested _ -> throwE BadClaims
    JwsEncoding a -> case filter (canEncodeJws a) jwks of
        []    -> throwE (KeyError "No matching key found for JWS algorithm")
        -- Wrapping the underlying action directly is equivalent to
        -- lifting it and re-throwing its Left.
        (k:_) -> ExceptT (Jws.jwkEncode a k msg)
    JweEncoding a e -> case filter (canEncodeJwe a) jwks of
        []    -> throwE (KeyError "No matching key found for JWE algorithm")
        (k:_) -> ExceptT (Jwe.jwkEncode a e k msg)
  where
    -- Encoded header used for every unsecured token.
    -- NOTE(review): RFC 7519 section 6.1 shows unsecured JWTs ending with a
    -- '.' (empty signature segment). The token built above has no trailing
    -- '.', so it has only two segments and 'decodeClaims' rejects it with
    -- 'BadDots'. Confirm that P.parseJwt accepts the trailing form before
    -- changing the output.
    unsecuredHdr :: ByteString
    unsecuredHdr = B64.encode (BC.pack "{\"alg\":\"none\"}")


-- | Uses the supplied keys to decode a JWT.
-- Locates a matching key by header @kid@ value where possible
-- or by suitable key type for the encoding algorithm.
--
-- The algorithm(s) used can optionally be supplied for validation
-- by setting the @JwtEncoding@ parameter, in which case an error will
-- be returned if they don't match. If you expect the tokens to use
-- a particular algorithm, then you should set this parameter.
--
-- For unsecured tokens (with algorithm "none"), the expected algorithm
-- must be set to @Just (JwsEncoding None)@ or an error will be returned.
--
-- Every candidate key that passes 'canDecodeJws' / 'canDecodeJwe' is tried.
-- The first successful result is returned, and the errors from individual
-- failing keys are discarded. If no key is suitable, or none succeeds, a
-- 'KeyError' is returned.
decode :: MonadRandom m
    => [Jwk]                           -- ^ The keys to use for decoding
    -> Maybe JwtEncoding               -- ^ The expected encoding information
    -> ByteString                      -- ^ The encoded JWT
    -> m (Either JwtError JwtContent)  -- ^ The decoded JWT payload, if successful
decode keySet encoding jwt = runExceptT $ do
    decodableJwt <- except (P.parseJwt jwt)

    -- One entry per candidate key; Nothing marks a key that failed.
    decodings <- case (decodableJwt, encoding) of
        (P.Unsecured p, Just (JwsEncoding None)) -> return [Just (Unsecured p)]
        -- Never accept an unsecured token unless the caller explicitly asked for one.
        (P.Unsecured _, _) ->
            throwE (BadAlgorithm "JWT is unsecured but expected 'alg' was not 'none'")
        (P.DecodableJws hdr _ _ _, e) -> do
            unless (isNothing e || e == Just (JwsEncoding (jwsAlg hdr))) $
                throwE (BadAlgorithm "Expected 'alg' doesn't match JWS header")
            ks <- checkKeys $ filter (canDecodeJws hdr) keySet
            mapM decodeWithJws ks
        (P.DecodableJwe hdr _ _ _ _ _, e) -> do
            unless (isNothing e || e == Just (JweEncoding (jweAlg hdr) (jweEnc hdr))) $
                throwE (BadAlgorithm "Expected encoding doesn't match JWE header")
            ks <- checkKeys $ filter (canDecodeJwe hdr) keySet
            mapM decodeWithJwe ks

    -- msum over Maybe picks the first Just, i.e. the first key that worked.
    case msum decodings of
        Nothing         -> throwE $ KeyError "None of the keys was able to decode the JWT"
        Just jwtContent -> return jwtContent
  where
    -- Verify the JWS with a single key. A verification failure becomes
    -- Nothing so that the remaining keys can still be tried.
    decodeWithJws :: MonadRandom m => Jwk -> ExceptT JwtError m (Maybe JwtContent)
    decodeWithJws k = either (const $ return Nothing) (return . Just . Jws) $ case k of
        Ed25519PublicJwk kPub _     -> Jws.ed25519Decode kPub jwt
        Ed25519PrivateJwk _ kPub _  -> Jws.ed25519Decode kPub jwt
        Ed448PublicJwk kPub _       -> Jws.ed448Decode kPub jwt
        Ed448PrivateJwk _ kPub _    -> Jws.ed448Decode kPub jwt
        RsaPublicJwk  kPub _ _ _    -> Jws.rsaDecode kPub jwt
        RsaPrivateJwk kPr  _ _ _    -> Jws.rsaDecode (private_pub kPr) jwt
        EcPublicJwk   kPub _ _ _ _  -> Jws.ecDecode kPub jwt
        EcPrivateJwk  kPr  _ _ _ _  -> Jws.ecDecode (ECDSA.toPublicKey kPr) jwt
        SymmetricJwk  kb   _ _ _    -> Jws.hmacDecode kb jwt
        UnsupportedJwk _            -> Left (KeyError "Unsupported JWKs cannot be used")

    -- Decrypt the JWE with a single key. Errors reported by Jwe.jwkDecode
    -- become Nothing so that the remaining keys can still be tried.
    decodeWithJwe :: MonadRandom m => Jwk -> ExceptT JwtError m (Maybe JwtContent)
    decodeWithJwe k = fmap (either (const Nothing) Just) (lift (Jwe.jwkDecode k jwt))

    -- Fail early with a KeyError when filtering left no candidate keys.
    checkKeys :: Monad n => [a] -> ExceptT JwtError n [a]
    checkKeys [] = throwE $ KeyError "No suitable key was found to decode the JWT"
    checkKeys ks = return ks


-- | Convenience function to return the claims contained in a JWS.
-- This is needed in situations such as client assertion authentication,
-- <https://tools.ietf.org/html/rfc7523>, where the contents of the JWT,
-- such as the @sub@ claim, may be required in order to work out
-- which key should be used to verify the token.
--
-- Obviously this should not be used by itself to decode a token since
-- no integrity checking is done and the contents may be forged.
--
-- The token must split into exactly three @.@-separated segments, otherwise
-- @BadDots 2@ is returned. The first segment is decoded with @B64.decode@
-- and parsed as the header. The second is decoded the same way and parsed as
-- JSON claims, with a parse failure reported as 'BadClaims'. The third
-- segment (the signature) is ignored.
decodeClaims :: (FromJSON a)
    => ByteString
    -> Either JwtError (JwtHeader, a)
decodeClaims jwt = case BC.split '.' jwt of
    -- Matching the exact shape keeps this total: no partial head/tail.
    [hdrPart, claimsPart, _] -> do
        hdr    <- B64.decode hdrPart >>= parseHeader
        claims <- B64.decode claimsPart >>= parseClaims
        return (hdr, claims)
    _ -> Left (BadDots 2)
  where
    parseClaims :: FromJSON b => ByteString -> Either JwtError b
    parseClaims bs = maybe (Left BadClaims) Right (decodeStrict' bs)