{-# LANGUAGE DataKinds #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}

-- |
-- SPDX-License-Identifier: BSD-3-Clause
--
-- Authentication logic for Swarm tournament server.
module Swarm.Web.Auth where

import Control.Monad.Catch
import Control.Monad.IO.Class (MonadIO, liftIO)
import Data.Aeson
import Data.ByteString qualified as BS
import Data.ByteString.Lazy qualified as LBS
import Data.ByteString.UTF8 as BSU
import Data.Map qualified as M
import Data.Text qualified as T
import Data.Text.Encoding qualified as DTE
import Data.Text.Lazy qualified as TL
import Database.SQLite.Simple.ToField
import GHC.Generics (Generic)
import Network.HTTP.Client qualified as HC
import Network.HTTP.Types (hAccept, hUserAgent, parseSimpleQuery, renderSimpleQuery)
import Servant
import Text.Read (readMaybe)

-- | Client credentials of the GitHub OAuth application, as sent to
-- GitHub's token endpoint when exchanging an authorization code.
data GitHubCredentials = GitHubCredentials
  { clientId :: BS.ByteString
  -- ^ Sent as the @client_id@ request parameter.
  , clientSecret :: BS.ByteString
  -- ^ Sent as the @client_secret@ request parameter.
  }

-- | Decodes a JSON object carrying the string-valued keys @CLIENT_ID@ and
-- @CLIENT_SECRET@. Each value is converted to a 'BS.ByteString' with
-- 'BSU.fromString'.
instance FromJSON GitHubCredentials where
  parseJSON = withObject "GitHubCredentials" $ \v -> do
    clientId <- BSU.fromString <$> v .: "CLIENT_ID"
    clientSecret <- BSU.fromString <$> v .: "CLIENT_SECRET"
    -- RecordWildCards picks up the two local bindings above.
    pure GitHubCredentials {..}

-- | The temporary @code@ value that is exchanged for access and refresh tokens.
newtype TokenExchangeCode = TokenExchangeCode BS.ByteString

-- | Wraps the UTF-8 encoding of the URL piece; this parser never fails.
instance FromHttpApiData TokenExchangeCode where
  parseUrlPiece = pure . TokenExchangeCode . DTE.encodeUtf8

-- | Token presented as a bearer credential when calling the GitHub API.
newtype AccessToken = AccessToken BS.ByteString

-- | Serialized exactly like the wrapped 'BS.ByteString'.
instance ToField AccessToken where
  toField (AccessToken x) = toField x

-- | The @refresh_token@ value returned by the token exchange.
newtype RefreshToken = RefreshToken BS.ByteString

-- | Serialized exactly like the wrapped 'BS.ByteString'.
instance ToField RefreshToken where
  toField (RefreshToken x) = toField x

-- | The fields read from the JSON body returned by
-- @https:\/\/api.github.com\/user@.
--
-- The 'FromJSON' instance is derived generically, so the JSON keys are
-- expected to match the field names.
data UserApiResponse = UserApiResponse
  { login :: TL.Text
  , id :: Int
  , name :: TL.Text
  -- ^ NOTE(review): GitHub's schema appears to allow @name@ to be null, in
  -- which case decoding the whole response would fail -- confirm.
  }
  deriving (Generic, FromJSON)

-- | A value paired with an expiration expressed as a number of seconds.
data Expirable a = Expirable
  { token :: a
  , expirationSeconds :: Int
  }

-- | Fetch the profile of the user that the given access token belongs to
-- by issuing a request to @https:\/\/api.github.com\/user@.
--
-- The token is attached as a bearer credential. The response body is
-- decoded as JSON into a 'UserApiResponse'; a decoding error is reported
-- through 'fail' with the decoder's message.
--
-- The request is built with 'HC.parseUrlThrow', so (per the http-client
-- documentation) a non-2xx response status raises an exception when the
-- request is performed.
fetchAuthenticatedUser ::
  (MonadIO m, MonadThrow m, MonadFail m) =>
  HC.Manager ->
  AccessToken ->
  m UserApiResponse
fetchAuthenticatedUser manager (AccessToken tok) = do
  req <- HC.parseUrlThrow "https://api.github.com/user"
  resp <-
    liftIO
      . flip HC.httpLbs manager
      -- Applied after the header list below has been set, so the bearer
      -- credential is added on top of those headers.
      . HC.applyBearerAuth tok
      $ req
        { HC.requestHeaders =
            [ (hAccept, "application/vnd.github+json")
            , (hUserAgent, "Swarm Gaming Hub")
            , ("X-GitHub-Api-Version", "2022-11-28")
            ]
        }
  either fail return $ eitherDecode $ HC.responseBody resp

-- | The access and refresh tokens obtained from a token exchange, each
-- paired with its expiration.
data ReceivedTokens = ReceivedTokens
  { accessToken :: Expirable AccessToken
  , refreshToken :: Expirable RefreshToken
  }

-- | Assemble 'ReceivedTokens' from the key/value pairs of a token-exchange
-- response.
--
-- Yields 'Nothing' unless all four of @access_token@, @expires_in@,
-- @refresh_token@ and @refresh_token_expires_in@ are present and both
-- expiration values parse as an 'Int'.
packExchangeResponse ::
  M.Map ByteString ByteString ->
  Maybe ReceivedTokens
packExchangeResponse valMap =
  ReceivedTokens
    <$> (Expirable <$> atVal <*> toInt "expires_in")
    <*> (Expirable <$> rtVal <*> toInt "refresh_token_expires_in")
 where
  -- Look up a key and parse its value as a decimal integer.
  toInt :: ByteString -> Maybe Int
  toInt k = readMaybe . BSU.toString =<< M.lookup k valMap

  atVal :: Maybe AccessToken
  atVal = AccessToken <$> M.lookup "access_token" valMap

  rtVal :: Maybe RefreshToken
  rtVal = RefreshToken <$> M.lookup "refresh_token" valMap

-- | Exchange a temporary authorization code for access and refresh tokens.
--
-- Issues a @POST@ to @https:\/\/github.com\/login\/oauth\/access_token@ with
-- @client_id@, @client_secret@ and @code@ rendered into the URL's query
-- string. The response body is parsed as a URL-encoded key/value list and
-- handed to 'packExchangeResponse'.
--
-- Calls 'fail' when 'packExchangeResponse' yields 'Nothing', which covers
-- any missing or unparseable token field, not only a missing access token.
--
-- The request is built with 'HC.parseUrlThrow', so (per the http-client
-- documentation) a non-2xx response status raises an exception when the
-- request is performed.
exchangeCode ::
  (MonadIO m, MonadThrow m, MonadFail m) =>
  HC.Manager ->
  GitHubCredentials ->
  TokenExchangeCode ->
  m ReceivedTokens
exchangeCode manager creds (TokenExchangeCode code) = do
  -- NOTE(review): the client secret travels in the request URL here; moving
  -- the parameters into a form-encoded body would keep it out of URL logs.
  -- Confirm the endpoint accepts that before changing the wire format.
  let qParms =
        T.unpack . DTE.decodeUtf8 $
          renderSimpleQuery
            True -- prepend the leading '?'
            [ ("client_id", clientId creds)
            , ("client_secret", clientSecret creds)
            , ("code", code)
            ]
  req <- HC.parseUrlThrow $ "https://github.com/login/oauth/access_token" <> qParms
  resp <- liftIO $ flip HC.httpLbs manager $ req {HC.method = "POST"}

  let parms = parseSimpleQuery $ LBS.toStrict $ HC.responseBody resp
      -- For a repeated key, 'M.fromList' keeps the last occurrence.
      valMap = M.fromList parms

  maybe
    (fail "Response did not include access token")
    return
    $ packExchangeResponse valMap