module Servant.Auth.Server.Internal.JWT where
import Control.Lens
import Control.Monad.Except
import Control.Monad.Reader
import qualified Crypto.JOSE as Jose
import qualified Crypto.JWT as Jose
import Data.Aeson (FromJSON, Result (..), ToJSON, fromJSON,
toJSON)
import Data.ByteArray (constEq)
import qualified Data.ByteString as BS
import qualified Data.ByteString.Lazy as BSL
import qualified Data.HashMap.Strict as HM
import Data.Maybe (fromMaybe)
import qualified Data.Text as T
import Data.Time (UTCTime)
import Network.Wai (requestHeaders)
import Servant.Auth.Server.Internal.ConfigTypes
import Servant.Auth.Server.Internal.Types
-- | Types that can be reconstructed from a JWT claims set.
--
-- The default implementation (available for any 'FromJSON' type) looks up
-- the private claim @dat@ in the unregistered claims and decodes it from
-- JSON. A missing claim or a JSON decoding failure is reported as 'Left'
-- with a textual message.
class FromJWT a where
  decodeJWT :: Jose.ClaimsSet -> Either T.Text a
  default decodeJWT :: FromJSON a => Jose.ClaimsSet -> Either T.Text a
  decodeJWT claims =
    case fromJSON <$> HM.lookup "dat" (claims ^. Jose.unregisteredClaims) of
      Nothing          -> Left "Missing 'dat' claim"
      Just (Error e)   -> Left (T.pack e)
      Just (Success a) -> Right a
-- | Types that can be turned into a JWT claims set.
--
-- The default implementation (available for any 'ToJSON' type) starts from
-- an empty claims set and stores the JSON encoding of the value under the
-- private claim @dat@. No registered claims (e.g. @exp@) are set here.
class ToJWT a where
  encodeJWT :: a -> Jose.ClaimsSet
  default encodeJWT :: ToJSON a => a -> Jose.ClaimsSet
  encodeJWT value = Jose.emptyClaimsSet & Jose.addClaim "dat" (toJSON value)
-- | Authenticate a request using a JWT carried in its @Authorization@ header.
--
-- The header must have the form @Bearer <token>@. The scheme name is matched
-- case-insensitively, as RFC 7235 requires for authentication schemes, so
-- @bearer@ and @BEARER@ are accepted too. The check then does three things:
--
-- * decodes the token as a compact JWS;
--
-- * verifies it with 'jwtSettingsToJwtValidationSettings' and
--   'validationKeys' taken from the given settings;
--
-- * turns the verified claims into a user with 'decodeJWT'.
--
-- Each of the following fails the check without producing a user (via
-- 'mempty' / 'mzero' of 'AuthCheck'):
--
-- * a missing header, or one without the bearer prefix;
--
-- * a token that does not decode or verify;
--
-- * claims that 'decodeJWT' rejects.
jwtAuthCheck :: FromJWT usr => JWTSettings -> AuthCheck usr
jwtAuthCheck config = do
  req <- ask
  token <- maybe mempty return $
    bearerToken =<< lookup "Authorization" (requestHeaders req)
  verifiedJWT <- liftIO $ runExceptT $ do
    unverifiedJWT <- Jose.decodeCompact $ BSL.fromStrict token
    Jose.verifyClaims (jwtSettingsToJwtValidationSettings config)
                      (validationKeys config)
                      unverifiedJWT
  case verifiedJWT of
    Left (_ :: Jose.JWTError) -> mzero
    Right claims -> either (const mzero) return $ decodeJWT claims
  where
    -- Split a header value into the "Bearer " prefix and the token after it.
    -- The prefix is compared after ASCII lower-casing, so any capitalisation
    -- of the scheme name is accepted.
    bearerToken authHdr = do
      let prefix = "bearer " :: BS.ByteString
          (scheme, rest) = BS.splitAt (BS.length prefix) authHdr
      guard (BS.map asciiToLower scheme `constEq` prefix)
      return rest
    -- Map an ASCII upper-case letter byte ('A'..'Z' = 65..90) to lower case.
    -- All other bytes are left unchanged.
    asciiToLower w
      | w >= 65 && w <= 90 = w + 32
      | otherwise          = w
-- | Sign the claims produced by 'encodeJWT' and render them as a compact JWT.
--
-- The signing algorithm is 'jwtAlg' from the settings when it is present.
-- Only when it is absent is one chosen from the signing key with
-- 'Jose.bestJWSAlg'. That choice is made lazily on purpose:
-- 'Jose.bestJWSAlg' runs in the error monad and can fail for some keys. Such
-- a failure must not abort signing when an algorithm was configured
-- explicitly.
--
-- When an expiry time is given, it is written to the @exp@ claim, replacing
-- any value 'encodeJWT' set. With 'Nothing', the claims are left as
-- 'encodeJWT' produced them.
--
-- Any jose error (algorithm selection or signing) is returned as 'Left'.
makeJWT :: ToJWT a
  => a -> JWTSettings -> Maybe UTCTime -> IO (Either Jose.Error BSL.ByteString)
makeJWT v cfg expiry = runExceptT $ do
  -- Run bestJWSAlg only as a fallback when no algorithm is configured.
  alg <- fromMaybe (Jose.bestJWSAlg $ signingKey cfg) (return <$> jwtAlg cfg)
  ejwt <- Jose.signClaims (signingKey cfg)
                          (Jose.newJWSHeader ((), alg))
                          (addExp $ encodeJWT v)
  return $ Jose.encodeCompact ejwt
  where
    addExp claims = case expiry of
      Nothing -> claims
      Just e  -> claims & Jose.claimExp ?~ Jose.NumericDate e