{-|
Module      : PostgREST.Middleware
Description : Sets the PostgreSQL GUCs, role, search_path and pre-request function. Validates JWT.
-}
{-# OPTIONS_GHC -fno-warn-orphans #-}
{-# LANGUAGE FlexibleContexts    #-}
{-# LANGUAGE ScopedTypeVariables #-}

module PostgREST.Middleware where

import qualified Data.Aeson          as JSON
import qualified Data.HashMap.Strict as M
import           Data.Scientific     (FPFormat (..), formatScientific,
                                      isInteger)
import qualified Hasql.Transaction   as H

import Network.Wai                   (Application, Response)
import Network.Wai.Middleware.Cors   (cors)
import Network.Wai.Middleware.Gzip   (def, gzip)
import Network.Wai.Middleware.Static (only, staticPolicy)

import Crypto.JWT

import PostgREST.ApiRequest   (ApiRequest (..))
import PostgREST.Auth         (JWTAttempt (..))
import PostgREST.Config       (AppConfig (..), corsPolicy)
import PostgREST.Error        (SimpleError (JwtTokenInvalid, JwtTokenMissing),
                               errorResponseFor)
import PostgREST.QueryBuilder (setLocalQuery, setLocalSearchPathQuery)
import Protolude              hiding (head, toS)
import Protolude.Conv         (toS)

-- | Run a request handler inside a transaction after applying the
-- per-request settings derived from the JWT attempt.
--
-- * 'JWTMissingSecret' yields the 'JwtTokenMissing' error response.
-- * 'JWTInvalid' yields a 'JwtTokenInvalid' error response; an expired
--   token gets the fixed message @"JWT expired"@, any other failure is
--   rendered with 'show'.
-- * 'JWTClaims' sends a single SQL statement built from the search path,
--   role, claims, method, path, headers, cookies and configured app
--   settings (in that order), then runs the optional pre-request function
--   from 'configReqCheck', and finally runs the handler.
--
-- In the error cases the handler is never invoked.
runWithClaims
  :: AppConfig
  -> JWTAttempt
  -> (ApiRequest -> H.Transaction Response)
  -> ApiRequest
  -> H.Transaction Response
runWithClaims conf eClaims app req =
  case eClaims of
    JWTMissingSecret ->
      pure $ errorResponseFor JwtTokenMissing
    JWTInvalid JWTExpired ->
      pure . errorResponseFor $ JwtTokenInvalid "JWT expired"
    JWTInvalid err ->
      pure . errorResponseFor $ JwtTokenInvalid (show err)
    JWTClaims claims -> do
        H.sql . toS $ mconcat settingsSql
        traverse_ H.sql customReqCheck
        app req
      where
        -- All per-request settings, concatenated into one statement.
        -- The relative order of the pieces is preserved deliberately.
        settingsSql = concat
          [ [setSearchPathSql]
          , setRoleSql
          , claimsSql
          , [methodSql, pathSql]
          , headersSql
          , cookiesSql
          , appSettingsSql
          ]

        setSearchPathSql =
          setLocalSearchPathQuery (iSchema req : configExtraSearchPath conf)

        -- Zero or one statement, depending on whether a role is present.
        setRoleSql =
          [ setLocalQuery mempty ("role", unquoted role)
          | Just role <- [M.lookup "role" claimsWithRole]
          ]

        claimsSql =
          [ setLocalQuery "request.jwt.claim." (name, unquoted value)
          | (name, value) <- M.toList claimsWithRole
          ]

        methodSql = setLocalQuery mempty ("request.method", toS (iMethod req))
        pathSql   = setLocalQuery mempty ("request.path", toS (iPath req))

        headersSql     = fmap (setLocalQuery "request.header.") (iHeaders req)
        cookiesSql     = fmap (setLocalQuery "request.cookie.") (iCookies req)
        appSettingsSql = fmap (setLocalQuery mempty) (configSettings conf)

        -- The role claim defaults to the anonymous role when the token does
        -- not carry one; 'M.union' is left-biased, so a role in the token
        -- takes precedence.
        claimsWithRole = M.union claims (M.singleton "role" anonRole)
        anonRole       = JSON.String (toS (configAnonRole conf))

        -- Optional call to the configured pre-request function.
        customReqCheck   = fmap callPreRequest (configReqCheck conf)
        callPreRequest f = "select " <> toS f <> "();"

-- | Middleware wrapped around the application. From the outside in:
-- gzip, CORS (using 'corsPolicy'), and a static policy that only maps
-- @favicon.ico@ to @static/favicon.ico@.
defaultMiddle :: Application -> Application
defaultMiddle app = gzip def . cors corsPolicy $ staticPolicy faviconOnly app
  where
    faviconOnly = only [("favicon.ico", "static/favicon.ico")]

-- | Render a JSON value as 'Text' without JSON string quoting.
--
-- * Strings are returned as they are.
-- * Numbers use fixed notation; integral values get zero decimals.
-- * Booleans are rendered with 'show' (@True@ / @False@).
-- * Any other value is JSON-encoded.
unquoted :: JSON.Value -> Text
unquoted value = case value of
  JSON.String txt -> txt
  JSON.Number num
    | isInteger num -> fixed (Just 0) num
    | otherwise     -> fixed Nothing num
  JSON.Bool flag  -> show flag
  other           -> toS (JSON.encode other)
  where
    fixed decimals = toS . formatScientific Fixed decimals