module Network.Wai.Middleware.Throttle.Internal where
import Prelude hiding (lookup)
import Control.Concurrent.TokenBucket (TokenBucket, newTokenBucket, tokenBucketTryAlloc1)
import Control.Exception.Safe (onException)
#if MIN_VERSION_cache(0,1,1)
import Control.Monad.STM (STM, atomically)
import Control.Monad.STM (retry)
import Data.Cache (Cache, delete, insert, insertSTM, lookupSTM, newCache)
#else
import Data.Cache (Cache, delete, insert, insert', lookup, newCache)
#endif
import Data.Hashable (Hashable, hashWithSalt)
import GHC.Word (Word64)
import Network.HTTP.Types.Status (status429)
import Network.Socket (SockAddr (SockAddrInet, SockAddrInet6, SockAddrUnix))
#if MIN_VERSION_network(2,6,1)
import Network.Socket (SockAddr (SockAddrCan))
#endif
import Network.Wai (Application, Request, Response, remoteHost, responseLBS)
import System.Clock (Clock (Monotonic), TimeSpec, getTime)
-- | Wrapper around 'SockAddr' that serves as the throttling key produced by
-- 'extractAddress' and used by 'initThrottler'. The wrapper exists so that
-- this module can give it its own 'Hashable', 'Eq' and 'Ord' instances.
newtype Address = Address SockAddr
-- | Hash an 'Address' using only the address payload of the wrapped
-- 'SockAddr'; every other constructor field is ignored, mirroring the
-- 'Eq' instance for 'Address'.
instance Hashable Address where
  hashWithSalt salt (Address sockAddr) = case sockAddr of
    SockAddrInet _ host -> hashWithSalt salt host
    SockAddrInet6 _ _ host _ -> hashWithSalt salt host
    SockAddrUnix path -> hashWithSalt salt path
#if MIN_VERSION_network(2,6,1)
    SockAddrCan iface -> hashWithSalt salt iface
#endif
-- | Two addresses are equal when they use the same 'SockAddr' constructor
-- and carry the same address payload. The remaining constructor fields do
-- not take part in the comparison, and different constructors never match.
instance Eq Address where
  Address lhs == Address rhs = case (lhs, rhs) of
    (SockAddrInet _ a, SockAddrInet _ b) -> a == b
    (SockAddrInet6 _ _ a _, SockAddrInet6 _ _ b _) -> a == b
    (SockAddrUnix a, SockAddrUnix b) -> a == b
#if MIN_VERSION_network(2,6,1)
    (SockAddrCan a, SockAddrCan b) -> a == b
#endif
    _ -> False
-- | Order addresses of the same 'SockAddr' constructor by their address
-- payload only. Any other pairing falls back to the ordering that
-- 'SockAddr' itself provides.
instance Ord Address where
  Address lhs <= Address rhs = case (lhs, rhs) of
    (SockAddrInet _ a, SockAddrInet _ b) -> a <= b
    (SockAddrInet6 _ _ a _, SockAddrInet6 _ _ b _) -> a <= b
    (SockAddrUnix a, SockAddrUnix b) -> a <= b
#if MIN_VERSION_network(2,6,1)
    (SockAddrCan a, SockAddrCan b) -> a <= b
#endif
    -- Mixed constructors: defer to the underlying 'SockAddr' ordering.
    _ -> lhs <= rhs
-- | Default key extractor: wrap the request's 'remoteHost' in an 'Address'.
-- This extractor always succeeds, so the result is always a 'Right'.
extractAddress :: Request -> Either Response Address
extractAddress request = Right (Address (remoteHost request))
-- | State of a single cache slot.
data CacheState a
  = -- | The slot holds a fully constructed value.
    CacheStatePresent a
  | -- | Placeholder stored by 'retrieveCache' while the value for this slot
    -- is still being created by 'processCacheResult'.
    CacheStateInitializing
-- | Outcome of a cache lookup performed by 'retrieveCache'.
data CacheResult a
  = -- | A value was already cached and is returned as-is.
    CacheResultExists a
  | -- | Nothing was cached; the caller is expected to create the value.
    CacheResultEmpty
-- | Runtime state of a throttler whose clients are identified by keys of
-- type @a@.
data Throttle a = Throttle
  { -- | Configuration this throttler was created with.
    throttleSettings :: ThrottleSettings
  , -- | Cache mapping each key to the state of its token bucket.
    throttleCache :: Cache a (CacheState TokenBucket)
  , -- | Derives the throttling key from a request. A 'Left' carries a
    -- response that 'throttle' sends back without consulting the cache.
    throttleGetKey :: Request -> Either Response a
  }
-- | User-facing throttling configuration.
data ThrottleSettings = ThrottleSettings
  { -- | Rate component; 'throttleRequest' hands @round (period / rate)@ to
    -- the token bucket.
    throttleSettingsRate :: Double
  , -- | Period component of the same quotient.
    -- NOTE(review): presumably expressed in microseconds (the default is
    -- 1000000) -- confirm against the token-bucket documentation.
    throttleSettingsPeriod :: Double
  , -- | Burst size passed to the token bucket on every allocation attempt.
    throttleSettingsBurst :: Word64
  , -- | Default expiration handed to 'newCache' when the throttler is built.
    throttleSettingsCacheExpiration :: TimeSpec
  , -- | Decides whether a request is subject to throttling at all; requests
    -- for which this returns 'False' go straight to the application.
    throttleSettingsIsThrottled :: Request -> Bool
  , -- | Builds the rejection response from the non-zero value returned by
    -- 'throttleRequest'.
    throttleSettingsOnThrottled :: Word64 -> Response
  }
-- | Default settings: rate 1, period 1000000, burst 1, every request is
-- throttled, and rejected requests receive a 429 response with a small JSON
-- body. Only the cache expiration interval has to be supplied.
defaultThrottleSettings :: TimeSpec -> ThrottleSettings
defaultThrottleSettings expirationInterval =
  ThrottleSettings
    { throttleSettingsRate = 1
    , throttleSettingsPeriod = 1000000
    , throttleSettingsBurst = 1
    , throttleSettingsCacheExpiration = expirationInterval
    , throttleSettingsIsThrottled = \_ -> True
    , throttleSettingsOnThrottled = \_ -> tooManyRequests
    }
  where
    -- The same response is used regardless of the value passed in.
    tooManyRequests =
      responseLBS
        status429
        [("Content-Type", "application/json")]
        "{\"message\":\"Too many requests.\"}"
-- | Create a throttler keyed by the client's 'Address', i.e.
-- 'initCustomThrottler' specialised to 'extractAddress'.
initThrottler :: ThrottleSettings -> IO (Throttle Address)
initThrottler settings = initCustomThrottler settings extractAddress
-- | Create a throttler with a caller-supplied key extractor. A fresh cache
-- is allocated whose default expiration is taken from
-- 'throttleSettingsCacheExpiration'.
initCustomThrottler :: ThrottleSettings -> (Request -> Either Response a) -> IO (Throttle a)
initCustomThrottler settings getKey = do
  cache <- newCache (Just (throttleSettingsCacheExpiration settings))
  pure
    Throttle
      { throttleSettings = settings
      , throttleCache = cache
      , throttleGetKey = getKey
      }
-- | Look up the token bucket cached under @throttleKey@.
--
-- * A present bucket is returned as 'CacheResultExists'.
-- * When nothing is cached, a 'CacheStateInitializing' placeholder is stored
--   (with 'Nothing' as its expiration argument) and 'CacheResultEmpty' is
--   returned, signalling that the caller must create the bucket.
-- * When another caller's placeholder is found, this function waits until
--   that placeholder has been replaced or removed.
#if MIN_VERSION_cache(0,1,1)
retrieveCache :: (Eq a, Hashable a) => Throttle a -> TimeSpec -> a -> STM (CacheResult TokenBucket)
retrieveCache th time throttleKey =
  lookupSTM True throttleKey cache time >>= \ case
    Just (CacheStatePresent oldBucket) -> pure $ CacheResultExists oldBucket
    -- Looking the key up again inside the same transaction would only see
    -- the same placeholder and spin. 'retry' abandons the transaction and
    -- blocks until the state it read has changed.
    Just CacheStateInitializing -> retry
    Nothing -> do
      insertSTM throttleKey CacheStateInitializing cache Nothing
      pure CacheResultEmpty
  where
    cache = throttleCache th
#else
retrieveCache :: (Eq a, Hashable a) => Throttle a -> TimeSpec -> a -> IO (CacheResult TokenBucket)
retrieveCache th time throttleKey =
  lookup cache throttleKey >>= \ case
    Just (CacheStatePresent oldBucket) -> pure $ CacheResultExists oldBucket
    -- No transactional wait is available in this variant, so poll the cache
    -- until the placeholder has been replaced or removed.
    Just CacheStateInitializing -> retrieveCache th time throttleKey
    Nothing -> do
      insert' cache Nothing throttleKey CacheStateInitializing
      pure CacheResultEmpty
  where
    cache = throttleCache th
#endif
-- | Turn a lookup result into a usable bucket. An existing bucket is
-- returned unchanged. For an empty result a new bucket is created and
-- stored under @throttleKey@; if that fails with an exception, the cache
-- entry for the key is deleted before the exception propagates.
processCacheResult :: (Eq a, Hashable a) => Throttle a -> a -> CacheResult TokenBucket -> IO TokenBucket
processCacheResult _ _ (CacheResultExists bucket) = pure bucket
processCacheResult th throttleKey CacheResultEmpty =
  createBucket `onException` dropEntry
  where
    cache = throttleCache th
    createBucket = do
      fresh <- newTokenBucket
      insert cache throttleKey (CacheStatePresent fresh)
      pure fresh
    dropEntry = delete cache throttleKey
-- | Fetch the bucket for @throttleKey@, creating it when absent. The
-- current monotonic time is read first and handed to 'retrieveCache'; the
-- lookup result is then resolved by 'processCacheResult'.
retrieveOrInitializeBucket :: (Eq a, Hashable a) => Throttle a -> a -> IO TokenBucket
retrieveOrInitializeBucket th throttleKey =
  getTime Monotonic >>= lookupAt >>= processCacheResult th throttleKey
  where
#if MIN_VERSION_cache(0,1,1)
    lookupAt now = atomically (retrieveCache th now throttleKey)
#else
    lookupAt now = retrieveCache th now throttleKey
#endif
-- | Try to allocate one token from the bucket belonging to @throttleKey@.
-- The result is whatever 'tokenBucketTryAlloc1' returns; 'throttle' treats
-- @0@ as "request allowed" and any other value as "request rejected".
-- NOTE(review): a non-zero result is presumably the time to wait before
-- the next token becomes available -- confirm against token-bucket docs.
throttleRequest :: (Eq a, Hashable a) => Throttle a -> a -> IO Word64
throttleRequest th throttleKey = do
  bucket <- retrieveOrInitializeBucket th throttleKey
  tokenBucketTryAlloc1 bucket burstSize inverseRate
  where
    settings = throttleSettings th
    burstSize = throttleSettingsBurst settings
    inverseRate :: Word64
    inverseRate =
      round (throttleSettingsPeriod settings / throttleSettingsRate settings)
-- | WAI middleware enforcing the throttler's limits.
--
-- Requests for which 'throttleSettingsIsThrottled' is 'False' are passed to
-- the application untouched. Otherwise the key is extracted: a 'Left'
-- response from the extractor is sent as-is, while a 'Right' key is checked
-- with 'throttleRequest'. A result of @0@ lets the request through, and any
-- other result is turned into a response by 'throttleSettingsOnThrottled'.
throttle :: (Eq a, Hashable a) => Throttle a -> Application -> Application
throttle th app req respond
  | not (throttleSettingsIsThrottled settings req) = app req respond
  | otherwise = either respond enforce (throttleGetKey th req)
  where
    settings = throttleSettings th
    enforce throttleKey = do
      outcome <- throttleRequest th throttleKey
      if outcome == 0
        then app req respond
        else respond (throttleSettingsOnThrottled settings outcome)
-- | Shows only which constructor is in use; the payload is never rendered,
-- so no 'Show' constraint on @a@ is needed.
instance Show (CacheState a) where
  show (CacheStatePresent _) = "Present"
  show CacheStateInitializing = "Initializing"
-- | Shows only which constructor is in use; the payload is never rendered,
-- so no 'Show' constraint on @a@ is needed.
instance Show (CacheResult a) where
  show (CacheResultExists _) = "Exists"
  show CacheResultEmpty = "Empty"