{-# LANGUAGE TemplateHaskell #-}
module Neptune.OAuth where
import Control.Concurrent (ThreadId, forkIO)
import Control.Lens
import qualified Data.Aeson as Aeson
import Data.Aeson.Lens
import Data.Time.Clock (NominalDiffTime)
import Data.Time.Clock.POSIX (getPOSIXTime)
import Network.HTTP.Req
  ( JsonResponse
  , POST (..)
  , ReqBodyUrlEnc (..)
  , defaultHttpConfig
  , jsonResponse
  , req
  , responseBody
  , runReq
  , useHttpsURI
  , (=:)
  )
import RIO hiding (Lens', (.~), (^.))
import qualified RIO.Text as T
import System.IO (hPutStrLn, stderr)
import Text.URI (mkURI)
import qualified Web.JWT as JWT
-- | State of an OAuth2 login: the current tokens plus what is needed to
-- refresh them. 'oauth2Setup' creates it and 'oauthRefresher' updates it
-- inside an 'MVar'.
--
-- Fields are strict, so evaluating a session also evaluates its fields. This
-- stops a stored session from retaining the decoded JWT or the refresh HTTP
-- response behind unevaluated field thunks, and makes a failing field
-- computation surface where the session is built rather than in whichever
-- thread happens to read that field first.
data OAuth2Session = OAuth2Session
  { _oas_client_id :: !Text
    -- ^ Client the access token was issued to (the token's @azp@ claim).
  , _oas_access_token :: !Text
    -- ^ Current access token.
  , _oas_refresh_token :: !Text
    -- ^ Refresh token sent to the token endpoint to obtain the next access token.
  , _oas_expires_in :: !NominalDiffTime
    -- ^ Access-token lifetime in seconds, relative to when the token was obtained.
  , _oas_refresh_url :: !Text
    -- ^ Token endpoint: the issuer URL followed by @/protocol/openid-connect/token@.
  }

-- Generates the lenses oas_client_id, oas_access_token, oas_refresh_token,
-- oas_expires_in and oas_refresh_url.
makeLenses ''OAuth2Session
-- | Build an 'OAuth2Session' from a freshly issued access/refresh token pair.
-- Also fork a thread running 'oauthRefresher', which keeps the session's
-- tokens up to date.
--
-- The access token is decoded as an /unverified/ JWT. Only these claims are
-- read:
--
-- * @iss@: the refresh URL is the issuer followed by
--   @/protocol/openid-connect/token@.
-- * @azp@: stored as the session's client id.
-- * @exp@: turned into a lifetime relative to the current time.
--
-- Throws (via 'throwString') if the token cannot be decoded or one of those
-- claims is missing. Previously these lookups were lazy partial accesses, so
-- a bad token let setup "succeed" and later silently killed the refresher
-- thread when it forced the expiry.
--
-- Returns the refresher's 'ThreadId' (presumably so the caller can stop it)
-- and the shared session variable.
oauth2Setup :: Text -> Text -> IO (ThreadId, MVar OAuth2Session)
oauth2Setup access_token refresh_token = do
  claims <-
    require "access token is not a decodable JWT" $
      JWT.claims <$> JWT.decode access_token
  issuer <-
    require "access token has no \"iss\" claim" $
      JWT.stringOrURIToText <$> JWT.iss claims
  client_name <-
    require "access token has no string \"azp\" claim" $
      firstOf (ix "azp" . _String) (JWT.unClaimsMap (JWT.unregisteredClaims claims))
  expires_at <-
    require "access token has no \"exp\" claim" $
      JWT.secondsSinceEpoch <$> JWT.exp claims
  now <- getPOSIXTime
  let refresh_url = T.append issuer "/protocol/openid-connect/token"
      -- Lifetime is measured from now; it may be negative for an
      -- already-expired token, in which case the refresher acts promptly.
      session =
        OAuth2Session client_name access_token refresh_token (expires_at - now) refresh_url
  oauth_session_var <- newMVar session
  refresh_thread <- forkIO (oauthRefresher oauth_session_var)
  return (refresh_thread, oauth_session_var)
  where
    -- Unwrap a required piece of the token, failing setup with a clear message.
    require :: String -> Maybe a -> IO a
    require msg = maybe (throwString ("oauth2Setup: " ++ msg)) pure
-- | Background loop that keeps an 'OAuth2Session' fresh. It never returns.
--
-- Each round does the following:
--
-- 1. Sleep until shortly before the current access token expires (see
--    @refreshDelay@).
-- 2. POST an OAuth2 @refresh_token@ grant to the session's refresh URL.
-- 3. Store the returned access token, refresh token and @expires_in@
--    lifetime in the 'MVar'.
--
-- The HTTP request is made outside the 'MVar' lock, so readers are never
-- blocked on the network. They keep seeing the old token, which is still
-- valid because the refresh starts ahead of expiry.
--
-- A failed refresh is reported on stderr and retried after @retryDelaySec@
-- seconds instead of killing this thread. Failures include network errors,
-- a non-2xx status, a bad refresh URL, or a response missing a field.
-- Letting such an error escape a 'forkIO' thread would silently stop all
-- future refreshes.
oauthRefresher :: MVar OAuth2Session -> IO ()
oauthRefresher session_var = do
  session <- readMVar session_var
  threadDelay (toMicros (refreshDelay (session ^. oas_expires_in)))
  refreshUntilSuccess
  oauthRefresher session_var
  where
    -- How long before expiry to refresh.
    refreshLeadSec :: NominalDiffTime
    refreshLeadSec = 30

    -- Pause between attempts after a failed refresh.
    retryDelaySec :: NominalDiffTime
    retryDelaySec = 5

    -- Seconds to wait before refreshing a token that lives @lifetime@
    -- seconds. The target is 'refreshLeadSec' before expiry. The wait is
    -- never less than half the lifetime, so short-lived tokens do not cause
    -- a refresh storm. It is also never less than one second, which covers
    -- zero or negative lifetimes.
    refreshDelay :: NominalDiffTime -> NominalDiffTime
    refreshDelay lifetime =
      max 1 (max (lifetime / 2) (lifetime - refreshLeadSec))

    -- threadDelay takes microseconds.
    toMicros :: NominalDiffTime -> Int
    toMicros secs = floor (1000000 * secs)

    -- Keep trying until one refresh succeeds. tryAny only catches
    -- synchronous exceptions, so killThread still stops this loop.
    refreshUntilSuccess :: IO ()
    refreshUntilSuccess = do
      result <- tryAny $ do
        current <- readMVar session_var
        (access_token, refresh_token, expires_in) <- fetchTokens current
        modifyMVar_ session_var $ \s ->
          pure $
            s
              & oas_access_token .~ access_token
              & oas_refresh_token .~ refresh_token
              & oas_expires_in .~ expires_in
      case result of
        Right () -> pure ()
        Left err -> do
          hPutStrLn stderr $
            "oauthRefresher: token refresh failed, retrying: "
              ++ displayException err
          threadDelay (toMicros retryDelaySec)
          refreshUntilSuccess

    -- Perform the refresh_token grant for the given session. Return the new
    -- access token, refresh token and lifetime. Every response field is
    -- checked here, so a malformed response throws inside tryAny instead of
    -- leaving a failing thunk in the session.
    fetchTokens :: OAuth2Session -> IO (Text, Text, NominalDiffTime)
    fetchTokens s = do
      let url = s ^. oas_refresh_url
          -- NOTE(review): the session's _oas_client_id (the token's azp
          -- claim) is not used; confirm the hard-coded client id is intended.
          body =
            ReqBodyUrlEnc $
              mconcat
                [ "grant_type" =: ("refresh_token" :: Text)
                , "refresh_token" =: (s ^. oas_refresh_token)
                , "client_id" =: ("neptune-cli" :: Text)
                ]
      (https_url, opt) <- case useHttpsURI =<< mkURI url of
        Nothing -> throwString $ "Bad refresh url " ++ T.unpack url
        Just parsed -> pure parsed
      resp <-
        responseBody
          <$> ( runReq defaultHttpConfig (req POST https_url body jsonResponse opt)
                  :: IO (JsonResponse Aeson.Value)
              )
      access_token <- require "access_token" $ firstOf (key "access_token" . _String) resp
      refresh_token <- require "refresh_token" $ firstOf (key "refresh_token" . _String) resp
      expires_in <- require "expires_in" $ firstOf (key "expires_in" . _Integral) resp
      pure (access_token, refresh_token, fromIntegral (expires_in :: Int))

    -- Unwrap a required response field or throw a descriptive error.
    require :: String -> Maybe a -> IO a
    require field =
      maybe
        (throwString ("oauthRefresher: refresh response has no usable \"" ++ field ++ "\" field"))
        pure