{-# LANGUAGE OverloadedStrings #-}
{-|
    Module: Web.OIDC.Client.IdTokenFlow
    Maintainer: krdlab@gmail.com
    Stability: experimental
-}
module Web.OIDC.Client.IdTokenFlow
    (
      getAuthenticationRequestUrl
    , getValidIdTokenClaims
    , prepareAuthenticationRequestUrl
    ) where

import           Control.Monad                      (when)
import           Control.Exception                  (throwIO, catch)
import           Control.Monad.IO.Class             (MonadIO, liftIO)
import           Data.Aeson                         (FromJSON)
import qualified Data.ByteString.Char8              as B
import           Data.List                          (nub)
import           Data.Maybe                         (isNothing, fromMaybe)
import           Data.Monoid                        ((<>))
import           Data.Text                          (unpack)
import           Data.Text.Encoding                 (decodeUtf8With)
import           Data.Text.Encoding.Error           (lenientDecode)
import qualified Jose.Jwt                           as Jwt
import           Network.HTTP.Client                (getUri, setQueryString)
import           Network.URI                        (URI)

import           Prelude                            hiding (exp)

import           Web.OIDC.Client.Internal           (parseUrl)
import qualified Web.OIDC.Client.Internal           as I
import           Web.OIDC.Client.Settings           (OIDC (..))
import           Web.OIDC.Client.Tokens             (IdTokenClaims (..), validateIdToken)
import           Web.OIDC.Client.Types              (OpenIdException (..),
                                                     Parameters, Scope,
                                                     SessionStore (..), State,
                                                     openId)

-- | Make URL for Authorization Request after generating state and nonce from 'SessionStore'.
--
--   A fresh state and a fresh nonce are both obtained from 'sessionStoreGenerate'
--   and stored together with 'sessionStoreSave' before the URL is built. The state
--   is passed to 'getAuthenticationRequestUrl' as the @state@ value, and the nonce
--   is appended to the optional parameters under the key @nonce@.
prepareAuthenticationRequestUrl
    :: (MonadIO m)
    => SessionStore m
    -> OIDC
    -> Scope            -- ^ used to specify what are privileges requested for tokens. (use `ScopeValue`)
    -> Parameters       -- ^ Optional parameters
    -> m URI
prepareAuthenticationRequestUrl store oidc scope params = do
    state <- sessionStoreGenerate store
    nonce' <- sessionStoreGenerate store
    -- Persist the pair first so the nonce can later be looked up by state.
    sessionStoreSave store state nonce'
    getAuthenticationRequestUrl oidc scope (Just state) $
        params ++ [("nonce", Just nonce')]

-- | Obtain the ID token, validate it and check its nonce against the nonce that
--   was saved in the 'SessionStore' for the given state.
--   Then deletes session info by 'sessionStoreDelete'.
--
--   The steps are, in order:
--
--   1. Look up the saved nonce with 'sessionStoreGet'; throws 'UnknownState' if
--      nothing is stored for the state.
--   2. Run the supplied action to get the raw ID token.
--   3. Call 'sessionStoreDelete' (this happens before validation, so the session
--      info is removed whether or not the token turns out to be valid).
--   4. Validate the token with 'validateIdToken'.
--   5. Throw 'MissingNonceInResponse' if the claims carry no nonce, or
--      'MismatchedNonces' if it differs from the saved one.
getValidIdTokenClaims
    :: (MonadIO m, FromJSON a)
    => SessionStore m
    -> OIDC
    -> State            -- ^ state value received from the IdP, used as the lookup key
    -> m B.ByteString   -- ^ action that yields the raw (encoded) ID token
    -> m (IdTokenClaims a)
getValidIdTokenClaims store oidc stateFromIdP getIdToken = do
    msavedNonce <- sessionStoreGet store stateFromIdP
    savedNonce <- maybe (liftIO $ throwIO UnknownState) pure msavedNonce
    jwt <- Jwt.Jwt <$> getIdToken
    sessionStoreDelete store
    idToken <- liftIO $ validateIdToken oidc jwt
    nonce' <- maybe (liftIO $ throwIO MissingNonceInResponse) pure (nonce idToken)
    when (nonce' /= savedNonce) $ liftIO $ throwIO MismatchedNonces
    pure idToken

-- | Make URL for Authorization Request.
--
--   The query string is built from the required parameters (@response_type=id_token@,
--   @response_mode=form_post@, @client_id@, @redirect_uri@ and @scope@), followed by
--   @state@ when one is given, followed by the caller's optional parameters.
--   The @scope@ value is the space-separated, de-duplicated list of 'openId' plus
--   the requested scopes.
--
--   An @HttpException@ raised while parsing the authorization server URL is
--   handed to 'I.rethrow'.
{-# WARNING getAuthenticationRequestUrl "This function doesn't manage state and nonce. Use prepareAuthenticationRequestUrl only unless your IdP doesn't support state and/or nonce." #-}
getAuthenticationRequestUrl
    :: (MonadIO m)
    => OIDC
    -> Scope            -- ^ used to specify what are privileges requested for tokens. (use `ScopeValue`)
    -> Maybe State      -- ^ used for CSRF mitigation. (recommended parameter)
    -> Parameters       -- ^ Optional parameters
    -> m URI
getAuthenticationRequestUrl oidc scope state params = do
    req <- liftIO $ parseUrl endpoint `catch` I.rethrow
    return $ getUri $ setQueryString query req
  where
    endpoint  = oidcAuthorizationServerUrl oidc
    -- Order matters for the resulting query string: required, state, then extras.
    query     = requireds ++ state' ++ params
    requireds =
        [ ("response_type", Just "id_token")
        , ("response_mode", Just "form_post")
        , ("client_id",     Just $ oidcClientId oidc)
        , ("redirect_uri",  Just $ oidcRedirectUri oidc)
        , ("scope",         Just . B.pack . unwords . nub . map unpack $ openId:scope)
        ]
    -- Only emit a "state" entry when the caller actually supplied one.
    state' =
        case state of
            Just _  -> [("state", state)]
            Nothing -> []