module PostgRESTWS.Claims
( validateClaims
) where
import Control.Lens
import qualified Crypto.JOSE.Types as JOSE.Types
import Crypto.JWT
import Data.Aeson (Value (..), decode, toJSON)
import qualified Data.HashMap.Strict as M
import Protolude
-- | JWT claim set as a JSON object: claim name mapped to its aeson 'Value'.
type Claims = M.HashMap Text Value
-- | Result of a successful 'validateClaims': the UTF-8 encoded @channel@
-- claim, the UTF-8 encoded @mode@ claim, and the full claim set.
type ConnectionInfo = (ByteString, ByteString, Claims)
-- | Verify a JWT and extract the connection information from its claims.
--
-- * @requestChannel@: channel to use when the token carries no @channel@
--   claim. If this is 'Nothing' and the claim is missing, validation fails.
-- * @secret@: either a JSON-encoded JWK or raw HS256 key bytes (see 'parseJWK').
-- * @jwtToken@: compact-serialised JWT.
--
-- Returns @Left@ with a human-readable message when the token fails
-- verification, or when a required claim is missing or is not a JSON string.
validateClaims :: Maybe ByteString -> ByteString -> LByteString -> IO (Either Text ConnectionInfo)
validateClaims requestChannel secret jwtToken =
  runExceptT $ do
    cl <- liftIO $ jwtClaims (parseJWK secret) jwtToken
    cl' <- case cl of
      JWTClaims c -> pure c
      -- Keep the underlying JWT error in the message instead of a bare
      -- "Error", so callers can distinguish failure causes.
      JWTInvalid err -> throwError ("Error: " <> show err)
      JWTMissingSecret -> throwError "Error: JWT secret is missing"
    channel <- claimAsJSON requestChannel "channel" cl'
    mode <- claimAsJSON Nothing "mode" cl'
    pure (channel, mode, cl')
  where
    -- Look up a claim that must be a JSON string and return it UTF-8 encoded.
    -- If the claim is absent, fall back to @defaultVal@ when one is given.
    claimAsJSON :: Maybe ByteString -> Text -> Claims -> ExceptT Text IO ByteString
    claimAsJSON defaultVal name cl = case M.lookup name cl of
      Just (String s) -> pure $ encodeUtf8 s
      Just _ -> throwError "claim is not string value"
      Nothing -> nonExistingClaim defaultVal name

    -- Handle an absent claim: use the default if there is one, otherwise fail.
    nonExistingClaim :: Maybe ByteString -> Text -> ExceptT Text IO ByteString
    nonExistingClaim Nothing name = throwError (name <> " not in claims")
    nonExistingClaim (Just defaultVal) _ = pure defaultVal
-- | Outcome of decoding and verifying a JWT with 'jwtClaims'.
data JWTAttempt
  = JWTInvalid JWTError   -- ^ decoding or claim verification failed
  | JWTMissingSecret      -- ^ no secret was available to verify with
  | JWTClaims Claims      -- ^ verified claim set, as a JSON object map
  deriving (Eq)
-- | Decode a compact JWT and verify it against the given key.
--
-- An empty token is not treated as an error: it yields an empty claim set.
-- The predicate passed to 'defaultJWTValidationSettings' is @const True@.
-- NOTE(review): assumes that argument is jose's audience check, in which case
-- every audience is accepted.
jwtClaims :: JWK -> LByteString -> IO JWTAttempt
jwtClaims _ "" = pure (JWTClaims M.empty)
jwtClaims key token =
  either JWTInvalid (JWTClaims . claims2map) <$> runExceptT verified
  where
    settings = defaultJWTValidationSettings (const True)

    verified :: ExceptT JWTError IO ClaimsSet
    verified = decodeCompact token >>= verifyClaims settings key
-- | Convert a verified claim set to a map of claim name to JSON value.
-- If the claim set does not serialise to a JSON object, the result is an
-- empty map.
claims2map :: ClaimsSet -> M.HashMap Text Value
claims2map claimSet = case toJSON claimSet of
  Object fields -> fields
  _ -> M.empty
-- | Build a symmetric (octet) JWK from raw key bytes. The key's use is set
-- to signing and its algorithm to HS256.
hs256jwk :: ByteString -> JWK
hs256jwk key = withAlg (withUse baseKey)
  where
    baseKey = fromKeyMaterial (OctKeyMaterial (OctKeyParameters (JOSE.Types.Base64Octets key)))
    withUse = jwkUse .~ Just Sig
    withAlg = jwkAlg .~ Just (JWSAlg HS256)
-- | Interpret a configured secret as a key. The secret is first decoded as a
-- JSON JWK. If that fails, its raw bytes are used as an HS256 key.
parseJWK :: ByteString -> JWK
parseJWK str = case decode (toS str) of
  Just parsed -> parsed
  Nothing -> hs256jwk str