--------------------------------------------------------------------------------
-- Rate Limiting Middleware for WAI                                           --
--------------------------------------------------------------------------------
-- This source code is licensed under the MIT license found in the LICENSE    --
-- file in the root directory of this source tree.                            --
--------------------------------------------------------------------------------

module Network.Wai.RateLimit.Redis (
    RedisBackendError(..),
    redisBackend
) where

--------------------------------------------------------------------------------

import Control.Exception

import qualified Data.ByteString as BS
import qualified Data.ByteString.Char8 as C8

import Database.Redis as Redis

import Network.Wai.RateLimit.Backend

--------------------------------------------------------------------------------

-- | Represents reasons why requests made to the Redis backend have failed.
data RedisBackendError
    = RedisBackendReply Reply
      -- ^ A Redis command returned an error 'Reply' (the 'Left' result of
      -- 'runRedis').
    | RedisBackendTxAborted
      -- ^ Presumably signals an aborted Redis transaction; not currently
      -- raised by 'redisBackend'.
    | RedisBackendTxError String
      -- ^ Presumably signals a failed Redis transaction, with a message; not
      -- currently raised by 'redisBackend'.
    deriving (Eq, Show)

-- | Lets a 'RedisBackendError' be thrown, and be wrapped in the rate limiting
-- library's 'BackendError' (which accepts any 'Exception').
instance Exception RedisBackendError

-- | 'redisBackend' @connection@ constructs a rate limiting 'Backend' for the
-- given redis @connection@.
--
-- Usage counters are stored as integer values under the given keys:
--
-- * 'backendGetUsage' reads a key with @GET@. A missing key, or a value that
--   is not an integer with nothing following it, counts as zero usage.
-- * 'backendIncAndGetUsage' increments a key by the given amount with
--   @INCRBY@ and returns the new value.
-- * 'backendExpireIn' makes a key expire in the given number of seconds with
--   @EXPIRE@.
--
-- If Redis answers any of these commands with an error 'Reply', a
-- 'BackendError' wrapping 'RedisBackendReply' is thrown.
redisBackend :: Connection -> Backend BS.ByteString
redisBackend conn = MkBackend
    { backendGetUsage = \key -> do
        mVal <- runRedisOrThrow (get key)
        -- force the parse now so the raw reply is not retained in a thunk
        pure $! parseUsage mVal
    , backendIncAndGetUsage = \key val ->
        -- increment the value of the key by the specified amount and return
        -- the resulting usage
        runRedisOrThrow (incrby key val)
    , backendExpireIn = \key seconds ->
        -- update the key to expire in the specified number of seconds; the
        -- Bool returned by EXPIRE is deliberately ignored
        () <$ runRedisOrThrow (expire key seconds)
    }
  where
    -- Runs a Redis command on @conn@, rethrowing an error reply as a
    -- 'BackendError'. 'throwIO' is used so the exception is raised at this
    -- point in the IO sequence rather than when some value is forced.
    runRedisOrThrow :: Redis (Either Reply a) -> IO a
    runRedisOrThrow action =
        runRedis conn action
            >>= either (throwIO . BackendError . RedisBackendReply) pure

    -- Interprets the result of GET as the current usage.
    parseUsage :: Maybe BS.ByteString -> Integer
    parseUsage mVal = case mVal >>= C8.readInteger of
        -- the key exists and holds exactly an integer: that is the usage
        Just (n, rest) | BS.null rest -> n
        -- the key does not exist, is not a valid integer, or has trailing
        -- bytes after the integer: no previous usage
        _ -> 0

--------------------------------------------------------------------------------