{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TupleSections #-}
module Network.Wai.Middleware.Auth.OAuth2
( OAuth2(..)
, oAuth2Parser
, URIParseException(..)
, parseAbsoluteURI
, getAccessToken
) where
import Control.Monad.Catch
import Data.Aeson.TH (defaultOptions,
deriveJSON,
fieldLabelModifier)
import qualified Data.ByteString as S
import qualified Data.ByteString.Char8 as S8 (pack)
import Data.Monoid ((<>))
import Data.Proxy (Proxy (..))
import qualified Data.Text as T
import Data.Text.Encoding (encodeUtf8,
decodeUtf8With)
import Data.Text.Encoding.Error (lenientDecode)
import Network.HTTP.Client.TLS (getGlobalManager)
import Network.HTTP.Types (status303, status403,
status404, status501)
import qualified Network.OAuth.OAuth2 as OA2
import Network.Wai (Request, queryString,
responseLBS)
import Network.Wai.Auth.Internal (encodeToken, decodeToken)
import Network.Wai.Auth.Tools (toLowerUnderscore)
import qualified Network.Wai.Middleware.Auth as MA
import Network.Wai.Middleware.Auth.Provider
import qualified URI.ByteString as U
import URI.ByteString (URI)
-- | Configuration of a generic OAuth2 authentication provider.
--
-- JSON instances for this record are derived elsewhere in this module with a
-- field label modifier that drops the three-character @oa2@ prefix and then
-- applies 'toLowerUnderscore'.
data OAuth2 = OAuth2
  { oa2ClientId :: T.Text
    -- ^ Client identifier handed to the underlying OAuth2 client.
  , oa2ClientSecret :: T.Text
    -- ^ Client secret handed to the underlying OAuth2 client.
  , oa2AuthorizeEndpoint :: T.Text
    -- ^ URL of the authorization endpoint; parsed as an absolute URI when a
    -- login request is handled.
  , oa2AccessTokenEndpoint :: T.Text
    -- ^ URL of the access token endpoint; parsed as an absolute URI when a
    -- login request is handled.
  , oa2Scope :: Maybe [T.Text]
    -- ^ Optional scopes, sent as the @scope@ query parameter of the
    -- authorization redirect. No parameter is added when this is 'Nothing'.
  , oa2ProviderInfo :: ProviderInfo
    -- ^ Value returned by 'getProviderInfo' for this provider.
  }
-- | Exception thrown by 'parseAbsoluteURI' when the supplied text is not
-- accepted by the strict URI parser. It wraps the parser's own error value.
data URIParseException =
  URIParseException U.URIParseError
  deriving (Show)

instance Exception URIParseException
-- | Parse a URL given as 'T.Text' into an absolute 'U.URI'.
--
-- The text is UTF-8 encoded and run through 'U.parseURI' with
-- 'U.strictURIParserOptions'. A parse failure is thrown in the calling
-- 'MonadThrow' as a 'URIParseException'.
parseAbsoluteURI :: MonadThrow m => T.Text -> m U.URI
parseAbsoluteURI urlTxt = either (throwM . URIParseException) return parsed
  where
    parsed = U.parseURI U.strictURIParserOptions (encodeUtf8 urlTxt)
-- | Internal alias for 'parseAbsoluteURI'; it simply delegates to it, so it
-- throws the same 'URIParseException' on malformed input.
parseAbsoluteURI' :: MonadThrow m => T.Text -> m U.URI
parseAbsoluteURI' urlTxt = parseAbsoluteURI urlTxt
-- | Wrap raw bytes (e.g. the @code@ query parameter) into an
-- 'OA2.ExchangeToken'. The bytes are decoded as UTF-8 leniently, so invalid
-- sequences do not cause a failure.
getExchangeToken :: S.ByteString -> OA2.ExchangeToken
getExchangeToken rawCode = OA2.ExchangeToken (decodeUtf8With lenientDecode rawCode)
-- | Append query parameters to a URI. This is 'OA2.appendQueryParams' with
-- its two arguments swapped, so that the URI comes first.
appendQueryParams :: URI -> [(S.ByteString, S.ByteString)] -> URI
appendQueryParams = flip OA2.appendQueryParams
-- | Convert the configured client id into the form expected by the OAuth2
-- client record. Both sides are 'T.Text' here, so the value is returned
-- unchanged.
getClientId :: T.Text -> T.Text
getClientId clientId = clientId
-- | Convert the configured client secret into the form expected by the
-- OAuth2 client record. Both sides are 'T.Text' here, so the value is
-- returned unchanged.
getClientSecret :: T.Text -> T.Text
getClientSecret clientSecret = clientSecret
-- | Serialize a URI reference to a strict 'S.ByteString' via
-- 'U.serializeURIRef''. 'handleLogin' uses the result as the value of the
-- @Location@ header.
getRedirectURI :: U.URIRef a -> S.ByteString
getRedirectURI uriRef = U.serializeURIRef' uriRef
-- | 'ProviderParser' for the generic 'OAuth2' provider, built with
-- 'mkProviderParser' from a proxy of the 'OAuth2' type.
oAuth2Parser :: ProviderParser
oAuth2Parser = mkProviderParser oauth2Proxy
  where
    oauth2Proxy :: Proxy OAuth2
    oauth2Proxy = Proxy
-- | Generic OAuth2 authentication provider.
--
-- 'handleLogin' dispatches on the path suffix below the provider root:
--
-- * empty suffix: respond with a 303 redirect to the authorization endpoint,
--   adding a @scope@ query parameter when 'oa2Scope' is set;
--
-- * @complete@: handle the callback. A @code@ query parameter is exchanged
--   for an access token, whose encoded form is passed to @onSuccess@.
--   Otherwise the @error@ query parameter decides which failure is reported;
--
-- * anything else: report a 404 through @onFailure@.
--
-- The endpoint and callback URLs are parsed with 'parseAbsoluteURI'' before
-- dispatching, so a malformed URL throws a 'URIParseException' for every
-- suffix.
instance AuthProvider OAuth2 where
  getProviderName _ = "oauth2"
  getProviderInfo = oa2ProviderInfo
  handleLogin oa2@OAuth2 {..} req suffix renderUrl onSuccess onFailure = do
    authEndpointURI <- parseAbsoluteURI' oa2AuthorizeEndpoint
    accessTokenEndpointURI <- parseAbsoluteURI' oa2AccessTokenEndpoint
    -- The callback points back at this provider's own @complete@ route.
    callbackURI <- parseAbsoluteURI' $ renderUrl (ProviderUrl ["complete"]) []
    let oauth2 =
          OA2.OAuth2
            { oauthClientId = getClientId oa2ClientId
            , oauthClientSecret = getClientSecret oa2ClientSecret
            , oauthOAuthorizeEndpoint = authEndpointURI
            , oauthAccessTokenEndpoint = accessTokenEndpointURI
            , oauthCallback = Just callbackURI
            }
    case suffix of
      [] -> do
        -- RFC 6749, section 3.3: the scope parameter is a list of
        -- space-delimited strings (not comma-delimited).
        let scope = (encodeUtf8 . T.intercalate " ") <$> oa2Scope
        let redirectUrl =
              getRedirectURI $
              appendQueryParams
                (OA2.authorizationUrl oauth2)
                (maybe [] ((: []) . ("scope", )) scope)
        return $
          responseLBS
            status303
            [("Location", redirectUrl)]
            "Redirect to OAuth2 Authentication server"
      ["complete"] ->
        let params = queryString req
         in case lookup "code" params of
              -- A non-empty @code@ parameter: exchange it for a token.
              Just (Just code) -> do
                man <- getGlobalManager
                eRes <- OA2.fetchAccessToken man oauth2 $ getExchangeToken code
                case eRes of
                  Left err -> onFailure status501 $ S8.pack $ show err
                  Right token -> onSuccess $ encodeToken token
              -- No usable @code@: inspect the @error@ parameter instead.
              _ ->
                case lookup "error" params of
                  (Just (Just "access_denied")) ->
                    onFailure
                      status403
                      "User rejected access to the application."
                  (Just (Just error_code)) ->
                    onFailure status501 $ "Received an error: " <> error_code
                  -- @error@ present but without a value.
                  (Just Nothing) ->
                    onFailure status501 $
                    "Unknown error connecting to " <>
                    encodeUtf8 (getProviderName oa2)
                  Nothing ->
                    onFailure
                      status404
                      "Page not found. Please continue with login."
      _ -> onFailure status404 "Page not found. Please continue with login."
-- JSON instances for 'OAuth2': each record field label has its first three
-- characters (the @oa2@ prefix) dropped and is then passed through
-- 'toLowerUnderscore'.
$(deriveJSON
    (defaultOptions
       {fieldLabelModifier = \label -> toLowerUnderscore (drop 3 label)})
    ''OAuth2)
-- | Extract the OAuth2 token of the authenticated user from a request.
--
-- Yields 'Nothing' when 'MA.getAuthUser' finds no user on the request, or
-- when the user's login state cannot be decoded by 'decodeToken'.
getAccessToken :: Request -> Maybe OA2.OAuth2Token
getAccessToken req =
  MA.getAuthUser req >>= \authUser ->
    case decodeToken (authLoginState authUser) of
      Left _ -> Nothing
      Right token -> Just token