--------------------------------------------------------------------------------
-- Rate Limiting Middleware for WAI                                           --
--------------------------------------------------------------------------------
-- This source code is licensed under the MIT license found in the LICENSE    --
-- file in the root directory of this source tree.                            --
--------------------------------------------------------------------------------

-- | Exports functions which implement rate-limiting strategies.
module Network.Wai.RateLimit.Strategy (
    Strategy(..),
    fixedWindow,
    slidingWindow
) where

--------------------------------------------------------------------------------

import Control.Monad

import Data.Time.Units

import Network.Wai
import Network.Wai.RateLimit.Backend

--------------------------------------------------------------------------------

-- | Represents rate limiting strategies.
newtype Strategy = MkStrategy {
    -- | 'strategyOnRequest' @request@ is a computation which determines
    -- whether the request should be allowed or not, based on the rate
    -- limiting strategy. It yields 'True' if the request should be allowed
    -- and 'False' if it should be rejected.
    strategyOnRequest :: Request -> IO Bool
}

-- | 'windowStrategy' @backend seconds capacity getKey cond request@
-- implements a general window-based rate limiting strategy.
--
-- The usage recorded for the key derived from @request@ is incremented by
-- one. If the resulting usage is within @capacity@, the request is allowed
-- and, when @cond@ holds for the updated usage, the key's expiry is (re)set
-- to @seconds@. Otherwise the request is rejected and the expiry is left
-- untouched; note that rejected requests still count towards the usage,
-- since the increment happens before the capacity check.
windowStrategy
    :: Backend key -- ^ The storage backend to use.
    -> Second -- ^ The number of seconds after which recorded usage expires.
    -> Integer -- ^ How much capacity each key should have.
    -> (Request -> IO key) -- ^ A function which computes a key for the
                           -- request.
    -> (Integer -> Bool) -- ^ A predicate on the updated usage which
                         -- determines whether the expiry timer should be
                         -- reset.
    -> Request -- ^ The request to apply the strategy to, used for deriving
               -- the key.
    -> IO Bool -- ^ 'True' if the request is allowed, 'False' otherwise.
windowStrategy backend seconds capacity getKey cond req = do
    -- get a key to identify the usage bucket for the request: this is
    -- up to the application and may be comprised of e.g. the IP of the
    -- client or a unique user id, followed by e.g. a timestamp
    key <- getKey req

    -- get usage for the key and increment it by 1
    used <- backendIncAndGetUsage backend key 1

    -- we got back the current usage: check whether it is within the
    -- acceptable limit and, if so, (re)set the expiry timer when the
    -- strategy's predicate asks for it
    if used <= capacity
    then do
        when (cond used) $
            backendExpireIn backend key (toInteger seconds)
        pure True
    else pure False

-- | 'fixedWindow' @backend seconds capacity getKey@ is a 'Strategy' which
-- limits the number of requests made by a client to @capacity@ within a
-- window of @seconds@. The expiry timer is only set by the first request
-- of a window (i.e. when the recorded usage is exactly 1), so later
-- requests do not move the end of the window.
--
-- NOTE(review): the expiry is only set on an allowed request, so with
-- @capacity < 1@ no request is ever allowed and the key never receives an
-- expiry — confirm whether such configurations need to be rejected.
fixedWindow
    :: Backend key -- ^ The storage backend to use.
    -> Second -- ^ The length of the window in seconds.
    -> Integer -- ^ The maximum number of requests allowed per window.
    -> (Request -> IO key) -- ^ A function which computes a key for the
                           -- request.
    -> Strategy
fixedWindow backend seconds capacity getKey = MkStrategy{
    -- only the first request in a window starts the expiry timer
    strategyOnRequest = windowStrategy backend seconds capacity getKey (== 1)
}

-- | 'slidingWindow' @backend seconds capacity getKey@ is a 'Strategy' which
-- limits the number of requests made by a client to @capacity@ within a
-- sliding window of @seconds@. That is, for every successful request, the
-- window is extended by @seconds@ so that a "break" of @seconds@ is required
-- after @capacity@-many requests have been made in a period during which the
-- timeout has never been exceeded. Rejected requests do not extend the
-- window.
slidingWindow
    :: Backend key -- ^ The storage backend to use.
    -> Second -- ^ The length of the window in seconds.
    -> Integer -- ^ The maximum number of requests allowed within the window.
    -> (Request -> IO key) -- ^ A function which computes a key for the
                           -- request.
    -> Strategy
slidingWindow backend seconds capacity getKey = MkStrategy{
    -- every allowed request resets the expiry timer
    strategyOnRequest =
        windowStrategy backend seconds capacity getKey (const True)
}

--------------------------------------------------------------------------------