{-# LANGUAGE FlexibleInstances   #-}     
{-# LANGUAGE RecordWildCards   #-}     
{-# LANGUAGE OverloadedStrings #-}
-- | An OpenID connect provider.
--
-- OpenID Connect is a simple identity layer on top of the OAuth2 protocol.
-- Learn more about it here: <https://openid.net/connect/>
--
-- @since 0.2.3.0
module Network.Wai.Middleware.Auth.OIDC
  ( -- * Creating a provider
    OpenIDConnect
  , discover
  , discoverURI
  -- * Customizing a provider
  , oidcClientId
  , oidcClientSecret
  , oidcProviderInfo
  , oidcManager
  , oidcScopes
  , oidcAllowedSkew
  -- * Accessing session data
  , getAccessToken
  , getIdToken
  ) where

import           Control.Applicative                  ((<|>))
import           Data.Function                        ((&))
import           Data.Traversable                     (for)
import           Foreign.C.Types                      (CTime (..))
import           System.IO.Unsafe                     (unsafePerformIO)

import           Control.Monad.Except                 (runExceptT)
import qualified Crypto.JOSE                          as JOSE
import qualified Crypto.JWT                           as JWT
import           Data.Aeson                           (FromJSON(parseJSON),
                                                       withObject, (.:), (.:?),
                                                       (.!=))
import qualified Data.ByteString.Char8                as S8
import qualified Data.Text                            as T
import qualified Data.Text.Lazy                       as TL
import qualified Data.Text.Lazy.Encoding              as TLE
import qualified Data.Time.Clock                      as Clock
import qualified Data.Vault.Lazy                      as Vault
import qualified Lens.Micro                           as Lens
import qualified Lens.Micro.Extras                    as Lens.Extras
import           Network.HTTP.Client                  (Manager)
import           Network.HTTP.Client.TLS              (getGlobalManager)
import           Network.HTTP.Simple                  (httpJSON,
                                                       getResponseBody,
                                                       parseRequestThrow)
import qualified Network.OAuth.OAuth2                 as OA2
import           Network.Wai                          (Request, vault)
import           System.PosixCompat.Time              (epochTime)
import qualified Text.Hamlet
import qualified URI.ByteString                       as U

import           Network.Wai.Auth.Internal            (Metadata(..),
                                                       decodeToken, encodeToken,
                                                       oauth2Login,
                                                       refreshTokens)
import           Network.Wai.Middleware.Auth.OAuth2   (parseAbsoluteURI,
                                                      getAccessToken)
import           Network.Wai.Middleware.Auth.Provider

-- | An OpenID Connect provider.
--
-- Obtain one with 'discover' (or 'discoverURI'), which downloads the
-- provider's configuration. Then adjust it with record update syntax on the
-- exported fields, e.g. @oidc { oidcClientId = "my-app" }@.
--
-- @since 0.2.3.0
data OpenIDConnect = OpenIDConnect
  { oidcMetadata     :: Metadata
    -- ^ Provider metadata read from the discovery document.
  , oidcJwkSet       :: JOSE.JWKSet
    -- ^ Key set that id token signatures are verified against.
  , oidcClientId     :: T.Text
    -- ^ The client id this application is registered with at the OpenID
    -- Connect provider. Defaults to the empty string, so you will need to
    -- set it.
    --
    -- @since 0.2.3.0
  , oidcClientSecret :: T.Text
    -- ^ The client secret of this application. Defaults to the empty
    -- string, so you will need to set it.
    --
    -- @since 0.2.3.0
  , oidcProviderInfo :: ProviderInfo
    -- ^ Presentation information for this provider. The default contains
    -- placeholder text; set it if you use the provider selection screen.
    --
    -- @since 0.2.3.0
  , oidcManager      :: Maybe Manager
    -- ^ HTTP manager for token requests. 'Nothing' (the default) means the
    -- global manager is used.
    --
    -- @since 0.2.3.0
  , oidcScopes       :: [T.Text]
    -- ^ Scopes to request. Defaults to just the "openid" scope.
    --
    -- @since 0.2.3.0
  , oidcAllowedSkew  :: Clock.NominalDiffTime
    -- ^ Clock skew tolerated when validating id tokens. Defaults to 0.
    --
    -- @since 0.2.3.0
  }

-- | Decode a provider from a JSON object.
--
-- Required keys: @metadata@, @jwk_set@, @client_id@ and @client_secret@.
--
-- Optional keys: @provider_info@, @scopes@ (default @["openid"]@) and
-- @allowed_skew@ (default @0@). When one of them is absent or @null@, it
-- falls back to the same default that 'discoverURI' uses.
--
-- The HTTP manager is never read from JSON. It is always 'Nothing', which
-- means the global manager is used.
instance FromJSON OpenIDConnect where
  parseJSON = withObject "OpenIDConnect Object" $ \obj -> do
    metadata     <- obj .: "metadata"
    jwkSet       <- obj .: "jwk_set"
    clientId     <- obj .: "client_id"
    clientSecret <- obj .: "client_secret"
    -- Use '.:?' rather than '.:' here. '.:' fails when the key is missing,
    -- which would make these "defaults" apply only to an explicit null.
    providerInfo <- obj .:? "provider_info" .!= defProviderInfo
    scopes       <- obj .:? "scopes"        .!= ["openid"]
    allowedSkew  <- obj .:? "allowed_skew"  .!= 0
    pure OpenIDConnect
      { oidcMetadata     = metadata
      , oidcJwkSet       = jwkSet
      , oidcClientId     = clientId
      , oidcClientSecret = clientSecret
      , oidcProviderInfo = providerInfo
      , oidcManager      = Nothing
      , oidcScopes       = scopes
      , oidcAllowedSkew  = allowedSkew
      }

-- | Login is delegated to 'oauth2Login' with an OAuth2 configuration built by
-- 'mkOauth2'. Sessions are kept alive by re-validating, and if needed
-- refreshing, the stored id token.
instance AuthProvider OpenIDConnect where
  -- The name does not depend on the provider configuration.
  getProviderName _ = "oidc"

  getProviderInfo oidc = oidcProviderInfo oidc

  -- Run the shared OAuth2 login flow. It requests the configured scopes and
  -- uses the callback that 'mkOauth2' derives from the URL renderer.
  handleLogin oidc@OpenIDConnect{} req suffix renderUrl onSuccess onFailure = do
    oauth2  <- mkOauth2 oidc (Just renderUrl)
    manager <- maybe getGlobalManager pure (oidcManager oidc)
    let scopes = Just (oidcScopes oidc)
    oauth2Login oauth2 manager scopes (getProviderName oidc)
      req suffix onSuccess onFailure

  -- If the stored id token still validates, the user is kept as-is.
  -- Otherwise the tokens are refreshed once. The user survives only if the
  -- refreshed id token validates; in that case the new tokens and the
  -- current time are written into the session. In both success cases the
  -- validated claims are stored in the request vault. Every failure
  -- (undecodable session, failed refresh, invalid token) yields 'Nothing'.
  refreshLoginState oidc req user =
    case decodeToken (authLoginState user) of
      Left _       -> pure Nothing
      Right tokens ->
        validateIdToken' oidc tokens >>= maybe (refresh tokens) (accept user)
    where
      accept u claims = pure (Just (storeClaims claims req, u))

      refresh tokens = do
        oauth2    <- mkOauth2 oidc Nothing
        manager   <- maybe getGlobalManager pure (oidcManager oidc)
        refreshed <- refreshTokens tokens manager oauth2
        case refreshed of
          Nothing        -> pure Nothing
          Just newTokens ->
            validateIdToken' oidc newTokens
              >>= maybe (pure Nothing) (renew newTokens)

      renew newTokens claims = do
        CTime now <- epochTime
        let renewed = user { authLoginState = encodeToken newTokens
                           , authLoginTime  = fromIntegral now
                           }
        accept renewed claims

-- | Fetch configuration for a provider from its discovery endpoint.
--
-- The path of the given URL is replaced by
-- @/.well-known/openid-configuration@. Any existing path is discarded; the
-- rest of the URL is kept as given. The result is handed to 'discoverURI'.
-- If the discovery document lives somewhere else, for example under an
-- issuer that has a path component, build that URI yourself and call
-- 'discoverURI' directly.
--
-- @since 0.2.3.0
discover :: T.Text -> IO OpenIDConnect
discover urlText = do
  issuerURI <- parseAbsoluteURI urlText
  discoverURI (issuerURI { U.uriPath = "/.well-known/openid-configuration" })

-- | Fetch configuration for a provider from an exact URI.
--
-- The discovery document is read from the URI as-is. The key set is then
-- downloaded from the metadata's 'jwksUri'. Client id and secret start out
-- empty and must be set by the caller. The remaining settings get their
-- defaults: placeholder provider info, the global manager, the "openid"
-- scope, and no allowed clock skew.
--
-- @since 0.2.3.1
discoverURI :: U.URI -> IO OpenIDConnect
discoverURI uri = do
  metadata <- fetchMetadata uri
  keys     <- fetchJWKSet (jwksUri metadata)
  pure OpenIDConnect
    { oidcMetadata     = metadata
    , oidcJwkSet       = keys
    , oidcClientId     = ""
    , oidcClientSecret = ""
    , oidcProviderInfo = defProviderInfo
    , oidcManager      = Nothing
    , oidcScopes       = ["openid"]
    , oidcAllowedSkew  = 0
    }

-- | Placeholder provider information: a generic title and two empty fields.
defProviderInfo :: ProviderInfo
defProviderInfo = ProviderInfo title "" ""
  where
    title :: T.Text
    title = "OpenID Connect Provider"

-- | Download the discovery document at the given URI and decode it as JSON.
-- Request and decoding errors are not caught here.
fetchMetadata :: U.URI -> IO Metadata
fetchMetadata metadataEndpoint =
  getResponseBody <$> (parseRequestThrow endpoint >>= httpJSON)
  where
    endpoint = S8.unpack (U.serializeURIRef' metadataEndpoint)

-- | Download the key set published at the given URL and decode it as JSON.
-- Request and decoding errors are not caught here.
fetchJWKSet :: T.Text -> IO JOSE.JWKSet
fetchJWKSet jwkSetEndpoint =
  fmap getResponseBody (httpJSON =<< parseRequestThrow (T.unpack jwkSetEndpoint))

-- | Build the OAuth2 client configuration for this provider.
--
-- Credentials come from the provider settings and endpoints from its
-- metadata. When a URL renderer is supplied, the callback is the absolute
-- URL of this provider's @complete@ route. Without one, no callback is set.
mkOauth2 :: OpenIDConnect -> Maybe (Text.Hamlet.Render ProviderUrl) -> IO OA2.OAuth2
mkOauth2 oidc@OpenIDConnect{} renderUrl = do
  callback <- for renderUrl completeURI
  pure OA2.OAuth2
    { oauthClientId            = oidcClientId oidc
    , oauthClientSecret        = Just (oidcClientSecret oidc)
    , oauthOAuthorizeEndpoint  = authorizationEndpoint meta
    , oauthAccessTokenEndpoint = tokenEndpoint meta
    , oauthCallback            = callback
    }
  where
    meta = oidcMetadata oidc
    -- Render the "complete" route (no query parameters) and parse it as an
    -- absolute URI.
    completeURI render = parseAbsoluteURI (render (ProviderUrl ["complete"]) [])

-- | Decode a compact-serialized id token as a signed JWT. Then verify its
-- signature against the provider's key set and check its claims with
-- 'validationSettings'. Any failure is returned as a 'JWT.JWTError'.
validateIdToken :: OpenIDConnect -> OA2.IdToken -> IO (Either JWT.JWTError JWT.ClaimsSet)
validateIdToken oidc (OA2.IdToken rawToken) =
  runExceptT (JOSE.decodeCompact compact >>= JWT.verifyClaims settings keys)
  where
    compact  = TLE.encodeUtf8 (TL.fromStrict rawToken)
    settings = validationSettings oidc
    keys     = oidcJwkSet oidc

-- | Validate the id token carried by a token response, if there is one.
-- Returns 'Nothing' when the response has no id token or the token fails
-- validation. The validation error itself is discarded.
validateIdToken' :: OpenIDConnect -> OA2.OAuth2Token -> IO (Maybe JWT.ClaimsSet)
validateIdToken' oidc tokens =
  maybe (pure Nothing) validate (OA2.idToken tokens)
  where
    validate idTok = either (const Nothing) Just <$> validateIdToken oidc idTok

-- The ID token validation below is stricter than the OIDC spec requires, to
-- make the job of validating tokens easier. If this is too limiting for your
-- use case, please open an issue.
--
-- Full spec for ID token validation:
-- https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation
--
-- Intended ways in which the validation below is stricter than the spec:
-- - The @aud@ claim may not contain any audiences beyond ourselves.
--   NOTE(review): 'validateAudience' is applied per audience value. Whether a
--   multi-valued @aud@ with extra entries is actually rejected depends on how
--   jose combines those per-value results — confirm.
--
-- Not handled: encrypted ID tokens. 'validateIdToken' decodes the compact
-- token directly as a signed JWT, without any decryption step.
validationSettings :: OpenIDConnect -> JWT.JWTValidationSettings
validationSettings oidc =
  audienceChecked
    -- Turn on jose's @iat@ (issued-at) check.
    -- NOTE(review): there is no explicit @exp@ setting here. The spec's
    -- "current time MUST be before exp" rule is presumably enforced by
    -- jose's claim validation by default — confirm.
    & JWT.jwtValidationSettingsCheckIssuedAt Lens..~ True
    -- "The Issuer Identifier for the OpenID Provider (which is typically
    -- obtained during Discovery) MUST exactly match the value of the iss
    -- (issuer) Claim."
    & JWT.jwtValidationSettingsIssuerPredicate Lens..~ validateIssuer oidc
    -- Allowed clock skew comes from the provider configuration.
    & JWT.jwtValidationSettingsAllowedSkew Lens..~ oidcAllowedSkew oidc
  where
    -- "The Client MUST validate that the aud (audience) Claim contains its
    -- client_id value registered at the Issuer identified by the iss
    -- (issuer) Claim as an audience. [...] The ID Token MUST be rejected if
    -- the ID Token does not list the Client as a valid audience, or if it
    -- contains additional audiences not trusted by the Client."
    audienceChecked = JWT.defaultJWTValidationSettings (validateAudience oidc)

-- | Audience predicate. A single audience value is accepted only if its text
-- form equals the configured client id.
validateAudience :: OpenIDConnect -> JWT.StringOrURI -> Bool
validateAudience oidc audClaim =
  maybe False (== oidcClientId oidc) (fromStringOrURI audClaim)

-- | Issuer predicate. The @iss@ value is accepted only if its text form
-- exactly equals the issuer recorded in the provider metadata.
validateIssuer :: OpenIDConnect -> JWT.StringOrURI -> Bool
validateIssuer oidc issClaim =
  maybe False (== issuer (oidcMetadata oidc)) (fromStringOrURI issClaim)

-- | Text form of a JWT 'JWT.StringOrURI'. A plain string is returned as-is;
-- a URI is rendered with 'show'.
fromStringOrURI :: JWT.StringOrURI -> Maybe T.Text
fromStringOrURI value = asString <|> asURI
  where
    asString = Lens.Extras.preview JWT.string value
    asURI    = T.pack . show <$> Lens.Extras.preview JWT.uri value

-- | Return the request with the validated id token claims saved in its vault
-- under 'idTokenKey'. 'getIdToken' reads them back from there.
storeClaims :: JWT.ClaimsSet -> Request -> Request
storeClaims claims req =
  let updatedVault = Vault.insert idTokenKey claims (vault req)
  in req { vault = updatedVault }

-- | Get the claims of the @IdToken@ for the current user.
--
-- If called on a @Request@ behind the middleware, this should always return
-- a @Just@ value.
--
-- The token these claims came from was validated when the request was
-- processed by the middleware.
--
-- @since 0.2.3.0
getIdToken :: Request -> Maybe JWT.ClaimsSet
getIdToken = Vault.lookup idTokenKey . vault

-- | Vault key under which 'storeClaims' saves the validated id token claims
-- and from which 'getIdToken' reads them.
--
-- The key is created once, at top level, with 'unsafePerformIO'. The NOINLINE
-- pragma is what makes this safe: it stops GHC from inlining the definition,
-- which could otherwise create several distinct keys. Values stored under
-- one key would then be invisible to lookups through another.
idTokenKey :: Vault.Key JWT.ClaimsSet
idTokenKey = unsafePerformIO $ Vault.newKey
{-# NOINLINE idTokenKey #-}