{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}

module Yesod.Auth.OAuth2.Dispatch
    ( FetchToken
    , fetchAccessToken
    , fetchAccessToken2
    , FetchCreds
    , dispatchAuthRequest
    ) where

import Control.Monad.Except
import Data.Text (Text)
import qualified Data.Text as T
import Data.Text.Encoding (encodeUtf8)
import Network.HTTP.Conduit (Manager)
import Network.OAuth.OAuth2
import Network.OAuth.OAuth2.TokenRequest (Errors)
import URI.ByteString.Extension
import UnliftIO.Exception
import Yesod.Auth hiding (ServerError)
import Yesod.Auth.OAuth2.DispatchError
import Yesod.Auth.OAuth2.ErrorResponse
import Yesod.Auth.OAuth2.Random
import Yesod.Core hiding (ErrorResponse)

-- | How to exchange an authorization code for an @'OAuth2Token'@
--
-- Pass 'fetchAccessToken' or 'fetchAccessToken2' here.
--
type FetchToken
    = Manager
    -> OAuth2
    -> ExchangeToken
    -> IO (OAuth2Result Errors OAuth2Token)

-- | How to turn an @'OAuth2Token'@ into user credentials for site @m@
type FetchCreds m
    = Manager
    -> OAuth2Token
    -> IO (Creds m)

-- | Route a request to the OAuth2 handshake step it belongs to
--
-- Only two routes are handled, both for the @GET@ method:
--
-- * @\/forward@ starts the handshake ('dispatchForward')
-- * @\/callback@ completes it ('dispatchCallback')
--
-- Errors thrown by either step are handled by 'handleDispatchError'. Any
-- other method or path responds with 'notFound'.
dispatchAuthRequest
    :: Text             -- ^ Name
    -> OAuth2           -- ^ Service details
    -> FetchToken       -- ^ How to get a token
    -> FetchCreds m     -- ^ How to get credentials
    -> Text             -- ^ Method
    -> [Text]           -- ^ Path pieces
    -> AuthHandler m TypedContent
dispatchAuthRequest name oauth2 getToken getCreds method pieces =
    case (method, pieces) of
        ("GET", ["forward"]) ->
            handleDispatchError $ dispatchForward name oauth2
        ("GET", ["callback"]) ->
            handleDispatchError $ dispatchCallback name oauth2 getToken getCreds
        _ -> notFound

-- | Handle @GET \/forward@
--
-- 1. Store a fresh random CSRF token in the session, under 'tokenSessionKey'
-- 2. Build the callback URI and @state@ parameter ('withCallbackAndState')
-- 3. Redirect to the Provider's authorization URL
--
dispatchForward
    :: (MonadError DispatchError m, MonadAuthHandler site m)
    => Text
    -> OAuth2
    -> m TypedContent
dispatchForward name oauth2 =
    setSessionCSRF (tokenSessionKey name)
        >>= withCallbackAndState name oauth2
        >>= redirect . toText . authorizationUrl

-- | Handle @GET \/callback@
--
-- 1. If 'onErrorResponse' detects a provider error response, throw it as
--    'OAuth2HandshakeError'
-- 2. Verify the URL's @state@ token matches the CSRF token in our session
-- 3. Exchange the @code@ parameter for an access token using @getToken@
-- 4. Turn that token into @'Creds'@ using @getCreds@ and log the user in
--
-- Failures go through 'MonadError'. A failed token exchange becomes
-- 'OAuth2ResultError'. An 'IOException' or 'YesodOAuth2Exception' thrown by
-- @getCreds@ becomes 'FetchCredsIOException' or
-- 'FetchCredsYesodOAuth2Exception', respectively.
--
dispatchCallback
    :: (MonadError DispatchError m, MonadAuthHandler site m)
    => Text
    -> OAuth2
    -> FetchToken
    -> FetchCreds site
    -> m TypedContent
