{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE NumericUnderscores #-}
{-# LANGUAGE RecordWildCards #-}

module Control.Concurrent.TokenLimiter.Concurrent
  ( Count,
    TokenLimitConfig (..),
    MonotonicTime,
    TokenLimiter (..),
    makeTokenLimiter,
    canDebit,
    tryDebit,
    waitDebit,

    -- * Helper functions
    computeCurrentCount,
  )
where

import Control.Concurrent
import Data.Word
import GHC.Clock
import GHC.Generics (Generic)
import Numeric.Natural

-- | A number of tokens held in, added to, or taken from a bucket.
type Count = Word64

-- | Settings for a 'TokenLimiter': how full its bucket starts out, how many
-- tokens the bucket can hold, and how quickly it refills.
data TokenLimitConfig = TokenLimitConfig
  { tokenLimitConfigInitialTokens :: !Count,
    -- ^ Number of tokens placed in the bucket when it is created.
    tokenLimitConfigMaxTokens :: !Count,
    -- ^ Capacity of the bucket: the most tokens it may hold at any one time.
    tokenLimitConfigTokensPerSecond :: !Count
    -- ^ Refill rate, in tokens added to the bucket per second.
  }
  deriving (Show, Eq, Generic)

-- | A timestamp in nanoseconds, as read from 'getMonotonicTimeNSec'.
--
-- This synonym exists purely for readability: it is a 'Word64' just like
-- 'Count', and without a distinct name the two would be easy to mix up.
type MonotonicTime = Word64

-- | A rate limiter built on the token bucket algorithm.
--
-- It is safe to share between threads, and guarantees that:
--
-- * <https://en.wikipedia.org/wiki/Thundering_herd_problem There will be no thundering herd problem>
-- * <https://hackage.haskell.org/package/base-4.14.1.0/docs/Control-Concurrent-MVar.html#v:modifyMVar Fairness: Waiting processes will be serviced in a first-come first-service order.>
data TokenLimiter = TokenLimiter
  { tokenLimiterConfig :: !TokenLimitConfig,
    -- ^ The configuration the limiter was created with.
    tokenLimiterLastServiced :: !(MVar (MonotonicTime, Count))
    -- ^ The moment the limiter was last updated, paired with the token count
    -- it held at that moment.
    --
    -- Note that this library assumes you never put anything into this 'MVar'
    -- yourself, and that it is only touched through this library's functions.
  }
  deriving (Eq, Generic)

-- | Create a fresh 'TokenLimiter' from a configuration.
--
-- The bucket starts with the smaller of 'tokenLimitConfigInitialTokens' and
-- 'tokenLimitConfigMaxTokens', and its last-serviced time is set to the
-- current monotonic time.
makeTokenLimiter :: TokenLimitConfig -> IO TokenLimiter
makeTokenLimiter config = do
  let startingCount =
        tokenLimitConfigInitialTokens config `min` tokenLimitConfigMaxTokens config
  createdAt <- getMonotonicTimeNSec
  lastServicedVar <- newMVar (createdAt, startingCount)
  pure
    TokenLimiter
      { tokenLimiterConfig = config,
        tokenLimiterLastServiced = lastServicedVar
      }

-- | Report whether @debit@ tokens could be taken from the bucket right now,
-- without actually taking them.
--
-- The answer can become stale _very_ quickly, because other threads may
-- debit tokens in the meantime. To check and debit in one atomic step, use
-- 'tryDebit' instead.
canDebit :: TokenLimiter -> Word64 -> IO Bool
canDebit limiter debit =
  withMVar (tokenLimiterLastServiced limiter) check
  where
    -- Read the clock while holding the lock and compare the bucket's current
    -- contents against the requested amount.
    check (lastServiced, countThen) =
      hasEnough <$> getMonotonicTimeNSec
      where
        hasEnough now =
          computeCurrentCount (tokenLimiterConfig limiter) lastServiced countThen now >= debit

-- | Atomically take @debit@ tokens from the bucket if it currently holds at
-- least that many.
--
-- Returns 'True' when the tokens were taken. On success, the stored state
-- becomes the current time together with the reduced count. On failure, the
-- stored state is left exactly as it was.
--
-- The stored count is a whole number and the stored time is reset to "now"
-- on success, so any fraction of a token accrued since the last update is
-- discarded at that point.
tryDebit :: TokenLimiter -> Word64 -> IO Bool
tryDebit limiter debit =
  modifyMVar (tokenLimiterLastServiced limiter) attempt
  where
    attempt previous@(lastServiced, countThen) = do
      now <- getMonotonicTimeNSec
      let available =
            computeCurrentCount (tokenLimiterConfig limiter) lastServiced countThen now
      pure $
        if available < debit
          then (previous, False)
          else ((now, available - debit), True)

-- | Block until @debit@ tokens can be taken from the bucket, then take them.
--
-- The limiter's 'MVar' is held (via 'modifyMVar_') for the whole wait, so
-- concurrent callers are served one at a time.
--
-- After every sleep the token count is recomputed from the monotonic clock
-- rather than assumed. 'threadDelay' only guarantees a /minimum/ delay, and
-- rounding could otherwise leave the bucket a token short.
--
-- Some requests can never be fully satisfied: @debit@ may exceed
-- 'tokenLimitConfigMaxTokens', or 'tokenLimitConfigTokensPerSecond' may be
-- @0@ while the bucket holds too few tokens. In that case the call waits only
-- until the bucket is as full as it can get, then empties it to zero. The
-- count never wraps around below zero.
--
-- As with 'tryDebit', the stored time is reset to the moment of the debit,
-- so any fraction of a token accrued up to then is discarded.
waitDebit :: TokenLimiter -> Word64 -> IO ()
waitDebit TokenLimiter {..} debit =
  modifyMVar_ tokenLimiterLastServiced $ \(lastServiced, countThen) ->
    let countAt :: MonotonicTime -> Count
        countAt = computeCurrentCount tokenLimiterConfig lastServiced countThen

        loop :: IO (MonotonicTime, Count)
        loop = do
          now <- getMonotonicTimeNSec
          let currentCount = countAt now
          -- Stop once the bucket is as full as this request needs (or as full
          -- as it can ever be), or when no tokens will ever be added.
          -- Subtracting at most currentCount keeps the Word64 count from
          -- wrapping around when the request cannot be fully satisfied.
          if currentCount >= target || tokensPerSecond == 0
            then pure (now, currentCount - min debit currentCount)
            else do
              -- threadDelay waits at least this long, possibly much longer, so
              -- the next iteration re-reads the clock instead of assuming.
              threadDelay (microsecondsFor (target - currentCount))
              loop
     in loop
  where
    tokensPerSecond :: Count
    tokensPerSecond = tokenLimitConfigTokensPerSecond tokenLimiterConfig

    -- The most tokens this request can ever find in the bucket at once.
    target :: Count
    target = min debit (tokenLimitConfigMaxTokens tokenLimiterConfig)

    -- Microseconds needed for @tokens@ more tokens to accrue. The division is
    -- exact and rounded up, so at least that many will have accrued, and the
    -- result is clamped to fit in an 'Int'. Only called when tokensPerSecond
    -- is nonzero.
    microsecondsFor :: Count -> Int
    microsecondsFor tokens =
      let rate = fromIntegral tokensPerSecond :: Natural
          micros = (fromIntegral tokens * 1_000_000 + rate - 1) `div` rate
       in fromIntegral (min micros (fromIntegral (maxBound :: Int)))

-- | Compute the current number of tokens in a bucket purely.
--
-- The inputs are the bucket's configuration, the time it was last serviced,
-- the token count at that time, and the current time. The result is the old
-- count plus the whole tokens accrued since then at
-- 'tokenLimitConfigTokensPerSecond' (partial tokens rounded down), capped at
-- 'tokenLimitConfigMaxTokens'.
--
-- The arithmetic is exact (in 'Natural'), so it cannot overflow or lose
-- precision however large the elapsed time or refill rate. If @now@ is
-- earlier than @lastServiced@, no tokens are added.
--
-- You should not need this function.
computeCurrentCount :: TokenLimitConfig -> MonotonicTime -> Count -> MonotonicTime -> Count
computeCurrentCount TokenLimitConfig {..} lastServiced countThen now =
  -- The result is at most 'capacity', which came from a Count, so it fits.
  fromIntegral (min capacity (fromIntegral countThen + countToAdd))
  where
    -- Elapsed nanoseconds. A monotonic clock should never go backwards, but
    -- if it appears to, add nothing instead of letting the Word64 subtraction
    -- wrap around to an enormous duration.
    nanoDiff :: Natural
    nanoDiff
      | now >= lastServiced = fromIntegral (now - lastServiced)
      | otherwise = 0

    -- Whole tokens accrued during that time. Doing this in Natural avoids
    -- both Double rounding and the silent wraparound of converting a huge
    -- Double back to Word64.
    countToAdd :: Natural
    countToAdd =
      (nanoDiff * fromIntegral tokenLimitConfigTokensPerSecond) `div` 1_000_000_000

    capacity :: Natural
    capacity = fromIntegral tokenLimitConfigMaxTokens