{-# LANGUAGE OverloadedStrings #-}

-- | Create the signed JWT needed to make the access token request
-- that gives server-to-server applications access to Google APIs.
--
-- For all usage details, see https://developers.google.com/identity/protocols/OAuth2ServiceAccount
--

module Network.Google.OAuth2.JWT
    (
       SignedJWT
    ,  Email
    ,  Scope
    ,  getSignedJWT

    -- * Utils
    , fromPEMString
    , fromPEMFile

    ) where

import           Codec.Crypto.RSA.Pure
import           Control.Monad              (unless)
import qualified Data.ByteString            as B
import           Data.ByteString.Base64.URL (encode)
import           Data.ByteString.Lazy       (fromStrict, toStrict)
import           Data.ByteString.Char8      (unpack)
import           Data.Maybe                 (fromMaybe, fromJust)
-- import           Data.Monoid                ((<>))
import qualified Data.Text                  as T
import           Data.Text.Encoding         (encodeUtf8)
import           Data.UnixTime              (getUnixTime, utSeconds)
import           Foreign.C.Types
import           OpenSSL.EVP.PKey           (toKeyPair)
import           OpenSSL.PEM                (PemPasswordSupply (PwNone),
                                             readPrivateKey)
import           OpenSSL.RSA

-- | A signed JWT in compact form (@header.claims.signature@), as built by
-- 'getSignedJWT', ready to be sent as the @assertion@ of an access-token
-- request. Two tokens are equal when their bytes are equal.
newtype SignedJWT = SignedJWT B.ByteString
  deriving Eq

-- | Renders the raw token text rather than a Haskell expression: every
-- byte of the token becomes one 'Char' (via "Data.ByteString.Char8" 'unpack').
instance Show SignedJWT where
  showsPrec _ (SignedJWT token) = showString (unpack token)

-- | An e-mail address, used both for the service account (the @iss@ claim)
-- and for the optional impersonated user (the @sub@ claim).
type Email = T.Text

-- | A single OAuth 2.0 scope; 'getSignedJWT' joins several of them with
-- spaces into the @scope@ claim.
type Scope = T.Text

-- | Load the private key obtained from the Google API Console from a PEM
-- file: the file's contents are read and handed to 'fromPEMString'.
fromPEMFile :: FilePath -> IO PrivateKey
fromPEMFile path = fromPEMString =<< readFile path

-- | Get the RSA private key obtained from the
-- Google API Console from a PEM 'String' (read with no passphrase).
--
-- > fromPEMString "-----BEGIN PRIVATE KEY-----\nB9e [...] bMdF\n-----END PRIVATE KEY-----\n"
--
-- Fails in 'IO' (via 'fail') as soon as the PEM turns out to hold a key that
-- is not an RSA key pair, so the error surfaces here rather than when the
-- key is first used for signing. Errors raised by 'readPrivateKey' itself
-- propagate unchanged.
fromPEMString :: String -> IO PrivateKey
fromPEMString s = do
  someKey <- readPrivateKey s PwNone
  case toKeyPair someKey :: Maybe RSAKeyPair of
    Nothing ->
      fail "fromPEMString: the PEM does not contain an RSA private key"
    Just k ->
      return
        PrivateKey
          { private_pub =
              PublicKey
                { public_size = rsaSize k
                , public_n    = rsaN k
                , public_e    = rsaE k
                }
          , private_d    = rsaD k
          , private_p    = rsaP k
          , private_q    = rsaQ k
            -- CRT parameters: OpenSSL reports them as 'Maybe'; when absent,
            -- fall back to 0 so the key still carries d, p and q.
          , private_dP   = fromMaybe 0 (rsaDMP1 k)
          , private_dQ   = fromMaybe 0 (rsaDMQ1 k)
          , private_qinv = fromMaybe 0 (rsaIQMP k)
          }

-- | Create the signed JWT ready for transmission
-- in the access token request as assertion value.
--
-- > grant_type=urn%3Aietf%3Aparams%3Aoauth%3Agrant-type%3Ajwt-bearer&assertion=
--
-- The result is the compact serialization @header.claims.signature@. Each
-- part is base64url-encoded without @=@ padding (RFC 7515, section 2), and
-- the signature is RSASSA-PKCS1-v1_5 over SHA-256 (@RS256@). Claim values
-- are JSON-escaped, so e-mails or scopes containing quotes, backslashes or
-- control characters cannot corrupt the claim set.
getSignedJWT
  :: Email
  -- ^ The email address of the service account.
  -> Maybe Email
  -- ^ The email address of the user for which the
  -- application is requesting delegated access.
  -> [Scope]
  -- ^ The list of the permissions that the application requests.
  -> Maybe Int
  -- ^ Expiration time in seconds, from 1 to 3600
  -- (maximum and default value is an hour, 3600).
  -> PrivateKey
  -- ^ The private key gotten from the PEM string obtained from the
  -- Google API Console.
  -> IO (Either String SignedJWT)
  -- ^ Either an error message (@"Bad expiration time"@ or @"RSAError"@)
  -- or a signed JWT.
getSignedJWT iss msub scs mxt pk
  -- An out-of-range lifetime is reported through the documented 'Left'
  -- channel instead of being thrown as an 'IOError'.
  | xt < 1 || xt > 3600 = return (Left "Bad expiration time")
  | otherwise = do
      t <- getUnixTime
      let signingInput = header <> "." <> toB64 (claims (utSeconds t))
      return $
        either
          (const (Left "RSAError"))
          (\sig ->
             Right (SignedJWT (signingInput <> "." <> b64url (toStrict sig))))
          (rsassa_pkcs1_v1_5_sign hashSHA256 pk (fromStrict signingInput))
  where
    -- Requested token lifetime in seconds.
    xt :: Int
    xt = fromMaybe 3600 mxt

    -- JSON claim set; @now@ is the issue time as returned by 'utSeconds'.
    claims :: CTime -> T.Text
    claims now =
      "{\"iss\":\"" <> jsonEscape iss <> "\","
        <> maybe T.empty (\e -> "\"sub\":\"" <> jsonEscape e <> "\",") msub
        <> "\"scope\":\"" <> jsonEscape (T.intercalate " " scs) <> "\""
        <> ",\"aud\":\"https://www.googleapis.com/oauth2/v4/token\""
        <> ",\"exp\":" <> toT (now + fromIntegral xt)
        <> ",\"iat\":" <> toT now
        <> "}"

    toT :: CTime -> T.Text
    toT = T.pack . show

    -- Escape text for use inside a JSON string literal (RFC 8259, section 7).
    jsonEscape :: T.Text -> T.Text
    jsonEscape = T.concatMap escapeChar

    escapeChar :: Char -> T.Text
    escapeChar '"'  = "\\\""
    escapeChar '\\' = "\\\\"
    escapeChar '\n' = "\\n"
    escapeChar '\r' = "\\r"
    escapeChar '\t' = "\\t"
    escapeChar c
      | c < ' '   =
          T.pack
            [ '\\', 'u', '0', '0'
            , hexDigit (fromEnum c `div` 16)
            , hexDigit (fromEnum c `mod` 16)
            ]
      | otherwise = T.singleton c

    -- Lower-case hex digit for a value in [0, 15].
    hexDigit :: Int -> Char
    hexDigit n
      | n < 10    = toEnum (fromEnum '0' + n)
      | otherwise = toEnum (fromEnum 'a' + n - 10)

    -- base64url without the trailing '=' padding that JWS forbids.
    b64url :: B.ByteString -> B.ByteString
    b64url = B.takeWhile (/= 0x3D) . encode

    toB64 :: T.Text -> B.ByteString
    toB64 = b64url . encodeUtf8

    header :: B.ByteString
    header = toB64 "{\"alg\":\"RS256\",\"typ\":\"JWT\"}"