{-|
Module      : Neptune.OAuth
Description : Neptune Client
Copyright   : (c) Jiasen Wu, 2020
License     : BSD-3-Clause
-}
{-# LANGUAGE TemplateHaskell #-}
module Neptune.OAuth where

import           Control.Concurrent    (ThreadId, forkIO)
import           Control.Lens
import qualified Data.Aeson            as Aeson
import           Data.Aeson.Lens
import           Data.Time.Clock       (NominalDiffTime)
import           Data.Time.Clock.POSIX (getPOSIXTime)
import           Network.HTTP.Req      (JsonResponse, POST (..),
                                        ReqBodyUrlEnc (..), defaultHttpConfig,
                                        jsonResponse, req, responseBody, runReq,
                                        useHttpsURI, (=:))
import           RIO                   hiding (Lens', (.~), (^.))
import qualified RIO.Text              as T
import           System.IO             (hPutStrLn, stderr)
import           Text.URI              (mkURI)
import qualified Web.JWT               as JWT


-- | An OAuth2 login session: the current tokens plus what is needed to
-- refresh them.
--
-- The fields are strict. Evaluating a session to WHNF therefore also
-- evaluates its tokens, URL and lifetime, and a stored session cannot hide
-- a deferred (possibly failing) computation in one of its fields.
data OAuth2Session = OAuth2Session
    { _oas_client_id     :: !Text            -- ^ OAuth2 client id ('oauth2Setup' reads it from the access token's @azp@ claim)
    , _oas_access_token  :: !Text            -- ^ access token for APIs
    , _oas_refresh_token :: !Text            -- ^ refresh token, used to obtain a new access token
    , _oas_expires_in    :: !NominalDiffTime -- ^ duration in which the access token is valid
    , _oas_refresh_url   :: !Text            -- ^ token endpoint used for refreshing
    }

-- Generates the lenses oas_client_id, oas_access_token, oas_refresh_token,
-- oas_expires_in and oas_refresh_url (the field names minus the leading '_').
makeLenses ''OAuth2Session

-- | Set up a background thread that keeps an OAuth2 session fresh.
--
-- The first argument is the access token (a JWT) and the second is the
-- refresh token.
--
-- The access token is decoded without signature verification. Three values
-- come from its claims:
--
-- * the issuer (@iss@), which gives the refresh endpoint
--   @\<iss\>\/protocol\/openid-connect\/token@;
-- * the client id (@azp@);
-- * the expiry time (@exp@), which is turned into a remaining lifetime
--   relative to now.
--
-- All claims are checked before anything is stored or forked. A malformed
-- token therefore makes this action throw (via 'throwString') in the caller,
-- instead of failing later inside the shared session or the refresher
-- thread.
--
-- Returns the id of the forked 'oauthRefresher' thread and the 'MVar'
-- holding the current session.
oauth2Setup :: Text -> Text -> IO (ThreadId, MVar OAuth2Session)
oauth2Setup access_token refresh_token = do
    jwt <- orThrow "access token is not a decodable JWT"
               (JWT.decode access_token)
    let claims = JWT.claims jwt
    issuer <- orThrow "access token has no 'iss' claim"
                  (JWT.stringOrURIToText <$> JWT.iss claims)
    -- 'azp' is not a registered claim, so it is looked up in the extra claims.
    client_name <- orThrow "access token has no string 'azp' claim"
                       (firstOf (ix "azp" . _String)
                                (JWT.unClaimsMap (JWT.unregisteredClaims claims)))
    expires_at <- orThrow "access token has no 'exp' claim"
                      (JWT.secondsSinceEpoch <$> JWT.exp claims)

    now <- getPOSIXTime
    let session = OAuth2Session
            { _oas_client_id     = client_name
            , _oas_access_token  = access_token
            , _oas_refresh_token = refresh_token
            , _oas_expires_in    = expires_at - now
            , _oas_refresh_url   = issuer `T.append` "/protocol/openid-connect/token"
            }

    oauth_session_var <- newMVar session
    refresh_thread    <- forkIO (oauthRefresher oauth_session_var)
    return (refresh_thread, oauth_session_var)
  where
    -- Unwrap a required value, or fail with a descriptive message.
    orThrow :: String -> Maybe a -> IO a
    orThrow problem = maybe (throwString ("oauth2Setup: " ++ problem)) return

-- | Background loop that keeps the 'OAuth2Session' in the given 'MVar' fresh.
--
-- Each round does the following:
--
-- 1. Read the current session.
-- 2. Sleep until 90% of the access token's lifetime has passed, so the new
--    token is in place before the old one expires.
-- 3. Exchange the refresh token for a new access/refresh token pair at the
--    session's refresh URL (OAuth2 @refresh_token@ grant), sending the
--    session's own client id.
--
-- The 'MVar' is held (via 'modifyMVar_') during the exchange, so readers
-- wait until the new tokens are installed.
--
-- If the exchange fails (bad URL, HTTP error, or a response missing a
-- field), 'modifyMVar_' restores the previous session. The error is then
-- written to stderr and the refresh is retried after 10 seconds, so one
-- failure does not stop the refresher for good.
--
-- Only synchronous exceptions are caught ('tryAny'), so the thread can still
-- be stopped by killing it.
oauthRefresher :: MVar OAuth2Session -> IO ()
oauthRefresher session_var = forever $ do
    current <- readMVar session_var
    threadDelay (refreshDelayMicros (current ^. oas_expires_in))
    refreshWithRetry
  where
    -- Fraction of the token lifetime after which the token is refreshed.
    refreshFraction :: NominalDiffTime
    refreshFraction = 0.9

    -- Pause between failed refresh attempts, in microseconds (10 s).
    retryDelayMicros :: Int
    retryDelayMicros = 10 * 1000000

    -- Microseconds to sleep before refreshing a token valid for the given
    -- duration. This is never negative; an already-expired token is
    -- refreshed immediately.
    refreshDelayMicros :: NominalDiffTime -> Int
    refreshDelayMicros valid_for =
        max 0 (floor (1000000 * refreshFraction * valid_for))

    -- Refresh once, retrying until an attempt succeeds.
    refreshWithRetry :: IO ()
    refreshWithRetry = do
        outcome <- tryAny (modifyMVar_ session_var refreshSession)
        case outcome of
            Right () -> return ()
            Left err -> do
                hPutStrLn stderr
                    ("oauthRefresher: token refresh failed, retrying: "
                        ++ displayException err)
                threadDelay retryDelayMicros
                refreshWithRetry

    -- One refresh_token grant. Returns the session with the new tokens and
    -- lifetime, and throws if anything is wrong.
    refreshSession :: OAuth2Session -> IO OAuth2Session
    refreshSession session = do
        let url = session ^. oas_refresh_url
            -- Sent as application/x-www-form-urlencoded:
            --   grant_type=refresh_token&refresh_token=...&client_id=...
            body = ReqBodyUrlEnc $ mconcat
                [ "grant_type"    =: ("refresh_token" :: Text)
                , "refresh_token" =: (session ^. oas_refresh_token)
                , "client_id"     =: (session ^. oas_client_id)
                ]
        (https_url, opt) <-
            maybe (throwString ("Bad refresh url " ++ T.unpack url)) return
                  (useHttpsURI =<< mkURI url)
        -- req's default HttpConfig retries transient failures
        -- (50ms delay, up to 5 times).
        response <- runReq defaultHttpConfig
                        (req POST https_url body jsonResponse opt)
                        :: IO (JsonResponse Aeson.Value)
        -- The response is a JSON object with, among others, the keys
        -- "access_token", "expires_in" and "refresh_token".
        let resp = responseBody response
        -- Each field is checked here, so a bad response fails inside
        -- modifyMVar_ (keeping the old session) instead of being stored.
        new_access  <- field "access_token"
                           (firstOf (key "access_token" . _String) resp)
        new_refresh <- field "refresh_token"
                           (firstOf (key "refresh_token" . _String) resp)
        lifetime    <- field "expires_in"
                           (firstOf (key "expires_in" . _Integral) resp :: Maybe Int)
        return $ session
            & oas_access_token  .~ new_access
            & oas_refresh_token .~ new_refresh
            & oas_expires_in    .~ fromIntegral lifetime

    -- Unwrap a required response field, or fail with a descriptive message.
    field :: String -> Maybe a -> IO a
    field name = maybe
        (throwString ("oauthRefresher: token response has no usable "
                        ++ show name ++ " field"))
        return