-- | Fast rate-limiting via token bucket algorithm. Uses lock-free
-- compare-and-swap operations on the fast path when debiting tokens.

{-# LANGUAGE BangPatterns  #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE GADTs         #-}
{-# LANGUAGE MagicHash     #-}
{-# LANGUAGE TypeFamilies  #-}
{-# LANGUAGE UnboxedTuples #-}

module Control.Concurrent.TokenLimiter
  ( Count
  , LimitConfig(..)
  , RateLimiter
  , newRateLimiter
  , tryDebit
  , penalize
  , waitDebit
  , defaultLimitConfig
  ) where

import           Control.Concurrent
import           Data.IORef
import           Foreign.Storable
import           GHC.Generics
import           GHC.Int
import           GHC.IO
import           GHC.Prim
import           System.Clock

-- | A number of tokens.
type Count = Int

-- | Static configuration shared by every operation on a 'RateLimiter'.
data LimitConfig = LimitConfig
  { maxBucketTokens             :: {-# UNPACK #-} !Count
    -- ^ maximum number of tokens the bucket can hold at any one time.
  , initialBucketTokens         :: {-# UNPACK #-} !Count
    -- ^ how many tokens should be in the bucket when it's created.
  , bucketRefillTokensPerSecond :: {-# UNPACK #-} !Count
    -- ^ how many tokens should replenish the bucket per second. A value of
    -- zero or less means the bucket is never refilled by 'tryDebit'. Rates
    -- above one token per nanosecond are treated as one token per
    -- nanosecond.
  , clockAction                 :: IO TimeSpec
    -- ^ clock action, 'defaultLimitConfig' uses the monotonic system clock.
    -- Mostly provided for mocking in the testsuite.
  , delayAction                 :: TimeSpec -> IO ()
    -- ^ action to delay for the given time interval. 'defaultLimitConfig'
    -- forwards to 'threadDelay'. Provided for mocking.
  } deriving (Generic)

-- | Mutable rate limiter state: the current token balance, held in a
-- one-element unboxed array so it can be updated with compare-and-swap, and
-- the timestamp up to which refills have been credited.
data RateLimiter = RateLimiter
  { _bucketTokens       :: !(MutableByteArray# RealWorld)
  , _bucketLastServiced :: {-# UNPACK #-} !(MVar TimeSpec)
  }

-- | Default configuration: a bucket holding at most 5 tokens, created with 1
-- token, refilled at 1 token per second, timed by the monotonic clock and
-- sleeping via 'threadDelay'.
defaultLimitConfig :: LimitConfig
defaultLimitConfig = LimitConfig 5 1 1 nowIO sleepIO
  where
    nowIO = getTime Monotonic
    -- 'threadDelay' takes microseconds; the interval arrives in nanoseconds.
    sleepIO x = threadDelay $! fromInteger (toNanoSecs x `div` 1000)

-- | Create a rate limiter holding 'initialBucketTokens' tokens, with its
-- last-serviced timestamp set to the current reading of 'clockAction'.
newRateLimiter :: LimitConfig -> IO RateLimiter
newRateLimiter lc = do
    !now <- nowIO
    !mv  <- newMVar now
    mk mv
  where
    initial        = initialBucketTokens lc
    nowIO          = clockAction lc
    !(I# initial#) = initial
    !(I# nbytes#)  = sizeOf $! initial
    -- Allocate room for exactly one Int and store the initial balance in it.
    mk mv = IO $ \s# ->
        case newByteArray# nbytes# s# of
          (# s1#, arr# #) -> case writeIntArray# arr# 0# initial# s1# of
            s2# -> (# s2#, RateLimiter arr# mv #)

-- | Convert a rate in tokens per second into whole nanoseconds per token.
--
-- The result is clamped to at least 1: integer division yields 0 for rates
-- above 10^9 tokens per second, and that value is later used as a divisor
-- when converting elapsed time into tokens. A rate of zero still raises a
-- divide-by-zero exception; 'tryDebit'' never calls this with a non-positive
-- rate.
rateToNsPer :: Integral a => a -> a
rateToNsPer tps = max 1 (1000000000 `div` tps)

-- | Read the current token balance out of the bucket array.
readBucket :: MutableByteArray# RealWorld -> IO Int
readBucket bucket# = IO $ \s# ->
    case readIntArray# bucket# 0# s# of
      (# s1#, w# #) -> (# s1#, I# w# #)

-- | Unconditionally debit this amount of tokens from the rate limiter, driving
-- it negative if necessary. Returns the new bucket balance.
--
-- /Since: 0.2/
penalize :: RateLimiter -> Count -> IO Count
penalize rl delta = addLoop
  where
    bucket#  = _bucketTokens rl
    rdBucket = readBucket bucket#

    -- Read, compute, compare-and-swap; start over if the value handed back
    -- by the swap differs from the balance we read.
    addLoop = do
        !b@(I# bb#) <- rdBucket
        let !ibb'@(I# bb'#) = b - delta
        IO $ \s# -> case casIntArray# bucket# 0# bb# bb'# s# of
          (# s1#, prev# #) ->
            if (I# prev#) == b
              then (# s1#, ibb' #)
              else let (IO f) = addLoop in f s1#

-- | Attempt to pull the given number of tokens from the bucket. Returns 'True'
-- if the tokens were successfully debited.
tryDebit :: LimitConfig -> RateLimiter -> Count -> IO Bool
tryDebit cfg rl cnt = do
    let nowIO = clockAction cfg
    snd <$> tryDebit' nowIO cfg rl cnt

-- | Worker behind 'tryDebit' and 'waitDebit'. Takes the clock as a separate
-- argument so that callers can substitute a caching clock. Returns the bucket
-- balance it last observed (after the debit, when successful) together with
-- the success flag.
tryDebit' :: IO TimeSpec -> LimitConfig -> RateLimiter -> Count -> IO (Int, Bool)
tryDebit' nowIO cfg rl ndebits = tryGrab
  where
    bucket#    = _bucketTokens rl
    mv         = _bucketLastServiced rl
    maxTokens  = maxBucketTokens cfg
    refillRate = bucketRefillTokensPerSecond cfg
    rdBucket   = readBucket bucket#

    -- Fast path: enough tokens are present, so try to claim them with a
    -- single compare-and-swap. Otherwise go and see whether a refill is due.
    tryGrab = do
        !nt <- rdBucket
        if nt >= ndebits
          then tryCas nt (nt - ndebits)
          else fetchMore nt

    -- Swap the balance from @nt@ to @newval@; if the balance changed in the
    -- meantime, start over from 'tryGrab'.
    tryCas !nt@(I# nt#) !newval@(I# newVal#) = IO $ \s# ->
        case casIntArray# bucket# 0# nt# newVal# s# of
          (# s1#, prevV# #) ->
            let prevV      = I# prevV#
                rest       = if prevV == nt
                               then return (newval, True)
                               else tryGrab
                (IO restF) = rest
            in restF s1#

    -- Credit @numNewTokens@ (an 'Integer', always positive here) to the
    -- bucket, capping the balance at 'maxBucketTokens'. The sum and the cap
    -- are computed in 'Integer' so that a very large refill cannot wrap
    -- around 'Int'; the capped result lies between the old balance and
    -- 'maxBucketTokens' and therefore fits in an 'Int'.
    addLoop !numNewTokens = go
      where
        go = do
            !b@(I# bb#) <- rdBucket
            let !ibb'@(I# bb'#) = fromInteger $
                    min (toInteger maxTokens) (toInteger b + numNewTokens)
            IO $ \s# -> case casIntArray# bucket# 0# bb# bb'# s# of
              (# s1#, prev# #) ->
                if (I# prev#) == b
                  then (# s1#, ibb' #)
                  else let (IO f) = go in f s1#

    -- Slow path, reached only when the observed balance @nt@ is below
    -- @ndebits@. Refills are serialised through the MVar holding the
    -- last-serviced timestamp.
    fetchMore !nt
      -- A non-positive refill rate can never add tokens; report failure
      -- directly rather than dividing by a zero or negative interval.
      | refillRate <= 0 = return (nt, False)
      | otherwise = do
          newBalance <- modifyMVar mv $ \lastUpdated -> do
              now <- nowIO
              let !numNanos      = toNanoSecs $ now - lastUpdated
              let !nanosPerToken = toInteger $ rateToNsPer refillRate
              let !numNewTokens  = numNanos `div` nanosPerToken
              -- Advance the timestamp only by the time that was actually
              -- converted into whole tokens, so the remainder carries over
              -- to the next refill.
              let !lastUpdated'  = lastUpdated
                                     + fromNanoSecs (numNewTokens * nanosPerToken)
              if numNewTokens > 0
                then do nb <- addLoop numNewTokens
                        return (lastUpdated', nb)
                else return (lastUpdated, nt)
          if newBalance >= ndebits
            then tryGrab
            else return (newBalance, False)

-- | Sleep, via 'delayAction', for roughly as long as it should take for the
-- bucket to go from @balance@ to @ntokens@ tokens at the configured refill
-- rate, less the time already elapsed between the last-serviced timestamp
-- and @now@. Always sleeps for at least one nanosecond.
waitForTokens :: TimeSpec -> LimitConfig -> RateLimiter -> Count -> Count -> IO ()
waitForTokens now cfg (RateLimiter _ mv) balance ntokens = do
    lastUpdated <- readMVar mv
    let numNeeded  = fromIntegral ntokens - balance
    let delta      = toNanoSecs $ now - lastUpdated
    let nanos      = nanosPerToken * toInteger numNeeded
    -- NOTE(review): the extra 500ns presumably rounds to the nearest
    -- microsecond for the default 'threadDelay'-based delay action -- confirm.
    let sleepNanos = max 1 (fromInteger (nanos - delta + 500))
    let !sleepSpec = fromNanoSecs sleepNanos
    sleepFor sleepSpec
  where
    nanosPerToken = toInteger $ rateToNsPer refillRate
    refillRate    = bucketRefillTokensPerSecond cfg
    sleepFor      = delayAction cfg

-- | Attempt to pull /k/ tokens from the bucket, sleeping in a loop until they
-- become available. Will not partially fulfill token requests (i.e. it loops
-- until the entire allotment is available in one swoop), and makes no attempt
-- at fairness or queueing (i.e. you will probably get \"thundering herd\" on
-- wakeup if a number of threads are contending for fresh tokens).
--
-- If 'bucketRefillTokensPerSecond' is zero and the tokens are not immediately
-- available, computing the sleep interval divides by zero and an arithmetic
-- exception is raised.
waitDebit :: LimitConfig -> RateLimiter -> Count -> IO ()
waitDebit lc rl ndebits = go
  where
    -- Read the clock at most once per attempt, so the debit attempt and the
    -- subsequent sleep computation work from the same timestamp.
    cacheClock ref = do
        m <- readIORef ref
        case m of
          Nothing -> do
              !now <- clockAction lc
              writeIORef ref (Just now)
              return now
          (Just t) -> return t

    go = do
        ref <- newIORef Nothing
        let clock = cacheClock ref
        (balance, b) <- tryDebit' clock lc rl ndebits
        if b
          then return $! ()
          else do
              now <- clock
              waitForTokens now lc rl balance ndebits >> go