{-# LANGUAGE OverloadedStrings #-}
module Control.RateLimiter
    ( WindowSize(..), RateLimitMode(..), RateLimitConfig(..)
    , RateLimiter, newRateLimiter
    , isRateLimited, withRetryRateLimiter
    )
where

import Control.Monad
import Control.Monad.Extra
import Control.Monad.Trans
import Data.IORef
import Data.Time
import Data.Time.TimeSpan
import qualified Data.Sequence as Seq
import qualified Data.Vector as V

-- | Calendar-aligned window granularity used by 'FixedWindow'.
--
-- The constructor order is part of the interface: the derived 'Enum' and
-- 'Bounded' instances follow declaration order, so @minBound@ is 'WsMinute'
-- and @maxBound@ is 'WsDay'. Do not reorder the constructors.
data WindowSize
    = WsMinute -- ^ window starts at second 0 of the current UTC minute
    | WsHour   -- ^ window starts at minute 0, second 0 of the current UTC hour
    | WsSecond -- ^ window starts at the current whole UTC second
    | WsDay    -- ^ window starts at UTC midnight of the current day
    deriving (Show, Eq, Enum, Bounded)

-- | How the counting window of a limit is determined.
data RateLimitMode
    = RollingWindow !DiffTime
      -- ^ Count hits whose timestamp lies within the given duration before
      -- the current time.
    | FixedWindow !WindowSize
      -- ^ Count hits since the start of the current calendar-aligned window
      -- (second, minute, hour or day, computed on the UTC clock).
    deriving (Show, Eq)

-- | A single limit: a window definition plus the number of hits allowed
-- per key inside that window.
data RateLimitConfig
    = RateLimitConfig
    { rlc_mode :: !RateLimitMode
      -- ^ how the current window is determined
    , rlc_maximum :: !Int
      -- ^ a hit is accepted only while fewer than this many hits for the
      -- same key are already recorded in the current window
    }

-- | One accepted hit: when it happened and which key it was recorded for.
data Entry k
  = Entry
  { eTime :: !UTCTime -- ^ timestamp at which the hit was recorded
  , eKey :: !k        -- ^ key the hit is counted against
  }

-- | State for one configured limit: the limit itself and the mutable log of
-- accepted hits (for all keys) that it is checked against.
data RateLimiterImpl k
    = RateLimiterImpl
    { rl_config :: !RateLimitConfig
      -- ^ the limit this state belongs to
    , rl_state :: !(IORef (Seq.Seq (Entry k)))
      -- ^ accepted hits; new hits are appended on the right
    }

-- | A rate limiter over keys of type @k@: one independent state per
-- configured limit. Create values with 'newRateLimiter'.
newtype RateLimiter k
    = RateLimiter { _unRateLimiter :: V.Vector (RateLimiterImpl k) }

-- | Create a new rate limiter from a vector of limit configurations.
--
-- Every configuration gets its own, initially empty, hit log, so the limits
-- are tracked independently of each other.
newRateLimiter :: V.Vector RateLimitConfig -> IO (RateLimiter k)
newRateLimiter cfgs = RateLimiter <$> traverse mkImpl cfgs
  where
    -- Pair one configuration with a fresh empty hit log.
    mkImpl cfg = RateLimiterImpl cfg <$> newIORef Seq.empty

-- | Check if a given key is rate limited. Use @()@ if you don't need
-- multiple keys.
--
-- Returns 'True' when one of the configured limits rejects the key.
--
-- This check is not read-only: every limit that is consulted and still has
-- capacity records a hit for the key at the current time. Limits are
-- consulted in vector order through 'anyM'.
--
-- NOTE(review): 'anyM' from "Control.Monad.Extra" is expected to stop at the
-- first 'True'. If so, limits before a rejecting one have already recorded
-- the hit while later ones are never consulted — confirm this is intended.
isRateLimited :: (Eq k, MonadIO m) => k -> RateLimiter k -> m Bool
isRateLimited key (RateLimiter rls) =
    anyM (isRateLimitedImpl key) (V.toList rls)

-- | Check a key against a single limit and, if it is accepted, record the
-- hit.
--
-- The hit log is pruned to the current window, the hits for @key@ are
-- counted, and the decision is made inside one 'atomicModifyIORef'' call.
-- The pruned log is written back in both cases; the new entry is appended
-- only when the key is accepted.
--
-- Returns 'True' when the key is limited (nothing recorded) and 'False'
-- when the hit was accepted and recorded.
isRateLimitedImpl :: (Eq k, MonadIO m) => k -> RateLimiterImpl k -> m Bool
isRateLimitedImpl key rl =
    liftIO $
    do -- The timestamp is sampled before the atomic update, not inside it.
       now <- getCurrentTime
       atomicModifyIORef' (rl_state rl) (step now)
  where
    cfg = rl_config rl

    -- Pure state transition: (new hit log, is the key limited?).
    step now xs =
        let (xs', bucketSize) = currentBucketSize now key (rlc_mode cfg) xs
        in if bucketSize < rlc_maximum cfg
              then (xs' Seq.|> Entry now key, False)
              else (xs', True)

-- | Run an action once the key is no longer rate limited.
--
-- The limiter is checked first. If the key is accepted the action runs
-- immediately; otherwise this sleeps for one second and checks again,
-- repeating until the key is accepted. There is no upper bound on the
-- number of retries.
withRetryRateLimiter :: (Eq k, MonadIO m) => k -> RateLimiter k -> m a -> m a
withRetryRateLimiter key rl action =
    do limited <- isRateLimited key rl
       if limited
          then do liftIO $ sleepTS (seconds 1)
                  withRetryRateLimiter key rl action
          else action

-- | Prune a hit log to the current window and count the hits for one key.
--
-- Arguments: the current time, the key to count, the window mode, and the
-- hit log. The result is the pruned log (hits of all keys that are still
-- inside the window) paired with the number of those hits that belong to
-- @key@.
--
-- Pruning keeps the longest suffix of the log whose timestamps are at or
-- after the window start, so it assumes entries are stored in
-- non-decreasing time order (they are only ever appended on the right).
currentBucketSize ::
    Eq k => UTCTime -> k -> RateLimitMode -> Seq.Seq (Entry k) -> (Seq.Seq (Entry k), Int)
currentBucketSize now key rlm times =
    (live, Seq.length (Seq.filter ((== key) . eKey) live))
  where
    dayTime = utctDayTime now
    tod = timeToTimeOfDay dayTime

    -- Same day as @now@, with the time-of-day replaced.
    atDayTime t = now { utctDayTime = t }

    -- Earliest timestamp that still counts towards the current window.
    windowStart =
        case rlm of
          RollingWindow dt ->
              addUTCTime (realToFrac (negate dt)) now
          FixedWindow WsSecond ->
              -- drop the fractional part of the seconds since midnight
              atDayTime (fromInteger (truncate dayTime))
          FixedWindow WsMinute ->
              atDayTime (timeOfDayToTime (tod { todSec = 0 }))
          FixedWindow WsHour ->
              atDayTime (timeOfDayToTime (tod { todMin = 0, todSec = 0 }))
          FixedWindow WsDay ->
              atDayTime 0

    live = Seq.takeWhileR ((>= windowStart) . eTime) times