module Network.Wai.Middleware.Throttle (
throttle
, WaiThrottle, CustomWaiThrottle
, initThrottler, initCustomThrottler
, ThrottleSettings(..)
, defaultThrottleSettings
, RequestHashable(..)
) where
import Control.Applicative ((<$>), pure)
import Control.Concurrent.STM
import Control.Concurrent.TokenBucket
import Control.Monad (join, liftM)
import Control.Monad.IO.Class (liftIO)
import Control.Monad.Except (ExceptT, runExceptT)
import Data.ByteString.Builder (stringUtf8, toLazyByteString)
import Data.Function (on)
import Data.Hashable (Hashable, hash, hashWithSalt)
import qualified Data.IntMap as IM
import Data.List (unionBy)
import Data.Monoid ((<>))
import Data.Text (Text, unpack)
import GHC.Word (Word64)
import qualified Network.HTTP.Types.Status as Http
import Network.Socket
import Network.Wai
#ifndef MIN_VERSION_network
#define MIN_VERSION_network(a,b,c) 1
#endif
-- | Handle to a throttler's shared state: a 'TVar' holding the token buckets
-- for keys of type @a@. Create one with 'initCustomThrottler' and pass it to
-- 'throttle'.
newtype CustomWaiThrottle a = WT (TVar (ThrottleState a))

-- | Throttler keyed on the client's remote address (see 'initThrottler').
type WaiThrottle = CustomWaiThrottle Address
-- | Default throttling key: the remote address of a request.
--
-- Only the host part takes part in hashing, equality and ordering. The port
-- (and, for IPv6, flow info and scope id) is ignored, so every connection
-- from the same host shares one token bucket.
newtype Address = Address SockAddr

-- NOTE: 'SockAddrCan' was added in network 2.6.1 and removed again in
-- network 3.0. The CAN cases below are therefore guarded on BOTH bounds.
-- With only the lower bound they would be compiled against network 3.x,
-- where the constructor no longer exists.

instance Hashable Address where
  hashWithSalt s (Address (SockAddrInet _ a))      = hashWithSalt s a
  hashWithSalt s (Address (SockAddrInet6 _ _ a _)) = hashWithSalt s a
  hashWithSalt s (Address (SockAddrUnix a))        = hashWithSalt s a
#if MIN_VERSION_network(2,6,1) && !MIN_VERSION_network(3,0,0)
  hashWithSalt s (Address (SockAddrCan a))         = hashWithSalt s a
#endif

-- | Two addresses are equal when they have the same family and host part.
-- Addresses from different families are never equal.
instance Eq Address where
  Address (SockAddrInet _ a)      == Address (SockAddrInet _ b)      = a == b
  Address (SockAddrInet6 _ _ a _) == Address (SockAddrInet6 _ _ b _) = a == b
  Address (SockAddrUnix a)        == Address (SockAddrUnix b)        = a == b
#if MIN_VERSION_network(2,6,1) && !MIN_VERSION_network(3,0,0)
  Address (SockAddrCan a)         == Address (SockAddrCan b)         = a == b
#endif
  _ == _ = False

-- | Same-family addresses are ordered by host part. Mixed families fall back
-- to the 'Ord' instance of 'SockAddr' itself.
instance Ord Address where
  Address (SockAddrInet _ a)      <= Address (SockAddrInet _ b)      = a <= b
  Address (SockAddrInet6 _ _ a _) <= Address (SockAddrInet6 _ _ b _) = a <= b
  Address (SockAddrUnix a)        <= Address (SockAddrUnix b)        = a <= b
#if MIN_VERSION_network(2,6,1) && !MIN_VERSION_network(3,0,0)
  Address (SockAddrCan a)         <= Address (SockAddrCan b)         = a <= b
#endif
  Address a <= Address b = a <= b
-- | Throttler state: token buckets indexed by the 'hash' of their key.
--
-- Each 'IM.IntMap' slot holds an association list, so keys whose hashes
-- collide are still kept apart by comparing them with '=='.
data ThrottleState a = ThrottleState !(IM.IntMap [(a,TokenBucket)])
-- | Configuration for 'throttle'.
data ThrottleSettings = ThrottleSettings
  { isThrottled :: !(Request -> IO Bool)
    -- ^ Decides whether a request is subject to throttling. Requests for
    -- which this returns 'False' are passed straight to the application
    -- without touching any token bucket.
  , onThrottled :: !(Word64 -> Response)
    -- ^ Response sent when the key's bucket has no token available. It
    -- receives the non-zero result of 'tokenBucketTryAlloc1'. NOTE(review):
    -- presumably this is the wait before a retry would succeed, in
    -- microseconds. Confirm against the token-bucket documentation.
  , onRequestError :: !(Text -> Response)
    -- ^ Response sent when 'requestToKey' fails; receives its error text.
  , throttleRate :: !Integer
    -- ^ Number of tokens granted per 'throttlePeriod'. The refill interval
    -- handed to 'tokenBucketTryAlloc1' is
    -- @round (throttlePeriod / throttleRate)@.
  , throttlePeriod :: !Integer
    -- ^ Length of the refill period, in the time unit that
    -- 'tokenBucketTryAlloc1' expects. The default of 1000000 suggests
    -- microseconds (TODO confirm).
  , throttleBurst :: !Integer
    -- ^ Burst size handed to 'tokenBucketTryAlloc1', i.e. the bucket's
    -- capacity.
  }
