module PostgresWebsockets.Claims
( ConnectionInfo,validateClaims
) where
import Protolude
import Control.Lens
import Crypto.JWT
import Data.List
import Data.Time.Clock (UTCTime)
import qualified Crypto.JOSE.Types as JOSE.Types
import qualified Data.HashMap.Strict as M
import qualified Data.Aeson as JSON
-- | Claims taken from a verified JWT payload, keyed by claim name
-- (e.g. @"mode"@, @"channel"@, @"channels"@).
type Claims = M.HashMap Text JSON.Value
-- | Result of a successful 'validateClaims': the channels the connection may
-- use, the UTF-8 encoded value of the @mode@ claim, and the full claim set.
type ConnectionInfo = ([ByteString], ByteString, Claims)
-- | Validate a JWT and work out which channels the connection may use.
--
-- The token is verified against the key derived from @secret@ (see
-- 'parseJWK') at the given @time@. On success the result holds:
--
--   * the allowed channels: the union of the @channel@ (string) and
--     @channels@ (array of strings) claims, de-duplicated in first-seen
--     order, then narrowed to @requestChannel@ when one is given. If the
--     token names no channels at all, the requested channel is allowed as-is;
--   * the value of the required @mode@ claim, UTF-8 encoded;
--   * the full claim set.
--
-- Fails with a 'Text' message in these cases:
--
--   * the token is expired ("Token expired");
--   * the token is otherwise invalid ("Error: ...");
--   * the @mode@ claim is missing or not a string ("Missing mode");
--   * no channel ends up allowed ("No allowed channels").
validateClaims
  :: Maybe ByteString -- ^ channel requested by the client, if any
  -> ByteString       -- ^ secret: a JSON-encoded JWK or raw HS256 key bytes
  -> LByteString      -- ^ the compact-encoded JWT
  -> UTCTime          -- ^ time used when verifying the token's claims
  -> IO (Either Text ConnectionInfo)
validateClaims requestChannel secret jwtToken time = runExceptT $ do
  attempt <- liftIO $ jwtClaims time (parseJWK secret) jwtToken
  cl' <- case attempt of
    JWTClaims c -> pure c
    JWTInvalid JWTExpired -> throwError "Token expired"
    JWTInvalid err -> throwError $ "Error: " <> show err
  -- De-duplicate in every case so that repeated entries in the "channels"
  -- claim cannot produce duplicate channels in the result.
  let channels =
        nub $
          maybeToList (claimAsJSON "channel" cl')
            <> fromMaybe [] (claimAsJSONList "channels" cl')
  mode <- maybe (throwError "Missing mode") pure (claimAsJSON "mode" cl')
  let validChannels = case requestChannel of
        -- A token that names no channels does not restrict the request.
        Just rc
          | null channels -> [rc]
          | otherwise -> filter (== rc) channels
        Nothing -> channels
  when (null validChannels) $ throwError "No allowed channels"
  pure (validChannels, mode, cl')
  where
    -- Look up a string-valued claim and UTF-8 encode it. Absence, or any
    -- JSON type other than a string, gives Nothing.
    claimAsJSON :: Text -> Claims -> Maybe ByteString
    claimAsJSON name cl = case M.lookup name cl of
      Just (JSON.String s) -> Just $ encodeUtf8 s
      _ -> Nothing

    -- Look up a claim holding a JSON array of strings and UTF-8 encode each
    -- element. Absence, or any other shape, gives Nothing.
    claimAsJSONList :: Text -> Claims -> Maybe [ByteString]
    claimAsJSONList name cl = do
      value <- M.lookup name cl
      case JSON.fromJSON value :: JSON.Result [Text] of
        JSON.Success xs -> Just $ encodeUtf8 <$> xs
        JSON.Error _ -> Nothing
-- | Outcome of decoding and verifying a JWT: either the verification error
-- reported by the JOSE library, or the token's claims.
data JWTAttempt
  = JWTInvalid JWTError
  | JWTClaims Claims
  deriving (Eq)
-- | Decode a compact-serialised JWT and verify it with the given key,
-- checking its claims as of @time@.
--
-- The validation settings accept any audience (the predicate is
-- @const True@). An empty payload short-circuits to an empty claim set
-- without any decoding or verification.
jwtClaims :: UTCTime -> JWK -> LByteString -> IO JWTAttempt
jwtClaims _ _ "" = return $ JWTClaims M.empty
jwtClaims time jwk' payload =
  either JWTInvalid (JWTClaims . claims2map)
    <$> runExceptT (decodeCompact payload >>= verifyClaimsAt settings jwk' time)
  where
    settings = defaultJWTValidationSettings (const True)
-- | Flatten a verified 'ClaimsSet' into a map from claim name to JSON value,
-- via its JSON encoding. An encoding that is not a JSON object yields an
-- empty map.
claims2map :: ClaimsSet -> M.HashMap Text JSON.Value
claims2map claimsSet =
  case JSON.toJSON claimsSet of
    JSON.Object obj -> obj
    _ -> M.empty
-- | Build a symmetric key from raw secret bytes, marked for signature use
-- with the HS256 algorithm.
hs256jwk :: ByteString -> JWK
hs256jwk key =
  let octets = JOSE.Types.Base64Octets key
      material = OctKeyMaterial (OctKeyParameters octets)
   in fromKeyMaterial material
        & jwkAlg ?~ JWSAlg HS256
        & jwkUse ?~ Sig
-- | Interpret a secret as a signing key. If the bytes decode as a JSON-encoded
-- JWK, that key is used; otherwise the raw bytes become an HS256 key via
-- 'hs256jwk'.
parseJWK :: ByteString -> JWK
parseJWK str =
  case JSON.decode (toS str) of
    Just parsed -> parsed
    Nothing -> hs256jwk str