{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}

{-|
Module      :  GitHub.REST.Auth
Maintainer  :  Brandon Chinn <brandonchinn178@gmail.com>
Stability   :  experimental
Portability :  portable

Definitions for handling authentication with the GitHub REST API.
-}
module GitHub.REST.Auth (
  Token (..),
  fromToken,

  -- * Helpers for using JWT tokens with the GitHub API
  getJWTToken,
) where

import qualified Crypto.PubKey.RSA as Crypto
import Data.Aeson ((.=))
import qualified Data.Aeson as Aeson
import qualified Data.Aeson.Text as Aeson
import Data.ByteString (ByteString)
import qualified Data.Text.Encoding as Text
import qualified Data.Text.Lazy as TextL
import Data.Time (getCurrentTime)
import Data.Time.Clock.POSIX (utcTimeToPOSIXSeconds)
import qualified Jose.Jwa as Jose
import qualified Jose.Jws as Jose
import qualified Jose.Jwt as Jose
import UnliftIO.Exception (Exception, throwIO)

-- | The token to use to authenticate with GitHub.
--
-- NOTE: the derived 'Show' instance prints the raw credential bytes, so avoid
-- logging 'Token' values.
data Token
  = -- | A token sent with the @token@ scheme (see 'fromToken').
    -- https://developer.github.com/v3/#authentication
    AccessToken ByteString
  | -- | A token sent with the @bearer@ scheme (see 'fromToken'); this is the
    -- constructor 'getJWTToken' produces.
    -- https://developer.github.com/apps/building-github-apps/authenticating-with-github-apps/#authenticating-as-a-github-app
    BearerToken ByteString
  deriving (Show)

-- | Render a 'Token' together with its authentication-scheme prefix:
-- @"token "@ for 'AccessToken' and @"bearer "@ for 'BearerToken'.
-- Presumably used as the value of an HTTP @Authorization@ header.
fromToken :: Token -> ByteString
fromToken = \case
  AccessToken t -> "token " <> t
  BearerToken t -> "bearer " <> t

-- | The ID of your GitHub application.
--
-- 'getJWTToken' renders it with 'show' and sends it as the JWT's @iss@ claim.
type AppId = Int

-- | Create a JWT token that expires in 10 minutes, for authenticating as a
-- GitHub App.
--
-- The JWT is signed with RS256 using the given private key. Its claims are:
--
-- * @iat@: the current POSIX time (whole seconds), backdated by 60 seconds
-- * @exp@: the current POSIX time plus 10 minutes
-- * @iss@: the app ID, rendered as a string
--
-- If signing fails, the underlying 'Jose.JwtError' is thrown as an exception
-- (wrapped in this module's @JwtError@ type).
getJWTToken :: Crypto.PrivateKey -> AppId -> IO Token
getJWTToken privKey appId = do
  -- use floor to ensure expiration doesn't go past 10 minutes
  -- https://github.com/orgs/community/discussions/24635#discussioncomment-3244803
  now <- floor . utcTimeToPOSIXSeconds <$> getCurrentTime
  BearerToken . Jose.unJwt <$> signToken (mkClaims now)
  where
    -- Serialise the claims set as a UTF-8 encoded JSON object.
    mkClaims :: Integer -> ByteString
    mkClaims now =
      Text.encodeUtf8 . TextL.toStrict . Aeson.encodeToLazyText $
        Aeson.object
          [ -- backdate "issued at" so a local clock running slightly ahead of
            -- GitHub's does not produce a token issued "in the future"
            "iat" .= (now - clockDriftAllowance)
          , "exp" .= (now + expirySeconds)
          , "iss" .= show appId
          ]

    -- Sign the claims; a signing failure is rethrown as an exception.
    signToken :: ByteString -> IO Jose.Jwt
    signToken claims =
      Jose.rsaEncode Jose.RS256 privKey claims >>= \case
        Right jwt -> pure jwt
        Left e -> throwIO $ JwtError e

    -- GitHub caps app JWT lifetimes at 10 minutes.
    expirySeconds :: Integer
    expirySeconds = 10 * 60

    clockDriftAllowance :: Integer
    clockDriftAllowance = 60

-- | Wrapper around 'Jose.JwtError' that derives an 'Exception' instance, so a
-- signing failure in 'getJWTToken' can be raised with 'throwIO'.
-- https://github.com/tekul/jose-jwt/issues/30
--
-- NOTE: this type is not exported from the module, so callers cannot catch
-- it by name (only generically, e.g. as @SomeException@).
data JwtError = JwtError Jose.JwtError
  deriving (Show, Exception)