--------------------------------------------------------------------------------
-- Rate Limiting Middleware for WAI                                           --
--------------------------------------------------------------------------------
-- This source code is licensed under the MIT license found in the LICENSE    --
-- file in the root directory of this source tree.                            --
--------------------------------------------------------------------------------

module Network.Wai.RateLimit.Strategy (
    Strategy(..),
    fixedWindow,
    slidingWindow
) where

--------------------------------------------------------------------------------

import Control.Monad

import Network.Wai
import Network.Wai.RateLimit.Backend

--------------------------------------------------------------------------------

-- | Represents rate limiting strategies.
data Strategy = MkStrategy {
    -- | 'strategyOnRequest' @request@ is a computation which determines
    -- whether the request should be allowed or not, based on the rate
    -- limiting strategy. 'True' means the request is allowed, 'False'
    -- means it should be rejected.
    strategyOnRequest :: Request -> IO Bool
}


-- | 'windowStrategy' @backend seconds capacity getKey cond request@ is the
-- shared implementation behind the window-based strategies of this module.
--
-- It computes the usage bucket for @request@ with @getKey@, increments that
-- bucket's usage by one through @backend@ and then:
--
--   * denies the request if the backend reports an error;
--
--   * allows the request if the new usage is at most @capacity@, first
--     calling 'backendExpireIn' with @seconds@ for the bucket when @cond@
--     holds for the new usage;
--
--   * denies the request otherwise.
--
-- Errors returned by 'backendExpireIn' are discarded; the request is still
-- allowed in that case.
windowStrategy :: Backend key err
               -> Integer
               -> Integer
               -> (Request -> IO key)
               -> (Integer -> Bool)
               -> Request
               -> IO Bool
windowStrategy backend seconds capacity getKey cond req = do
    -- get a key to identify the usage bucket for the request: this is
    -- up to the application and may be comprised of e.g. the IP of the
    -- client or a unique user id, followed by e.g. a timestamp
    key <- getKey req

    -- get usage for the key and increment it by 1
    result <- backendIncAndGetUsage backend key 1

    case result of
        -- a backend error occurred, deny the request
        Left _ -> pure False
        -- we got back the current usage: check whether it is within the
        -- acceptable limit and, if so, add to the expiry timer
        Right used
            | used <= capacity -> do
                -- NOTE(review): a failed expiry update is ignored here; for
                -- a strategy that only sets the expiry once per bucket this
                -- may leave the bucket without an expiry — confirm intended
                when (cond used) $
                    void $ backendExpireIn backend key seconds
                pure True
            | otherwise -> pure False

-- | 'fixedWindow' @backend seconds capacity getKey@ is a 'Strategy' which
-- limits the number of requests made by a client to @capacity@ within a
-- window of @seconds@. Requests are grouped into usage buckets by @getKey@.
-- The expiry of a bucket is only set by the allowed request that brings its
-- usage to 1, so later requests inside the window do not extend it.
fixedWindow :: Backend key err
            -> Integer
            -> Integer
            -> (Request -> IO key)
            -> Strategy
fixedWindow backend seconds capacity getKey = MkStrategy{
    -- only the first counted request of a bucket starts the window
    strategyOnRequest =
        windowStrategy backend seconds capacity getKey (== 1)
}

-- | 'slidingWindow' @backend seconds capacity getKey@ is a 'Strategy' which
-- limits the number of requests made by a client to @capacity@ within a
-- sliding window of @seconds@. Requests are grouped into usage buckets by
-- @getKey@. That is, for every successful request, the window is extended
-- by @seconds@ so that a "break" of @seconds@ is required after
-- @capacity@-many requests have been made in a period during which the
-- timeout has never been exceeded.
slidingWindow :: Backend key err
              -> Integer
              -> Integer
              -> (Request -> IO key)
              -> Strategy
slidingWindow backend seconds capacity getKey = MkStrategy{
    -- every allowed request refreshes the bucket's expiry
    strategyOnRequest =
        windowStrategy backend seconds capacity getKey (const True)
}

--------------------------------------------------------------------------------