module GitHub.App.Token.Refresh
  ( HasExpiresAt (..)
  , Refresh
  , refreshing
  , getRefresh
  , cancelRefresh
  ) where

import GitHub.App.Token.Prelude

import Data.Time (diffUTCTime, getCurrentTime)
import Data.Void (Void)
import GitHub.App.Token.Generate (AccessToken (..))
import UnliftIO (MonadUnliftIO)
import UnliftIO.Async (Async, async, cancel)
import UnliftIO.Concurrent (threadDelay)
import UnliftIO.IORef (IORef, newIORef, readIORef, writeIORef)

-- | Values that know when they stop being valid
--
-- 'refreshing' uses this to decide how long to wait before regenerating a
-- value.
class HasExpiresAt a where
  -- | The moment at which the value expires
  expiresAt :: a -> UTCTime

-- | An access token expires at the time in its @expires_at@ field.
instance HasExpiresAt AccessToken where
  expiresAt = (.expires_at)

-- | A value that a background thread keeps regenerating before it expires
--
-- Create one with 'refreshing', read it with 'getRefresh', and stop the
-- background thread with 'cancelRefresh'.
data Refresh a = Refresh
  { ref :: IORef a
  -- ^ The most recently generated value
  , thread :: Async Void
  -- ^ The background refresh loop; typed 'Void' because it never returns
  }

-- | Run an action to (e.g.) generate a token, and start a thread that keeps
-- it fresh
--
-- 'refreshing' runs the action once to create the initial value. It then forks
-- a background thread that loops forever:
--
-- * sleep for 95% of the time remaining until the current value's
--   'expiresAt' (no sleep at all if it has already expired);
-- * run the action again;
-- * store the result.
--
-- @
-- ref <- 'refreshing' $ 'generateInstallationToken' creds installationId
-- @
--
-- Use 'getRefresh' to access the (possibly) updated value.
--
-- @
-- for_ repos $ \\repo -> do
--   token <- 'getRefresh' ref
--   makeSomeRequest token repo
-- @
--
-- If you can't rely on program exit to clean up this background thread, you
-- can manually cancel it:
--
-- @
-- 'cancelRefresh' ref
-- @
--
-- Exceptions thrown by the action inside the background thread are not
-- handled here. Such an exception ends the refresh loop. After that,
-- 'getRefresh' keeps returning the last value that was generated
-- successfully, even once it has expired.
refreshing :: (MonadUnliftIO m, HasExpiresAt a) => m a -> m (Refresh a)
refreshing f = do
  x <- f
  ref <- newIORef x
  thread <- async $ loop ref x
  pure Refresh {ref, thread}
 where
  -- Recurses forever, so its result type is free to unify with the 'Void'
  -- stored in 'Refresh'.
  loop var current = do
    now <- liftIO getCurrentTime
    threadDelay $ refreshInMicroseconds current now
    updated <- f
    writeIORef var updated
    loop var updated

-- | How long to sleep, in microseconds, before refreshing a value
--
-- The delay is 95% of the time remaining until the value's 'expiresAt', so
-- the refresh happens a little before expiry. A value that has already
-- expired yields @0@ (refresh right away).
--
-- The result is clamped to @'maxBound' :: 'Int'@. Converting an out-of-range
-- 'Double' directly to 'Int' is not well-defined and can wrap to a negative
-- delay, which would defeat the sleep. On a 32-bit 'Int', a value that
-- expires one hour from now already exceeds that range.
refreshInMicroseconds :: HasExpiresAt a => a -> UTCTime -> Int
refreshInMicroseconds a now =
  fromInteger $ min (toInteger (maxBound :: Int)) $ round @Double @Integer micros
 where
  -- negative remaining time means refresh right away
  remaining = max 0 $ diffUTCTime (expiresAt a) now
  -- convert to microseconds, then take 95% to refresh a little early
  micros = realToFrac remaining * 1000000 * 0.95

-- | Read the current value of a 'Refresh'
--
-- This is the value most recently stored by the background thread, or the
-- initial value if no refresh has happened yet.
getRefresh :: MonadIO m => Refresh a -> m a
getRefresh = readIORef . (.ref)

-- | Stop the background thread started by 'refreshing'
--
-- This uses 'cancel', so it waits for the thread to exit. The 'Refresh' can
-- still be read with 'getRefresh' afterwards, but its value is no longer
-- updated.
cancelRefresh :: MonadIO m => Refresh a -> m ()
cancelRefresh = cancel . (.thread)