{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards   #-}
{-# LANGUAGE TemplateHaskell   #-}
{-# LANGUAGE TupleSections     #-}
module Network.Wai.Middleware.Auth.OAuth2
  ( OAuth2(..)
  , oAuth2Parser
  , URIParseException(..)
  , parseAbsoluteURI
  , getAccessToken
  ) where

import           Control.Monad.Catch
import           Data.Aeson.TH                        (defaultOptions,
                                                       deriveJSON,
                                                       fieldLabelModifier)
import           Data.Functor                         ((<&>))
import           Data.Int                             (Int64)
import           Data.Proxy                           (Proxy (..))
import qualified Data.Text                            as T
import           Data.Text.Encoding                   (encodeUtf8)
import           Foreign.C.Types                      (CTime (..))
import           Network.HTTP.Client.TLS              (getGlobalManager)
import qualified Network.OAuth.OAuth2                 as OA2
import           Network.Wai                          (Request)
import           Network.Wai.Auth.Internal            (decodeToken, encodeToken,
                                                       oauth2Login,
                                                       refreshTokens)
import           Network.Wai.Auth.Tools               (toLowerUnderscore)
import qualified Network.Wai.Middleware.Auth          as MA
import           Network.Wai.Middleware.Auth.Provider
import           System.PosixCompat.Time              (epochTime)
import qualified URI.ByteString                       as U

-- | General OAuth2 authentication `Provider`.
--
-- Field order matters: the JSON instances are derived from this record via
-- Template Haskell further down in this module.
data OAuth2 = OAuth2
  { oa2ClientId            :: T.Text
    -- ^ Client identifier, handed to @hoauth2@ as its client id.
  , oa2ClientSecret        :: T.Text
    -- ^ Client secret, handed to @hoauth2@ as its client secret.
  , oa2AuthorizeEndpoint   :: T.Text
    -- ^ Authorization endpoint; must parse as an absolute URI.
  , oa2AccessTokenEndpoint :: T.Text
    -- ^ Access-token endpoint; must parse as an absolute URI.
  , oa2Scope               :: Maybe [T.Text]
    -- ^ Optional scope list, passed through to the login flow.
  , oa2ProviderInfo        :: ProviderInfo
    -- ^ Returned as-is by 'getProviderInfo'.
  }

-- | Raised when a URL is not a well-formed absolute URI. Thrown by
-- `parseAbsoluteURI` and, through it, by the `OAuth2` `Provider` instance
-- (`handleLogin` and token refresh).
--
-- @since 0.1.2.0
data URIParseException
  = URIParseException U.URIParseError
    -- ^ The parse error reported by @uri-bytestring@.
  deriving (Show)

instance Exception URIParseException

-- | Parse a UTF-8 encoded absolute URI using @uri-bytestring@'s strict parser
-- options. A malformed URI is reported through 'throwM' as a
-- `URIParseException` wrapping the underlying parse error.
--
-- @since 0.1.2.0
parseAbsoluteURI :: MonadThrow m => T.Text -> m U.URI
parseAbsoluteURI =
  either (throwM . URIParseException) return
    . U.parseURI U.strictURIParserOptions
    . encodeUtf8

-- | Client id as sent to @hoauth2@. Currently returns its argument unchanged.
getClientId :: T.Text -> T.Text
getClientId clientId = clientId

-- | Client secret as sent to @hoauth2@. Currently returns its argument
-- unchanged.
getClientSecret :: T.Text -> T.Text
getClientSecret secret = secret

-- Derive the Aeson JSON instances for 'OAuth2'. Each JSON key is the record
-- field name with its three-character "oa2" prefix dropped, then passed
-- through 'toLowerUnderscore' (presumably giving e.g. "client_id" — confirm).
deriveJSON
  (defaultOptions { fieldLabelModifier = toLowerUnderscore . drop 3 })
  ''OAuth2

-- | Aeson-backed `ProviderParser` for the `OAuth2` provider, built with
-- 'mkProviderParser' from the JSON instance derived in this module.
--
-- @since 0.1.0
oAuth2Parser :: ProviderParser
oAuth2Parser = mkProviderParser oa2Proxy
  where
    oa2Proxy :: Proxy OAuth2
    oa2Proxy = Proxy


instance AuthProvider OAuth2 where
  getProviderName _ = "oauth2"

  getProviderInfo = oa2ProviderInfo

  -- Build the hoauth2 configuration, including the "complete" callback URL
  -- rendered for this provider, then delegate to 'oauth2Login'. Any
  -- malformed endpoint or callback URL throws 'URIParseException'.
  handleLogin oa2 req suffix renderUrl onSuccess onFailure = do
    oauth2 <- mkOA2Config oa2 (Just (renderUrl (ProviderUrl ["complete"]) []))
    man <- getGlobalManager
    oauth2Login
      oauth2
      man
      (oa2Scope oa2)
      (getProviderName oa2)
      req
      suffix
      onSuccess
      onFailure

  -- Keep the session alive, refreshing the stored tokens when they have
  -- expired.
  --   * Undecodable login state -> 'Nothing'.
  --   * Token still valid       -> the request/user pair unchanged.
  --   * Token expired           -> refresh; 'Nothing' if the refresh fails,
  --                                otherwise the user with new tokens and
  --                                the current time as login time.
  refreshLoginState oa2 req user =
    case decodeToken (authLoginState user) of
      Left _ -> pure Nothing
      Right tokens -> do
        CTime now <- epochTime
        if tokenExpired user now tokens
          then do
            -- The endpoint URIs are only needed for a refresh, so they are
            -- parsed here rather than on every request. The callback is
            -- 'Nothing': it is not needed to obtain a refreshed token, and
            -- the URL renderer is not available in this method.
            oauth2 <- mkOA2Config oa2 Nothing
            man <- getGlobalManager
            refreshed <- refreshTokens tokens man oauth2
            pure $ refreshed <&> \newTokens ->
              ( req
              , user
                  { authLoginState = encodeToken newTokens
                  , authLoginTime = now
                  }
              )
          else pure (Just (req, user))

-- | Build the @hoauth2@ client configuration for an 'OAuth2' provider.
--
-- Parses, in this order, 'oa2AuthorizeEndpoint', 'oa2AccessTokenEndpoint'
-- and (when given) the callback URL with 'parseAbsoluteURI'. The first
-- malformed one is thrown as 'URIParseException'. Shared by 'handleLogin'
-- and 'refreshLoginState' so both build the configuration the same way.
mkOA2Config :: MonadThrow m => OAuth2 -> Maybe T.Text -> m OA2.OAuth2
mkOA2Config OAuth2 {..} callbackTxt = do
  authEndpointURI <- parseAbsoluteURI oa2AuthorizeEndpoint
  accessTokenEndpointURI <- parseAbsoluteURI oa2AccessTokenEndpoint
  callbackURI <- traverse parseAbsoluteURI callbackTxt
  pure
    OA2.OAuth2
      { oauthClientId = getClientId oa2ClientId
      , oauthClientSecret = Just (getClientSecret oa2ClientSecret)
      , oauthOAuthorizeEndpoint = authEndpointURI
      , oauthAccessTokenEndpoint = accessTokenEndpointURI
      , oauthCallback = callbackURI
      }

-- | Whether the token's lifetime, counted from the user's 'authLoginTime',
-- ended strictly before @now@ (both in seconds as 'Int64'). A token without
-- an @expires_in@ value is never considered expired.
tokenExpired :: AuthUser -> Int64 -> OA2.OAuth2Token -> Bool
tokenExpired user now tokens = maybe False hasLapsed (OA2.expiresIn tokens)
  where
    hasLapsed lifetime = authLoginTime user + fromIntegral lifetime < now

-- | Get the @AccessToken@ for the current user.
--
-- Returns 'Nothing' when the request carries no authenticated user, or when
-- the user's stored login state cannot be decoded. If called on a @Request@
-- behind the middleware, should always return a @Just@ value.
--
-- @since 0.2.0.0
getAccessToken :: Request -> Maybe OA2.OAuth2Token
getAccessToken req =
  MA.getAuthUser req >>= \user ->
    case decodeToken (authLoginState user) of
      Left _      -> Nothing
      Right token -> Just token