{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE LambdaCase #-}
module PostgREST.Auth (
containsRole
, jwtClaims
, JWTAttempt(..)
, parseSecret
) where
import qualified Crypto.JOSE.Types as JOSE.Types
import qualified Data.Aeson as JSON
import qualified Data.HashMap.Strict as M
import Data.Vector as V
import Control.Lens (set)
import Data.Time.Clock (UTCTime)
import Control.Lens.Operators
import Crypto.JWT
import PostgREST.Types
import Protolude hiding (toS)
import Protolude.Conv (toS)
-- | Result of trying to turn a request's JWT payload into a claims map
-- (produced by 'jwtClaims').
data JWTAttempt
  = JWTInvalid JWTError
    -- ^ Decoding or verification of the token failed; carries the jose error.
  | JWTMissingSecret
    -- ^ A non-empty token was supplied but no key set was configured.
  | JWTClaims (M.HashMap Text JSON.Value)
    -- ^ The verified claims (an empty map when the payload was empty).
  deriving (Show, Eq)
-- | Decode and verify a compact-serialized JWT, returning its claims.
--
-- * An empty payload short-circuits to an empty claims map; no
--   verification is attempted and the key set is not consulted.
-- * A non-empty payload with no key set yields 'JWTMissingSecret'.
-- * Otherwise the token is decoded and its claims verified against the
--   key set at the given time, allowing a clock skew of 1 (the jose
--   'allowedSkew' setting). If an audience is configured, it is checked
--   for equality against the token's audience values; if not, any
--   audience is accepted. Failures become 'JWTInvalid'. On success the
--   claims are flattened by 'claims2map', which places the value found
--   at the optional JSPath under the "role" key.
jwtClaims :: Maybe JWKSet -> Maybe StringOrURI -> LByteString -> UTCTime -> Maybe JSPath -> IO JWTAttempt
jwtClaims _ _ "" _ _ = return $ JWTClaims M.empty
jwtClaims Nothing _ _ _ _ = return JWTMissingSecret
jwtClaims (Just keys) audience payload time jspath =
  toAttempt <$> runExceptT (decodeCompact payload >>= verifyClaimsAt validation keys time)
  where
    -- Without a configured audience, every aud value passes.
    audienceCheck = maybe (const True) (==) audience
    validation = set allowedSkew 1 $ defaultJWTValidationSettings audienceCheck
    toAttempt = either JWTInvalid (\claimsSet -> JWTClaims (claims2map claimsSet jspath))
-- | Flatten a verified claims set into a map of top-level JSON claims.
--
-- Any top-level "role" claim is always dropped. If a JSPath is given and
-- resolves to a value inside the claims JSON, that value is stored under
-- "role" instead. If the claims do not serialize to a JSON object, the
-- result is an empty map.
claims2map :: ClaimsSet -> Maybe JSPath -> M.HashMap Text JSON.Value
claims2map claims jspath =
  case claimsJson of
    JSON.Object claimsObj -> M.union (M.delete "role" claimsObj) roleEntry
    _ -> M.empty
  where
    claimsJson = JSON.toJSON claims
    -- Only evaluated in the Object branch above.
    roleEntry = maybe M.empty (M.singleton "role") (jspath >>= walkJSPath (Just claimsJson))
-- | Follow a JSPath into a JSON value.
--
-- A 'JSPKey' step looks a key up in an object and a 'JSPIdx' step indexes
-- into an array (out-of-range indices give 'Nothing'). Any step applied to
-- a value of the wrong shape yields 'Nothing'. An empty path returns the
-- starting value unchanged.
walkJSPath :: Maybe JSON.Value -> JSPath -> Maybe JSON.Value
walkJSPath current [] = current
walkJSPath current (segment : remaining) = do
  value <- current
  child <- descend segment value
  walkJSPath (Just child) remaining
  where
    descend (JSPKey key) (JSON.Object obj) = M.lookup key obj
    descend (JSPIdx idx) (JSON.Array arr) = arr V.!? idx
    descend _ _ = Nothing
-- | True exactly when the attempt produced claims that include a "role" key.
containsRole :: JWTAttempt -> Bool
containsRole = \case
  JWTClaims claims -> M.member "role" claims
  _ -> False
-- | Interpret a configured secret as a JWK set.
--
-- The following are tried in order:
--
-- 1. Decode it as JSON for a whole 'JWKSet'.
-- 2. Decode it as JSON for a single 'JWK', wrapped in a one-element set.
-- 3. Fall back to using the raw bytes as a symmetric HS256 key.
parseSecret :: ByteString -> JWKSet
parseSecret str =
  case JSON.decode (toS str) :: Maybe JWKSet of
    Just jwkSet -> jwkSet
    Nothing ->
      case JSON.decode (toS str) :: Maybe JWK of
        Just singleJwk -> JWKSet [singleJwk]
        Nothing -> JWKSet [jwkFromSecret str]
-- | Build a symmetric (octet) JWK from raw secret bytes, marked for
-- signature use ("sig") with the HS256 algorithm.
jwkFromSecret :: ByteString -> JWK
jwkFromSecret key =
  set jwkAlg (Just (JWSAlg HS256))
    . set jwkUse (Just Sig)
    $ fromKeyMaterial octets
  where
    octets = OctKeyMaterial . OctKeyParameters $ JOSE.Types.Base64Octets key