{-# LANGUAGE OverloadedStrings #-}

-- | JWE encrypted token support.
--
-- To create a JWE, you need to select two algorithms. One is an AES algorithm
-- used to encrypt the content of your token (for example, @A128GCM@), for which
-- a single-use key is generated internally. The second is used to encrypt
-- this content-encryption key and can be either an RSA or AES-keywrap algorithm.
-- You need to generate a suitable key to use with this, or load one from storage.
--
-- AES is much faster and creates shorter tokens, but both the encoder and decoder
-- of the token need to have a copy of the key, which they must keep secret. With
-- RSA anyone can send you a JWE if they have a copy of your public key.
--
-- In the example below, we show encoding and decoding using a 2048-bit RSA
-- key pair (the size is passed to @generateRsaKeyPair@ in bytes, i.e. 256).
-- If using RSA, use one of the @RSA_OAEP@ algorithms. @RSA1_5@ is
-- deprecated due to <https://robotattack.org/ known vulnerabilities>.
--
-- >>> import Jose.Jwe
-- >>> import Jose.Jwa
-- >>> import Jose.Jwk (generateRsaKeyPair, generateSymmetricKey, KeyUse(Enc), KeyId)
-- >>> (kPub, kPr) <- generateRsaKeyPair 256 (KeyId "My RSA Key") Enc Nothing
-- >>> Right (Jwt jwt) <- jwkEncode RSA_OAEP A128GCM kPub (Claims "secret claims")
-- >>> Right (Jwe (hdr, claims)) <- jwkDecode kPr jwt
-- >>> claims
-- "secret claims"
--
-- Using 128-bit AES keywrap is very similar; the main difference is that
-- we generate a 128-bit symmetric key (16 bytes) instead of an RSA key pair:
--
-- >>> aesKey <- generateSymmetricKey 16 (KeyId "My Keywrap Key") Enc Nothing
-- >>> Right (Jwt jwt) <- jwkEncode A128KW A128GCM aesKey (Claims "more secret claims")
-- >>> Right (Jwe (hdr, claims)) <- jwkDecode aesKey jwt
-- >>> claims
-- "more secret claims"

module Jose.Jwe
    ( jwkEncode
    , jwkDecode
    , rsaEncode
    , rsaDecode
    )
where

import Control.Monad.Trans (lift)
import Control.Monad.Trans.Except
import Crypto.Cipher.Types (AuthTag(..))
import Crypto.PubKey.RSA (PrivateKey(..), PublicKey(..), generateBlinder, private_pub)
import Crypto.Random (MonadRandom)
import qualified Data.Aeson as A
import Data.ByteArray (ByteArray, ScrubbedBytes)
import qualified Data.ByteArray as BA
import Data.ByteString (ByteString)
import qualified Data.ByteString as B
import qualified Data.ByteString.Lazy as BL
import Data.Maybe (isNothing)
import Jose.Types
import qualified Jose.Internal.Base64 as B64
import Jose.Internal.Crypto
import Jose.Jwa
import Jose.Jwk
import qualified Jose.Internal.Parser as P

-- | Create a JWE using a JWK.
--
-- The JWK selects how the generated content-encryption key is protected:
--
-- * an RSA public or private key encrypts it with the (public part of the)
--   RSA key, using the given RSA 'JweAlg';
-- * a symmetric key wraps it with the given AES keywrap 'JweAlg'.
--
-- Any other kind of JWK is rejected with a 'KeyError'. If the key and the
-- algorithms are otherwise inconsistent, the error comes from the
-- underlying encryption step.
--
-- The JSON header contains @alg@ and @enc@. It also contains @cty@ set to
-- @\"JWT\"@ when the payload is a nested JWT, and @kid@ when the key carries
-- a key ID.
jwkEncode :: MonadRandom m
    => JweAlg                          -- ^ Algorithm to use for key encryption
    -> Enc                             -- ^ Content encryption algorithm
    -> Jwk                             -- ^ The key to use to encrypt the content key
    -> Payload                         -- ^ The token content (claims or nested JWT)
    -> m (Either JwtError Jwt)         -- ^ The encoded JWE if successful
jwkEncode a e jwk payload = runExceptT (encodeWith jwk)
  where
    -- Choose the content-key encryption function from the JWK's type.
    encodeWith (RsaPublicJwk kPub kid _ _) = withKeyEncrypter kid (rsaEncrypter kPub)
    encodeWith (RsaPrivateJwk kPr kid _ _) = withKeyEncrypter kid (rsaEncrypter (private_pub kPr))
    encodeWith (SymmetricJwk kek kid _ _)  = withKeyEncrypter kid (keyWrapper kek)
    encodeWith _                           = throwE $ KeyError "JWK cannot encode a JWE"

    withKeyEncrypter kid encryptCek = doEncode (hdr kid) e encryptCek content

    -- RSA: encrypt the content key with the public key.
    rsaEncrypter kPub = ExceptT . rsaEncrypt kPub a

    -- Symmetric key: AES keywrap. 'keyWrap' returns a plain 'Either', so it
    -- is lifted into 'ExceptT' with 'return'.
    keyWrapper kek = ExceptT . return . keyWrap a (BA.convert kek)

    hdr :: Maybe KeyId -> B.ByteString
    hdr kid = BL.toStrict . BL.concat $
        [ "{\"alg\":", A.encode a
        , ",\"enc\":", A.encode e
        , maybe "" (\c -> BL.concat [",\"cty\":\"", c, "\"" ]) contentType
        , kidField kid
        , "}"
        ]

    -- The "kid" member is only emitted when the key has an ID.
    kidField kid
        | isNothing kid = ""
        | otherwise     = BL.concat [",\"kid\":", A.encode kid ]

    -- Nested JWTs are flagged with a content type of "JWT".
    (contentType, content) = case payload of
        Claims c       -> (Nothing, c)
        Nested (Jwt b) -> (Just "JWT", b)


-- | Try to decode a JWE using a JWK.
--
-- Only RSA private keys and symmetric keys can decrypt a JWE. For an RSA
-- private key, a fresh RSA blinder is generated for the decryption. A
-- symmetric key is used to unwrap the content key with AES keywrap.
-- 'UnsupportedJwk' values and all other key types produce a 'KeyError'.
-- If the key does not suit the algorithms named in the token's header,
-- decoding fails with an error.
jwkDecode :: MonadRandom m
    => Jwk
    -> ByteString
    -> m (Either JwtError JwtContent)
jwkDecode jwk jwt = runExceptT (Jwe <$> decodeWith jwk)
  where
    decodeWith (RsaPrivateJwk kPr _ _ _) = do
        blinder <- lift $ generateBlinder (public_n (private_pub kPr))
        doDecode (rsaDecrypt (Just blinder) kPr) jwt
    decodeWith (SymmetricJwk kb _ _ _) =
        doDecode (keyUnwrap (BA.convert kb)) jwt
    decodeWith (UnsupportedJwk _) =
        throwE (KeyError "Unsupported JWK cannot be used to decode JWE")
    decodeWith _ =
        throwE $ KeyError "This JWK cannot decode a JWE"


-- | Parse a compact JWE and decrypt its payload.
--
-- @decodeCek@ recovers the content-encryption key (CEK) from the encrypted
-- key part, given the key-management algorithm in the header. If it fails,
-- or returns a key whose length differs from a freshly generated random key
-- for the header's @enc@ algorithm, the random key is used instead. As a
-- result, a bad encrypted key surfaces only as a 'BadCrypto' failure of
-- payload decryption, not as a distinct error (cf. RFC 7516, Section 11.5).
--
-- Returns the parsed header together with the decrypted claims. Input that
-- parses as anything other than a JWE yields a 'BadHeader' error.
doDecode :: MonadRandom m
    => (JweAlg -> ByteString -> Either JwtError ScrubbedBytes)
    -> ByteString
    -> ExceptT JwtError m Jwe
doDecode decodeCek jwt = either throwE return (P.parseJwt jwt) >>= decryptParsed
  where
    decryptParsed (P.DecodableJwe hdr (P.EncryptedCEK ek) iv (P.Payload payload) tag (P.AAD aad)) = do
        let enc = jweEnc hdr
        -- A random key of the correct size, used whenever CEK recovery fails.
        (fallbackCek, _) <- lift $ generateCmkAndIV enc
        let cek = selectCek fallbackCek (decodeCek (jweAlg hdr) ek)
        case decryptPayload enc cek iv aad tag payload of
            Just claims -> return (hdr, claims)
            Nothing     -> throwE BadCrypto
    decryptParsed _ = throwE (BadHeader "Content is not a JWE")

    -- Keep the recovered key only if recovery succeeded and its length
    -- matches the fallback key; otherwise use the fallback.
    selectCek fallback (Right k)
        | BA.length k == BA.length fallback = k
    selectCek fallback _ = fallback


-- | Encrypt a payload and assemble the compact JWE serialization.
--
-- A fresh content-encryption key (CEK) and IV are generated for the content
-- encryption algorithm. The payload is encrypted with them, using the
-- base64url-encoded header as the additional authenticated data. The CEK
-- itself is then encrypted with the supplied key-encryption function.
--
-- The result is five base64url-encoded parts joined by dots: header,
-- encrypted key, IV, ciphertext and authentication tag.
--
-- If the payload cannot be encrypted, this fails with 'BadCrypto' inside
-- 'ExceptT'. It does not crash on an incomplete pattern match.
doEncode :: (MonadRandom m, ByteArray ba)
    => ByteString                                        -- ^ The JSON header
    -> Enc                                               -- ^ Content encryption algorithm
    -> (ScrubbedBytes -> ExceptT JwtError m ByteString)  -- ^ Encrypts the generated CEK
    -> ba                                                -- ^ The payload to encrypt
    -> ExceptT JwtError m Jwt
doEncode hdr e encryptKey claims = do
    (cmk, iv) <- lift (generateCmkAndIV e)
    -- encryptPayload returns Maybe; turn Nothing into an error rather than
    -- using a partial `let Just ... =` binding.
    (AuthTag sig, ct) <- maybe (throwE BadCrypto) return $
        encryptPayload e cmk iv aad claims
    jweKey <- encryptKey cmk
    let jwe = B.intercalate "." $ map B64.encode
                [hdr, jweKey, BA.convert iv, BA.convert ct, BA.convert sig]
    return (Jwt jwe)
  where
    -- The AAD is the base64url-encoded header, as it appears in the token.
    aad = B64.encode hdr

-- | Creates a JWE with the content key encoded using RSA.
--
-- The JSON header contains only the @alg@ and @enc@ members. Prefer an
-- @RSA_OAEP@ algorithm; @RSA1_5@ is deprecated because of known
-- vulnerabilities.
rsaEncode :: MonadRandom m
    => JweAlg          -- ^ RSA algorithm to use (@RSA_OAEP@ or @RSA1_5@)
    -> Enc             -- ^ Content encryption algorithm
    -> PublicKey       -- ^ RSA key to encrypt with
    -> ByteString      -- ^ The JWT claims (content)
    -> m (Either JwtError Jwt) -- ^ The encoded JWE
rsaEncode a e kPub claims = runExceptT (doEncode header e encryptCek claims)
  where
    -- Encrypt the generated content key with the RSA public key.
    encryptCek cek = ExceptT (rsaEncrypt kPub a cek)

    header :: ByteString
    header = BL.toStrict . BL.concat $
        ["{\"alg\":", A.encode a, ",", "\"enc\":", A.encode e, "}"]


-- | Decrypts a JWE whose content key was encrypted with RSA.
--
-- A fresh RSA blinder is generated from the key's modulus and used for the
-- private-key decryption of the content key.
rsaDecode :: MonadRandom m
    => PrivateKey               -- ^ Decryption key
    -> ByteString               -- ^ The encoded JWE
    -> m (Either JwtError Jwe)  -- ^ The decoded JWE header and claims, unless an error occurs
rsaDecode pk jwt = do
    blinder <- generateBlinder (public_n (private_pub pk))
    runExceptT $ doDecode (rsaDecrypt (Just blinder) pk) jwt