-- | Create an empty throttler that keys requests on the client's remote
-- address.
initThrottler :: IO WaiThrottle
initThrottler = WT <$> newTVarIO (ThrottleState IM.empty)
-- | Create an empty throttler for an arbitrary key type @a@. Pair it with a
-- 'RequestHashable' instance to throttle on something other than the remote
-- address.
initCustomThrottler :: IO (CustomWaiThrottle a)
initCustomThrottler = liftM WT freshState
  where
    -- A new TVar holding a map with no buckets yet.
    freshState = newTVarIO (ThrottleState IM.empty)
-- | Default settings.
--
-- * Every request is throttled ('isThrottled' always returns 'True').
-- * One token per period of 1000000 units (presumably microseconds, i.e. one
--   request per second; TODO confirm), with a burst of 1.
-- * Throttled requests get @429@ with a JSON body.
-- * Key-extraction failures get @400@ with a JSON body carrying the error
--   text, escaped as a JSON string.
defaultThrottleSettings :: ThrottleSettings
defaultThrottleSettings
  = ThrottleSettings
      { isThrottled    = return . const True
      , throttleRate   = 1 :: Integer
      , throttlePeriod = 1000000 :: Integer
      , throttleBurst  = 1 :: Integer
      , onThrottled    = onThrottled'
      , onRequestError = onRequestError'
      }
  where
    onThrottled' _ =
      responseLBS
        Http.status429
        [ ("Content-Type", "application/json")
        ]
        "{\"message\":\"Too many requests.\"}"

    -- The reason text is embedded inside a JSON string literal. It must be
    -- escaped; otherwise a quote, backslash or control character in it
    -- yields malformed (or attacker-shaped) JSON.
    onRequestError' reason =
      responseLBS
        Http.status400
        [ ("Content-Type", "application/json")
        ]
        ("{\"message\":\""
           <> toLazyByteString (stringUtf8 $ escapeJson $ unpack reason)
           <> "\"}")

    -- Escape a string for use inside a JSON string literal (RFC 8259,
    -- section 7). Only '"', '\\' and U+0000..U+001F need escaping.
    escapeJson :: String -> String
    escapeJson = concatMap escapeChar

    escapeChar :: Char -> String
    escapeChar '"'  = "\\\""
    escapeChar '\\' = "\\\\"
    escapeChar c
      | c < ' '   = let code = fromEnum c
                    in "\\u00" ++ [hexDigit (code `div` 16), hexDigit (code `mod` 16)]
      | otherwise = [c]

    -- Lower-case hex digit for a value in 0..15.
    hexDigit :: Int -> Char
    hexDigit d
      | d < 10    = toEnum (fromEnum '0' + d)
      | otherwise = toEnum (fromEnum 'a' + d - 10)
-- | Types that can be extracted from a 'Request' and used as throttling keys.
-- Each distinct key (compared with '==') gets its own token bucket, and
-- 'hash' is used to index the bucket map.
class (Eq a, Ord a, Hashable a) => RequestHashable a where
  -- | Extract the key from a request. A 'Left'/thrown error text is turned
  -- into a response by 'onRequestError'.
  requestToKey :: (Functor m, Monad m) => Request -> ExceptT Text m a
-- | Key a request on its remote host. This never fails.
instance RequestHashable Address where
  requestToKey request = pure (Address (remoteHost request))
-- | WAI middleware that rate-limits requests with one token bucket per key.
--
-- 1. Requests for which 'isThrottled' returns 'False' are passed straight to
--    the wrapped application.
-- 2. Otherwise the request is mapped to a key with 'requestToKey'. If that
--    fails, the request is answered with 'onRequestError'.
-- 3. The key's bucket (created on first use) is asked for one token:
--    * a result of 0 lets the request through;
--    * any other result is handed to 'onThrottled'.
throttle :: RequestHashable a
         => ThrottleSettings
         -> CustomWaiThrottle a
         -> Application
         -> Application
throttle ThrottleSettings{..} (WT tmap) app req respond = do
  reqIsThrottled <- isThrottled req
  remaining <- if reqIsThrottled
                 then runExceptT throttleReq
                 else return $ Right 0
  case remaining of
    Left err -> respond $ onRequestError err
    Right 0  -> app req respond
    Right n  -> respond $ onThrottled n
  where
    period  = fromInteger throttlePeriod :: Double
    -- Refill interval: throttlePeriod / throttleRate, rounded.
    invRate = round (period / (fromInteger throttleRate :: Double))
    burst   = fromInteger throttleBurst

    throttleReq = do
      k <- requestToKey req
      liftIO $ do
        bucket <- getBucket k
        -- The bucket is shared between concurrent requests for the same
        -- key; allocation happens outside STM on the bucket itself.
        tokenBucketTryAlloc1 bucket burst invRate

    -- Return the key's bucket, registering a new one if the key is unseen.
    -- The lookup and the insertion run in ONE STM transaction. Doing a read
    -- and a later write in separate transactions loses concurrent updates:
    -- newly inserted buckets get dropped and the client starts again with a
    -- full bucket. The fresh bucket is allocated up front because
    -- 'newTokenBucket' runs in IO; it is simply discarded if the key exists.
    getBucket k = do
      fresh <- newTokenBucket
      atomically $ do
        ThrottleState m <- readTVar tmap
        case lookupBucket k m of
          Just existing -> return existing
          Nothing -> do
            writeTVar tmap (ThrottleState (insertBucket k fresh m))
            return fresh

    -- Find the key's bucket in its hash slot; collisions are resolved by '=='.
    lookupBucket k m = join (lookup k <$> IM.lookup (hash k) m)

    -- Add (or replace) the key's bucket in its hash slot, keeping the other
    -- keys that share the same hash.
    insertBucket k bucket =
      IM.insertWith (unionBy ((==) `on` fst)) (hash k) [(k, bucket)]