{-# LANGUAGE CPP               #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards   #-}
{-# LANGUAGE TemplateHaskell   #-}
{-# LANGUAGE TupleSections     #-}
module Network.Wai.Middleware.Auth.OAuth2
  ( OAuth2(..)
  , oAuth2Parser
  , URIParseException(..)
  , parseAbsoluteURI
  ) where

import           Control.Monad.Catch
import           Data.Aeson.TH                        (defaultOptions,
                                                       deriveJSON,
                                                       fieldLabelModifier)
import qualified Data.ByteString                      as S
import qualified Data.ByteString.Lazy                 as SL
import           Data.Monoid                          ((<>))
import           Data.Proxy                           (Proxy (..))
import qualified Data.Text                            as T
import           Data.Text.Encoding                   (encodeUtf8)
import           Network.HTTP.Client.TLS              (getGlobalManager)
import           Network.HTTP.Types                   (status303, status403,
                                                       status404, status501)
import qualified Network.OAuth.OAuth2                 as OA2
import           Network.Wai                          (queryString, responseLBS)
import           Network.Wai.Auth.Tools               (toLowerUnderscore)
import           Network.Wai.Middleware.Auth.Provider
import qualified URI.ByteString                       as U

#if MIN_VERSION_hoauth2(1,0,0)
import           Data.Text.Encoding                   (decodeUtf8With)
import           Data.Text.Encoding.Error             (lenientDecode)
import           URI.ByteString                       (URI)
#else
-- Older hoauth2 releases export their own 'URI' type; alias it so the
-- signatures below can mention 'URI' regardless of the hoauth2 version.
type URI = OA2.URI
#endif

-- | General OAuth2 authentication `Provider`.
--
-- The JSON keys for these fields are derived from the field names by the
-- 'deriveJSON' splice at the end of this module.
data OAuth2 = OAuth2
  { oa2ClientId            :: T.Text
    -- ^ Client id sent to the provider (hoauth2 @oauthClientId@).
  , oa2ClientSecret        :: T.Text
    -- ^ Client secret sent to the provider (hoauth2 @oauthClientSecret@).
  , oa2AuthorizeEndpoint   :: T.Text
    -- ^ Absolute URL of the provider's authorization endpoint; the user is
    -- redirected there to start the login.
  , oa2AccessTokenEndpoint :: T.Text
    -- ^ Absolute URL of the provider's token endpoint, used to exchange the
    -- returned @code@ for an access token.
  , oa2Scope               :: Maybe [T.Text]
    -- ^ Scopes to request. When present they are joined with @,@ and sent as
    -- the @scope@ query parameter; 'Nothing' omits the parameter entirely.
  , oa2ProviderInfo        :: ProviderInfo
    -- ^ Provider metadata, returned by 'getProviderInfo'.
  }

-- | Signals that a URL is not a well-formed absolute URI, carrying the
-- parser error reported by "URI.ByteString". It is thrown by
-- `parseAbsoluteURI`, and therefore also by `handleLogin` of the `OAuth2`
-- `Provider` instance when a configured endpoint or the callback URL fails
-- to parse.
data URIParseException
  = URIParseException U.URIParseError
  deriving (Show)

-- | Default 'Exception' methods: throwable via 'throwM' and catchable as a
-- 'URIParseException'.
instance Exception URIParseException

-- | Parse a URL as an absolute URI with the strict "URI.ByteString" parser
-- options. The text is UTF-8 encoded first; a malformed input is reported
-- by throwing `URIParseException` in the ambient 'MonadThrow'.
parseAbsoluteURI :: MonadThrow m => T.Text -> m U.URI
parseAbsoluteURI =
  either (throwM . URIParseException) return
    . U.parseURI U.strictURIParserOptions
    . encodeUtf8

#if MIN_VERSION_hoauth2(1,0,0)

-- hoauth2 >= 1.0: URIs are parsed 'U.URI' values and client credentials are
-- 'T.Text'. These helpers, together with their counterparts in the @#else@
-- branch, hide the API differences between hoauth2 versions so the
-- 'AuthProvider' instance can be written once.

-- | Parse an endpoint URL; hoauth2 >= 1.0 consumes the parsed URI directly.
parseAbsoluteURI' :: MonadThrow m => T.Text -> m U.URI
parseAbsoluteURI' = parseAbsoluteURI

-- | Wrap the raw @code@ query parameter as an exchange token. The bytes are
-- decoded as UTF-8 with 'lenientDecode', so invalid sequences are replaced
-- rather than causing a failure.
getExchangeToken :: S.ByteString -> OA2.ExchangeToken
getExchangeToken = OA2.ExchangeToken . decodeUtf8With lenientDecode

-- | Append query parameters to a URI (hoauth2 takes the arguments in the
-- opposite order).
appendQueryParams :: URI -> [(S.ByteString, S.ByteString)] -> URI
appendQueryParams uri params =
  OA2.appendQueryParams params uri

-- | The client id is already the 'T.Text' hoauth2 expects.
getClientId :: T.Text -> T.Text
getClientId = id

-- | The client secret is already the 'T.Text' hoauth2 expects.
getClientSecret :: T.Text -> T.Text
getClientSecret = id

-- | Serialize a URI to bytes, e.g. for a @Location@ header.
getRedirectURI :: U.URIRef a -> S.ByteString
getRedirectURI = U.serializeURIRef'

-- | Extract the access-token text from hoauth2's token record, UTF-8 encoded.
getAccessToken :: OA2.OAuth2Token -> S.ByteString
getAccessToken = encodeUtf8 . OA2.atoken . OA2.accessToken

#else

-- hoauth2 < 1.0: URIs, client credentials and access tokens are all raw
-- 'S.ByteString's.

-- | Validate an endpoint URL, then hand it to hoauth2 in serialized form.
parseAbsoluteURI' :: MonadThrow m => T.Text -> m URI
parseAbsoluteURI' urlTxt = U.serializeURIRef' <$> parseAbsoluteURI urlTxt

-- | The raw @code@ query parameter is used as-is.
getExchangeToken :: S.ByteString -> S.ByteString
getExchangeToken = id

-- | Append query parameters to a URI using hoauth2's own helper.
appendQueryParams :: URI -> [(S.ByteString, S.ByteString)] -> URI
appendQueryParams uri params = OA2.appendQueryParam uri params

-- | Encode the client id as UTF-8 bytes for hoauth2.
getClientId :: T.Text -> S.ByteString
getClientId = encodeUtf8

-- | Encode the client secret as UTF-8 bytes for hoauth2.
getClientSecret :: T.Text -> S.ByteString
getClientSecret = encodeUtf8

-- | The URI is already serialized bytes.
getRedirectURI :: URI -> S.ByteString
getRedirectURI = id

-- | Extract the raw access-token bytes from hoauth2's token record.
getAccessToken :: OA2.AccessToken -> S.ByteString
getAccessToken = OA2.accessToken

#endif


-- | Aeson parser for the `OAuth2` provider configuration, obtained by
-- instantiating 'mkProviderParser' at the `OAuth2` type.
--
-- @since 0.1.0
oAuth2Parser :: ProviderParser
oAuth2Parser = mkProviderParser oAuth2Proxy
  where
    oAuth2Proxy :: Proxy OAuth2
    oAuth2Proxy = Proxy

-- | `AuthProvider` instance implementing the OAuth2 authorization-code flow.
instance AuthProvider OAuth2 where
  getProviderName _ = "oauth2"
  getProviderInfo = oa2ProviderInfo
  -- Login is handled in two steps, selected by the URL suffix:
  --
  --   * @[]@: answer with a 303 redirect to the provider's authorization
  --     endpoint, carrying our client id, callback URL and (optionally) the
  --     configured scopes.
  --   * @["complete"]@: the provider sent the user back to the callback;
  --     exchange the @code@ query parameter for an access token, or report
  --     the @error@ parameter it sent instead.
  --
  -- Any other suffix is answered with a 404 via @onFailure@. The endpoint
  -- URLs and the rendered callback URL go through parseAbsoluteURI', which
  -- throws `URIParseException` if one of them is malformed.
  handleLogin oa2@OAuth2 {..} req suffix renderUrl onSuccess onFailure = do
    authEndpointURI <- parseAbsoluteURI' oa2AuthorizeEndpoint
    accessTokenEndpointURI <- parseAbsoluteURI' oa2AccessTokenEndpoint
    callbackURI <- parseAbsoluteURI' $ renderUrl (ProviderUrl ["complete"]) []
    -- Unqualified field names are accepted for the qualified constructor
    -- because RecordWildCards implies DisambiguateRecordFields.
    let oauth2 =
          OA2.OAuth2
            { oauthClientId = getClientId oa2ClientId
            , oauthClientSecret = getClientSecret oa2ClientSecret
            , oauthOAuthorizeEndpoint = authEndpointURI
            , oauthAccessTokenEndpoint = accessTokenEndpointURI
            , oauthCallback = Just callbackURI
            }
    case suffix of
      [] -> do
        -- NOTE(review): RFC 6749 section 3.3 specifies space-delimited
        -- scopes; this sends them comma-separated. Confirm the targeted
        -- providers accept that before changing it.
        let scope = (encodeUtf8 . T.intercalate ",") <$> oa2Scope
            redirectUrl =
              getRedirectURI $
                appendQueryParams
                  (OA2.authorizationUrl oauth2)
                  (maybe [] ((: []) . ("scope", )) scope)
        return $
          responseLBS
            status303
            [("Location", redirectUrl)]
            "Redirect to OAuth2 Authentication server"
      ["complete"] ->
        let params = queryString req
        in case lookup "code" params of
             Just (Just code) -> do
               man <- getGlobalManager
               eRes <- OA2.fetchAccessToken man oauth2 $ getExchangeToken code
               case eRes of
                 -- NOTE(review): 'SL.toStrict' assumes hoauth2 reports the
                 -- error as a lazy ByteString; confirm this still holds for
                 -- the hoauth2 version selected by the CPP guards above.
                 Left err    -> onFailure status501 $ SL.toStrict err
                 Right token -> onSuccess $ getAccessToken token
             _ ->
               case lookup "error" params of
                 Just (Just "access_denied") ->
                   onFailure
                     status403
                     "User rejected access to the application."
                 Just (Just error_code) ->
                   onFailure status501 $ "Received an error: " <> error_code
                 Just Nothing ->
                   onFailure status501 $
                     "Unknown error connecting to " <>
                     encodeUtf8 (getProviderName oa2)
                 Nothing ->
                   onFailure
                     status404
                     "Page not found. Please continue with login."
      _ -> onFailure status404 "Page not found. Please continue with login."

-- JSON (de)serialisation for 'OAuth2'. Each key is the record field name
-- with its three-character @oa2@ prefix dropped and the remainder passed
-- through 'toLowerUnderscore' (presumably camelCase to snake_case, e.g.
-- @oa2ClientId@ to @client_id@; confirm against Network.Wai.Auth.Tools).
$( deriveJSON
     defaultOptions {fieldLabelModifier = toLowerUnderscore . drop 3}
     ''OAuth2
 )