{-| This module provides the JWT claims validation. Since websockets and listening connections in the database tend to be resource intensive (not to mention stateful) we need claims authorizing a specific channel and mode of operation. -} module PostgRESTWS.Claims ( validateClaims ) where import Control.Lens import qualified Crypto.JOSE.Types as JOSE.Types import Crypto.JWT import Data.Aeson (Value (..), decode, toJSON) import qualified Data.HashMap.Strict as M import Protolude type Claims = M.HashMap Text Value type ConnectionInfo = (ByteString, ByteString, Claims) {-| Given a secret, a token and a timestamp it validates the claims and returns either an error message or a triple containing channel, mode and claims hashmap. -} validateClaims :: Maybe ByteString -> ByteString -> LByteString -> IO (Either Text ConnectionInfo) validateClaims requestChannel secret jwtToken = runExceptT $ do cl <- liftIO $ jwtClaims (parseJWK secret) jwtToken cl' <- case cl of JWTClaims c -> pure c _ -> throwError "Error" channel <- claimAsJSON requestChannel "channel" cl' mode <- claimAsJSON Nothing "mode" cl' pure (channel, mode, cl') where claimAsJSON :: Maybe ByteString -> Text -> Claims -> ExceptT Text IO ByteString claimAsJSON defaultVal name cl = case M.lookup name cl of Just (String s) -> pure $ encodeUtf8 s Just _ -> throwError "claim is not string value" Nothing -> nonExistingClaim defaultVal name nonExistingClaim :: Maybe ByteString -> Text -> ExceptT Text IO ByteString nonExistingClaim Nothing name = throwError (name <> " not in claims") nonExistingClaim (Just defaultVal) _ = pure defaultVal {- Private functions and types copied from postgrest This code duplication will be short lived since postgrest will migrate towards jose Then this library will use jose's verifyClaims and error types. -} {-| Possible situations encountered with client JWTs -} data JWTAttempt = JWTInvalid JWTError | JWTMissingSecret | JWTClaims (M.HashMap Text Value) deriving Eq {-| Receives the JWT secret (from config) and a JWT and returns a map of JWT claims. -} jwtClaims :: JWK -> LByteString -> IO JWTAttempt jwtClaims _ "" = return $ JWTClaims M.empty jwtClaims secret payload = do let validation = defaultJWTValidationSettings (const True) eJwt <- runExceptT $ do jwt <- decodeCompact payload verifyClaims validation secret jwt return $ case eJwt of Left e -> JWTInvalid e Right jwt -> JWTClaims . claims2map $ jwt {-| Internal helper used to turn JWT ClaimSet into something easier to work with -} claims2map :: ClaimsSet -> M.HashMap Text Value claims2map = val2map . toJSON where val2map (Object o) = o val2map _ = M.empty {-| Internal helper to generate HMAC-SHA256. When the jwt key in the config file is a simple string rather than a JWK object, we'll apply this function to it. -} hs256jwk :: ByteString -> JWK hs256jwk key = fromKeyMaterial km & jwkUse .~ Just Sig & jwkAlg .~ (Just $ JWSAlg HS256) where km = OctKeyMaterial (OctKeyParameters (JOSE.Types.Base64Octets key)) parseJWK :: ByteString -> JWK parseJWK str = fromMaybe (hs256jwk str) (decode (toS str) :: Maybe JWK)