{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE OverloadedStrings #-}
{-# OPTIONS_HADDOCK prune #-}

-- | JSON Web Key (JWK) support: the 'Jwk' and 'JwkSet' types, their JSON
-- encoding and decoding, and helpers for picking the keys that are usable
-- with a particular JWS or JWE header.
module Jose.Jwk
    ( KeyType
    , KeyUse (..)
    , KeyId
    , Jwk (..)
    , JwkSet (..)
    , validateForJws
    , findMatchingJwsKeys
    , findMatchingJweKeys
    )
where

import Control.Applicative (pure)
import Control.Monad (when)
import qualified Crypto.PubKey.ECC.ECDSA as ECDSA
import qualified Crypto.PubKey.RSA as RSA
import qualified Crypto.Types.PubKey.ECC as ECC
import Crypto.Number.Serialize
import Data.Aeson (genericToJSON, Value(..), FromJSON(..), ToJSON(..), withText)
import Data.Aeson.Types (Parser, Options (..), defaultOptions)
import Data.ByteString (ByteString)
import qualified Data.ByteString as B
import Data.List (find)
import Data.Maybe (isJust)
import Data.Text (Text)
import qualified Data.Text.Encoding as TE
import GHC.Generics (Generic)

import qualified Jose.Internal.Base64 as B64
import Jose.Jwa
import Jose.Types (JwtError(..), KeyId, JwsHeader(..), JweHeader(..))

-- | The key family, as carried in the JWK @kty@ member.
data KeyType = Rsa
             | Ec
             | Oct
             deriving (Eq, Show)

-- | The named elliptic curves accepted in the JWK @crv@ member.
data EcCurve = P_256
             | P_384
             | P_521
             deriving (Eq, Show)

-- | The intended use of a key (JWK @use@ member): signing or encryption.
data KeyUse = Sig
            | Enc
            deriving (Eq, Show)

-- | A single JSON Web Key.
--
-- Each constructor holds the key material followed by the optional @kid@,
-- @use@ and @alg@ members; the EC constructors also record the curve.
data Jwk = RsaPublicJwk  RSA.PublicKey   (Maybe KeyId) (Maybe KeyUse) (Maybe Alg)
         | RsaPrivateJwk RSA.PrivateKey  (Maybe KeyId) (Maybe KeyUse) (Maybe Alg)
         | EcPublicJwk   ECDSA.PublicKey (Maybe KeyId) (Maybe KeyUse) (Maybe Alg) EcCurve
         | EcPrivateJwk  ECDSA.KeyPair   (Maybe KeyId) (Maybe KeyUse) (Maybe Alg) EcCurve
         | SymmetricJwk  ByteString      (Maybe KeyId) (Maybe KeyUse) (Maybe Alg)
           deriving (Show, Eq)

-- | A JWK set: a JSON object whose @keys@ member is an array of JWKs.
data JwkSet = JwkSet
    { keys :: [Jwk]
    } deriving (Show, Eq, Generic)

-- | Check that a key may be used with the given JWS algorithm.
--
-- The key must not be marked for encryption only (@use@ = @enc@), and its
-- type must match the algorithm family: HS* needs a symmetric key, RS* an
-- RSA key and ES* an EC key. The 'None' algorithm never accepts a key.
validateForJws :: JwsAlg -> Jwk -> Either JwtError ()
validateForJws a jwk = do
    when (jwkUse jwk == Just Enc) $ Left (KeyError "JWK is for encryption only")
    either (Left . KeyError) (const $ Right ()) $ case a of
        HS256 -> mustBeSymmetric
        HS384 -> mustBeSymmetric
        HS512 -> mustBeSymmetric
        RS256 -> mustBeRsa
        RS384 -> mustBeRsa
        RS512 -> mustBeRsa
        ES256 -> mustBeEc
        ES384 -> mustBeEc
        ES512 -> mustBeEc
        None  -> Left "JWS with alg 'None' does not require a key"
  where
    mustBeRsa = case jwk of
        RsaPrivateJwk {} -> Right ()
        RsaPublicJwk  {} -> Right ()
        _                -> Left "JWK must be an RSA key"
    mustBeSymmetric = case jwk of
        SymmetricJwk {} -> Right ()
        _               -> Left "JWK must be symmetric"
    mustBeEc = case jwk of
        EcPrivateJwk {} -> Right ()
        EcPublicJwk  {} -> Right ()
        _               -> Left "JWK must be an EC key"

-- | 'True' when 'validateForJws' accepts the key for the algorithm.
canDecodeJws :: JwsAlg -> Jwk -> Bool
canDecodeJws al jwk = either (const False) (const True) $ validateForJws al jwk

-- | 'True' for RSA private keys that are not marked for signing only.
-- The JWE algorithm itself is currently not inspected.
canDecodeJwe :: JweAlg -> Jwk -> Bool
canDecodeJwe _ jwk = jwkUse jwk /= Just Sig && case jwk of
    RsaPrivateJwk {} -> True
    _                -> False

-- | Map a JWK curve name to the corresponding SEC curve parameters.
curve :: EcCurve -> ECC.Curve
curve c = ECC.getCurveByName $ case c of
    P_256 -> ECC.SEC_p256r1
    P_384 -> ECC.SEC_p384r1
    P_521 -> ECC.SEC_p521r1

-- | Octet length of a field element (and of the private scalar) for a curve.
--
-- RFC 7518 section 6.2 requires the @x@, @y@ and @d@ members of an EC JWK
-- to be exactly this long, including any leading zero octets.
ecCoordinateLength :: EcCurve -> Int
ecCoordinateLength c = case c of
    P_256 -> 32
    P_384 -> 48
    P_521 -> 66

-- | The @kid@ of a key, if it has one.
jwkId :: Jwk -> Maybe KeyId
jwkId key = case key of
    RsaPublicJwk  _ keyId _ _   -> keyId
    RsaPrivateJwk _ keyId _ _   -> keyId
    EcPublicJwk   _ keyId _ _ _ -> keyId
    EcPrivateJwk  _ keyId _ _ _ -> keyId
    SymmetricJwk  _ keyId _ _   -> keyId

-- | The @use@ of a key, if it has one.
jwkUse :: Jwk -> Maybe KeyUse
jwkUse key = case key of
    RsaPublicJwk  _ _ u _   -> u
    RsaPrivateJwk _ _ u _   -> u
    EcPublicJwk   _ _ u _ _ -> u
    EcPrivateJwk  _ _ u _ _ -> u
    SymmetricJwk  _ _ u _   -> u

-- | The first key in the list whose @kid@ equals the given id.
findKeyById :: [Jwk] -> KeyId -> Maybe Jwk
findKeyById jwks keyId = find ((== Just keyId) . jwkId) jwks

-- | Select the keys that could verify a JWS with the given header.
--
-- If the header carries a @kid@ and a key with that id exists, only that
-- key is considered; otherwise every key is a candidate. Candidates are
-- kept only if 'validateForJws' accepts them for the header's algorithm,
-- which also excludes keys whose @use@ is @enc@.
findMatchingJwsKeys :: [Jwk] -> JwsHeader -> [Jwk]
findMatchingJwsKeys jwks hdr = filter (canDecodeJws (jwsAlg hdr)) $ filterById (jwsKid hdr) jwks

-- | Narrow a key list to the key with the given id when such a key exists.
-- With no id, or an id that matches nothing, the list is returned unchanged.
filterById :: Maybe KeyId -> [Jwk] -> [Jwk]
filterById keyId jwks = case keyId of
    Just i  -> maybe jwks (:[]) $ findKeyById jwks i
    Nothing -> jwks

-- | Select the keys that could decrypt a JWE with the given header.
--
-- The @kid@ narrowing works as in 'findMatchingJwsKeys'; afterwards only
-- RSA private keys that are not marked for signing (@use@ = @sig@) remain.
findMatchingJweKeys :: [Jwk] -> JweHeader -> [Jwk]
findMatchingJweKeys jwks hdr = filter (canDecodeJwe (jweAlg hdr)) $ filterById (jweKid hdr) jwks

-- | Raw bytes that appear base64-encoded (via "Jose.Internal.Base64")
-- inside a JWK.
newtype JwkBytes = JwkBytes {bytes :: ByteString} deriving (Show)

instance FromJSON KeyType where
    parseJSON = withText "KeyType" $ \t ->
        case t of
            "RSA" -> pure Rsa
            "EC"  -> pure Ec
            "oct" -> pure Oct
            _     -> fail "unsupported key type"

instance ToJSON KeyType where
    toJSON kt = case kt of
        Rsa -> String "RSA"
        Ec  -> String "EC"
        Oct -> String "oct"

instance FromJSON KeyUse where
    parseJSON = withText "KeyUse" $ \t ->
        case t of
            "sig" -> pure Sig
            "enc" -> pure Enc
            _     -> fail "'use' value must be either 'sig' or 'enc'"

instance ToJSON KeyUse where
    toJSON ku = case ku of
        Sig -> String "sig"
        Enc -> String "enc"

instance FromJSON EcCurve where
    parseJSON = withText "EcCurve" $ \t ->
        case t of
            "P-256" -> pure P_256
            "P-384" -> pure P_384
            "P-521" -> pure P_521
            _       -> fail "unsupported 'crv' value"

instance ToJSON EcCurve where
    toJSON c = case c of
        P_256 -> String "P-256"
        P_384 -> String "P-384"
        P_521 -> String "P-521"

instance FromJSON JwkBytes where
    parseJSON = withText "JwkBytes" $ \t ->
        case B64.decode (TE.encodeUtf8 t) of
            Left  _ -> fail "could not base64 decode bytes"
            Right b -> pure $ JwkBytes b

instance ToJSON JwkBytes where
    toJSON (JwkBytes b) = String . TE.decodeUtf8 $ B64.encode b

-- Decoding goes through the flat 'JwkData' record, which is then checked
-- against the supported member combinations by 'createJwk'.
instance FromJSON Jwk where
    parseJSON o@(Object _) = do
        jwkData <- parseJSON o :: Parser JwkData
        case createJwk jwkData of
            Left  err -> fail err
            Right jwk -> return jwk

    parseJSON _ = fail "Jwk must be a JSON object"

-- Encoding fills in the relevant members of 'defJwk' for each constructor.
instance ToJSON Jwk where
    toJSON jwk = toJSON $ case jwk of
        RsaPublicJwk pubKey mId mUse mAlg ->
            createPubData pubKey mId mUse mAlg

        RsaPrivateJwk privKey mId mUse mAlg ->
            let pubData = createPubData (RSA.private_pub privKey) mId mUse mAlg
            in  pubData
                    { d  = Just . JwkBytes . i2osp $ RSA.private_d privKey
                    , p  = i2b $ RSA.private_p privKey
                    , q  = i2b $ RSA.private_q privKey
                    , dp = i2b $ RSA.private_dP privKey
                    , dq = i2b $ RSA.private_dQ privKey
                    , qi = i2b $ RSA.private_qinv privKey
                    }

        SymmetricJwk bs mId mUse mAlg -> defJwk
            { kty = Oct
            , k   = Just $ JwkBytes bs
            , kid = mId
            , use = mUse
            , alg = mAlg
            }

        EcPublicJwk pubKey mId mUse mAlg c -> defJwk
            { kty = Ec
            , x   = fst (ecPoint c pubKey)
            , y   = snd (ecPoint c pubKey)
            , kid = mId
            , use = mUse
            , alg = mAlg
            , crv = Just c
            }

        EcPrivateJwk kp mId mUse mAlg c -> defJwk
            { kty = Ec
            , x   = fst (ecPoint c (ECDSA.toPublicKey kp))
            , y   = snd (ecPoint c (ECDSA.toPublicKey kp))
            , d   = Just (fixedBytes c (ECDSA.private_d (ECDSA.toPrivateKey kp)))
            , kid = mId
            , use = mUse
            , alg = mAlg
            , crv = Just c
            }
      where
        -- Minimal-length encoding, used for the RSA members; a zero value
        -- is omitted entirely.
        i2b 0 = Nothing
        i2b i = Just . JwkBytes . i2osp $ i

        -- EC members must have the curve's full width (RFC 7518 section 6.2),
        -- so the minimal encoding is left-padded with zero octets.
        fixedBytes :: EcCurve -> Integer -> JwkBytes
        fixedBytes c = JwkBytes . padTo (ecCoordinateLength c) . i2osp

        -- B.replicate yields an empty string for a non-positive count, so a
        -- value that is already long enough is left as it is.
        padTo :: Int -> ByteString -> ByteString
        padTo len bs = B.replicate (len - B.length bs) 0 `B.append` bs

        ecPoint :: EcCurve -> ECDSA.PublicKey -> (Maybe JwkBytes, Maybe JwkBytes)
        ecPoint c pk = case ECDSA.public_q pk of
            ECC.Point xi yi -> (Just (fixedBytes c xi), Just (fixedBytes c yi))
            _               -> (Nothing, Nothing)

        createPubData pubKey mId mUse mAlg = defJwk
            { n   = i2b (RSA.public_n pubKey)
            , e   = i2b (RSA.public_e pubKey)
            , kid = mId
            , use = mUse
            , alg = mAlg
            }

instance ToJSON JwkSet
instance FromJSON JwkSet

-- | Encoding options for 'JwkData': members whose value is 'Nothing' are
-- left out of the JSON object.
aesonOptions :: Options
aesonOptions = defaultOptions { omitNothingFields = True }

-- | Flat record of every JWK member this module understands. It is the
-- intermediate form for JSON encoding and decoding of 'Jwk'; the field
-- names are used unchanged as the JSON member names.
data JwkData = J { kty :: KeyType
                 -- There's probably a better way to parse this
                 -- than encoding all the possible key params
                 -- but this will do for now.
                 , n   :: Maybe JwkBytes
                 , e   :: Maybe JwkBytes
                 , d   :: Maybe JwkBytes
                 , p   :: Maybe JwkBytes
                 , q   :: Maybe JwkBytes
                 , dp  :: Maybe JwkBytes
                 , dq  :: Maybe JwkBytes
                 , qi  :: Maybe JwkBytes
                 , k   :: Maybe JwkBytes
                 , crv :: Maybe EcCurve
                 , x   :: Maybe JwkBytes
                 , y   :: Maybe JwkBytes
                 , use :: Maybe KeyUse
                 , alg :: Maybe Alg
                 , kid :: Maybe Text
                 , x5u :: Maybe Text
                 , x5c :: Maybe [Text]
                 , x5t :: Maybe Text
                 } deriving (Show, Generic)

instance FromJSON JwkData
instance ToJSON JwkData where
    toJSON = genericToJSON aesonOptions

-- | Template record with every optional member absent. Every branch of the
-- 'ToJSON' 'Jwk' instance overrides @kty@ (directly or via the RSA helper,
-- which keeps 'Rsa') and @use@.
defJwk :: JwkData
defJwk = J
    { kty = Rsa
    , n   = Nothing
    , e   = Nothing
    , d   = Nothing
    , p   = Nothing
    , q   = Nothing
    , dp  = Nothing
    , dq  = Nothing
    , qi  = Nothing
    , k   = Nothing
    , crv = Nothing
    , x   = Nothing
    , y   = Nothing
    , use = Just Sig
    , alg = Nothing
    , kid = Nothing
    , x5u = Nothing
    , x5c = Nothing
    , x5t = Nothing
    }

-- | Build a 'Jwk' from parsed JWK members.
--
-- Each accepted combination of present and absent members maps to one
-- constructor; anything else is rejected. The error names the key type and
-- which members were present, but deliberately not their values: members
-- such as @d@, @p@, @q@ and @k@ hold secret key material, and parse errors
-- are commonly logged.
createJwk :: JwkData -> Either String Jwk
createJwk kd = case kd of
    J Rsa (Just nb) (Just eb) Nothing Nothing Nothing Nothing Nothing Nothing Nothing Nothing Nothing Nothing u a i _ _ _ ->
        return $ RsaPublicJwk (rsaPub nb eb) i u a
    J Rsa (Just nb) (Just eb) (Just db) mp mq mdp mdq mqi Nothing Nothing Nothing Nothing u a i _ _ _ ->
        return $ RsaPrivateJwk (RSA.PrivateKey (rsaPub nb eb) (os2ip $ bytes db) (os2mip mp) (os2mip mq) (os2mip mdp) (os2mip mdq) (os2mip mqi)) i u a
    J Oct Nothing Nothing Nothing Nothing Nothing Nothing Nothing Nothing (Just kb) Nothing Nothing Nothing u a i Nothing Nothing Nothing ->
        return $ SymmetricJwk (bytes kb) i u a
    J Ec Nothing Nothing Nothing Nothing Nothing Nothing Nothing Nothing Nothing (Just crv') (Just xb) (Just yb) u a i Nothing Nothing Nothing ->
        return $ EcPublicJwk (ECDSA.PublicKey (curve crv') (ecPoint xb yb)) i u a crv'
    J Ec Nothing Nothing (Just db) Nothing Nothing Nothing Nothing Nothing Nothing (Just crv') (Just xb) (Just yb) u a i Nothing Nothing Nothing ->
        return $ EcPrivateJwk (ECDSA.KeyPair (curve crv') (ecPoint xb yb) (os2ip (bytes db))) i u a crv'
    _ -> Left $ "Invalid key data. Didn't match any known JWK parameter combinations: kty="
                ++ show (kty kd) ++ ", parameters present=" ++ show presentParams
  where
    rsaPub nb eb = let m  = os2ip $ bytes nb
                       ex = os2ip $ bytes eb
                   in  RSA.PublicKey (rsaSize m 1) m ex

    -- Smallest number of octets that can hold the modulus.
    rsaSize m i = if (2 ^ (i * 8)) > m then i else rsaSize m (i+1)

    -- Absent optional CRT members are stored as 0.
    os2mip = maybe 0 (os2ip . bytes)

    ecPoint xb yb = ECC.Point (os2ip (bytes xb)) (os2ip (bytes yb))

    presentParams :: [String]
    presentParams = [name | (name, True) <- paramPresence]

    paramPresence :: [(String, Bool)]
    paramPresence =
        [ ("n",   isJust (n kd))
        , ("e",   isJust (e kd))
        , ("d",   isJust (d kd))
        , ("p",   isJust (p kd))
        , ("q",   isJust (q kd))
        , ("dp",  isJust (dp kd))
        , ("dq",  isJust (dq kd))
        , ("qi",  isJust (qi kd))
        , ("k",   isJust (k kd))
        , ("crv", isJust (crv kd))
        , ("x",   isJust (x kd))
        , ("y",   isJust (y kd))
        , ("use", isJust (use kd))
        , ("alg", isJust (alg kd))
        , ("kid", isJust (kid kd))
        , ("x5u", isJust (x5u kd))
        , ("x5c", isJust (x5c kd))
        , ("x5t", isJust (x5t kd))
        ]