module GitHub.App.Token.Refresh
  ( HasExpiresAt (..)
  , Refresh
  , refreshing
  , getRefresh
  , cancelRefresh
  ) where

import GitHub.App.Token.Prelude

import Control.Monad (forever)
import Data.Time (getCurrentTime)
import Data.Void (Void)
import GitHub.App.Token.Generate (AccessToken (..))
import UnliftIO (MonadUnliftIO)
import UnliftIO.Async (Async, async, cancel)
import UnliftIO.Concurrent (threadDelay)
import UnliftIO.IORef (IORef, newIORef, readIORef, writeIORef)

-- | Values that carry an expiration time
class HasExpiresAt a where
  -- | The time at which the value expires
  expiresAt :: a -> UTCTime

-- | An 'AccessToken' expires at the time held in its @expires_at@ field
instance HasExpiresAt AccessToken where
  expiresAt = (.expires_at)

-- | Handle to a value of type @a@ together with the background thread
-- associated with it
data Refresh a = Refresh
  { ref :: IORef a
  -- ^ Mutable cell holding the current value
  , thread :: Async Void
  -- ^ Background thread; typed 'Void' because it is not expected to return
  }

-- | Run an action to (e.g.) generate a token and create a thread to refresh it
--
-- 'refreshing' runs the action once to create an initial token, then starts a
-- background thread that checks the token's 'expiresAt' on a loop (roughly
-- every half second). Once that time is at or before the current time, the
-- action is run again and its result replaces the stored token.
--
-- @
-- ref <- 'refreshing' $ 'generateInstallationToken' creds installationId
-- @
--
-- Use 'getRefresh' to access the (possibly) updated token.
--
-- @
-- for_ repos $ \repo -> do
--   token <- 'getRefresh' ref
--   makeSomeRequest token repo
-- @
--
-- If you can't rely on program exit to clean up this background thread, you can
-- manually cancel it:
--
-- @
-- 'cancelRefresh' ref
-- @
--
-- NOTE(review): nothing here catches exceptions thrown by the action, so a
-- failed refresh presumably ends the background loop and leaves the last token
-- in place. Confirm that is acceptable for your use case.
refreshing :: (MonadUnliftIO m, HasExpiresAt a) => m a -> m (Refresh a)
refreshing f = do
  -- Produce the initial value up front so 'getRefresh' always has something
  -- to return.
  x <- f
  ref <- newIORef x
  thread <- async $ forever $ do
    -- 'threadDelay' takes microseconds
    threadDelay $ round @Double $ 0.5 * 1000000 -- 0.5s
    now <- liftIO getCurrentTime
    isExpired <- (<= now) . expiresAt <$> readIORef ref
    -- Only re-run the action once the stored value has actually expired
    when isExpired $ writeIORef ref =<< f
  pure Refresh {ref, thread}

-- | Read the value currently held by a 'Refresh'
--
-- This returns whatever was most recently stored, which is either the initial
-- value or the result of the latest refresh.
getRefresh :: MonadIO m => Refresh a -> m a
getRefresh = readIORef . (.ref)

-- | Cancel the background thread of a 'Refresh'
--
-- This only cancels the thread; the 'Refresh' value itself is not modified
-- here.
cancelRefresh :: MonadIO m => Refresh a -> m ()
cancelRefresh = cancel . (.thread)