{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
module Yesod.Auth.OAuth2.Dispatch
    ( FetchToken
    , fetchAccessToken
    , fetchAccessToken2
    , FetchCreds
    , dispatchAuthRequest
    )
where

import Control.Exception.Safe
import Control.Monad (unless, (<=<))
import Crypto.Random (getRandomBytes)
import Data.ByteArray.Encoding (Base(Base64), convertToBase)
import Data.ByteString (ByteString)
import Data.Text (Text)
import qualified Data.Text as T
import Data.Text.Encoding (decodeUtf8, encodeUtf8)
import Network.HTTP.Conduit (Manager)
import Network.OAuth.OAuth2
import Network.OAuth.OAuth2.TokenRequest (Errors)
import URI.ByteString.Extension
import Yesod.Auth hiding (ServerError)
import Yesod.Auth.OAuth2.ErrorResponse
import Yesod.Auth.OAuth2.Exception
import Yesod.Core hiding (ErrorResponse)

-- | Strategy for obtaining an @'OAuth2Token'@ from an @'ExchangeToken'@
--
-- In practice this is 'fetchAccessToken' or 'fetchAccessToken2'.
--
type FetchToken =
    Manager
    -> OAuth2
    -> ExchangeToken
    -> IO (OAuth2Result Errors OAuth2Token)

-- | Strategy for turning an @'OAuth2Token'@ into user @'Creds'@
--
-- The 'Manager' is supplied so the implementation can make further HTTP
-- requests (presumably to a provider-specific user-info endpoint).
type FetchCreds m =
    Manager
    -> OAuth2Token
    -> IO (Creds m)

-- | Route a request for this plugin to the matching handshake step
--
-- * @GET \/forward@ starts the handshake ('dispatchForward')
-- * @GET \/callback@ completes it ('dispatchCallback')
-- * any other method/path combination responds with 'notFound'
--
dispatchAuthRequest
    :: Text             -- ^ Name
    -> OAuth2           -- ^ Service details
    -> FetchToken       -- ^ How to get a token
    -> FetchCreds m     -- ^ How to get credentials
    -> Text             -- ^ Method
    -> [Text]           -- ^ Path pieces
    -> AuthHandler m TypedContent
dispatchAuthRequest name oauth2 getToken getCreds method pieces =
    case (method, pieces) of
        ("GET", ["forward"]) ->
            dispatchForward name oauth2
        ("GET", ["callback"]) ->
            dispatchCallback name oauth2 getToken getCreds
        _ ->
            notFound

-- | Serve @GET \/forward@: start the handshake
--
-- A fresh CSRF token is stored in the session under this plugin's key. The
-- token is attached as the @state@ query parameter of the provider's
-- authorization URL, along with our callback URL. The user is then
-- redirected to that authorization URL.
--
dispatchForward :: Text -> OAuth2 -> AuthHandler m TypedContent
dispatchForward name oauth2 = do
    let sessionKey = tokenSessionKey name
    csrf <- setSessionCSRF sessionKey
    prepared <- withCallbackAndState name oauth2 csrf
    let target = authorizationUrl prepared
    redirect (toText target)

-- | Serve @GET \/callback@: finish the handshake
--
-- Steps, in order:
--
-- 1. Check the @state@ query parameter against the CSRF token stored in the
--    session for this plugin ('verifySessionCSRF').
-- 2. Install 'oauth2HandshakeError' via 'onErrorResponse', so a provider
--    error response is logged and shown to the user.
-- 3. Exchange the @code@ query parameter for an access token using the
--    supplied 'FetchToken'.
-- 4. Build @'Creds'@ from that token using the supplied 'FetchCreds' and
--    log the user in with 'setCredsRedirect'.
--
-- A 'Left' from the token fetch is passed to 'unexpectedError', which logs
-- it and redirects. So is an 'IOException' or 'YesodOAuth2Exception' thrown
-- while fetching credentials.
--
dispatchCallback
    :: Text
    -> OAuth2
    -> FetchToken
    -> FetchCreds m
    -> AuthHandler m TypedContent
dispatchCallback name oauth2 getToken getCreds = do
    csrf <- verifySessionCSRF (tokenSessionKey name)
    onErrorResponse (oauth2HandshakeError name)
    code <- requireGetParam "code"
    manager <- authHttpManager
    configured <- withCallbackAndState name oauth2 csrf
    token <- orFail (getToken manager configured (ExchangeToken code))
    creds <- orFail (tryFetchCreds (getCreds manager token))
    setCredsRedirect creds
  where
    -- Run an IO action producing 'Either': a 'Right' value is returned, a
    -- 'Left' is logged and turned into an opaque redirect.
    orFail :: Show e => IO (Either e a) -> AuthHandler m a
    orFail action = (either (unexpectedError name) pure <=< liftIO) action

-- | Handle an OAuth2 @'ErrorResponse'@
--
-- These come from the OAuth2 provider (e.g. Invalid Grant or Invalid Scope)
-- and /may/ be something the user can act on. The full error is logged at
-- error level. The user is redirected to the login page with only
-- @'erUserMessage'@, which is written to be safe to display.
--
oauth2HandshakeError :: Text -> ErrorResponse -> AuthHandler m a
oauth2HandshakeError name err = do
    $(logError) logLine
    redirectMessage userLine
  where
    logLine = mconcat ["Handshake failure in ", name, " plugin: ", tshow err]
    userLine = "OAuth2 handshake failure: " <> erUserMessage err

-- | Handle an unexpected error
--
-- Used for failures while processing the callback. The error's 'Show'
-- output is logged at error level, so it reaches only the server logs. The
-- user is redirected with a fixed, opaque message.
--
unexpectedError :: Show e => Text -> e -> AuthHandler m a
unexpectedError name err = do
    let details = "Error in " <> name <> " OAuth2 plugin: " <> tshow err
    $(logError) details
    redirectMessage "Unexpected error logging in with OAuth2"

-- | Store @msg@ with 'setMessage' and redirect to the auth subsite's
-- 'LoginR' (lifted into the parent site)
redirectMessage :: Text -> AuthHandler m a
redirectMessage msg = do
    loginRoute <- ($ LoginR) <$> getRouteToParent
    setMessage (toHtml msg)
    redirect loginRoute

-- | Run a credentials fetch, turning anticipated failures into 'Left'
--
-- Only 'IOException' and 'YesodOAuth2Exception' are caught; each is wrapped
-- with 'toException'. Any other exception propagates unchanged.
--
tryFetchCreds :: IO a -> IO (Either SomeException a)
tryFetchCreds action =
    handle (asLeft @YesodOAuth2Exception)
        . handle (asLeft @IOException)
        $ Right <$> action
  where
    -- Wrap a caught exception (of whichever type was selected) as 'Left'.
    asLeft :: Exception e => e -> IO (Either SomeException b)
    asLeft = pure . Left . toException

-- | Return a copy of the service config prepared for the handshake
--
-- 'oauthCallback' is set to this plugin's @callback@ route, rendered
-- through the parent site. @state=<csrf>@ is added to the authorize
-- endpoint's query. If the rendered callback does not parse as a URI, this
-- throws via 'throwString'; the message suggests a non-absolute Approot as
-- the likely cause.
--
withCallbackAndState :: Text -> OAuth2 -> Text -> AuthHandler m OAuth2
withCallbackAndState name oauth2 csrf = do
    render <- getParentUrlRender
    let callbackText = render (PluginR name ["callback"])
    callback <- case fromText callbackText of
        Just uri -> pure uri
        Nothing ->
            liftIO . throwString $ concat
                [ "Invalid callback URI: "
                , T.unpack callbackText
                , ". Not using an absolute Approot?"
                ]
    let stateParam = [("state", encodeUtf8 csrf)]
    pure oauth2
        { oauthCallback = Just callback
        , oauthOAuthorizeEndpoint =
            oauthOAuthorizeEndpoint oauth2 `withQuery` stateParam
        }

-- | Get a URL renderer for routes of the current subsite
--
-- A subsite route is first lifted into the parent site with
-- 'getRouteToParent', then rendered with the parent's 'getUrlRender'.
getParentUrlRender :: MonadHandler m => m (Route (SubHandlerSite m) -> Text)
getParentUrlRender = do
    renderParent <- getUrlRender
    toParent <- getRouteToParent
    pure (renderParent . toParent)

-- | Generate a random CSRF token, store it in the session, and return it
--
-- The token is 64 bytes from 'getRandomBytes', Base64-encoded (88
-- characters), with every @+@ removed. It is therefore at most 88
-- characters long.
--
-- Why drop @+@: the token is sent to the provider as the @state@ query
-- parameter and must come back unchanged for 'verifySessionCSRF' to accept
-- it. If a provider echoes it back without percent-encoding, a literal @+@
-- in the query string is decoded as a space. The round-tripped value would
-- then no longer match the session, and the login would be rejected.
--
setSessionCSRF :: MonadHandler m => Text -> m Text
setSessionCSRF sessionKey = do
    csrfToken <- liftIO randomToken
    csrfToken <$ setSession sessionKey csrfToken
  where
    randomToken :: IO Text
    randomToken =
        T.filter (/= '+')
            . decodeUtf8
            . convertToBase @ByteString Base64
            <$> getRandomBytes 64

-- | Check the callback's @state@ parameter against our session token
--
-- Reads the required @state@ query parameter, then looks up and deletes
-- the session entry stored under @sessionKey@. Unless the two are equal,
-- responds with 'permissionDenied'. On success the token is returned.
verifySessionCSRF :: MonadHandler m => Text -> m Text
verifySessionCSRF sessionKey = do
    supplied <- requireGetParam "state"
    expected <- lookupSession sessionKey
    deleteSession sessionKey
    unless (Just supplied == expected) $
        permissionDenied "Invalid OAuth2 state token"
    pure supplied

-- | Look up a required query-string parameter
--
-- Responds with 'invalidArgs' when the parameter is absent.
requireGetParam :: MonadHandler m => Text -> m Text
requireGetParam key =
    lookupGetParam key >>= maybe missing pure
  where
    missing = invalidArgs ["The '" <> key <> "' parameter is required"]

-- | Session key under which the named plugin's CSRF token is stored
tokenSessionKey :: Text -> Text
tokenSessionKey = ("_yesod_oauth2_" <>)

-- | Render a value with 'show', as 'Text'
tshow :: Show a => a -> Text
tshow x = T.pack (show x)