{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE LambdaCase #-}
module PostgREST.Auth (
containsRole
, jwtClaims
, JWTAttempt(..)
, parseJWK
) where
import Control.Lens.Operators
import Control.Lens (set)
import qualified Data.Aeson as JSON
import qualified Data.HashMap.Strict as M
import Data.Time.Clock (UTCTime)
import Data.Vector as V
import PostgREST.Types
import Protolude
import qualified Crypto.JOSE.Types as JOSE.Types
import Crypto.JWT
-- | Outcome of trying to authenticate a request from its JWT payload.
data JWTAttempt
  = JWTInvalid JWTError
    -- ^ The token could not be decoded, or it failed signature/claims
    --   verification.
  | JWTMissingSecret
    -- ^ A non-empty token was supplied but no key is configured to
    --   verify it.
  | JWTClaims (M.HashMap Text JSON.Value)
    -- ^ Top-level claims by name. This is empty when the request carried
    --   no token at all.
  deriving (Eq, Show)
-- | Decode and verify a compact-serialized JWT and extract its claims.
--
-- * An empty payload is treated as "no token". It yields an empty claim
--   map, whether or not a key is configured.
-- * A non-empty payload with no key yields 'JWTMissingSecret'.
-- * Otherwise the token is decoded and verified against the key at the
--   given time. Verification allows 1 second of clock skew. When an
--   audience is configured, the token's audience is compared to it with
--   '=='; when none is configured, any audience is accepted.
--   Any decoding or verification error is reported as 'JWTInvalid'.
--   On success the claims are returned via 'claims2map', which places the
--   role found at the given path under the @role@ key.
jwtClaims
  :: Maybe JWK          -- ^ verification key, if one is configured
  -> Maybe StringOrURI  -- ^ expected audience, if any
  -> LByteString        -- ^ compact-serialized token ("" means no token)
  -> UTCTime            -- ^ time at which to check time-based claims
  -> Maybe JSPath       -- ^ where to find the role inside the claims
  -> IO JWTAttempt
jwtClaims _ _ "" _ _ = pure (JWTClaims M.empty)
jwtClaims Nothing _ _ _ _ = pure JWTMissingSecret
jwtClaims (Just key) audience payload time jspath =
  either JWTInvalid (\verified -> JWTClaims (claims2map verified jspath))
    <$> runExceptT (decodeCompact payload >>= verifyClaimsAt settings key time)
  where
    -- Without a configured audience, every audience value is acceptable.
    audienceCheck = maybe (const True) (==) audience
    settings = set allowedSkew 1 (defaultJWTValidationSettings audienceCheck)
-- | Flatten a verified claims set into a map of its top-level claims.
--
-- Any top-level @role@ claim is discarded. It is replaced by the value
-- found by following the given path through the whole claims document.
-- When there is no path, or the path does not resolve, the result has no
-- @role@ key. If the claims do not encode to a JSON object, the result is
-- an empty map.
claims2map :: ClaimsSet -> Maybe JSPath -> M.HashMap Text JSON.Value
claims2map claims jspath =
  case JSON.toJSON claims of
    doc@(JSON.Object fields) ->
      let resolvedRole = jspath >>= walkJSPath (Just doc)
      in maybe (M.delete "role" fields) (\r -> M.insert "role" r fields) resolvedRole
    _ -> M.empty
-- | Follow a path of object keys and array indices into a JSON value.
--
-- An empty path returns the starting value unchanged. Each 'JSPKey' step
-- must land on an object containing that key. Each 'JSPIdx' step must land
-- on an array with that index in range. Any mismatch or missing entry
-- yields 'Nothing'.
walkJSPath :: Maybe JSON.Value -> JSPath -> Maybe JSON.Value
walkJSPath current path =
  case path of
    [] -> current
    segment : remaining -> do
      value <- current
      next <- descend value segment
      walkJSPath (Just next) remaining
  where
    -- Take a single path step, or fail if the step does not fit the value.
    descend (JSON.Object obj) (JSPKey key) = M.lookup key obj
    descend (JSON.Array arr) (JSPIdx idx) = arr V.!? idx
    descend _ _ = Nothing
-- | Whether an authentication attempt produced claims with a @role@ key.
--
-- Only 'JWTClaims' can hold a role, so every other outcome answers 'False'.
containsRole :: JWTAttempt -> Bool
containsRole = \case
  JWTClaims claims -> "role" `M.member` claims
  _ -> False
-- | Interpret a configured secret as a JSON Web Key.
--
-- If the bytes decode as a JSON-encoded JWK, that key is used as-is.
-- Otherwise the raw bytes are used as a shared HS256 secret via
-- 'hs256jwk'.
parseJWK :: ByteString -> JWK
parseJWK str =
  case JSON.decode (toS str) of
    Just jwk -> jwk
    Nothing -> hs256jwk str
-- | Build a JWK from raw secret bytes for HMAC-SHA256 signing.
--
-- The bytes become symmetric ("oct") key material. The resulting key is
-- tagged with use 'Sig' and algorithm 'HS256'.
hs256jwk :: ByteString -> JWK
hs256jwk secretBytes = markForHS256 (fromKeyMaterial octKey)
  where
    octKey = OctKeyMaterial (OctKeyParameters (JOSE.Types.Base64Octets secretBytes))
    -- Set the key's use first, then its algorithm.
    markForHS256 = set jwkAlg (Just (JWSAlg HS256)) . set jwkUse (Just Sig)