-- | WAI middleware that throttles incoming requests per remote address,
-- keeping one token bucket for each client address it has seen.
module Network.Wai.Middleware.Throttle (
-- * Middleware
throttle
-- * Throttle state
, WaiThrottle
, initThrottler
-- * Settings
, ThrottleSettings(..)
, defaultThrottleSettings
) where
import Control.Applicative ((<$>))
import Control.Concurrent.STM
import Control.Concurrent.TokenBucket
import Control.Monad (join, liftM)
import Data.Function (on)
import Data.Hashable (Hashable, hash, hashWithSalt)
import qualified Data.IntMap as IM
import Data.List (unionBy)
import GHC.Word (Word64)
import qualified Network.HTTP.Types.Status as Http
import Network.Socket
import Network.Wai
-- Fallback for builds where the MIN_VERSION_network macro is not provided
-- (NOTE(review): presumably builds that are not driven by Cabal). In that
-- case every version check on the network package evaluates to true.
#ifndef MIN_VERSION_network
#define MIN_VERSION_network(a,b,c) 1
#endif
-- | Handle to the shared throttling state: a 'TVar' holding the current
-- 'ThrottleState'. Create one with 'initThrottler' and pass it to 'throttle'.
newtype WaiThrottle = WT (TVar ThrottleState)
-- | Wrapper around 'SockAddr' used as the key identifying a client. The
-- 'Hashable', 'Eq' and 'Ord' instances of this wrapper look at a single field
-- of each constructor and ignore the others (NOTE(review): for the inet
-- constructors the inspected field should be the host address, so the port
-- is ignored; confirm against the network package documentation).
newtype Address = Address SockAddr
-- | Hashes an 'Address' using only the field selected from the wrapped
-- 'SockAddr' constructor; all other constructor fields are ignored.
instance Hashable Address where
  hashWithSalt salt (Address sockAddr) =
    case sockAddr of
      SockAddrInet _ host      -> hashWithSalt salt host
      SockAddrInet6 _ _ host _ -> hashWithSalt salt host
      SockAddrUnix path        -> hashWithSalt salt path
#if MIN_VERSION_network(2,6,1)
      SockAddrCan iface        -> hashWithSalt salt iface
#endif
-- | Two addresses are equal when they use the same 'SockAddr' constructor
-- and the compared field matches; addresses built from different
-- constructors are never equal.
instance Eq Address where
  Address lhs == Address rhs =
    case (lhs, rhs) of
      (SockAddrInet _ a, SockAddrInet _ b)           -> a == b
      (SockAddrInet6 _ _ a _, SockAddrInet6 _ _ b _) -> a == b
      (SockAddrUnix a, SockAddrUnix b)               -> a == b
#if MIN_VERSION_network(2,6,1)
      (SockAddrCan a, SockAddrCan b)                 -> a == b
#endif
      _                                              -> False
-- | Orders addresses of the same 'SockAddr' constructor by the compared
-- field only; any other pairing falls back to the ordering of the wrapped
-- 'SockAddr' values themselves.
instance Ord Address where
  Address lhs <= Address rhs =
    case (lhs, rhs) of
      (SockAddrInet _ a, SockAddrInet _ b)           -> a <= b
      (SockAddrInet6 _ _ a _, SockAddrInet6 _ _ b _) -> a <= b
      (SockAddrUnix a, SockAddrUnix b)               -> a <= b
#if MIN_VERSION_network(2,6,1)
      (SockAddrCan a, SockAddrCan b)                 -> a <= b
#endif
      _                                              -> lhs <= rhs
