{-# LANGUAGE OverloadedStrings #-}

-- | JWE encrypted token support.
--
-- To create a JWE, you need to select two algorithms. One is an AES algorithm
-- used to encrypt the content of your token (for example, @A128GCM@), for which
-- a single-use key is generated internally. The second is used to encrypt
-- this content-encryption key and can be either an RSA or AES-keywrap algorithm.
-- You need to generate a suitable key to use with this, or load one from storage.
--
-- AES is much faster and creates shorter tokens, but both the encoder and decoder
-- of the token need to have a copy of the key, which they must keep secret. With
-- RSA, anyone can send you a JWE if they have a copy of your public key.
--
-- In the example below, we show encoding and decoding using a 2048 bit RSA key pair
-- (256 bytes). If using RSA, use one of the @RSA_OAEP@ algorithms. @RSA1_5@ is
-- deprecated due to <https://robotattack.org/ known vulnerabilities>.
--
-- >>> import Jose.Jwe
-- >>> import Jose.Jwa
-- >>> import Jose.Jwk (generateRsaKeyPair, generateSymmetricKey, KeyUse(Enc), KeyId)
-- >>> (kPub, kPr) <- generateRsaKeyPair 256 (KeyId "My RSA Key") Enc Nothing
-- >>> Right (Jwt jwt) <- jwkEncode RSA_OAEP A128GCM kPub (Claims "secret claims")
-- >>> Right (Jwe (hdr, claims)) <- jwkDecode kPr jwt
-- >>> claims
-- "secret claims"
--
-- Using 128-bit AES keywrap is very similar; the main difference is that
-- we generate a 128-bit symmetric key (16 bytes):
--
-- >>> aesKey <- generateSymmetricKey 16 (KeyId "My Keywrap Key") Enc Nothing
-- >>> Right (Jwt jwt) <- jwkEncode A128KW A128GCM aesKey (Claims "more secret claims")
-- >>> Right (Jwe (hdr, claims)) <- jwkDecode aesKey jwt
-- >>> claims
-- "more secret claims"

module Jose.Jwe
    ( jwkEncode
    , jwkDecode
    , rsaEncode
    , rsaDecode
    )
where

import Control.Monad.Trans (lift)
import Control.Monad.Trans.Except
import Crypto.Cipher.Types (AuthTag(..))
import Crypto.PubKey.RSA (PrivateKey(..), PublicKey(..), generateBlinder, private_pub)
import Crypto.Random (MonadRandom)
import qualified Data.Aeson as A
import Data.ByteArray (ByteArray, ScrubbedBytes)
import qualified Data.ByteArray as BA
import Data.ByteString (ByteString)
import qualified Data.ByteString as B
import qualified Data.ByteString.Lazy as BL
import Data.Maybe (isNothing)
import Jose.Types
import qualified Jose.Internal.Base64 as B64
import Jose.Internal.Crypto
import Jose.Jwa
import Jose.Jwk
import qualified Jose.Internal.Parser as P

-- | Create a JWE using a JWK.
--
-- The key and algorithms must be consistent or an error will be returned.
-- RSA keys (public or private) encrypt the generated content key with
-- 'rsaEncrypt'; symmetric keys wrap it with 'keyWrap'. Any other kind of
-- JWK is rejected with a 'KeyError'.
--
-- The JOSE header always carries @alg@ and @enc@, adds @"cty":"JWT"@ when
-- the payload is a nested token, and adds @kid@ when the key has an id.
jwkEncode :: MonadRandom m
    => JweAlg                          -- ^ Algorithm to use for key encryption
    -> Enc                             -- ^ Content encryption algorithm
    -> Jwk                             -- ^ The key to use to encrypt the content key
    -> Payload                         -- ^ The token content (claims or nested JWT)
    -> m (Either JwtError Jwt)         -- ^ The encoded JWE if successful
jwkEncode a e jwk payload = runExceptT (encodeWith jwk)
  where
    -- Choose how the content key is protected, based on the key type.
    encodeWith (RsaPublicJwk kPub kid _ _) = encodeFor kid (rsaWrap kPub)
    encodeWith (RsaPrivateJwk kPr kid _ _) = encodeFor kid (rsaWrap (private_pub kPr))
    encodeWith (SymmetricJwk kek kid _ _)  = encodeFor kid (aesWrap kek)
    encodeWith _                           = throwE (KeyError "JWK cannot encode a JWE")

    -- Run the shared encoder with the header for this key id.
    encodeFor kid protectCek = doEncode (header kid) e protectCek content

    -- RSA: encrypt the content key with the public key.
    rsaWrap pub cek = ExceptT (rsaEncrypt pub a cek)

    -- Symmetric: AES key-wrap the content key (a pure operation).
    aesWrap kek cek = ExceptT (return (keyWrap a (BA.convert kek) cek))

    header :: Maybe KeyId -> B.ByteString
    header kid = BL.toStrict $ BL.concat
        [ "{\"alg\":", A.encode a
        , ",\"enc\":", A.encode e
        , ctyField
        , kidField kid
        , "}"
        ]

    -- Empty unless the payload is a nested JWT.
    ctyField = maybe "" (\c -> BL.concat [",\"cty\":\"", c, "\""]) contentType

    -- Empty unless the key carries an id.
    kidField kid
        | isNothing kid = ""
        | otherwise     = BL.append ",\"kid\":" (A.encode kid)

    -- Split the payload into its optional content type and raw bytes.
    (contentType, content) = case payload of
        Claims c       -> (Nothing, c)
        Nested (Jwt b) -> (Just "JWT", b)


-- | Try to decode a JWE using a JWK.
--
-- If the key type does not match the content encoding algorithm,
-- an error will be returned.
--
-- An RSA private key decrypts the content key with 'rsaDecrypt', using a
-- blinder generated from the key's public modulus. A symmetric key unwraps
-- it with 'keyUnwrap'. An 'UnsupportedJwk', and any other key type
-- (including an RSA public key), is rejected with a 'KeyError'.
jwkDecode :: MonadRandom m
    => Jwk
    -> ByteString
    -> m (Either JwtError JwtContent)
jwkDecode jwk jwt = runExceptT (decodeWith jwk)
  where
    decodeWith (RsaPrivateJwk kPr _ _ _) = do
        blinder <- lift (generateBlinder (public_n (private_pub kPr)))
        Jwe <$> doDecode (rsaDecrypt (Just blinder) kPr) jwt
    decodeWith (SymmetricJwk kb _ _ _) =
        Jwe <$> doDecode (keyUnwrap (BA.convert kb)) jwt
    decodeWith (UnsupportedJwk _) =
        throwE (KeyError "Unsupported JWK cannot be used to decode JWE")
    decodeWith _ =
        throwE (KeyError "This JWK cannot decode a JWE")


-- | Decoding pipeline shared by 'jwkDecode' and 'rsaDecode'.
--
-- Parses @jwt@ and, if it is a JWE, recovers the content-encryption key
-- (CEK) with @decodeCek@ and decrypts the payload with it.
--
-- A CEK is always generated for the header's @enc@ before the encrypted
-- key is examined. If @decodeCek@ fails, or returns a key whose length
-- differs from the generated one, the generated key is used instead. Every
-- key-decryption problem therefore surfaces only as 'BadCrypto' from
-- payload decryption. This matches the "substitute a randomly generated
-- CEK" guidance of RFC 7516 section 11.5; that this is the intent is an
-- inference, so confirm before changing it.
--
-- Parsed input that is not a JWE is rejected with 'BadHeader'.
doDecode :: MonadRandom m
    => (JweAlg -> ByteString -> Either JwtError ScrubbedBytes) -- ^ Decrypts the encrypted CEK for an algorithm
    -> ByteString                                            -- ^ The encoded token
    -> ExceptT JwtError m Jwe                                -- ^ The header and decrypted content
doDecode decodeCek jwt = ExceptT (return (P.parseJwt jwt)) >>= fromParsed
  where
    fromParsed (P.DecodableJwe hdr (P.EncryptedCEK ek) iv (P.Payload payload) tag (P.AAD aad)) = do
        let enc = jweEnc hdr
        -- Only the key half of the generated (key, IV) pair is needed.
        (fallbackCek, _) <- lift (generateCmkAndIV enc)
        let cek = pickCek fallbackCek (decodeCek (jweAlg hdr) ek)
        case decryptPayload enc cek iv aad tag payload of
            Just claims -> return (hdr, claims)
            Nothing     -> throwE BadCrypto
    fromParsed _ = throwE (BadHeader "Content is not a JWE")

    -- Keep the decrypted CEK only if decryption succeeded with the
    -- expected length; otherwise fall back to the generated key.
    pickCek fallback (Right decrypted)
        | BA.length decrypted == BA.length fallback = decrypted
    pickCek fallback _ = fallback


-- | Encoding pipeline shared by 'jwkEncode' and 'rsaEncode'.
--
-- Generates a fresh content-encryption key (CEK) and IV for @e@. It then
-- encrypts @claims@, using the encoded header as additional authenticated
-- data, and protects the CEK with @encryptKey@. The result is the five
-- dot-separated parts @header.key.iv.ciphertext.tag@, each encoded with
-- 'B64.encode'.
--
-- If payload encryption fails, the result is 'BadCrypto' in the error
-- channel. This replaces a partial @let Just ...@ binding, which would
-- instead have thrown a pattern-match exception whenever the returned
-- 'Jwt' was later inspected.
doEncode :: (MonadRandom m, ByteArray ba)
    => ByteString                                        -- ^ JSON header (not yet encoded)
    -> Enc                                               -- ^ Content encryption algorithm
    -> (ScrubbedBytes -> ExceptT JwtError m ByteString)  -- ^ Encrypts the generated CEK
    -> ba                                                -- ^ Plaintext content
    -> ExceptT JwtError m Jwt                            -- ^ The encoded JWE
doEncode hdr e encryptKey claims = do
    (cmk, iv) <- lift (generateCmkAndIV e)
    (AuthTag sig, ct) <- maybe (throwE BadCrypto) return (encryptPayload e cmk iv aad claims)
    jweKey <- encryptKey cmk
    let parts = [hdr, jweKey, BA.convert iv, BA.convert ct, BA.convert sig]
    return (Jwt (B.intercalate "." (map B64.encode parts)))
  where
    -- The encoded header doubles as the AEAD additional authenticated data.
    aad = B64.encode hdr

-- | Creates a JWE with the content key encoded using RSA.
--
-- The header contains only the @alg@ and @enc@ fields. Prefer one of the
-- @RSA_OAEP@ algorithms; @RSA1_5@ is deprecated (see the module
-- documentation).
rsaEncode :: MonadRandom m
    => JweAlg          -- ^ RSA algorithm to use (@RSA_OAEP@ or @RSA1_5@)
    -> Enc             -- ^ Content encryption algorithm
    -> PublicKey       -- ^ RSA key to encrypt with
    -> ByteString      -- ^ The JWT claims (content)
    -> m (Either JwtError Jwt) -- ^ The encoded JWE
rsaEncode a e kPub claims =
    runExceptT (doEncode header e encryptCek claims)
  where
    -- Encrypt the generated content key with the recipient's public key.
    encryptCek cek = ExceptT (rsaEncrypt kPub a cek)

    header = BL.toStrict (BL.concat ["{\"alg\":", A.encode a, ",\"enc\":", A.encode e, "}"])


-- | Decrypts a JWE whose content key was encrypted with RSA.
--
-- A fresh blinder is generated from the key's public modulus for each call
-- and passed to 'rsaDecrypt'.
rsaDecode :: MonadRandom m
    => PrivateKey               -- ^ Decryption key
    -> ByteString               -- ^ The encoded JWE
    -> m (Either JwtError Jwe)  -- ^ The decoded JWT, unless an error occurs
rsaDecode pk jwt = runExceptT $
    lift (generateBlinder modulus) >>= \blinder ->
        doDecode (rsaDecrypt (Just blinder) pk) jwt
  where
    modulus = public_n (private_pub pk)