{-|
Module      : PostgresWebsockets.Claims
Description : Parse and validate JWT to open postgres-websockets channels.

This module provides JWT claims validation. Since websockets and
listening connections in the database tend to be resource-intensive
(not to mention stateful), we require claims that authorize specific
channels and a mode of operation.
-}
module PostgresWebsockets.Claims
  ( ConnectionInfo,validateClaims
  ) where

import Protolude
import Control.Lens
import Crypto.JWT
import Data.List
import Data.Time.Clock (UTCTime)
import qualified Crypto.JOSE.Types as JOSE.Types
import qualified Data.HashMap.Strict as M
import qualified Data.Aeson as JSON


{-| Token claims as a JSON object: claim name mapped to its JSON value. -}
type Claims = M.HashMap Text JSON.Value

{-| Result of a successful validation: the allowed channels, the value of the
    @mode@ claim, and the complete claims map.
-}
type ConnectionInfo =
  ( [ByteString] -- allowed channels
  , ByteString   -- mode
  , Claims       -- all claims carried by the token
  )

{-| Validate a client JWT and work out which channels the connection may use.

    Arguments, in order:

    * the channel requested by the client, if any;
    * the secret used to verify the token, either a JSON-encoded JWK or a
      raw string that is used as an HS256 key;
    * the encoded token (an empty token yields an empty claims map, which
      then fails with "Missing mode");
    * the time at which the token's claims are verified.

    Candidate channels come from the @channel@ claim (a string) and the
    @channels@ claim (a list of strings); the combined list is always
    de-duplicated, with @channel@ first. When a channel is requested and the
    token names no channels, the requested channel is allowed. When the token
    does name channels, the request is allowed only if it is one of them.
    Without a request, every channel named in the token is allowed.

    Returns either an error message ("Token expired", "Error: ...",
    "Missing mode" or "No allowed channels") or a triple of the allowed
    channels, the @mode@ claim and the full claims map.
-}
validateClaims
  :: Maybe ByteString
  -> ByteString
  -> LByteString
  -> UTCTime
  -> IO (Either Text ConnectionInfo)
validateClaims requestChannel secret jwtToken time = runExceptT $ do
  attempt <- liftIO $ jwtClaims time (parseJWK secret) jwtToken
  claimsMap <- case attempt of
    JWTClaims c -> pure c
    JWTInvalid JWTExpired -> throwError "Token expired"
    JWTInvalid err -> throwError $ "Error: " <> show err
  let listedChannels = fromMaybe [] (textListClaim "channels" claimsMap)
      -- De-duplicate in every case, not only when both claims are present,
      -- so the same channel is never granted twice.
      channels =
        nub (maybe listedChannels (: listedChannels) (textClaim "channel" claimsMap))
  mode <- maybe (throwError "Missing mode") pure (textClaim "mode" claimsMap)
  let allowedChannels = case requestChannel of
        Nothing -> channels
        Just rc
          -- The token does not restrict channels, so honour the request.
          | null channels -> [rc]
          | otherwise -> filter (== rc) channels
  validChannels <-
    if null allowedChannels
      then throwError "No allowed channels"
      else pure allowedChannels
  pure (validChannels, mode, claimsMap)
 where
  -- A claim holding a JSON string, UTF-8 encoded; Nothing if the claim is
  -- absent or is not a string.
  textClaim :: Text -> Claims -> Maybe ByteString
  textClaim name cl = case M.lookup name cl of
    Just (JSON.String s) -> Just (encodeUtf8 s)
    _ -> Nothing

  -- A claim holding a JSON array of strings, each UTF-8 encoded; Nothing if
  -- the claim is absent or has any other shape.
  textListClaim :: Text -> Claims -> Maybe [ByteString]
  textListClaim name cl = do
    rawClaim <- M.lookup name cl
    case JSON.fromJSON rawClaim :: JSON.Result [Text] of
      JSON.Success xs -> Just (encodeUtf8 <$> xs)
      JSON.Error _ -> Nothing

{-|
  Outcome of decoding and verifying a client-supplied JWT.
-}
data JWTAttempt
  = JWTInvalid JWTError
    -- ^ decoding or verification failed with the given jose error
  | JWTClaims Claims
    -- ^ verification succeeded; carries the token's claims
  deriving Eq

{-|
  Decode a compact-serialised JWT and verify it with the given key at the
  given time, returning its claims as a map on success.

  An empty token is not verified at all and yields an empty claims map.
  The predicate passed to 'defaultJWTValidationSettings' is @const True@.
  NOTE(review): per the jose docs this is the audience check, so every
  audience is accepted; confirm this is intended.
-}
jwtClaims :: UTCTime -> JWK -> LByteString -> IO JWTAttempt
jwtClaims _ _ "" = pure (JWTClaims M.empty)
jwtClaims currentTime verificationKey rawToken =
  either JWTInvalid (JWTClaims . claims2map) <$> runExceptT
    (decodeCompact rawToken >>= verifyClaimsAt settings verificationKey currentTime)
 where
  settings = defaultJWTValidationSettings (const True)

{-|
  Turn a verified 'ClaimsSet' into a map from claim name to JSON value by
  going through its JSON encoding. If the encoding is not a JSON object
  (presumably never the case for a claims set), the result is an empty map.
  NOTE(review): matching 'JSON.Object' against a HashMap relies on
  aeson < 2.0, where an object wraps @HashMap Text Value@.
-}
claims2map :: ClaimsSet -> M.HashMap Text JSON.Value
claims2map verifiedClaims = case JSON.toJSON verifiedClaims of
  JSON.Object fields -> fields
  _ -> M.empty

{-|
  Build a JWK from a raw secret string. The bytes become symmetric ("oct")
  key material, and the key is marked for signature use with the HS256
  algorithm. 'parseJWK' falls back to this when the configured secret is
  not a JSON-encoded JWK.
-}
hs256jwk :: ByteString -> JWK
hs256jwk key =
  let octets = JOSE.Types.Base64Octets key
      material = OctKeyMaterial (OctKeyParameters octets)
  in fromKeyMaterial material & jwkUse ?~ Sig & jwkAlg ?~ JWSAlg HS256

{-|
  Interpret the configured secret. If it decodes as a JSON JWK it is used
  as-is; otherwise the raw bytes are turned into an HS256 key via 'hs256jwk'.
-}
parseJWK :: ByteString -> JWK
parseJWK secret = case JSON.decode (toS secret) of
  Just parsedKey -> parsedKey
  Nothing -> hs256jwk secret