--------------------------------------------------------------------------------
-- Rate Limiting Middleware for WAI                                           --
--------------------------------------------------------------------------------
-- This source code is licensed under the MIT license found in the LICENSE    --
-- file in the root directory of this source tree.                            --
--------------------------------------------------------------------------------

module Network.Wai.RateLimit.Redis (
    RedisBackendError(..),
    redisBackend
) where

--------------------------------------------------------------------------------

import qualified Data.ByteString as BS
import qualified Data.ByteString.Char8 as C8

import Database.Redis as Redis

import Network.Wai.RateLimit.Backend

--------------------------------------------------------------------------------

-- | Represents reasons why requests made to the Redis backend have failed.
data RedisBackendError
    = RedisBackendReply Reply
    -- ^ Redis answered a command with an error 'Reply'.
    | RedisBackendTxAborted
    -- ^ A Redis transaction was aborted. Not produced by 'redisBackend';
    -- presumably mirrors hedis's @TxAborted@ transaction result (confirm).
    | RedisBackendTxError String
    -- ^ A Redis transaction failed with the given message. Not produced by
    -- 'redisBackend'; presumably mirrors hedis's @TxError@ result (confirm).
    deriving (Eq, Show)

-- | 'redisBackend' @connection@ constructs a rate limiting 'Backend' for the
-- given redis @connection@.
--
-- Every operation runs a single command via 'runRedis' on @connection@; an
-- error 'Reply' from Redis is reported as 'RedisBackendReply'.
redisBackend :: Connection -> Backend BS.ByteString RedisBackendError
redisBackend conn = MkBackend
    { -- read the current usage stored under the key (GET)
      backendGetUsage = \key ->
        fmap parseUsage <$> runCmd (get key)
      -- increment the value of the key by the specified amount and return
      -- the updated value (INCRBY)
    , backendIncAndGetUsage = \key val ->
        runCmd (incrby key val)
      -- update the key to expire in the specified number of seconds
      -- (EXPIRE); the Bool returned by Redis is deliberately discarded
    , backendExpireIn = \key seconds ->
        fmap (const ()) <$> runCmd (expire key seconds)
    }
  where
    -- run one command on the connection, turning an error reply from
    -- Redis into a 'RedisBackendReply'
    runCmd :: Redis (Either Reply a) -> IO (Either RedisBackendError a)
    runCmd cmd = either (Left . RedisBackendReply) Right <$> runRedis conn cmd

    -- the key does not exist or does not hold exactly an integer (something
    -- follows the digits): no previous usage; otherwise the stored integer
    parseUsage :: Maybe BS.ByteString -> Integer
    parseUsage mVal = case mVal >>= C8.readInteger of
        Just (n, rest) | BS.null rest -> n
        _ -> 0

--------------------------------------------------------------------------------