{-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE RecordWildCards #-} {-# LANGUAGE TemplateHaskell #-} {-# LANGUAGE TupleSections #-} module Network.Wai.Middleware.Auth.OAuth2 ( OAuth2(..) , oAuth2Parser ) where import Data.Aeson.TH (defaultOptions, deriveJSON, fieldLabelModifier) import qualified Data.ByteString.Lazy as SL import Data.Monoid ((<>)) import Data.Proxy (Proxy (..)) import qualified Data.Text as T import Data.Text.Encoding (encodeUtf8) import Network.HTTP.Client.TLS (getGlobalManager) import Network.HTTP.Types (status303, status403, status404, status501) import qualified Network.OAuth.OAuth2 as OA2 import Network.Wai (queryString, responseLBS) import Network.Wai.Auth.Tools (toLowerUnderscore) import Network.Wai.Middleware.Auth.Provider data OAuth2 = OAuth2 { oa2ClientId :: T.Text , oa2ClientSecret :: T.Text , oa2AuthorizeEndpoint :: T.Text , oa2AccessTokenEndpoint :: T.Text , oa2Scope :: Maybe [T.Text] , oa2ProviderInfo :: ProviderInfo } -- | Aeson parser for `OAuth2` provider. -- -- @since 0.1.0 oAuth2Parser :: ProviderParser oAuth2Parser = mkProviderParser (Proxy :: Proxy OAuth2) instance AuthProvider OAuth2 where getProviderName _ = "oauth2" getProviderInfo = oa2ProviderInfo handleLogin oa2@OAuth2 {..} req suffix renderUrl onSuccess onFailure = do let oauth2 = OA2.OAuth2 { oauthClientId = encodeUtf8 oa2ClientId , oauthClientSecret = encodeUtf8 oa2ClientSecret , oauthOAuthorizeEndpoint = encodeUtf8 oa2AuthorizeEndpoint , oauthAccessTokenEndpoint = encodeUtf8 oa2AccessTokenEndpoint , oauthCallback = Just $ encodeUtf8 $ renderUrl (ProviderUrl ["complete"]) [] } case suffix of [] -> do let scope = (encodeUtf8 . T.intercalate ",") <$> oa2Scope let redirectUrl = OA2.appendQueryParam (OA2.authorizationUrl oauth2) $ maybe [] ((: []) . ("scope", )) scope return $ responseLBS status303 [("Location", redirectUrl)] "Redirect to OAuth2 Authentication server" ["complete"] -> let params = queryString req in case lookup "code" params of Just (Just code) -> do man <- getGlobalManager eRes <- OA2.fetchAccessToken man oauth2 code case eRes of Left err -> onFailure status501 $ SL.toStrict err Right token -> onSuccess $ OA2.accessToken token _ -> case lookup "error" params of (Just (Just "access_denied")) -> onFailure status403 "User rejected access to the application." (Just (Just error_code)) -> onFailure status501 $ "Received an error: " <> error_code (Just Nothing) -> onFailure status501 $ "Unknown error connecting to " <> encodeUtf8 (getProviderName oa2) Nothing -> onFailure status404 "Page not found. Please continue with login." _ -> onFailure status404 "Page not found. Please continue with login." $(deriveJSON defaultOptions { fieldLabelModifier = toLowerUnderscore . drop 3} ''OAuth2)