{-# LANGUAGE OverloadedStrings #-}
{-|
    Module: Web.OIDC.Client.Discovery
    Maintainer: krdlab@gmail.com
    Stability: experimental
-}
module Web.OIDC.Client.Discovery
    (
      discover

    -- * OpenID Provider Issuers
    , google

    -- * OpenID Provider Configuration Information
    , Provider(..)
    , Configuration(..)

    -- * For testing
    , generateDiscoveryUrl
    ) where

import           Control.Monad.Catch                (catch, throwM)
import           Data.Aeson                         (eitherDecode)
import           Data.ByteString                    (append, isSuffixOf)
import           Data.Monoid                        ((<>))
import           Data.Text                          (pack)
import qualified Jose.Jwk                           as Jwk
import           Network.HTTP.Client                (Manager, Request, httpLbs,
                                                     path, responseBody)

import           Web.OIDC.Client.Discovery.Issuers  (google)
import           Web.OIDC.Client.Discovery.Provider (Configuration (..),
                                                     Provider (..))
import           Web.OIDC.Client.Internal           (parseUrl, rethrow)
import           Web.OIDC.Client.Types              (IssuerLocation,
                                                     OpenIdException (..))

-- | Obtain an OpenID Provider's configuration and its JWK set.
--
-- The configuration document is fetched from the URL built by
-- 'generateDiscoveryUrl'. The JWK set is then fetched from the
-- configuration's 'jwksUri'. 'HttpException's raised by either request are
-- handed to 'rethrow'.
--
-- Throws 'DiscoveryException' if the configuration or the JWK set response
-- body cannot be decoded. The JWK set is only requested after the
-- configuration has decoded successfully.
discover
    :: IssuerLocation   -- ^ OpenID Provider's Issuer location
    -> Manager
    -> IO Provider
discover location manager = do
    conf <- getConfiguration `catch` rethrow
    c    <- either (failWith "Failed to decode configuration: ") return conf
    json <- getJwkSetJson (jwksUri c) `catch` rethrow
    keys <- either (failWith "Failed to decode JwkSet: ") return (jwks json)
    return $ Provider c keys
  where
    -- Fetch the discovery document and attempt to decode it as a
    -- 'Configuration'; decoding failures are returned, not thrown.
    getConfiguration :: IO (Either String Configuration)
    getConfiguration = do
        req <- generateDiscoveryUrl location
        res <- httpLbs req manager
        return $ eitherDecode $ responseBody res

    -- Fetch the raw body served at the JWK set URL.
    getJwkSetJson url = do
        req <- parseUrl url
        res <- httpLbs req manager
        return $ responseBody res

    -- Decode a JWK set document and keep only its list of keys.
    jwks j = Jwk.keys <$> eitherDecode j

    -- Turn a decode error into a 'DiscoveryException', prefixing the
    -- decoder's message with the given context.
    failWith prefix err = throwM $ DiscoveryException (prefix <> pack err)

-- | Build the request for an issuer's OpenID Provider configuration document.
--
-- The issuer location is parsed with 'parseUrl'. Then
-- @.well-known/openid-configuration@ is appended to the request path, with
-- a @/@ separator inserted only when the path does not already end in one.
-- Any exception thrown by 'parseUrl' propagates unchanged.
generateDiscoveryUrl :: IssuerLocation -> IO Request
generateDiscoveryUrl location =
    appendPath ".well-known/openid-configuration" <$> parseUrl location
  where
    -- Append @suffix@ to the request path, avoiding a doubled slash.
    appendPath suffix req =
        let p  = path req
            p' = if "/" `isSuffixOf` p then p else p `append` "/"
        in  req { path = p' `append` suffix }