{-# LANGUAGE OverloadedStrings #-}
{-|
    Module: Web.OIDC.Client.IdTokenFlow
    Maintainer: krdlab@gmail.com
    Stability: experimental
-}
module Web.OIDC.Client.IdTokenFlow
    (
      getAuthenticationRequestUrl
    , getValidIdTokenClaims
    , prepareAuthenticationRequestUrl
    ) where

import           Control.Monad                      (when)
import           Control.Exception                  (throwIO, catch)
import           Control.Monad.IO.Class             (MonadIO, liftIO)
import           Data.Aeson                         (FromJSON)
import qualified Data.ByteString.Char8              as B
import           Data.List                          (nub)
import           Data.Maybe                         (isNothing, fromMaybe)
import           Data.Monoid                        ((<>))
import           Data.Text                          (unpack)
import           Data.Text.Encoding                 (decodeUtf8With)
import           Data.Text.Encoding.Error           (lenientDecode)
import qualified Jose.Jwt                           as Jwt
import           Network.HTTP.Client                (getUri, setQueryString)
import           Network.URI                        (URI)

import           Prelude                            hiding (exp)

import           Web.OIDC.Client.Internal           (parseUrl)
import qualified Web.OIDC.Client.Internal           as I
import           Web.OIDC.Client.Settings           (OIDC (..))
import           Web.OIDC.Client.Tokens             (IdTokenClaims (..), validateIdToken)
import           Web.OIDC.Client.Types              (OpenIdException (..),
                                                     Parameters, Scope,
                                                     SessionStore (..), State,
                                                     openId)

-- | Build the Authorization Request URL, managing @state@ and @nonce@ for
--   the caller.
--
--   Two fresh values are drawn from 'sessionStoreGenerate' (state first,
--   then nonce) and persisted together with 'sessionStoreSave'. The URL itself
--   is built by 'getAuthenticationRequestUrl':
--
--   * the state is passed as its @state@ argument;
--   * the nonce is appended to the caller's parameters as @nonce@.
prepareAuthenticationRequestUrl
    :: (MonadIO m)
    => SessionStore m
    -> OIDC
    -> Scope            -- ^ requested scope values (see `ScopeValue`)
    -> Parameters       -- ^ extra query parameters supplied by the caller
    -> m URI
prepareAuthenticationRequestUrl store oidc scope params = do
    generatedState <- sessionStoreGenerate store
    generatedNonce <- sessionStoreGenerate store
    sessionStoreSave store generatedState generatedNonce
    getAuthenticationRequestUrl oidc scope (Just generatedState) (withNonce generatedNonce)
  where
    -- The nonce is placed after the caller's parameters, keeping their order intact.
    withNonce n = params ++ [("nonce", Just n)]

-- | Validate the ID token returned by the IdP against the state and nonce
--   stored in the 'SessionStore', returning the token's claims.
--
--   The checks run in this order:
--
--   1. Read the saved state and nonce with 'sessionStoreGet'. Throw
--      'ValidationException' unless the saved state is @Just stateFromIdP@.
--   2. Throw 'ValidationException' if no nonce was saved.
--   3. Run the supplied action to obtain the ID token bytes. Wrap them as a
--      'Jwt.Jwt', then clear the session with 'sessionStoreDelete'.
--   4. Validate the token with 'validateIdToken'. Throw 'ValidationException'
--      unless its @nonce@ claim equals the saved nonce; a token with no
--      @nonce@ claim is rejected too.
--
--   Errors raised by 'validateIdToken' itself propagate unchanged.
getValidIdTokenClaims
    :: (MonadIO m, FromJSON a)
    => SessionStore m
    -> OIDC
    -> State            -- ^ @state@ value received back from the IdP
    -> m B.ByteString   -- ^ action producing the ID token bytes
    -> m (IdTokenClaims a)
getValidIdTokenClaims store oidc stateFromIdP getIdToken = do
    (savedState, savedNonce) <- sessionStoreGet store
    when (savedState /= Just stateFromIdP) $
        invalid $ "Inconsistent state: " <> decodeUtf8With lenientDecode stateFromIdP
    expectedNonce <- maybe (invalid "Nonce is not saved!") pure savedNonce
    jwt <- Jwt.Jwt <$> getIdToken
    -- The session is cleared before validation, so a failed validation
    -- below still consumes the stored state/nonce.
    sessionStoreDelete store
    claims <- liftIO $ validateIdToken oidc jwt
    -- Covers both a missing nonce claim (Nothing) and a different value.
    when (nonce claims /= Just expectedNonce) $
        invalid "Nonce does not match request."
    pure claims
  where
    -- Throw a 'ValidationException' with the given message from any MonadIO.
    invalid msg = liftIO . throwIO $ ValidationException msg

-- | Build the Authorization Request URL for the @id_token@ / @form_post@ flow
--   from the provider's authorization endpoint ('oidcAuthorizationServerUrl').
--
--   Query parameters appear in this order:
--
--   * fixed parameters: @response_type@, @response_mode@, @client_id@,
--     @redirect_uri@ and @scope@ (@openid@ prepended, duplicates removed);
--   * @state@, when one is given;
--   * the caller's parameters.
--
--   No nonce is added here. Pass one in the parameters, as
--   'prepareAuthenticationRequestUrl' does.
{-# WARNING getAuthenticationRequestUrl "This function doesn't manage state and nonce. Use prepareAuthenticationRequestUrl only unless your IdP doesn't support state and/or nonce." #-}
getAuthenticationRequestUrl
    :: (MonadIO m)
    => OIDC
    -> Scope            -- ^ requested scope values (see `ScopeValue`); @openid@ is always included
    -> Maybe State      -- ^ optional @state@ for CSRF mitigation (recommended)
    -> Parameters       -- ^ extra query parameters supplied by the caller
    -> m URI
getAuthenticationRequestUrl oidc scope state params = do
    -- NOTE(review): HttpExceptions from URL parsing are handed to I.rethrow;
    -- what they are converted to is defined in Web.OIDC.Client.Internal.
    baseRequest <- liftIO (parseUrl (oidcAuthorizationServerUrl oidc) `catch` I.rethrow)
    pure (getUri (setQueryString queryParams baseRequest))
  where
    queryParams :: Parameters
    queryParams = concat [fixedParams, stateParam, params]

    fixedParams :: Parameters
    fixedParams =
        [ ("response_type", Just "id_token")
        , ("response_mode", Just "form_post")
        , ("client_id",     Just (oidcClientId oidc))
        , ("redirect_uri",  Just (oidcRedirectUri oidc))
        , ("scope",         Just scopeValue)
        ]

    -- Space-separated scope list; "openid" first, later duplicates dropped.
    scopeValue :: B.ByteString
    scopeValue = B.pack (unwords (nub (map unpack (openId : scope))))

    stateParam :: Parameters
    stateParam = maybe [] (\s -> [("state", Just s)]) state