module Network.Wai.Middleware.Throttler
( ThrottleCache(..)
, newMemoryThrottleCache
, throttle
, throttlePath
) where
import Control.Concurrent.MVar (MVar, modifyMVar, newMVar, takeMVar, putMVar, readMVar)
import Data.Map (Map)
import qualified Data.Map as Map (empty, insert, lookup)
import Data.Maybe (fromMaybe)
import Data.ByteString as B (ByteString, take, length)
import Data.ByteString.Char8 as B8 (pack)
import Data.Time.Clock (getCurrentTime, UTCTime, NominalDiffTime)
import Data.Time.Clock.POSIX (posixSecondsToUTCTime, utcTimeToPOSIXSeconds)
import Network.Wai (Request, Response, Middleware, rawPathInfo, responseLBS)
import Network.HTTP.Types.Status (status403)
-- | A store of per-key hit counters used by 'throttle' and 'throttlePath'.
class ThrottleCache cache where
  -- | Look up the counter for the given key and return its count.
  -- 'throttle' rejects a request once this count exceeds its limit.
  -- Implementations presumably record the current hit as part of this call
  -- (the in-memory instance in this module does).
  cacheCount :: cache -> ByteString -> IO Int
-- | In-memory throttle cache.
--
-- Fields, in order: a request limit, the length of each counting window, and
-- an 'MVar'-guarded map from key to (window start, hit count in that window).
--
-- Fields are strict so that the long-lived cache value never holds
-- unevaluated thunks for its configuration.
data MemoryThrottleCache = MemoryThrottleCache
  !Int
  !NominalDiffTime
  !(MVar (Map ByteString (UTCTime, Int)))
-- | Create an empty in-memory throttle cache.
--
-- The limit is stored in the returned value. Note that 'throttle' and
-- 'throttlePath' each take their own limit argument as well.
newMemoryThrottleCache
  :: Int             -- ^ request limit stored in the cache
  -> NominalDiffTime -- ^ length of each counting window
  -> IO MemoryThrottleCache
newMemoryThrottleCache limit period = do
  store <- newMVar Map.empty
  return (MemoryThrottleCache limit period store)
-- | Fixed-window counter kept in an in-memory 'Map'.
--
-- Each key maps to the start of the window its count belongs to and the
-- number of hits seen in that window.
instance ThrottleCache MemoryThrottleCache where
  -- Record one hit for @key@ and return the hit count for the current window.
  --
  -- Windows are aligned to multiples of the period, truncated to whole seconds
  -- (minimum one second), counted from the POSIX epoch. A hit in a new window
  -- restarts the count at 1; a hit in the stored window increments it. The
  -- window is checked on every hit so counts never leak across windows.
  --
  -- 'modifyMVar' puts the original map back if an exception is raised while
  -- the lock is held, so a failing request cannot leave the 'MVar' empty and
  -- block all later requests.
  --
  -- NOTE(review): entries are never evicted, so the map grows with the number
  -- of distinct keys ever seen.
  cacheCount (MemoryThrottleCache _ period store) key =
    modifyMVar store $ \entries -> do
      now <- getCurrentTime
      let window = alignTime now period
          count = case Map.lookup key entries of
            Just (start, hits) | start == window -> hits + 1
            _ -> 1
          entries' = Map.insert key (window, count) entries
      -- Force everything before releasing the lock so no thunks (or latent
      -- exceptions) end up stored inside the MVar.
      window `seq` count `seq` entries' `seq` return (entries', count)
    where
      -- Round a time down to the start of its window.
      alignTime :: UTCTime -> NominalDiffTime -> UTCTime
      alignTime time len =
        posixSecondsToUTCTime (fromIntegral (secs - (secs `mod` step)))
        where
          secs = floor (utcTimeToPOSIXSeconds time) :: Integer
          -- Clamp to at least one second: a sub-second or non-positive
          -- period would otherwise floor to 0 (or below) and break `mod`.
          step = max 1 (floor len)
-- | Like 'throttle', but only for requests whose raw path equals @path@
-- exactly. Every other request goes straight to the wrapped application.
throttlePath
  :: ThrottleCache cache
  => ByteString                    -- ^ exact raw path to throttle
  -> cache                         -- ^ counter store
  -> Int                           -- ^ largest count still allowed through
  -> (Request -> Maybe ByteString) -- ^ key extractor; 'Nothing' skips throttling
  -> Middleware
throttlePath path cache limit getKey app req
  | rawPathInfo req /= path = app req
  | otherwise = throttle cache limit getKey app req
-- | Rate-limit requests per key.
--
-- The key is extracted with @getKey@. Requests with no key pass through
-- without being counted. For keyed requests the cache counter is queried via
-- 'cacheCount'. Once it exceeds @limit@, the wrapped application is skipped
-- and an empty 403 response with no headers is returned.
--
-- NOTE(review): returning a 'Response' directly from the middleware looks
-- like the pre-3.0 WAI 'Application' shape (Request -> IO Response); confirm
-- against the WAI version in use.
throttle
  :: ThrottleCache cache
  => cache                         -- ^ counter store
  -> Int                           -- ^ largest count still allowed through
  -> (Request -> Maybe ByteString) -- ^ key extractor; 'Nothing' skips throttling
  -> Middleware
throttle cache limit getKey app req =
  maybe (app req) checkKey (getKey req)
  where
    checkKey key = do
      hits <- cacheCount cache key
      if hits <= limit
        then app req
        else return (responseLBS status403 [] "")