{-# OPTIONS_HADDOCK hide, not-home #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE TupleSections #-}
-- | Internal helpers shared by the OAuth2-based providers: binary
-- (de)serialisation of 'OA2.OAuth2Token', the authorization-code login flow,
-- token refresh, and the provider 'Metadata' record.
module Network.Wai.Auth.Internal
  ( OAuth2TokenBinary(..)
  , Metadata(..)
  , encodeToken
  , decodeToken
  , oauth2Login
  , refreshTokens
  ) where

import qualified Data.Aeson as Aeson
import Data.Binary (Binary(get, put), encode, decodeOrFail)
import qualified Data.ByteString as S
import qualified Data.ByteString.Char8 as S8 (pack)
import qualified Data.ByteString.Lazy as SL
import qualified Data.Text as T
import Data.Text.Encoding (encodeUtf8, decodeUtf8With)
import Data.Text.Encoding.Error (lenientDecode)
import GHC.Generics (Generic)
import Network.HTTP.Client (Manager)
import Network.HTTP.Types (Status, status303, status403, status404, status501)
import qualified Network.OAuth.OAuth2 as OA2
import Network.Wai (Request, Response, queryString, responseLBS)
import Network.Wai.Middleware.Auth.Provider
import qualified URI.ByteString as U
import URI.ByteString (URI)

-- | Decode a token previously produced by 'encodeToken'.
--
-- Returns @Left@ with the "Data.Binary" error message when decoding fails.
-- Any bytes left over after a successful decode are ignored.
decodeToken :: S.ByteString -> Either String OA2.OAuth2Token
decodeToken bs =
  case decodeOrFail (SL.fromStrict bs) of
    Left (_, _, err) -> Left err
    Right (_leftover, _, token) -> Right (unOAuth2TokenBinary token)

-- | Serialise a token to a strict 'S.ByteString' via its 'Binary' encoding.
-- 'decodeToken' is the inverse.
encodeToken :: OA2.OAuth2Token -> S.ByteString
encodeToken = SL.toStrict . encode . OAuth2TokenBinary

-- | Newtype wrapper that gives 'OA2.OAuth2Token' a 'Binary' instance without
-- declaring an orphan instance on the library type.
newtype OAuth2TokenBinary = OAuth2TokenBinary
  { unOAuth2TokenBinary :: OA2.OAuth2Token
  } deriving (Show)

-- | Wire layout, in order: access token, refresh token, expires-in, token
-- type, id token. 'get' must read the fields in exactly the order 'put'
-- writes them. Reordering them would make previously encoded values
-- undecodable.
instance Binary OAuth2TokenBinary where
  put (OAuth2TokenBinary token) = do
    put $ OA2.atoken $ OA2.accessToken token
    put $ OA2.rtoken <$> OA2.refreshToken token
    put $ OA2.expiresIn token
    put $ OA2.tokenType token
    put $ OA2.idtoken <$> OA2.idToken token
  get =
    fmap OAuth2TokenBinary $
      OA2.OAuth2Token
        <$> (OA2.AccessToken <$> get)
        <*> (fmap OA2.RefreshToken <$> get)
        <*> get
        <*> get
        <*> (fmap OA2.IdToken <$> get)

-- | Drive the OAuth2 authorization-code flow for one provider.
--
-- * Path suffix @[]@: respond with a 303 redirect to the provider's
--   authorization URL, adding the requested scopes.
-- * Path suffix @["complete"]@: handle the provider's callback. It exchanges
--   the @code@ query parameter for a token, or reports the provider's
--   @error@ parameter.
-- * Any other suffix: 404.
--
-- NOTE(review): no @state@ parameter is added to the authorization request,
-- and none is checked on the callback. RFC 6749 §10.12 recommends one for
-- CSRF protection. Confirm whether this is handled elsewhere.
oauth2Login
  :: OA2.OAuth2
  -- ^ Client configuration. Used to build the authorization URL and to
  -- exchange the code.
  -> Manager
  -- ^ HTTP manager used for the code-for-token exchange.
  -> Maybe [T.Text]
  -- ^ Scopes to request. They are sent space-separated as the @scope@
  -- parameter. The parameter is omitted for 'Nothing' or an empty list.
  -> T.Text
  -- ^ Provider name. Used only in an error message.
  -> Request
  -- ^ Incoming request. Its query string is read on the callback step.
  -> [T.Text]
  -- ^ Remaining path segments after the provider's prefix.
  -> (AuthLoginState -> IO Response)
  -- ^ Success continuation. It receives the token serialised with
  -- 'encodeToken'.
  -> (Status -> S.ByteString -> IO Response)
  -- ^ Failure continuation. It receives a status and a message body.
  -> IO Response
oauth2Login oauth2 man oa2Scope providerName req suffix onSuccess onFailure =
  case suffix of
    [] ->
      pure $
        responseLBS
          status303
          [("Location", getRedirectURI authUrl)]
          "Redirect to OAuth2 Authentication server"
    ["complete"] ->
      case lookup "code" params of
        Just (Just code) -> do
          eRes <- OA2.fetchAccessToken man oauth2 $ getExchangeToken code
          case eRes of
            Left err -> onFailure status501 $ S8.pack $ show err
            Right token -> onSuccess $ encodeToken token
        _ -> handleProviderError
    _ -> onFailure status404 "Page not found. Please continue with login."
  where
    params = queryString req

    -- https://tools.ietf.org/html/rfc6749#section-3.3
    -- The scope value must contain at least one scope token. An empty scope
    -- list is therefore treated like no scope list, rather than sending an
    -- empty `scope=` parameter.
    scopeParams :: [(S.ByteString, S.ByteString)]
    scopeParams =
      case oa2Scope of
        Just scopes@(_ : _) ->
          [("scope", encodeUtf8 (T.intercalate " " scopes))]
        _ -> []

    authUrl = appendQueryParams (OA2.authorizationUrl oauth2) scopeParams

    -- Callback arrived without a usable `code`; inspect the `error` parameter.
    handleProviderError :: IO Response
    handleProviderError =
      case lookup "error" params of
        Just (Just "access_denied") ->
          onFailure status403 "User rejected access to the application."
        Just (Just errorCode) ->
          -- NOTE(review): errorCode comes verbatim from the query string and is
          -- echoed back. Confirm that onFailure escapes it or serves it as
          -- plain text.
          onFailure status501 $ "Received an error: " <> errorCode
        Just Nothing ->
          onFailure status501 $
            "Unknown error connecting to " <> encodeUtf8 providerName
        Nothing ->
          onFailure status404 "Page not found. Please continue with login."
-- | Use the refresh token carried by @tokens@, if any, to obtain a new token
-- set from the provider.
--
-- Returns 'Nothing' when @tokens@ has no refresh token, or when the refresh
-- request fails. The failure details are discarded.
--
-- RFC 6749 §6 lets the authorization server omit a new refresh token from
-- the refresh response. In that case the refresh token that was just used
-- is copied into the result, so the returned token can be refreshed again
-- later.
refreshTokens :: OA2.OAuth2Token -> Manager -> OA2.OAuth2 -> IO (Maybe OA2.OAuth2Token)
refreshTokens tokens manager oauth2 =
  case OA2.refreshToken tokens of
    Nothing -> pure Nothing
    Just usedRefresh ->
      either (const Nothing) (Just . keepRefreshToken usedRefresh)
        <$> OA2.refreshAccessToken manager oauth2 usedRefresh
  where
    -- Prefer a refresh token issued in the response. Otherwise fall back to
    -- the one that was used.
    keepRefreshToken usedRefresh newTokens =
      case OA2.refreshToken newTokens of
        Just _ -> newTokens
        Nothing -> newTokens { OA2.refreshToken = Just usedRefresh }

-- | Wrap an authorization @code@ query value as an 'OA2.ExchangeToken'.
-- Invalid UTF-8 is decoded leniently (invalid bytes are replaced) rather
-- than failing.
getExchangeToken :: S.ByteString -> OA2.ExchangeToken
getExchangeToken = OA2.ExchangeToken . decodeUtf8With lenientDecode

-- | 'OA2.appendQueryParams' with the URI first and the parameters second.
appendQueryParams :: URI -> [(S.ByteString, S.ByteString)] -> URI
appendQueryParams uri params = OA2.appendQueryParams params uri

-- | Render a URI reference to a strict 'S.ByteString', for example for a
-- @Location@ header.
getRedirectURI :: U.URIRef a -> S.ByteString
getRedirectURI = U.serializeURIRef' 

-- | Provider metadata document. The field set appears to mirror OpenID
-- Connect Discovery provider metadata.
--
-- JSON keys are the snake_case forms of the field names, produced by
-- 'metadataAesonOptions' (e.g. @jwksUri@ becomes @jwks_uri@). The 'U.URI'
-- fields rely on JSON instances for URIs that are not declared in this
-- block. Presumably they are imported; confirm.
data Metadata = Metadata
  { issuer :: T.Text
  , authorizationEndpoint :: U.URI
  -- ^ JSON key @authorization_endpoint@.
  , tokenEndpoint :: U.URI
  -- ^ JSON key @token_endpoint@.
  , userinfoEndpoint :: Maybe T.Text
  , revocationEndpoint :: Maybe T.Text
  , jwksUri :: T.Text
  , responseTypesSupported :: [T.Text]
  , subjectTypesSupported :: [T.Text]
  , idTokenSigningAlgValuesSupported :: [T.Text]
  , scopesSupported :: Maybe [T.Text]
  , tokenEndpointAuthMethodsSupported :: Maybe [T.Text]
  , claimsSupported :: Maybe [T.Text]
  } deriving (Generic)

instance Aeson.FromJSON Metadata where
  parseJSON = Aeson.genericParseJSON metadataAesonOptions

instance Aeson.ToJSON Metadata where
  toJSON = Aeson.genericToJSON metadataAesonOptions
  toEncoding = Aeson.genericToEncoding metadataAesonOptions

-- | Map camelCase record fields to snake_case JSON keys.
metadataAesonOptions :: Aeson.Options
metadataAesonOptions =
  Aeson.defaultOptions {Aeson.fieldLabelModifier = Aeson.camelTo2 '_'}