module GitHub.App.Token.JWT
  ( signJWT
  , ExpirationTime (..)
  , Issuer (..)

    -- * Private RSA Key data
  , PrivateKey (..)

    -- * Errors
  , InvalidPrivateKey (..)
  , InvalidDate (..)
  , InvalidIssuer (..)
  ) where

import GitHub.App.Token.Prelude

import Data.Text.Encoding (encodeUtf8)
import Data.Time (NominalDiffTime, addUTCTime, getCurrentTime)
import Data.Time.Clock.POSIX (utcTimeToPOSIXSeconds)
import Web.JWT qualified as JWT

-- | How long after issuance the signed JWT remains valid.
--
-- 'signJWT' adds this duration to the current time to compute the @exp@
-- claim.
newtype ExpirationTime = ExpirationTime {unwrap :: NominalDiffTime}

-- | Value placed in the JWT @iss@ claim by 'signJWT'.
--
-- NOTE(review): for a GitHub App this is presumably the App ID (or Client
-- ID) — confirm against callers.
newtype Issuer = Issuer {unwrap :: Text}
  deriving stock Show

-- | Raw RSA private key bytes, handed to 'JWT.readRsaSecret' by 'signJWT'.
--
-- The 'Show' instance is written by hand so the key material is never
-- rendered. A 'PrivateKey' is embedded in the 'InvalidPrivateKey' exception
-- thrown by 'signJWT'. A derived instance would therefore print the whole
-- secret wherever that exception is shown (logs, error reports, ...).
newtype PrivateKey = PrivateKey
  { unwrap :: ByteString
  }

-- | Redacting instance: always renders as @PrivateKey <redacted>@,
-- parenthesised when it appears as a constructor argument, like a derived
-- instance would be.
instance Show PrivateKey where
  showsPrec d _ = showParen (d > 10) $ showString "PrivateKey <redacted>"

-- | Thrown by 'signJWT' when 'JWT.readRsaSecret' cannot parse the supplied
-- 'PrivateKey'. Carries the key that was rejected.
--
-- NOTE(review): the derived 'Show' renders the wrapped key through
-- 'PrivateKey''s own 'Show' instance. If that instance exposes the key
-- bytes, so will this exception's message.
newtype InvalidPrivateKey = InvalidPrivateKey PrivateKey
  deriving stock Show

-- Default methods only; equivalent to @deriving anyclass (Exception)@.
instance Exception InvalidPrivateKey

-- | Thrown by 'signJWT' when a timestamp cannot be turned into a JWT
-- 'JWT.NumericDate'.
data InvalidDate = InvalidDate
  { field :: String
  -- ^ Name of the claim being built (@"iat"@ or @"exp"@ in 'signJWT')
  , date :: UTCTime
  -- ^ The time value that was rejected
  }
  deriving stock Show

-- Default methods only; equivalent to @deriving anyclass (Exception)@.
instance Exception InvalidDate

-- | Thrown by 'signJWT' when 'JWT.stringOrURI' rejects the 'Issuer' text.
-- Carries the issuer that was rejected.
newtype InvalidIssuer = InvalidIssuer Issuer
  deriving stock Show

-- Default methods only; equivalent to @deriving anyclass (Exception)@.
instance Exception InvalidIssuer

-- | Build an RS256-signed JWT and return it as UTF-8 encoded bytes.
--
-- The claims set contains:
--
-- * @iat@ — the current time
-- * @exp@ — the current time plus the given 'ExpirationTime'
-- * @iss@ — the given 'Issuer'
--
-- Failures are raised with 'throwIO', checked in this order:
--
-- * 'InvalidPrivateKey' when 'JWT.readRsaSecret' cannot parse the key
-- * 'InvalidDate' when 'numericDate' rejects the @iat@ or @exp@ time
-- * 'InvalidIssuer' when 'JWT.stringOrURI' rejects the issuer text
--
-- NOTE(review): GitHub's JWT guide suggests back-dating @iat@ by ~60s to
-- tolerate clock drift. This implementation uses the current time as-is.
signJWT
  :: MonadIO m
  => ExpirationTime
  -> Issuer
  -> PrivateKey
  -> m ByteString
signJWT expirationTime issuer privateKey = liftIO $ do
  issuedAt <- getCurrentTime
  let expiresAt = addUTCTime expirationTime.unwrap issuedAt

  rsaKey <-
    orThrow (InvalidPrivateKey privateKey) $
      JWT.readRsaSecret privateKey.unwrap
  iatClaim <- orThrow (InvalidDate "iat" issuedAt) $ numericDate issuedAt
  expClaim <- orThrow (InvalidDate "exp" expiresAt) $ numericDate expiresAt
  issClaim <-
    orThrow (InvalidIssuer issuer) $ JWT.stringOrURI issuer.unwrap

  let header = mempty {JWT.alg = Just JWT.RS256}
      claims =
        mempty
          { JWT.iat = Just iatClaim
          , JWT.exp = Just expClaim
          , JWT.iss = Just issClaim
          }

  pure . encodeUtf8 $
    JWT.encodeSigned (JWT.EncodeRSAPrivateKey rsaKey) header claims
  where
    -- Unwrap a 'Maybe', throwing the given exception on 'Nothing'.
    orThrow :: Exception e => e -> Maybe a -> IO a
    orThrow err = maybe (throwIO err) pure

-- | Convert a 'UTCTime' to a JWT 'JWT.NumericDate' via its POSIX-seconds
-- representation. Yields 'Nothing' whenever 'JWT.numericDate' rejects the
-- value.
numericDate :: UTCTime -> Maybe JWT.NumericDate
numericDate time = JWT.numericDate (utcTimeToPOSIXSeconds time)