{-# LANGUAGE OverloadedStrings #-}
{- |
Wai middleware for request throttling.

Basic idea: on every (matching) request a counter is incremented. If it
exceeds the given limit, the request is blocked and an error response is sent
to the client. The request counter resets after a defined period of time.

The 'throttle' function limits requests to the underlying application. If you
wish to limit only some of your requests, you need to do the routing
yourself. For convenience, the 'throttlePath' function is provided, which
applies throttling only to requests with a matching URL path.
-}
module Network.Wai.Middleware.Throttler
  ( ThrottleCache(..)
  , MemoryThrottleCache
  , newMemoryThrottleCache
  , throttle
  , throttlePath
  ) where

import Control.Concurrent.MVar (MVar, modifyMVar, newMVar, takeMVar, putMVar, readMVar)
import Data.Map (Map)
import qualified Data.Map as Map (empty, insert, lookup)
import Data.Maybe (fromMaybe)
import Data.ByteString as B (ByteString, take, length)
import Data.ByteString.Char8 as B8 (pack)
import Data.Time.Clock (getCurrentTime, UTCTime, NominalDiffTime)
import Data.Time.Clock.POSIX (posixSecondsToUTCTime, utcTimeToPOSIXSeconds)
import Network.Wai (Request, Response, Middleware, rawPathInfo, responseLBS)
import Network.HTTP.Types.Status (tooManyRequests429)

-- | Cache type class. A throttle cache is used to store request counts.
-- It can store multiple counts under different keys, e.g. client IP
-- addresses or user logins.
class ThrottleCache cache where
  -- | Increment the count for the given key and return the new count value.
  -- The cache should automatically reset counts to zero after a defined
  -- period.
  cacheCount :: cache       -- ^ cache
             -> ByteString  -- ^ key
             -> IO Int

-- | In-memory throttle cache: a 'Map' from key to
-- (start of the current counting window, request count), guarded by an
-- 'MVar'. Entries are never removed, so memory use grows with the number of
-- distinct keys seen.
data MemoryThrottleCache = MemoryThrottleCache
  !Int                                     -- limit
  !NominalDiffTime                         -- limit renew period
  !(MVar (Map ByteString (UTCTime, Int)))  -- per-key (window start, count)

-- | Create an in-memory throttle cache.
--
-- Normally a throttle cache does not need to know what the limit is. This
-- one uses it to avoid calling 'getCurrentTime' on every request: the clock
-- is only read when a key's count would exceed the limit. Pass the same
-- limit to 'throttle' / 'throttlePath' as you pass here.
--
-- Counting windows start at whole multiples of the period since the POSIX
-- epoch. A non-positive period disables window alignment.
newMemoryThrottleCache :: Int              -- ^ limit
                       -> NominalDiffTime  -- ^ limit renew period
                       -> IO MemoryThrottleCache
newMemoryThrottleCache limit period =
  fmap (MemoryThrottleCache limit period) (newMVar Map.empty)

instance ThrottleCache MemoryThrottleCache where
  -- The clock is only consulted once a count would exceed the limit, so a
  -- key's window start is not refreshed while it stays under the limit.
  -- When the key then goes over the limit in a later window, its count
  -- restarts at 1. This is an approximation: up to nearly twice the limit
  -- may be admitted within a single window.
  cacheCount (MemoryThrottleCache limit period var) key =
    -- 'modifyMVar' puts the original map back if an exception (including an
    -- asynchronous one) interrupts the update. A bare takeMVar/putMVar pair
    -- would leave the MVar empty and block every later request forever.
    modifyMVar var $ \counts -> do
      (start, count) <- next (Map.lookup key counts)
      let counts' = Map.insert key (start, count) counts
      counts' `seq` return (counts', count)
    where
      -- New (window start, count) entry for this key.
      next :: Maybe (UTCTime, Int) -> IO (UTCTime, Int)
      next Nothing = do
        now <- getCurrentTime
        entry (alignTime now) 1
      next (Just (start, count))
        | count + 1 <= limit = entry start (count + 1)
        | otherwise = do
            window <- fmap alignTime getCurrentTime
            if window == start
              then entry start (count + 1)
              else entry window 1

      -- Force both components so the map never holds pending computations.
      entry :: UTCTime -> Int -> IO (UTCTime, Int)
      entry start count = start `seq` count `seq` return (start, count)

      -- Round a time down to the start of its period-sized window, counted
      -- from the POSIX epoch. Works for fractional and sub-second periods;
      -- for whole-second periods it matches truncating to whole seconds
      -- first. A non-positive period would divide by zero, so in that case
      -- the time is returned unchanged.
      alignTime :: UTCTime -> UTCTime
      alignTime time
        | period <= 0 = time
        | otherwise   = posixSecondsToUTCTime (fromIntegral windows * period)
        where
          windows = floor (utcTimeToPOSIXSeconds time / period) :: Integer

-- | Apply 'throttle' only to requests whose raw URL path ('rawPathInfo')
-- equals the given path. All other requests go straight to the wrapped
-- application.
throttlePath :: ThrottleCache cache
             => ByteString                     -- ^ URL path to match
             -> cache                          -- ^ cache to store request counts
             -> Int                            -- ^ request limit
             -> (Request -> Maybe ByteString)  -- ^ function to get cache key based on request. If Nothing is returned, request is not throttled
             -> Middleware
throttlePath path cache limit getKey app req
  | rawPathInfo req == path = throttle cache limit getKey app req
  | otherwise               = app req

-- | Wai middleware that cuts requests if the request rate is higher than the
-- defined level.
--
-- Every request for which the key function returns a key is counted with
-- 'cacheCount', including requests that end up rejected. Once the count
-- exceeds the limit, a 429 (Too Many Requests) response with no headers and
-- an empty body is returned instead of calling the application.
throttle :: ThrottleCache cache
         => cache                          -- ^ cache to store request counts
         -> Int                            -- ^ request limit
         -> (Request -> Maybe ByteString)  -- ^ function to get cache key based on request. If Nothing is returned, request is not throttled
         -> Middleware
throttle cache limit getKey app req =
  case getKey req of
    Nothing  -> app req
    Just key -> do
      count <- cacheCount cache key
      if count > limit
        then return throttledResponse
        else app req
  where
    throttledResponse :: Response
    throttledResponse = responseLBS tooManyRequests429 [] ""