module PostgREST.Auth (
claimsToSQL
, containsRole
, jwtClaims
, tokenJWT
) where
import Lens.Micro
import Lens.Micro.Aeson
import Data.Aeson (Value (..), parseJSON, toJSON)
import Data.Aeson.Types (parseMaybe, emptyObject, emptyArray)
import qualified Data.ByteString as BS
import qualified Data.Vector as V
import qualified Data.HashMap.Strict as M
import Data.Maybe (fromMaybe, maybeToList, fromJust)
import Data.Monoid ((<>))
import Data.String.Conversions (cs)
import Data.Text (Text)
import Data.Time.Clock (NominalDiffTime)
import PostgREST.QueryBuilder (pgFmtIdent, pgFmtLit, unquoted)
import qualified Web.JWT as JWT
-- | Turn a JWT claim set into the SQL statements run at the start of a
-- request transaction.
--
-- If a @role@ claim is present it becomes @set local role <literal>;@ and is
-- placed first in the result. Every remaining claim @k@ becomes
-- @set local <quoted "postgrest.claims.k"> = <literal>;@.
-- Names are quoted with 'pgFmtIdent' and values go through 'unquoted' and
-- then 'pgFmtLit', so claim contents are escaped rather than spliced raw.
claimsToSQL :: M.HashMap Text Value -> [BS.ByteString]
claimsToSQL claims =
  maybeToList roleStmt ++
    [ "set local " <> cs (pgFmtIdent ("postgrest.claims." <> name))
        <> " = " <> sqlLiteral v <> ";"
    | (name, v) <- M.toList (M.delete "role" claims)
    ]
  where
    -- The role claim is handled separately: it switches the database role
    -- instead of being exposed as a postgrest.claims.* setting.
    roleStmt = (\r -> "set local role " <> sqlLiteral r <> ";")
      <$> M.lookup "role" claims
    sqlLiteral x = cs (pgFmtLit (unquoted x))
-- | Decode a JWT, verify its signature with @secret@, and return its claims.
--
-- * An empty token string yields an empty claim set.
-- * If 'JWT.decodeAndVerifySignature' yields 'Nothing', the result is
--   @Left "Invalid JWT"@.
-- * If the claims contain an integer @exp@ that is less than or equal to
--   @time@, the result is @Left "JWT expired"@. A missing @exp@, or one not
--   readable through '_Integer', counts as not expired.
-- * Otherwise the claims object is returned as a map. If the encoded claims
--   are not a JSON object, the result is an empty map.
--
-- NOTE(review): @time@ is presumably the current POSIX time in seconds,
-- matching the unit of @exp@. Confirm this at the call site.
jwtClaims :: JWT.Secret -> Text -> NominalDiffTime -> Either Text (M.HashMap Text Value)
jwtClaims _ "" _ = Right M.empty
jwtClaims secret jwt time =
  -- Match on the decoded claims once, so no partial 'fromJust' is needed to
  -- recover them after the expiry check.
  case toJSON . JWT.claims <$> JWT.decodeAndVerifySignature secret jwt of
    Nothing -> Left "Invalid JWT"
    Just claims
      | isExpired claims -> Left "JWT expired"
      | otherwise        -> Right (value2map claims)
  where
    isExpired claims =
      maybe False ((<= time) . fromInteger) (claims ^? key "exp" . _Integer)
    value2map (Object o) = o
    value2map _          = M.empty
-- | Sign an HS256 JWT from a JSON payload.
--
-- Only the first element of an array payload is used as the claims object.
-- Non-array payloads and empty arrays are treated as an empty JSON object.
-- If the selected value does not parse as a 'JWT.JWTClaimsSet', 'JWT.def'
-- is signed instead.
tokenJWT :: JWT.Secret -> Value -> Text
tokenJWT secret payload =
  JWT.encodeSigned JWT.HS256 secret (fromMaybe JWT.def parsedClaims)
  where
    firstElem = case payload of
      Array xs -> xs V.!? 0
      _        -> Nothing
    parsedClaims =
      parseMaybe parseJSON (fromMaybe emptyObject firstElem) :: Maybe JWT.JWTClaimsSet
-- | True only when claim decoding succeeded and the claims have a @role@ key.
-- A decoding error ('Left') always yields 'False'.
containsRole :: Either Text (M.HashMap Text Value) -> Bool
containsRole = either (const False) (M.member "role")