{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}

module Yesod.Auth.OAuth2.Dispatch
  ( FetchToken
  , fetchAccessToken
  , fetchAccessToken2
  , FetchCreds
  , dispatchAuthRequest
  ) where

import Control.Monad (unless)
import Control.Monad.Except (MonadError (..))
import Data.Text (Text)
import qualified Data.Text as T
import Data.Text.Encoding (encodeUtf8)
import Network.HTTP.Conduit (Manager)
import Network.OAuth.OAuth2.Compat
import URI.ByteString.Extension
import UnliftIO.Exception
import Yesod.Auth hiding (ServerError)
import Yesod.Auth.OAuth2.DispatchError
import Yesod.Auth.OAuth2.ErrorResponse
import Yesod.Auth.OAuth2.Random
import Yesod.Core hiding (ErrorResponse)

-- | The action that exchanges an authorization code for an @'OAuth2Token'@.
--
-- Callers normally supply 'fetchAccessToken' or 'fetchAccessToken2' here.
type FetchToken
  = Manager
  -> OAuth2
  -> ExchangeToken
  -> IO (OAuth2Result Errors OAuth2Token)

-- | The action that turns an @'OAuth2Token'@ into user credentials.
--
-- @site@ is the site type the resulting @'Creds'@ belong to.
type FetchCreds site
  = Manager
  -> OAuth2Token
  -> IO (Creds site)

-- | Dispatch the various OAuth2 handshake routes
--
-- * @GET \/forward@ starts the handshake (see 'dispatchForward').
-- * @GET \/callback@ completes it (see 'dispatchCallback').
-- * Any other method or path responds with 'notFound'.
--
-- Both handlers run in @'ExceptT' 'DispatchError'@; 'handleDispatchError'
-- turns that back into a plain handler result.
dispatchAuthRequest
  :: Text
  -- ^ Name
  -> OAuth2
  -- ^ Service details
  -> FetchToken
  -- ^ How to get a token
  -> FetchCreds m
  -- ^ How to get credentials
  -> Text
  -- ^ Method
  -> [Text]
  -- ^ Path pieces
  -> AuthHandler m TypedContent
dispatchAuthRequest name oauth2 _ _ "GET" ["forward"] =
  handleDispatchError $ dispatchForward name oauth2
dispatchAuthRequest name oauth2 getToken getCreds "GET" ["callback"] =
  handleDispatchError $ dispatchCallback name oauth2 getToken getCreds
dispatchAuthRequest _ _ _ _ _ _ = notFound

-- | Handle @GET \/forward@
--
-- 1. Set a random CSRF token in our session (see 'setSessionCSRF')
-- 2. Redirect to the Provider's authorization URL. That URL carries our
--    callback route as the redirect URI and the CSRF token as @state@.
dispatchForward
  :: (MonadError DispatchError m, MonadAuthHandler site m)
  => Text
  -- ^ Plugin name
  -> OAuth2
  -- ^ Service details
  -> m TypedContent
dispatchForward name oauth2 = do
  csrf <- setSessionCSRF $ tokenSessionKey name
  oauth2' <- withCallbackAndState name oauth2 csrf
  redirect $ toText $ authorizationUrl oauth2'

-- | Handle @GET \/callback@
--
-- 1. Fail if the Provider sent back an OAuth2 error response
-- 2. Verify the URL's CSRF token (@state@) matches our session
-- 3. Use the @code@ parameter to fetch an AccessToken for the Provider
-- 4. Use the AccessToken to construct a @'Creds'@ value for the Provider,
--    and pass it to 'setCredsRedirect'
--
-- Failures are raised through 'MonadError' as a 'DispatchError':
--
-- * 'OAuth2HandshakeError' when the Provider returned an error response;
-- * 'OAuth2ResultError' when fetching the token fails;
-- * 'FetchCredsIOException' or 'FetchCredsYesodOAuth2Exception' when the
--   credentials action throws one of those exceptions.
--
-- Errors from 'verifySessionCSRF', 'requireGetParam' and
-- 'withCallbackAndState' propagate unchanged.
dispatchCallback
  :: (MonadError DispatchError m, MonadAuthHandler site m)
  => Text
  -- ^ Plugin name
  -> OAuth2
  -- ^ Service details
  -> FetchToken
  -- ^ How to get a token
  -> FetchCreds site
  -- ^ How to get credentials
  -> m TypedContent