-- | Internal throttling state: an 'IM.IntMap' keyed by the hash of an
-- 'Address'. Each entry is an association list of the addresses sharing that
-- hash, each paired with its own 'TokenBucket'.
data ThrottleState = ThrottleState !(IM.IntMap [(Address,TokenBucket)])
-- | Settings consumed by 'throttle'.
data ThrottleSettings = ThrottleSettings
{
isThrottled :: !(Request -> IO Bool)
-- ^ Decides whether a request is subject to throttling. Requests for which
-- this yields 'False' are passed straight to the wrapped application.
, onThrottled :: !(Word64 -> Response)
-- ^ Builds the response sent when a request is rejected. The argument is the
-- non-zero value returned by 'tokenBucketTryAlloc1' (NOTE(review): presumably
-- the time to wait before retrying; confirm against the token-bucket docs).
, throttleRate :: !Integer
-- ^ Rate component: the inverse rate handed to the token bucket is computed
-- as @throttlePeriod / throttleRate@, rounded.
, throttlePeriod :: !Integer
-- ^ Period component of the inverse rate (NOTE(review): unit is presumably
-- microseconds, as expected by the token-bucket package; confirm).
, throttleBurst :: !Integer
-- ^ Burst size handed to the token bucket.
}
-- | Creates a fresh throttler whose state contains no buckets yet.
initThrottler :: IO WaiThrottle
initThrottler = WT `liftM` newTVarIO emptyState
  where
    -- State with no known client addresses.
    emptyState = ThrottleState IM.empty
-- | Default settings: every request is subject to throttling, with a rate of
-- 1, a period of 10^6 and a burst of 1. Rejected requests receive a 429
-- response carrying a small JSON body.
defaultThrottleSettings :: ThrottleSettings
defaultThrottleSettings =
  ThrottleSettings
    { isThrottled    = \_ -> return True
    , onThrottled    = const $
        responseLBS
          Http.status429
          [("Content-Type", "application/json")]
          "{\"message\":\"Too many requests.\"}"
    , throttleRate   = 1
    , throttlePeriod = 10 ^ (6 :: Int)
    , throttleBurst  = 1
    }
-- | WAI middleware that rate-limits requests per remote host.
--
-- For each request accepted by 'isThrottled', the token bucket belonging to
-- the request's remote address is looked up (and created on first sight) and
-- a single token allocation is attempted. When 'tokenBucketTryAlloc1' yields
-- a non-zero value the request is answered with 'onThrottled' applied to that
-- value; otherwise the wrapped application handles the request.
--
-- Requests rejected by 'isThrottled' bypass the buckets entirely.
throttle :: ThrottleSettings
         -> WaiThrottle
         -> Application
         -> Application
throttle ThrottleSettings{..} (WT tmap) app req respond = do
    reqIsThrottled <- isThrottled req
    remaining <- if reqIsThrottled
                   then throttleReq
                   else return 0
    if remaining /= 0
      then respond $ onThrottled remaining
      else app req respond
  where
    -- Attempt to take one token from the bucket of the requesting address.
    throttleReq = do
      let remoteAddr = Address . remoteHost $ req
          period     = fromInteger throttlePeriod :: Double
          rate       = fromInteger throttleRate :: Double
          invRate    = round (period / rate)
          burst      = fromInteger throttleBurst
      bucket <- bucketFor remoteAddr
      tokenBucketTryAlloc1 bucket burst invRate

    -- Find the bucket for an address, registering a new one if none exists.
    --
    -- The shared state is only written when a bucket is actually missing, and
    -- that write happens in the same transaction that re-reads the current
    -- state. Writing back a snapshot taken in an earlier transaction would
    -- silently discard buckets inserted by concurrent requests, handing those
    -- clients a fresh bucket (and therefore a fresh burst) on their next hit.
    bucketFor remoteAddr = do
      ThrottleState snapshot <- readTVarIO tmap
      case addressToBucket remoteAddr snapshot of
        Just bucket -> return bucket
        Nothing -> do
          -- Bucket creation is an IO action, so it has to happen outside STM.
          fresh <- newTokenBucket
          atomically $ do
            ThrottleState current <- readTVar tmap
            case addressToBucket remoteAddr current of
              -- Another request registered this address in the meantime;
              -- share its bucket and drop the one created above.
              Just winner -> return winner
              Nothing -> do
                writeTVar tmap $!
                  ThrottleState (insertBucket remoteAddr fresh current)
                return fresh

    -- Look an address up: first by hash, then within the collision list.
    addressToBucket remoteAddr m =
      join (lookup remoteAddr <$> IM.lookup (hash remoteAddr) m)

    -- Add an address/bucket pair, merging with any entries sharing its hash.
    insertBucket remoteAddr bucket m =
      let col = unionBy ((==) `on` fst)
      in IM.insertWith col (hash remoteAddr) [(remoteAddr, bucket)] m