-----------------------------------------------------------------------------
-- |
-- Module      : Network.Wai.Middleware.Throttle
-- Description : WAI Request Throttling Middleware
-- Copyright   : (c) 2015 Christopher Reichert
-- License     : BSD3
-- Maintainer  : Christopher Reichert <creichert07@gmail.com>
-- Stability   : experimental
-- Portability : POSIX
--
-- Uses a <https://en.wikipedia.org/wiki/Token_bucket Token Bucket>
-- algorithm (from the token-bucket package) to throttle WAI Requests.
--
--
-- == Example
--
-- @
-- main = do
--   st <- initThrottler
--   let payload = "{ \"api\": \"return data\" }"
--       app = throttle defaultThrottleSettings st
--               $ \_ f -> f (responseLBS status200 [] payload)
--   Warp.run 3000 app
-- @

{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards   #-}

module Network.Wai.Middleware.Throttle (

      -- * Wai Request Throttling Middleware
      throttle

      -- * Wai Throttle middleware state.
      --
      -- | Essentially, a TVar with a HashMap for indexing
      -- remote IP address
    , WaiThrottle
    , initThrottler

      -- * Throttle settings and configuration
    , ThrottleSettings(..)
    , defaultThrottleSettings
    ) where

import           Control.Applicative
import           Control.Concurrent.STM
import           Control.Concurrent.TokenBucket
import           Control.Monad                  (liftM)
import qualified Data.ByteString.Char8          as BS
import qualified Data.HashMap.Strict            as H
import qualified Data.Text                      as T
import           GHC.Word                       (Word64)
import qualified Network.HTTP.Types.Status      as Http
import           Network.Socket
import           Network.Wai


-- | Opaque handle to the shared throttling state. Create one with
-- 'initThrottler' and pass the same handle to every 'throttle' call that
-- should share its per-client buckets.
newtype WaiThrottle = WT (TVar ThrottleState)


-- | A 'HashMap' mapping the remote IP address to a 'TokenBucket'
data ThrottleState = ThrottleState !(H.HashMap T.Text TokenBucket)


-- | Settings which control various behaviors in the middleware.
data ThrottleSettings = ThrottleSettings
    {
      -- | Determines whether the 'Request' is throttled
      isThrottled :: !(Request -> IO Bool)

      -- | Function to run when the request is throttled.
      --
      -- The first argument is a 'Word64' containing the amount
      -- of microseconds until the next retry should be attempted
    , onThrottled :: !(Word64 -> Response)

      -- Zone name
      --
      -- TODO use list of zones (rules)
      -- , throttleZone :: !T.Text

      -- | Rate, in requests per second. It is converted to a
      -- microseconds-per-request interval by dividing @1e6@ by it, so it
      -- must be positive.
    , throttleRate :: !Integer -- requests / second

      -- Maximum size of the address cache in MB (similar to nginx)
      --
      -- You can store approximately 160,000 addresses in 10MB with
      -- \$binary_remote_addr.
      -- , throttleCacheSize :: !Integer

      -- | Burst rate. Converted with 'fromInteger' to a 'Word64', so it
      -- must be non-negative.
    , throttleBurst :: !Integer
    }


-- | Create an empty throttling state (no client has a bucket yet).
initThrottler :: IO WaiThrottle
initThrottler = liftM WT $ newTVarIO $ ThrottleState H.empty


-- | Default settings to throttle requests.
--
-- Every request is subject to throttling, at a rate of one request per
-- second with a burst size of one. Throttled requests receive a 429
-- response with a JSON body and an @X-RateLimit-Reset@ header.
defaultThrottleSettings :: ThrottleSettings
defaultThrottleSettings
    = ThrottleSettings {
        isThrottled         = return . const True
   -- , throttleZone        = ""   -- empty zone
      , throttleRate        = 1    -- req / sec
   -- , throttleCacheSize   = 10   -- 10M address cache
      , throttleBurst       = 1    -- burst size of 1 request
      , onThrottled         = onThrottled'
      }
  where
    bshow = BS.pack . show

    -- remaining = bshow (if 5000 - c < 0
    --                      then 0
    --                      else 5000 - c)

    -- @rt@ is the wait in microseconds; the header carries it in seconds.
    onThrottled' rt =
      responseLBS
        Http.status429
        [ ("Content-Type", "application/json")
     -- , ("X-RateLimit-Limit", "5000")
     -- , ("X-RateLimit-Remaining", remaining)
        , ("X-RateLimit-Reset", bshow (fromIntegral rt / 1000000.0 :: Double))
        ]
        -- match YesodAuth error message renderer
        "{\"message\":\"Too many requests.\"}"


-- | WAI Request Throttling Middleware.
--
-- Uses a 'Request's 'remoteHost' function to resolve the
-- remote IP address.
--
-- For each request 'isThrottled' decides whether throttling applies. If it
-- does, one token is requested from the bucket belonging to the remote
-- address. A non-zero result is the number of microseconds until a retry
-- should be attempted and is handed to 'onThrottled'; otherwise the wrapped
-- application handles the request.
throttle :: ThrottleSettings
         -> WaiThrottle
         -> Application
         -> Application
throttle ThrottleSettings{..} (WT tmap) app req respond = do

    -- determine whether the request needs throttling
    reqIsThrottled <- isThrottled req

    -- microseconds until the next retry (if the request was refused),
    -- 0 otherwise.
    retryAfter <- if reqIsThrottled
                    then throttleReq
                    else return 0

    if retryAfter /= 0
        then respond $ onThrottled retryAfter
        else app req respond
  where
    throttleReq = do
      key    <- remoteAddrKey (remoteHost req)
      bucket <- lookupOrCreateBucket tmap key
      tokenBucketTryAlloc1 bucket burst invRate

    burst, invRate :: Word64
    burst   = fromInteger throttleBurst
    -- requests/second -> microseconds per request
    invRate = round (1e6 / (fromInteger throttleRate :: Double))


-- | Derive the bucket key for a peer address.
--
-- Total over every 'SockAddr' constructor: the previous implementation only
-- matched 'SockAddrInet' and failed with a pattern-match error on anything
-- else (e.g. IPv6 peers).
remoteAddrKey :: SockAddr -> IO T.Text
-- IPv4: dotted-quad text, the same key format as before.
remoteAddrKey (SockAddrInet _ host)      = T.pack <$> inet_ntoa host
-- IPv6: key on the host address only, so the port does not split one
-- client across several buckets.
remoteAddrKey (SockAddrInet6 _ _ host _) = return $ T.pack (show host)
-- Anything else (e.g. Unix sockets): fall back to the 'Show' rendering.
remoteAddrKey other                      = return $ T.pack (show other)


-- | Fetch the bucket stored under a key, creating and registering a fresh
-- one when there is none.
--
-- The insert happens in the same STM transaction that re-checks the map, so
-- concurrent requests can neither drop another client's bucket nor replace
-- one that is already registered. The map is only written when a bucket is
-- actually missing.
lookupOrCreateBucket :: TVar ThrottleState -> T.Text -> IO TokenBucket
lookupOrCreateBucket tmap key = do
    ThrottleState seen <- atomically $ readTVar tmap
    case H.lookup key seen of
      Just bucket -> return bucket
      Nothing     -> do
        -- Buckets are allocated in IO, so create a candidate up front and
        -- discard it if another thread registered one in the meantime.
        candidate <- newTokenBucket
        atomically $ do
          ThrottleState current <- readTVar tmap
          case H.lookup key current of
            Just existing -> return existing
            Nothing       -> do
              writeTVar tmap (ThrottleState (H.insert key candidate current))
              return candidate