{-# LANGUAGE OverloadedStrings #-}
{-# OPTIONS_HADDOCK prune #-}

-- | Internal functions for encrypting and signing / decrypting
-- and verifying JWT content.

module Jose.Internal.Crypto
    ( hmacSign
    , hmacVerify
    , rsaSign
    , rsaVerify
    , rsaEncrypt
    , rsaDecrypt
    , ecVerify
    , encryptPayload
    , decryptPayload
    , generateCmkAndIV
    , pad
    , unpad
    )
where

import Control.Applicative
import Crypto.Cipher.Types (AuthTag(..))
import Crypto.Number.Serialize (os2ip)
import qualified Crypto.PubKey.ECC.ECDSA as ECDSA
import qualified Crypto.PubKey.RSA as RSA
import qualified Crypto.PubKey.RSA.PKCS15 as PKCS15
import qualified Crypto.PubKey.RSA.OAEP as OAEP
import Crypto.Random (CPRG, cprgGenerate)
import qualified Crypto.Cipher.AES as AES
import Crypto.PubKey.HashDescr
import Crypto.MAC.HMAC (hmac)
import Data.Byteable (constEqBytes)
import Data.ByteString (ByteString)
import qualified Data.ByteString as B
import qualified Data.Serialize as Serialize
import qualified Data.Text as T
import Data.Word (Word64, Word8)

import Jose.Jwa
import Jose.Types (JwtError(..))

-- | Sign a message with an HMAC key.
--
-- Fails with 'BadAlgorithm' if the algorithm has no entry in the
-- table of HMAC hashes (@HS256@, @HS384@ and @HS512@).
hmacSign :: JwsAlg      -- ^ HMAC algorithm to use
         -> ByteString  -- ^ Key
         -> ByteString  -- ^ The message/content
         -> Either JwtError ByteString -- ^ HMAC output
hmacSign a k m = do
    hash <- maybe (Left $ BadAlgorithm $ T.pack $ "Not an HMAC algorithm: " ++ show a) return $ lookup a hmacHashes
    return $ hmac (hashFunction hash) blockSize k m
  where
    -- HMAC pads the key to the block size of the underlying hash, which
    -- is 64 bytes for SHA-256 but 128 bytes for SHA-384 and SHA-512.
    -- Using 64 for the larger hashes yields a non-standard MAC.
    blockSize = case a of
        HS384 -> 128
        HS512 -> 128
        _     -> 64

-- | Verify the HMAC for a given message.
-- Returns false if the MAC is incorrect or the 'Alg' is not an HMAC.
--
-- The comparison is done with 'constEqBytes' rather than '=='.
hmacVerify :: JwsAlg      -- ^ HMAC Algorithm to use
           -> ByteString  -- ^ Key
           -> ByteString  -- ^ The message/content
           -> ByteString  -- ^ The signature to check
           -> Bool        -- ^ Whether the signature is correct
hmacVerify a key msg sig = either (const False) (`constEqBytes` sig) $ hmacSign a key msg

-- | Sign a message using an RSA private key.
--
-- Signing can only fail when the algorithm is not an RSA algorithm, or
-- when the RSA key is too small for the signature padding to fit. With
-- real-world RSA keys the latter shouldn't happen in practice.
rsaSign :: Maybe RSA.Blinder  -- ^ RSA blinder
        -> JwsAlg             -- ^ Algorithm to use. Must be one of @RSA256@, @RSA384@ or @RSA512@
        -> RSA.PrivateKey     -- ^ Private key to sign with
        -> ByteString         -- ^ Message to sign
        -> Either JwtError ByteString -- ^ The signature
rsaSign blinder a key msg =
    lookupRSAHash a >>= \hashDescr ->
        either (const $ Left BadCrypto) Right $ PKCS15.sign blinder hashDescr key msg

-- | Check an RSA signature over a message with a public key.
--
-- The result is 'False' when the signature does not verify, and also
-- when the 'Alg' value is not an RSA signature algorithm.
rsaVerify :: JwsAlg         -- ^ The signature algorithm. Used to obtain the hash function.
          -> RSA.PublicKey  -- ^ The key to check the signature with
          -> ByteString     -- ^ The message/content
          -> ByteString     -- ^ The signature to check
          -> Bool           -- ^ Whether the signature is correct
rsaVerify a key msg sig =
    either (const False) verifyWith (lookupRSAHash a)
  where
    verifyWith hashDescr = PKCS15.verify hashDescr key msg sig

-- | Check an EC signature over a message with a public key.
--
-- The signature bytes are split into two halves which are read as the
-- integers @r@ and @s@. The result is 'False' when the signature does
-- not verify, and also when the 'Alg' value is not an EC signature
-- algorithm.
ecVerify :: JwsAlg           -- ^ The signature algorithm. Used to obtain the hash function.
         -> ECDSA.PublicKey  -- ^ The key to check the signature with
         -> ByteString       -- ^ The message/content
         -> ByteString       -- ^ The signature to check
         -> Bool             -- ^ Whether the signature is correct
ecVerify a key msg sig = maybe False check (lookupECHash a)
  where
    check hashFn = ECDSA.verify hashFn key signature msg
    (rBytes, sBytes) = B.splitAt (B.length sig `div` 2) sig
    signature = ECDSA.Signature (os2ip rBytes) (os2ip sBytes)

-- Table mapping each HMAC signature algorithm to its hash description.
hmacHashes :: [(JwsAlg, HashDescr)]
hmacHashes =
    [ (HS256, hashDescrSHA256)
    , (HS384, hashDescrSHA384)
    , (HS512, hashDescrSHA512)
    ]

-- Hash function for an EC signature algorithm, or 'Nothing' for any
-- other algorithm.
lookupECHash :: JwsAlg -> Maybe HashFunction
lookupECHash alg = hashFunction <$> descr
  where
    descr = case alg of
        ES256 -> Just hashDescrSHA256
        ES384 -> Just hashDescrSHA384
        ES512 -> Just hashDescrSHA512
        _     -> Nothing

-- Hash description for an RSA signature algorithm, or a 'BadAlgorithm'
-- error naming the offending algorithm.
lookupRSAHash :: JwsAlg -> Either JwtError HashDescr
lookupRSAHash RS256 = Right hashDescrSHA256
lookupRSAHash RS384 = Right hashDescrSHA384
lookupRSAHash RS512 = Right hashDescrSHA512
lookupRSAHash alg   = Left (BadAlgorithm (T.pack ("Not an RSA algorithm: " ++ show alg)))

-- | Generates the symmetric key (content management key) and IV
--
-- Used to encrypt a message. The key is drawn from the generator
-- first, then the IV; the final generator state is returned.
generateCmkAndIV :: CPRG g
                 => g    -- ^ The random number generator
                 -> Enc  -- ^ The encryption algorithm to be used
                 -> (B.ByteString, B.ByteString, g) -- ^ The key, IV and generator
generateCmkAndIV g e =
    let (cmk, g1) = cprgGenerate (keySize e) g
        (iv, g2)  = cprgGenerate (ivSize e) g1  -- iv for aes gcm or cbc
    in  (cmk, iv, g2)

-- Length in bytes of the content key for each encryption algorithm.
keySize :: Enc -> Int
keySize enc = case enc of
    A128GCM       -> 16
    A256GCM       -> 32
    A128CBC_HS256 -> 32
    A256CBC_HS512 -> 64

-- Length in bytes of the IV for each encryption algorithm.
ivSize :: Enc -> Int
ivSize enc = case enc of
    A128GCM -> 12
    A256GCM -> 12
    _       -> 16

-- | Encrypts a message (typically a symmetric key) using RSA.
--
-- The return type has no room for an error, so if the underlying RSA
-- operation fails, forcing the returned ciphertext raises an 'error'
-- with a descriptive message. The returned generator stays usable
-- either way.
rsaEncrypt :: CPRG g
           => g              -- ^ Random number generator
           -> JweAlg         -- ^ The algorithm (either @RSA1_5@ or @RSA_OAEP@)
           -> RSA.PublicKey  -- ^ The encryption key
           -> B.ByteString   -- ^ The message to encrypt
           -> (B.ByteString, g) -- ^ The encrypted message and new generator
rsaEncrypt gen a pubKey content = (ct, g')
  where
    encrypt = case a of
        RSA1_5   -> PKCS15.encrypt gen
        RSA_OAEP -> OAEP.encrypt gen oaepParams
    (result, g') = encrypt pubKey content
    -- TODO: Check that we can't cause any errors here with our RSA public key
    ct = either (const encryptionFailed) id result
    encryptionFailed =
        error "Jose.Internal.Crypto.rsaEncrypt: RSA encryption failed (is the message too long for the key?)"

-- | Decrypts an RSA encrypted message.
--
-- Any failure reported by the underlying RSA operation is returned as
-- 'BadCrypto'.
rsaDecrypt :: Maybe RSA.Blinder  -- ^ RSA blinder
           -> JweAlg             -- ^ The RSA algorithm to use
           -> RSA.PrivateKey     -- ^ The decryption key
           -> B.ByteString       -- ^ The encrypted content
           -> Either JwtError B.ByteString -- ^ The decrypted key
rsaDecrypt blinder a rsaKey jweKey = either (const $ Left BadCrypto) Right $ decrypt rsaKey jweKey
  where
    decrypt = case a of
        RSA1_5   -> PKCS15.decrypt blinder
        RSA_OAEP -> OAEP.decrypt blinder oaepParams

-- OAEP parameters used for both encryption and decryption: the library
-- defaults with SHA-1 as the hash function.
oaepParams :: OAEP.OAEPParams
oaepParams = OAEP.defaultOAEPParams (hashFunction hashDescrSHA1)

-- TODO: Need to check that the key length and IV are valid for enc.

-- | Decrypt an AES encrypted message.
--
-- Returns 'BadCrypto' if the key or IV length does not match the
-- encryption algorithm, or if the CBC ciphertext or its padding is
-- malformed, and 'BadSignature' if the authentication tag does not
-- match. For the CBC-HMAC algorithms the tag is verified /before/ the
-- ciphertext is decrypted and unpadded, so that padding errors cannot
-- be distinguished from MAC errors on unauthenticated input.
decryptPayload :: Enc         -- ^ Encryption algorithm
               -> ByteString  -- ^ Content management key
               -> ByteString  -- ^ IV
               -> ByteString  -- ^ Additional authentication data
               -> ByteString  -- ^ The integrity protection value to be checked
               -> ByteString  -- ^ The encrypted JWT payload
               -> Either JwtError ByteString
decryptPayload e cek iv aad sig ct
    -- Reject wrongly sized keys/IVs up front instead of letting the
    -- AES primitives fail on them.
    | B.length cek /= keySize e || B.length iv /= ivSize e = Left BadCrypto
    | otherwise = case e of
        A128GCM       -> decryptGCM
        A256GCM       -> decryptGCM
        A128CBC_HS256 -> decryptCBC 16 64 hashDescrSHA256
        A256CBC_HS512 -> decryptCBC 32 128 hashDescrSHA512
  where
    expectedTag = AuthTag sig

    decryptGCM
        | tag == expectedTag = Right plaintext
        | otherwise          = Left BadSignature
      where
        (plaintext, tag) = AES.decryptGCM (AES.initAES cek) iv aad ct

    decryptCBC tagLen blockSize h
        | tag /= expectedTag = Left BadSignature
        -- CBC ciphertext must be a non-empty whole number of AES blocks.
        | B.null ct || B.length ct `mod` 16 /= 0 = Left BadCrypto
        | otherwise = unpad $ AES.decryptCBC (AES.initAES encKey) iv ct
      where
        -- The first half of the key is the MAC key, the second half
        -- the AES key.
        (macKey, encKey) = B.splitAt (B.length cek `div` 2) cek
        tag = authTag tagLen blockSize h macKey $ cbcMacInput aad iv ct

-- | Encrypt a message using AES.
--
-- For the GCM algorithms the ciphertext and tag come straight from
-- AES-GCM. For the CBC-HMAC algorithms the message is padded with
-- 'pad', encrypted with AES-CBC under the second half of the key, and
-- the tag is a truncated HMAC (keyed with the first half of the key)
-- over the AAD, IV, ciphertext and AAD length.
--
-- The key and IV are not validated here; they are expected to have the
-- sizes produced by 'generateCmkAndIV' for the given algorithm.
encryptPayload :: Enc         -- ^ Encryption algorithm
               -> ByteString  -- ^ Content management key
               -> ByteString  -- ^ IV
               -> ByteString  -- ^ Additional authenticated data
               -> ByteString  -- ^ The message/JWT claims
               -> (ByteString, AuthTag) -- ^ Ciphertext claims and signature tag
encryptPayload e cek iv aad msg = case e of
    A128GCM       -> aesgcm
    A256GCM       -> aesgcm
    A128CBC_HS256 -> (aescbc, sig 16 64 hashDescrSHA256)
    A256CBC_HS512 -> (aescbc, sig 32 128 hashDescrSHA512)
  where
    aesgcm = AES.encryptGCM (AES.initAES cek) iv aad msg
    (macKey, encKey) = B.splitAt (B.length cek `div` 2) cek
    aescbc = AES.encryptCBC (AES.initAES encKey) iv (pad msg)
    sig tagLen blockSize h = authTag tagLen blockSize h macKey $ cbcMacInput aad iv aescbc

-- The byte string that is MACed by the CBC-HMAC algorithms: the AAD,
-- the IV, the ciphertext, and the AAD length in bits serialized as a
-- 'Word64'.
cbcMacInput :: ByteString -> ByteString -> ByteString -> ByteString
cbcMacInput aad iv ct = B.concat [aad, iv, ct, Serialize.encode al]
  where
    al = fromIntegral (B.length aad) * 8 :: Word64

-- Truncated HMAC used as the authentication tag of the CBC-HMAC
-- algorithms. Arguments: tag length in bytes, HMAC block size of the
-- hash in bytes (64 for SHA-256, 128 for SHA-512), the hash, the MAC
-- key and the message.
authTag :: Int -> Int -> HashDescr -> ByteString -> ByteString -> AuthTag
authTag tagLen blockSize h k m =
    AuthTag $ B.take tagLen $ hmac (hashFunction h) blockSize k m

-- | Remove the padding added by 'pad'.
--
-- The last byte gives the padding length, which must be between 1 and
-- 16, must not exceed the input length, and every padding byte must
-- equal that value. Anything else, including empty input, yields
-- 'BadCrypto'.
unpad :: ByteString -> Either JwtError ByteString
unpad bs
    | B.null bs = Left BadCrypto  -- guards the 'B.last' below
    | padLen < 1 || padLen > 16 || padLen > len = Left BadCrypto
    | B.any (/= padByte) padding = Left BadCrypto
    | otherwise = Right pt
  where
    len = B.length bs
    padByte = B.last bs
    padLen = fromIntegral padByte
    (pt, padding) = B.splitAt (len - padLen) bs

-- | Pad a message up to a multiple of 16 bytes.
--
-- Between 1 and 16 bytes are appended, each holding the number of
-- bytes added; input that is already block-aligned gets a full extra
-- block.
pad :: ByteString -> ByteString
pad bs = B.append bs padding
  where
    lastBlockSize = B.length bs `mod` 16
    padByte = fromIntegral $ 16 - lastBlockSize :: Word8
    padding = B.replicate (fromIntegral padByte) padByte