dispatchCallback name oauth2 getToken getCreds = do
  onErrorResponse $ throwError . OAuth2HandshakeError
  csrf <- verifySessionCSRF $ tokenSessionKey name
  code <- requireGetParam "code"
  manager <- authHttpManager
  -- The token request must repeat the redirect URI sent with the
  -- authorization request (RFC 6749 section 4.1.3). Rebuild it the same way
  -- 'dispatchForward' did.
  oauth2' <- withCallbackAndState name oauth2 csrf
  token <-
    either (throwError . OAuth2ResultError) pure
      =<< liftIO (getToken manager oauth2' $ ExchangeToken code)
  creds <-
    liftIO (getCreds manager token)
      `catch` (throwError . FetchCredsIOException)
      `catch` (throwError . FetchCredsYesodOAuth2Exception)
  setCredsRedirect creds

-- | Prepare an 'OAuth2' configuration for this plugin's handshake.
--
-- Sets the redirect URI to the plugin's @callback@ route, rendered with the
-- parent site's URL renderer. Appends the CSRF token as the @state@ query
-- parameter of the authorization endpoint.
--
-- Throws 'InvalidCallbackUri' if the rendered callback route cannot be
-- parsed as a URI.
withCallbackAndState
  :: (MonadError DispatchError m, MonadAuthHandler site m)
  => Text
  -- ^ Plugin name
  -> OAuth2
  -- ^ Service details
  -> Text
  -- ^ CSRF token
  -> m OAuth2
withCallbackAndState name oauth2 csrf = do
  uri <- ($ PluginR name ["callback"]) <$> getParentUrlRender
  callback <- maybe (throwError $ InvalidCallbackUri uri) pure $ fromText uri
  pure
    oauth2
      { oauth2RedirectUri = Just callback
      , oauth2AuthorizeEndpoint =
          oauth2AuthorizeEndpoint oauth2
            `withQuery` [("state", encodeUtf8 csrf)]
      }

-- | A URL renderer for routes of the current subsite.
--
-- Each subsite route is first lifted into the parent site with
-- 'getRouteToParent', then rendered with the parent's 'getUrlRender'.
getParentUrlRender :: MonadHandler m => m (Route (SubHandlerSite m) -> Text)
getParentUrlRender = (.) <$> getUrlRender <*> getRouteToParent

-- | Generate a random CSRF token, store it in the session under the given
-- key, and return it.
--
-- The token comes from @'randomText' 64@ with every @+@ character removed.
-- Some (but not all) providers decode a @+@ in the state token as a space
-- when sending it back to us, which would make verification fail. If we
-- instead reversed that decoding, we would fail on the providers that
-- /don't/ do it.
--
-- Because of the filtering, the returned token may be slightly shorter than
-- what 'randomText' produced.
setSessionCSRF :: MonadHandler m => Text -> m Text
setSessionCSRF sessionKey = do
  csrfToken <- liftIO randomToken
  csrfToken <$ setSession sessionKey csrfToken
 where
  randomToken :: IO Text
  randomToken = T.filter (/= '+') <$> randomText 64

-- | Verify the callback provided the same CSRF token as in our session.
--
-- Reads the @state@ query parameter and compares it with the value stored
-- under the given session key. It returns the token when they match.
--
-- Throws 'MissingParameter' if @state@ is absent. Throws 'InvalidStateToken'
-- if the tokens differ, including when the session holds no token.
--
-- Once @state@ has been read, the session entry is deleted whether or not
-- the comparison succeeds. A stored token therefore cannot be checked twice.
verifySessionCSRF
  :: (MonadError DispatchError m, MonadHandler m) => Text -> m Text
verifySessionCSRF sessionKey = do
  token <- requireGetParam "state"
  sessionToken <- lookupSession sessionKey
  deleteSession sessionKey
  token
    <$ unless
      (sessionToken == Just token)
      (throwError $ InvalidStateToken sessionToken token)

-- | Look up a required query-string (GET) parameter.
--
-- Throws 'MissingParameter' with the parameter name when it is absent.
requireGetParam
  :: (MonadError DispatchError m, MonadHandler m) => Text -> m Text
requireGetParam key =
  maybe (throwError $ MissingParameter key) pure =<< lookupGetParam key

-- | The session key for the named plugin's CSRF token.
--
-- The key is the plugin name with @_yesod_oauth2_@ prepended.
tokenSessionKey :: Text -> Text
tokenSessionKey name = "_yesod_oauth2_" <> name