dispatchCallback name oauth2 getToken getCreds = do
    onErrorResponse (throwError . OAuth2HandshakeError)
    csrf <- verifySessionCSRF (tokenSessionKey name)
    code <- requireGetParam "code"
    manager <- authHttpManager
    -- Rebuild the OAuth2 settings exactly as dispatchForward did (same name,
    -- same CSRF token), so the token request uses the same callback URI.
    oauth2' <- withCallbackAndState name oauth2 csrf
    tokenResult <- liftIO (getToken manager oauth2' (ExchangeToken code))
    token <- case tokenResult of
        Left err -> throwError (OAuth2ResultError err)
        Right tok -> pure tok
    let fetchCreds = liftIO (getCreds manager token)
        onIOError = throwError . FetchCredsIOException
        onOAuth2Error = throwError . FetchCredsYesodOAuth2Exception
    creds <- (fetchCreds `catch` onIOError) `catch` onOAuth2Error
    setCredsRedirect creds

-- | Prepare the 'OAuth2' settings for one handshake
--
-- Steps:
--
-- * Render this plugin's @callback@ route with the parent site's URL renderer.
-- * Parse the result as a URI, throwing 'InvalidCallbackUri' if parsing fails.
-- * Return a copy of @oauth2@ whose 'oauthCallback' is that URI.
-- * In that copy, add the given CSRF token to the authorize endpoint as a
--   @state@ query parameter via 'withQuery'.
--
withCallbackAndState
    :: (MonadError DispatchError m, MonadAuthHandler site m)
    => Text
    -> OAuth2
    -> Text
    -> m OAuth2
withCallbackAndState name oauth2 csrf = do
    renderParent <- getParentUrlRender
    let uri = renderParent (PluginR name ["callback"])
        stateQuery = [("state", encodeUtf8 csrf)]
    case fromText uri of
        Nothing -> throwError (InvalidCallbackUri uri)
        Just callback ->
            pure oauth2
                { oauthCallback = Just callback
                , oauthOAuthorizeEndpoint =
                    withQuery (oauthOAuthorizeEndpoint oauth2) stateQuery
                }

-- | Obtain a renderer for routes of the current subsite
--
-- Each subsite route is first lifted into the parent site with
-- 'getRouteToParent', then rendered with the parent's 'getUrlRender'.
getParentUrlRender :: MonadHandler m => m (Route (SubHandlerSite m) -> Text)
getParentUrlRender = do
    renderUrl <- getUrlRender
    toParent <- getRouteToParent
    pure (renderUrl . toParent)

-- | Generate a random CSRF token, store it in the session under the given
-- key, and return it
--
-- The token is @'randomText' 64@ with every @+@ character removed. Some (but
-- not all) providers decode a @+@ in the state token as a space when sending
-- it back to us, which makes verification fail. If we compensated for that,
-- we would instead fail on the providers that /don't/ do it.
--
-- Excluding @+@ entirely avoids the problem. The cost is that the token may be
-- a few characters shorter than the raw random text.
--
-- NOTE(review): the exact token length depends on how 'randomText' interprets
-- its @64@ argument (bytes vs. characters). Confirm in
-- "Yesod.Auth.OAuth2.Random" before relying on a particular length.
--
setSessionCSRF :: MonadHandler m => Text -> m Text
setSessionCSRF sessionKey = do
    token <- liftIO (T.filter (/= '+') <$> randomText 64)
    setSession sessionKey token
    pure token

-- | Verify the callback provided the same CSRF token as in our session
--
-- Steps, in order:
--
-- 1. Read the required @state@ GET parameter. A missing parameter throws
--    'MissingParameter' before the session is touched.
-- 2. Look up the stored value under @sessionKey@.
-- 3. Delete that session entry, whether or not the values match.
--
-- Returns the token when it equals the stored value. Otherwise throws
-- 'InvalidStateToken' with the stored value (if any) and the received token.
verifySessionCSRF
    :: (MonadError DispatchError m, MonadHandler m) => Text -> m Text
verifySessionCSRF sessionKey = do
    token <- requireGetParam "state"
    stored <- lookupSession sessionKey
    deleteSession sessionKey
    if stored == Just token
        then pure token
        else throwError (InvalidStateToken stored token)

-- | Look up a GET parameter, throwing 'MissingParameter' when it is absent
requireGetParam
    :: (MonadError DispatchError m, MonadHandler m) => Text -> m Text
requireGetParam key = do
    found <- lookupGetParam key
    case found of
        Just value -> pure value
        Nothing -> throwError (MissingParameter key)

-- | Session key under which the CSRF token for plugin @name@ is stored
tokenSessionKey :: Text -> Text
tokenSessionKey name = T.pack "_yesod_oauth2_" <> name