-- | Fast rate-limiting via the token bucket algorithm. Debiting tokens uses
-- lock-free compare-and-swap operations on the fast path.

{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UnboxedTuples #-}

module Control.Concurrent.TokenLimiter
  ( Count
  , LimitConfig(..)
  , RateLimiter
  , newRateLimiter
  , tryDebit
  , waitDebit
  , defaultLimitConfig
  ) where

import Control.Concurrent
import Data.IORef
import Foreign.Storable
import GHC.Generics
import GHC.Int
import GHC.IO
import GHC.Prim
import System.Clock

-- | A quantity of tokens: bucket capacities, refill rates and debit amounts
-- are all expressed with this type.
type Count = Int

-- | Static settings for a token bucket, plus the clock and sleep actions the
-- limiter uses. The config is passed alongside the 'RateLimiter' to every
-- operation; it is not stored inside the limiter itself.
data LimitConfig = LimitConfig {
    maxBucketTokens :: {-# UNPACK #-} !Count
      -- ^ maximum number of tokens the bucket can hold at any one time.
      -- Applied as a cap whenever refilled tokens are added to the bucket.
  , initialBucketTokens :: {-# UNPACK #-} !Count
      -- ^ how many tokens should be in the bucket when it's created.
  , bucketRefillTokensPerSecond :: {-# UNPACK #-} !Count
      -- ^ how many tokens should replenish the bucket per second. The
      -- per-token refill interval is derived by dividing one second (in
      -- nanoseconds) by this value.
  , clockAction :: IO TimeSpec
      -- ^ clock action, 'defaultLimitConfig' uses the coarse monotonic system
      -- clock. Mostly provided for mocking in the testsuite.
  , delayAction :: TimeSpec -> IO ()
      -- ^ action to delay for the given time interval. 'defaultLimitConfig'
      -- forwards to 'threadDelay'. Provided for mocking.
  } deriving (Generic)

-- | Mutable state of one token bucket. Create with 'newRateLimiter'.
data RateLimiter = RateLimiter {
    _bucketTokens :: !(MutableByteArray# RealWorld)
      -- ^ byte array sized for a single 'Int'; the element at index 0 is the
      -- current token count, read directly and updated with compare-and-swap.
  , _bucketLastServiced :: {-# UNPACK #-} !(MVar TimeSpec)
      -- ^ timestamp up to which elapsed time has already been converted into
      -- tokens. The refill path holds this 'MVar' while it computes newly
      -- accrued tokens.
  }


-- | Default settings: a bucket of at most 5 tokens that starts with 1 token
-- and refills at 1 token per second, timed with the coarse monotonic clock
-- and sleeping via 'threadDelay'.
defaultLimitConfig :: LimitConfig
defaultLimitConfig = LimitConfig
    { maxBucketTokens = 5
    , initialBucketTokens = 1
    , bucketRefillTokensPerSecond = 1
    , clockAction = getTime MonotonicCoarse
      -- 'threadDelay' takes microseconds, so the nanosecond interval is
      -- divided by 1000 (rounding down).
    , delayAction = \interval ->
        threadDelay $! fromInteger (toNanoSecs interval `div` 1000)
    }


-- | Build a new limiter from the given configuration. The bucket starts with
-- 'initialBucketTokens' tokens and the last-serviced timestamp is set to the
-- current reading of 'clockAction'.
newRateLimiter :: LimitConfig -> IO RateLimiter
newRateLimiter lc = do
    !createdAt <- clockAction lc
    !serviced <- newMVar createdAt
    allocate serviced
  where
    seed = initialBucketTokens lc
    !(I# seed#) = seed
    -- The bucket is a byte array just large enough for one 'Int'.
    !(I# width#) = sizeOf seed
    -- Allocate the array and store the starting token count at index 0.
    allocate serviced = IO $ \s0# ->
        case newByteArray# width# s0# of
          (# s1#, slot# #) ->
            case writeIntArray# slot# 0# seed# s1# of
              s2# -> (# s2#, RateLimiter slot# serviced #)

-- | Convert a refill rate in tokens per second into the number of
-- nanoseconds it takes to accrue a single token.
--
-- Rates above one token per nanosecond would truncate to an interval of 0,
-- which callers then divide by; such rates are clamped to 1 ns per token.
-- All other inputs behave exactly as a plain integer division: in
-- particular a rate of zero still raises a divide-by-zero exception.
rateToNsPer :: Integral a => a -> a
rateToNsPer tps
  | tps > nanosPerSecond = 1
  | otherwise = nanosPerSecond `div` tps
  where
    nanosPerSecond = 1000000000

-- | Read the current token count, i.e. the 'Int' stored at index 0 of the
-- bucket array. This is a plain read with no locking.
readBucket :: MutableByteArray# RealWorld -> IO Int
readBucket bucket# = IO readSlot
  where
    readSlot s0# =
        case readIntArray# bucket# 0# s0# of
          (# s1#, n# #) -> (# s1#, I# n# #)

-- | Try to take the requested number of tokens out of the bucket without
-- blocking, using the configuration's 'clockAction' as the time source.
-- The result is 'True' when the tokens were debited and 'False' otherwise.
tryDebit :: LimitConfig -> RateLimiter -> Count -> IO Bool
tryDebit cfg limiter wanted = tryDebit' (clockAction cfg) cfg limiter wanted

-- | Worker behind 'tryDebit', parameterised over the clock action so that
-- callers can supply a cached clock.
--
-- Fast path: read the bucket and, when it holds enough tokens, debit it
-- with a single compare-and-swap (retrying from the start on contention).
--
-- Slow path: when the bucket is short, take the last-serviced 'MVar',
-- convert the elapsed time into newly accrued tokens and settle the request
-- against the bucket contents /plus/ the accrued tokens. Whatever is left
-- over stays in the bucket, capped at 'maxBucketTokens'. On failure neither
-- the bucket nor the timestamp is modified, so accrued time is not lost.
--
-- Returns 'True' when @ndebits@ tokens were debited, 'False' otherwise. A
-- non-positive refill rate never refills, so the slow path then just
-- reports 'False'.
tryDebit' :: IO TimeSpec -> LimitConfig -> RateLimiter -> Count -> IO Bool
tryDebit' nowIO cfg rl ndebits = tryGrab
  where
    bucket# = _bucketTokens rl
    mv = _bucketLastServiced rl
    maxTokens = maxBucketTokens cfg
    refillRate = bucketRefillTokensPerSecond cfg
    rdBucket = readBucket bucket#

    -- Fast path: lock-free debit when the bucket already holds enough.
    tryGrab = do
        !nt <- rdBucket
        if nt >= ndebits
          then tryCas nt (nt - ndebits)
          else fetchMore

    -- casIntArray# returns the value that was in the slot before the
    -- operation; the swap took effect only if that equals what we expected.
    tryCas !nt@(I# nt#) !(I# newVal#) =
        IO $ \s# -> case casIntArray# bucket# 0# nt# newVal# s# of
                      (# s1#, prevV# #)
                        | I# prevV# == nt -> (# s1#, True #)
                        | otherwise -> let (IO retry) = tryGrab in retry s1#

    -- Settle the debit against the bucket contents plus the freshly accrued
    -- tokens in one compare-and-swap. Fast-path debits from other threads
    -- may race with us, hence the retry loop. Returns whether the debit was
    -- granted; the bucket is left untouched when it was not.
    settle !numNewTokens = go
      where
        go = do
            !b@(I# b#) <- rdBucket
            let !avail = b + numNewTokens
            -- TODO: allow partial debit fulfillment?
            if avail < ndebits
              then return False
              else do
                let !(I# b'#) = min maxTokens (avail - ndebits)
                IO $ \s# -> case casIntArray# bucket# 0# b# b'# s# of
                  (# s1#, prev# #)
                    | I# prev# == b -> (# s1#, True #)
                    | otherwise -> let (IO again) = go in again s1#

    -- Slow path: serialised on the last-serviced MVar.
    fetchMore
      -- A bucket that never refills cannot satisfy the request; bail out
      -- before dividing by the rate.
      | refillRate <= 0 = return False
      | otherwise = modifyMVar mv $ \lastUpdated -> do
          !now <- nowIO
          -- Guard against a zero interval for very high rates.
          let !nanosPerToken = max 1 (toInteger (rateToNsPer refillRate))
              !elapsedNanos = toNanoSecs (now - lastUpdated)
              -- A clock reading older than the timestamp yields no tokens.
              !accrued = max 0 (elapsedNanos `div` nanosPerToken)
          ok <- settle (fromInteger accrued)
          if ok
            then do
              -- Advance only by whole tokens so the fractional remainder of
              -- the elapsed time keeps counting towards the next token.
              let !lastUpdated' = lastUpdated +
                                  fromNanoSecs (accrued * nanosPerToken)
              return (lastUpdated', True)
            else return (lastUpdated, False)

-- | Sleep (via 'delayAction') for roughly as long as it should take for the
-- bucket's shortfall against @ntokens@ to accrue, given the supplied current
-- time and the time already elapsed since the last-serviced timestamp.
-- Always sleeps for at least one nanosecond's worth of interval.
waitForTokens :: TimeSpec -> LimitConfig -> RateLimiter -> Count -> IO ()
waitForTokens now cfg (RateLimiter bucket# mv) ntokens = do
    have <- readBucket bucket#
    lastServiced <- readMVar mv
    let owedNanos = perTokenNanos * toInteger (ntokens - have)
        elapsedNanos = toNanoSecs (now - lastServiced)
        -- 500ns of padding before the lower bound of 1ns is applied.
        !pause = fromNanoSecs (max 1 (owedNanos - elapsedNanos + 500))
    delayAction cfg pause
  where
    perTokenNanos = toInteger (rateToNsPer (bucketRefillTokensPerSecond cfg))

-- | Blocking debit: keep trying to take /k/ tokens, sleeping between
-- attempts, until the whole request is granted at once. Requests are never
-- partially fulfilled, and there is no fairness or queueing between waiting
-- threads, so many threads contending for fresh tokens will all wake up and
-- race for them (\"thundering herd\").
waitDebit :: LimitConfig -> RateLimiter -> Count -> IO ()
waitDebit lc rl ndebits = attempt
  where
    -- Memoised clock: consults the real clock only on first use, so each
    -- attempt reads the time at most once.
    memoClock cell = readIORef cell >>= maybe refresh return
      where
        refresh = do
            !t <- clockAction lc
            writeIORef cell (Just t)
            return t

    -- One attempt uses a fresh cache; on failure, sleep and start over.
    attempt = do
        cell <- newIORef Nothing
        let clock = memoClock cell
        granted <- tryDebit' clock lc rl ndebits
        if granted
          then return ()
          else do
            t <- clock
            waitForTokens t lc rl ndebits
            attempt