{-# LANGUAGE ScopedTypeVariables #-}

-- | This module implements rate-limiting functionality for Haskell programs.
-- Rate-limiting is useful when trying to control / limit access to a
-- particular resource over time. For example, you might want to limit the
-- rate at which you make requests to a server, as an administrator may block
-- your access if you make too many requests too quickly. Similarly, one may
-- wish to rate-limit certain communication actions, in order to avoid
-- accidentally performing a denial-of-service attack on a critical resource.
--
-- The fundamental idea of this library is that given some basic information
-- about the requests you want rate limited, it will return you a function
-- that hides all the rate-limiting detail. In short, you make a call to one
-- of the function generators in this file, and you will be returned a function
-- to use. For example:
--
-- @
--   do f <- generateRateLimitedFunction ...
--      ...
--      res1 <- f a
--      ...
--      res2 <- f b
--      ...
-- @
--
-- The calls to the generated function (f) will be rate limited based on the
-- parameters given to 'generateRateLimitedFunction'.
--
-- 'generateRateLimitedFunction' is the most general version of the rate
-- limiting functionality, but specialized versions of it are also exported
-- for convenience.
--
module Control.RateLimit (
    generateRateLimitedFunction
  , RateLimit(..)
  , ResultsCombiner
  , dontCombine
  , rateLimitInvocation
  , rateLimitExecution
  ) where

import Control.Concurrent
import Control.Concurrent.STM
import Control.Exception (SomeException, throwIO, try)
import Control.Monad (void)
import Data.Functor (($>))
import Data.Time.Clock.POSIX (getPOSIXTime)
import Data.Time.Units

-- | The rate at which to limit an action.
data RateLimit a
  = -- | Rate limit the action to invocation once per time unit. With this
    -- option, the time it takes for the action to take place is not taken
    -- into consideration when computing the rate, only the time between
    -- invocations of the action. This may cause the action to execute
    -- concurrently, as an invocation may occur while an action is still
    -- running.
    PerInvocation a
  | -- | Rate limit the action to execution once per time unit. With this
    -- option, the time it takes for the action to take place is taken into
    -- account, and all actions will necessarily occur sequentially. However,
    -- if your action takes longer than the time unit given, then the rate of
    -- execution will be slower than the given unit of time.
    PerExecution a

-- | In some cases, if two requests are waiting to be run, it may be possible
-- to combine them into a single request and thus increase the overall
-- bandwidth. The rate limit system supports this, but needs a little extra
-- information to make everything work out right.
--
-- A combiner is given two queued requests, the earlier one first. It returns
-- 'Nothing' when the two requests /cannot/ be combined (they are then run
-- separately). Otherwise it returns 'Just' a pair of:
--
--   * a new request representing the combination of the two, and
--
--   * a function that splits the response to that combined request into two
--     responses: the first component is delivered to the earlier request
--     (the combiner's first argument), the second to the later request (its
--     second argument).
--
-- A combined request may itself be passed back to the combiner together with
-- the next queued request, so several requests can be folded into one.
type ResultsCombiner req resp =
  req -> req -> Maybe (req, resp -> (resp, resp))

-- | A 'ResultsCombiner' that never combines requests: for every pair of
-- requests it answers 'Nothing', so each request is run on its own.
dontCombine :: ResultsCombiner a b
dontCombine _ _ = Nothing

-- | Rate limit the invocation of a given action. This is equivalent to calling
-- 'generateRateLimitedFunction' with a 'PerInvocation' rate limit and the
-- 'dontCombine' combining function.
rateLimitInvocation :: TimeUnit t
                    => t
                       -- ^ The time unit bounding the rate at which the
                       -- action is invoked (see 'PerInvocation')
                    -> (req -> IO resp)
                       -- ^ The action to rate limit
                    -> IO (req -> IO resp)
rateLimitInvocation pertime action =
  generateRateLimitedFunction (PerInvocation pertime) action dontCombine

-- | Rate limit the execution of a given action. This is equivalent to calling
-- 'generateRateLimitedFunction' with a 'PerExecution' rate limit and the
-- 'dontCombine' combining function.
rateLimitExecution :: TimeUnit t
                   => t
                      -- ^ The time unit bounding the rate at which the
                      -- action is executed (see 'PerExecution')
                   -> (req -> IO resp)
                      -- ^ The action to rate limit
                   -> IO (req -> IO resp)
rateLimitExecution pertime action =
  generateRateLimitedFunction (PerExecution pertime) action dontCombine

-- | The most generic way to rate limit an invocation.
--
-- This forks a background thread that reads queued requests from an internal
-- channel and runs them at the given rate. The returned function enqueues its
-- argument and blocks until the matching response is available. If the
-- action throws an exception, that exception is re-thrown in the waiting
-- caller (and in every caller whose request was combined into the failing
-- one) instead of leaving them blocked.
generateRateLimitedFunction :: forall req resp t
                             . TimeUnit t
                            => RateLimit t
                               -- ^ What is the rate limit for this action
                            -> (req -> IO resp)
                               -- ^ What is the action you want to rate limit,
                               -- given as an IO function from requests to
                               -- responses?
                            -> ResultsCombiner req resp
                               -- ^ A function that can combine requests if
                               -- rate limiting happens. If you cannot combine
                               -- two requests into one request, we suggest
                               -- using 'dontCombine'.
                            -> IO (req -> IO resp)
generateRateLimitedFunction ratelimit action combiner = do
  chan <- atomically newTChan
  void $ forkIO $ runner Nothing 0 chan
  return $ resultFunction chan

 where
  -- currentMicroseconds: the current POSIX time in whole microseconds,
  -- computed directly as an 'Integer'. (Going through 'fromEnum' yields an
  -- 'Int' of picoseconds, which overflows for present-day timestamps.)
  currentMicroseconds :: IO Integer
  currentMicroseconds = floor . (* 1000000) <$> getPOSIXTime

  -- delayMicroseconds: sleep for the given number of microseconds.
  -- 'threadDelay' takes an 'Int', so long delays are split into chunks that
  -- cannot overflow it. Non-positive delays return immediately.
  delayMicroseconds :: Integer -> IO ()
  delayMicroseconds us
    | us <= 0   = return ()
    | otherwise = do
        let chunk = min us (toInteger (maxBound :: Int))
        threadDelay (fromInteger chunk)
        delayMicroseconds (us - chunk)

  -- runner: Repeatedly run requests from the channel, keeping track of the
  -- time immediately before the last request, and a "sleep discount"
  -- allowance we can spend (i.e. reduce future sleep times) based on the
  -- amount of time we've "overslept" in the past.
  runner :: Maybe Integer
         -> Integer
         -> TChan (req, MVar (Either SomeException resp))
         -> IO a
  runner mLastRun lastAllowance chan = do
    (req, respMV) <- atomically $ readTChan chan
    let baseHandler :: Either SomeException resp -> IO ()
        baseHandler = putMVar respMV

    -- should we wait for some amount of time before running?
    beforeWait <- currentMicroseconds
    let targetPeriod :: Integer
        targetPeriod = toMicroseconds $ getRate ratelimit

        -- With no previous run there is nothing to wait for, so behave as if
        -- a full period had already elapsed. Clamp at zero so a wall clock
        -- stepping backwards cannot stretch the wait beyond one period.
        timeSinceLastRun :: Integer
        timeSinceLastRun = case mLastRun of
          Just lastRun -> max 0 (beforeWait - lastRun)
          Nothing      -> targetPeriod

        targetDelay :: Integer
        targetDelay = targetPeriod - timeSinceLastRun - lastAllowance

    -- sleep if necessary; determine sleep-discount allowance for next round
    nextAllowance <- if targetDelay <= 0
      -- No sleep needed. Carry forward only the unspent part of the existing
      -- oversleep credit: idle time must not turn into credit, otherwise a
      -- long pause would let a burst of requests through unthrottled.
      then pure $ min lastAllowance (negate targetDelay)
      else do
        -- sleep for *at least* our target delay time
        delayMicroseconds targetDelay
        afterWait <- currentMicroseconds
        let slept     = afterWait - beforeWait
            overslept = slept - targetDelay
        -- a backwards clock step must not yield a negative credit
        pure $ max 0 overslept

    -- before running, can we combine this with any other requests on the pipe?
    (req', finalHandler) <- updateRequestWithFollowers chan req baseHandler
    -- Capture any exception from the action so it is handed back to the
    -- waiting caller(s) rather than killing this thread or stranding them.
    let run :: IO ()
        run = try (action req') >>= finalHandler

    beforeRun <- currentMicroseconds
    if shouldFork ratelimit
      then void $ forkIO run
      else run

    runner (Just beforeRun) nextAllowance chan

  -- updateRequestWithFollowers: We have one request. Can we combine it with
  -- the requests queued behind it into a cohesive whole? Returns the
  -- (possibly combined) request and a handler that distributes its result to
  -- every caller whose request was folded in.
  updateRequestWithFollowers :: TChan (req, MVar (Either SomeException resp))
                             -> req
                             -> (Either SomeException resp -> IO ())
                             -> IO (req, Either SomeException resp -> IO ())
  updateRequestWithFollowers chan req handler = do
    isEmpty <- atomically $ isEmptyTChan chan
    if isEmpty
      then return (req, handler)
      else do
        mCombinedAndMV <- atomically $ do
          tup@(next, nextRespMV) <- readTChan chan
          case combiner req next of
            -- not combinable: put the follower back at the head of the queue
            Nothing       -> unGetTChan chan tup $> Nothing
            Just combined -> return $ Just (combined, nextRespMV)

        case mCombinedAndMV of
          Nothing ->
            return (req, handler)
          Just ((req', splitResponse), nextRespMV) ->
            updateRequestWithFollowers chan req' $ \result ->
              case result of
                -- the combined request failed: every participant sees it
                Left err -> do
                  putMVar nextRespMV (Left err)
                  handler (Left err)
                -- first half goes to the earlier request, second to `next`
                Right resp -> do
                  let (theirs, mine) = splitResponse resp
                  putMVar nextRespMV (Right mine)
                  handler (Right theirs)

  -- shouldFork: should we fork or execute the action in place?
  shouldFork :: RateLimit t -> Bool
  shouldFork (PerInvocation _) = True
  shouldFork (PerExecution _)  = False

  -- getRate: what is the rate of this action?
  getRate :: RateLimit t -> t
  getRate (PerInvocation x) = x
  getRate (PerExecution  x) = x

  -- resultFunction: the function (partially applied on the channel) that will
  -- be returned from this monstrosity. It enqueues the request, waits for the
  -- reply, and re-throws the action's exception if there was one.
  resultFunction :: TChan (req, MVar (Either SomeException resp))
                 -> req
                 -> IO resp
  resultFunction chan req = do
    respMV <- newEmptyMVar
    atomically $ writeTChan chan (req, respMV)
    takeMVar respMV >>= either throwIO return