{-# LANGUAGE OverloadedStrings #-}

-- | JWS signed token support: HMAC, RSA and EdDSA (Ed25519/Ed448) signing, plus ECDSA signature verification.
--
-- Example usage with HMAC:
--
-- >>> import Jose.Jws
-- >>> import Jose.Jwa
-- >>> let Right (Jwt jwt) = hmacEncode HS256 "secretmackey" "public claims"
-- >>> jwt
-- "eyJhbGciOiJIUzI1NiJ9.cHVibGljIGNsYWltcw.GDV7RdBrCYfCtFCZZGPy_sWry4GwfX3ckMywXUyxBsc"
-- >>> hmacDecode "wrongkey" jwt
-- Left BadSignature
-- >>> hmacDecode "secretmackey" jwt
-- Right (JwsHeader {jwsAlg = HS256, jwsTyp = Nothing, jwsCty = Nothing, jwsKid = Nothing},"public claims")

module Jose.Jws
    ( jwkEncode
    , hmacEncode
    , hmacDecode
    , rsaEncode
    , rsaDecode
    , ecDecode
    , ed25519Encode
    , ed25519Decode
    , ed448Encode
    , ed448Decode
    )
where

import qualified Crypto.PubKey.ECC.ECDSA as ECDSA
import qualified Crypto.PubKey.Ed25519 as Ed25519
import qualified Crypto.PubKey.Ed448 as Ed448
import Crypto.PubKey.RSA (PrivateKey(..), PublicKey(..), generateBlinder)
import Crypto.Random (MonadRandom)
import Data.ByteString (ByteString)
import qualified Data.ByteString as B

import Jose.Types
import qualified Jose.Internal.Base64 as B64
import Jose.Internal.Crypto
import qualified Jose.Internal.Parser as P
import Jose.Jwa
import Jose.Jwk (Jwk (..))

-- | Create a JWS signed with a JWK.
--
-- The key and algorithm must be consistent or an error will be returned:
--
-- * RSA private keys and symmetric keys are handed to the RSA / HMAC
--   signing code together with the requested algorithm; whether the
--   algorithm suits the key is decided there.
--
-- * Ed25519 and Ed448 private keys can only be used with 'EdDSA'; any other
--   algorithm yields a 'KeyError'.
--
-- * Any other key (for example a public key) cannot be used for signing
--   and yields 'BadAlgorithm'.
--
-- The key's @kid@, if present, is copied into the JWS header.
jwkEncode :: MonadRandom m
          => JwsAlg                          -- ^ The algorithm to use
          -> Jwk                             -- ^ The key to sign with
          -> Payload                         -- ^ The public JWT claims
          -> m (Either JwtError Jwt)         -- ^ The encoded token, if successful
jwkEncode a key payload = case key of
    RsaPrivateJwk kPr kid _ _ -> rsaEncodeInternal a kPr (sigTarget a kid payload)
    SymmetricJwk  k   kid _ _ -> return $ hmacEncodeInternal a k (sigTarget a kid payload)
    Ed25519PrivateJwk kPr kPub kid -> return $
        case a of
            EdDSA -> Right $ ed25519EncodeInternal kPr kPub (sigTarget EdDSA kid payload)
            _     -> Left (KeyError "Algorithm cannot be used with an Ed25519 key")
    Ed448PrivateJwk kPr kPub kid -> return $
        case a of
            EdDSA -> Right $ ed448EncodeInternal kPr kPub (sigTarget EdDSA kid payload)
            _     -> Left (KeyError "Algorithm cannot be used with an Ed448 key")
    -- Every remaining key lands here: public keys as well as EC private keys.
    -- The message must therefore not claim that the key is an EC key.
    _ -> return $ Left $
        BadAlgorithm "Key cannot be used for signing (EC signing is not supported)"

-- | Create a JWS with an HMAC for validation.
--
-- The generated header names only the algorithm: it carries no key id and
-- no content type. The claims are signed exactly as given.
hmacEncode :: JwsAlg       -- ^ The MAC algorithm to use
           -> ByteString   -- ^ The MAC key
           -> ByteString   -- ^ The public JWT claims (token content)
           -> Either JwtError Jwt -- ^ The encoded JWS token
hmacEncode a key payload = hmacEncodeInternal a key signingInput
  where
    signingInput = sigTarget a Nothing (Claims payload)

-- | MAC an already-assembled signing input (@header.payload@) and append the
-- encoded MAC as the third, dot-separated segment.
--
-- Any error reported by 'hmacSign' is returned unchanged.
hmacEncodeInternal :: JwsAlg
                   -> ByteString
                   -> ByteString
                   -> Either JwtError Jwt
hmacEncodeInternal a key st = do
    mac <- hmacSign a key st
    Right (Jwt (B.concat [st, ".", B64.encode mac]))

-- | Decodes and validates an HMAC signed JWS.
--
-- The algorithm named in the token's header is passed to 'hmacVerify'
-- together with the supplied key.
hmacDecode :: ByteString          -- ^ The HMAC key
           -> ByteString          -- ^ The JWS token to decode
           -> Either JwtError Jws -- ^ The decoded token if successful
hmacDecode key = decode checkMac
  where
    checkMac alg = hmacVerify alg key

-- | Creates a JWS with an RSA signature.
--
-- The generated header names only the algorithm (no key id). Random
-- values needed for signing are drawn from the 'MonadRandom' context.
rsaEncode :: MonadRandom m
          => JwsAlg                           -- ^ The RSA algorithm to use
          -> PrivateKey                       -- ^ The key to sign with
          -> ByteString                       -- ^ The public JWT claims (token content)
          -> m (Either JwtError Jwt)          -- ^ The encoded JWS token
rsaEncode a pk = rsaEncodeInternal a pk . sigTarget a Nothing . Claims

-- | RSA-sign an already-assembled signing input and append the encoded
-- signature as the third, dot-separated segment.
--
-- A blinder for the key's public modulus is generated first. This happens
-- even if 'rsaSign' later rejects the algorithm. Any error from 'rsaSign'
-- is returned unchanged.
rsaEncodeInternal :: MonadRandom m
                  => JwsAlg
                  -> PrivateKey
                  -> ByteString
                  -> m (Either JwtError Jwt)
rsaEncodeInternal a pk st = do
    let modulus = public_n (private_pub pk)
    blinder <- generateBlinder modulus
    return (appendSig <$> rsaSign (Just blinder) a pk st)
  where
    appendSig sig = Jwt (B.concat [st, ".", B64.encode sig])


-- | Decode and validate an Ed25519 (EdDSA) signed JWS.
ed25519Decode :: Ed25519.PublicKey    -- ^ The key to check the signature with
              -> ByteString           -- ^ The encoded JWS
              -> Either JwtError Jws  -- ^ The decoded token if successful
ed25519Decode key = decode (\alg -> ed25519Verify alg key)


-- | Create a JWS signed with an Ed25519 key pair.
--
-- The header always names 'EdDSA' as the algorithm and carries no key id.
ed25519Encode :: Ed25519.SecretKey  -- ^ The secret key to sign with
              -> Ed25519.PublicKey  -- ^ The public half of the same key pair
              -> ByteString         -- ^ The public JWT claims (token content)
              -> Jwt                -- ^ The encoded JWS token
ed25519Encode kPr kPub payload = ed25519EncodeInternal kPr kPub signingInput
  where
    signingInput = sigTarget EdDSA Nothing (Claims payload)


-- | Ed25519-sign an already-assembled signing input and append the encoded
-- signature as the third, dot-separated segment.
--
-- 'Ed25519.sign' returns a signature directly, so no 'Either' is involved.
ed25519EncodeInternal :: Ed25519.SecretKey
                      -> Ed25519.PublicKey
                      -> ByteString
                      -> Jwt
ed25519EncodeInternal kPr kPub signMe =
    Jwt (B.concat [signMe, ".", B64.encode (Ed25519.sign kPr kPub signMe)])


-- | Decode and validate an Ed448 (EdDSA) signed JWS.
ed448Decode :: Ed448.PublicKey      -- ^ The key to check the signature with
            -> ByteString           -- ^ The encoded JWS
            -> Either JwtError Jws  -- ^ The decoded token if successful
ed448Decode key = decode (\alg -> ed448Verify alg key)


-- | Create a JWS signed with an Ed448 key pair.
--
-- The header always names 'EdDSA' as the algorithm and carries no key id.
ed448Encode :: Ed448.SecretKey  -- ^ The secret key to sign with
            -> Ed448.PublicKey  -- ^ The public half of the same key pair
            -> ByteString       -- ^ The public JWT claims (token content)
            -> Jwt              -- ^ The encoded JWS token
ed448Encode kPr kPub payload = ed448EncodeInternal kPr kPub signingInput
  where
    signingInput = sigTarget EdDSA Nothing (Claims payload)


-- | Ed448-sign an already-assembled signing input and append the encoded
-- signature as the third, dot-separated segment.
--
-- 'Ed448.sign' returns a signature directly, so no 'Either' is involved.
ed448EncodeInternal :: Ed448.SecretKey
                    -> Ed448.PublicKey
                    -> ByteString
                    -> Jwt
ed448EncodeInternal kPr kPub signMe =
    Jwt (B.concat [signMe, ".", B64.encode (Ed448.sign kPr kPub signMe)])


-- | Decode and validate an RSA signed JWS.
--
-- The algorithm named in the token's header is passed to 'rsaVerify'
-- together with the supplied public key.
rsaDecode :: PublicKey            -- ^ The key to check the signature with
          -> ByteString           -- ^ The encoded JWS
          -> Either JwtError Jws  -- ^ The decoded token if successful
rsaDecode key = decode checkSig
  where
    checkSig alg = rsaVerify alg key


-- | Decode and validate an EC signed JWS.
--
-- The algorithm named in the token's header is passed to 'ecVerify'
-- together with the supplied public key.
ecDecode :: ECDSA.PublicKey       -- ^ The key to check the signature with
         -> ByteString            -- ^ The encoded JWS
         -> Either JwtError Jws   -- ^ The decoded token if successful
ecDecode key = decode checkSig
  where
    checkSig alg = ecVerify alg key

-- | Build the JWS signing input: the encoded JSON header and the encoded
-- payload, joined by a single dot.
--
-- The header records the algorithm and the optional key id. A 'Nested'
-- payload (an already-encoded JWT) also sets the content type to @"JWT"@.
-- Plain 'Claims' leave the content type unset.
sigTarget :: JwsAlg -> Maybe KeyId -> Payload -> ByteString
sigTarget a kid payload =
    B.intercalate "." [B64.encode (encodeHeader header), B64.encode body]
  where
    header = defJwsHdr { jwsAlg = a, jwsKid = kid, jwsCty = cty }
    (cty, body) = case payload of
        Nested (Jwt inner) -> (Just "JWT", inner)
        Claims claims      -> (Nothing, claims)

-- | A signature check. It receives the algorithm named in the token's
-- header, then the signing input and the signature, both as extracted by
-- the parser. NOTE(review): the signature is assumed to be already
-- base64-decoded by "Jose.Internal.Parser"; confirm.
type JwsVerifier = JwsAlg -> ByteString -> ByteString -> Bool


-- | Parse a compact-serialized token and check its signature.
--
-- Returns the header and payload when the verifier accepts the signature
-- and 'BadSignature' when it rejects it. Returns 'BadHeader' when the token
-- parses as something other than a JWS. Errors from 'P.parseJwt' are
-- returned unchanged.
decode :: JwsVerifier -> ByteString -> Either JwtError Jws
decode verify jwt = P.parseJwt jwt >>= checkJws
  where
    -- The algorithm comes from the token's own header, which is not yet
    -- authenticated. Each verifier must refuse algorithms that do not fit
    -- its key.
    checkJws (P.DecodableJws hdr (P.Payload p) (P.Sig sig) (P.SigTarget signed))
        | verify (jwsAlg hdr) signed sig = Right (hdr, p)
        | otherwise                      = Left BadSignature
    checkJws _ = Left (BadHeader "JWT is not a JWS")