-- |
-- Copyright: (c) 2022 Aditya Manthramurthy
-- SPDX-License-Identifier: Apache-2.0
-- Maintainer: Aditya Manthramurthy <aditya.mmy@gmail.com>
--
-- A wai-rate-limit backend using PostgreSQL.
module Network.Wai.RateLimit.Postgres
  ( PGBackendError (..),
    postgresBackend,
  )
where

import Control.Concurrent (forkIO, threadDelay)
import Control.Exception (Exception, Handler (..), catches, throwIO, try)
import Control.Monad (forever, void)
import Data.ByteString (ByteString)
import Data.Pool (Pool, withResource)
import Data.String (fromString)
import Data.Text (Text, unpack)
import qualified Data.Text as T
import qualified Database.PostgreSQL.Simple as PG
import Network.Wai.RateLimit.Backend (Backend (..), BackendError (..))

-- | The ways in which requests made to the Postgres backend can fail.
data PGBackendError
  = -- | Creating the rate-limit table at startup raised a SQL error.
    PGBackendErrorInit PG.SqlError
  | -- | postgresql-simple could not format a rate-limit query.
    PGBackendErrorBugFmt PG.FormatError
  | -- | postgresql-simple reported a 'PG.QueryError' for a rate-limit query.
    PGBackendErrorBugQry PG.QueryError
  | -- | A query result could not be decoded into the expected Haskell type.
    PGBackendErrorBugRes PG.ResultError
  | -- | The server reported a SQL error while running a rate-limit query.
    PGBackendErrorBugSql PG.SqlError
  | -- | The usage lookup for a key returned more than one row.
    PGBackendErrorAtMostOneRow
  | -- | The increment-and-get query did not return exactly one row.
    PGBackendErrorExactlyOneRow
  | -- | Setting a key's expiry did not update exactly one row.
    PGBackendErrorExactlyOneUpdate
  deriving stock Eq
  deriving stock Show

-- | Uses the default 'Exception' methods; the backend operations throw this
-- wrapped in 'BackendError'.
instance Exception PGBackendError

-- | Ensure the rate-limit table exists, creating it when it is missing.
--
-- The table stores a binary @key@ (primary key), its accumulated @usage@ and
-- an @expires_at@ timestamp that defaults to one week from insertion.
--
-- Any 'PG.SqlError' raised by the @CREATE TABLE@ statement is rethrown as a
-- 'BackendError' wrapping 'PGBackendErrorInit'.
--
-- NOTE(review): @tableName@ is spliced verbatim into the SQL text, so it must
-- come from trusted configuration, never from request data.
initPostgresBackend :: Pool PG.Connection -> Text -> IO ()
initPostgresBackend pool tableName =
  withResource pool $ \conn -> do
    outcome <- try (PG.execute_ conn ddl)
    case outcome of
      Left sqlErr -> throwIO (BackendError (PGBackendErrorInit sqlErr))
      Right _rowCount -> pure ()
  where
    -- Idempotent DDL: safe to run on every start-up.
    ddl :: PG.Query
    ddl =
      fromString . unpack . T.unwords $
        [ "CREATE TABLE IF NOT EXISTS",
          tableName,
          "(key BYTEA PRIMARY KEY,",
          "usage INT8 NOT NULL,",
          "expires_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP + '1 week'::INTERVAL)"
        ]

-- | Handlers for use with 'catches' around postgresql-simple calls.
--
-- Each postgresql-simple exception type ('PG.FormatError', 'PG.QueryError',
-- 'PG.ResultError', 'PG.SqlError') is rethrown as a 'BackendError' wrapping
-- the corresponding 'PGBackendError' constructor.
sqlHandlers :: [Handler a]
sqlHandlers =
  [ rethrowAs PGBackendErrorBugFmt,
    rethrowAs PGBackendErrorBugQry,
    rethrowAs PGBackendErrorBugRes,
    rethrowAs PGBackendErrorBugSql
  ]
  where
    -- Catch an @e@, wrap it with the given constructor, and rethrow it.
    rethrowAs :: Exception e => (e -> PGBackendError) -> Handler b
    rethrowAs wrap = Handler (\err -> throwIO (BackendError (wrap err)))

-- | Read the current usage recorded for @key@.
--
-- Only rows whose @expires_at@ lies in the future are considered, so a missing
-- or expired key reports a usage of @0@. Throws a 'BackendError' wrapping
-- 'PGBackendErrorAtMostOneRow' if more than one row matches, and converts
-- postgresql-simple exceptions via 'sqlHandlers'.
pgBackendGetUsage :: Pool PG.Connection -> Text -> ByteString -> IO Integer
pgBackendGetUsage p tableName key =
  withResource p (\conn -> fetch conn >>= singleOrZero)
  where
    -- Run the lookup, translating driver exceptions into 'BackendError'.
    fetch :: PG.Connection -> IO [PG.Only Integer]
    fetch conn =
      PG.query conn selectUsage (PG.Only (PG.Binary key)) `catches` sqlHandlers

    -- No live row means no usage; exactly one row carries the usage.
    singleOrZero :: [PG.Only Integer] -> IO Integer
    singleOrZero [] = pure 0
    singleOrZero [PG.Only usage] = pure usage
    singleOrZero _ = throwIO (BackendError PGBackendErrorAtMostOneRow)

    selectUsage :: PG.Query
    selectUsage =
      fromString . unpack . T.unwords $
        [ "SELECT usage FROM",
          tableName,
          "WHERE key = ?",
          "AND expires_at > CURRENT_TIMESTAMP"
        ]

-- | Add @usage@ to the counter for @key@ and return the new total.
--
-- A new key is inserted with @usage@ as its count and the column default for
-- @expires_at@. For an existing key:
--
-- * if its row is still live, @usage@ is added to the stored count;
-- * if its row has expired, the count restarts at @usage@ and the expiry is
--   pushed one week into the future.
--
-- Throws a 'BackendError' wrapping 'PGBackendErrorExactlyOneRow' unless the
-- statement returns exactly one row, and converts postgresql-simple exceptions
-- via 'sqlHandlers'.
pgBackendIncAndGetUsage :: Pool PG.Connection -> Text -> ByteString -> Integer -> IO Integer
pgBackendIncAndGetUsage p tableName key usage =
  withResource p (\conn -> upsert conn >>= expectOne)
  where
    -- Execute the upsert, translating driver exceptions into 'BackendError'.
    upsert :: PG.Connection -> IO [PG.Only Integer]
    upsert conn =
      PG.query conn incAndGetQuery (PG.Binary key, usage) `catches` sqlHandlers

    -- RETURNING on a single-row upsert must yield exactly one row.
    expectOne :: [PG.Only Integer] -> IO Integer
    expectOne [PG.Only total] = pure total
    expectOne _ = throwIO (BackendError PGBackendErrorExactlyOneRow)

    incAndGetQuery :: PG.Query
    incAndGetQuery =
      fromString . unpack . T.unwords $
        [ "INSERT INTO",
          tableName,
          "as rl",
          "(key, usage) VALUES (?, ?)",
          "ON CONFLICT (key) DO UPDATE SET",
          "usage = CASE WHEN rl.expires_at > CURRENT_TIMESTAMP THEN rl.usage + EXCLUDED.usage ELSE EXCLUDED.usage END,",
          "expires_at = CASE WHEN rl.expires_at > CURRENT_TIMESTAMP THEN rl.expires_at ELSE CURRENT_TIMESTAMP + '1 week'::INTERVAL END",
          "RETURNING usage"
        ]

-- | Set the expiry of @key@ to @seconds@ seconds from now.
--
-- Throws a 'BackendError' wrapping 'PGBackendErrorExactlyOneUpdate' unless
-- exactly one row was updated, and converts postgresql-simple exceptions via
-- 'sqlHandlers'.
pgBackendExpireIn :: Pool PG.Connection -> Text -> ByteString -> Integer -> IO ()
pgBackendExpireIn p tableName key seconds = withResource p $ \c -> do
  count <-
    PG.execute c expireInQuery (seconds, PG.Binary key)
      `catches` sqlHandlers
  if count /= 1
    then throwIO (BackendError PGBackendErrorExactlyOneUpdate)
    else pure ()
  where
    -- The number of seconds is bound as a genuine numeric parameter and
    -- multiplied by a one-second interval. The previous form put the `?`
    -- placeholder inside a SQL string literal ('? second'), which only worked
    -- because postgresql-simple's substitution is not quote-aware and an
    -- Integer happens to render without quotes.
    expireInQuery :: PG.Query
    expireInQuery =
      fromString $
        unpack $
          T.intercalate
            " "
            [ "UPDATE",
              tableName,
              "SET expires_at = CURRENT_TIMESTAMP + (? * INTERVAL '1 second')",
              "WHERE key = ?"
            ]

-- | Fork a background thread that repeatedly deletes expired rows.
--
-- Each pass deletes at most 'batchSize' rows whose @expires_at@ has passed,
-- then sleeps. The pause adapts to how much was deleted:
--
-- * a full batch waits 100 ms, since more expired rows probably remain;
-- * a partial batch waits 1 s;
-- * an empty pass waits 10 s.
--
-- postgresql-simple exceptions, including a failure to obtain a connection
-- from the pool, are swallowed. The pass is retried after 10 s, so a transient
-- database outage does not permanently stop the cleanup.
pgBackendCleanup :: Pool PG.Connection -> Text -> IO ()
pgBackendCleanup p tableName =
  void $
    forkIO $
      forever $ do
        -- Guard the whole pool interaction, not just the DELETE, so that
        -- connection-acquisition errors are retried instead of killing the
        -- thread.
        res <- tryDBErr $ withResource p $ \c -> PG.execute_ c removeExpired
        case res of
          Left _ -> threadDelay d10s
          Right n -> delay n
  where
    d10s, d1s, d100ms :: Int
    d10s = 10_000_000
    d1s = 1_000_000
    d100ms = 100_000

    -- Maximum number of rows removed by one DELETE; also drives 'delay'.
    batchSize :: Int
    batchSize = 5000

    -- Try to ensure we clean up as fast as garbage is created.
    delay :: (Num a, Ord a) => a -> IO ()
    delay n
      | n >= fromIntegral batchSize = threadDelay d100ms
      | n > 0 = threadDelay d1s
      | otherwise = threadDelay d10s

    -- Catch the postgresql-simple exception types directly as a
    -- 'PGBackendError'. Deliberately not built on 'sqlHandlers': those
    -- rethrow a 'BackendError' wrapper, which a 'try' at type
    -- 'PGBackendError' never matches, so errors would escape and kill the
    -- thread.
    tryDBErr :: IO a -> IO (Either PGBackendError a)
    tryDBErr action =
      (Right <$> action)
        `catches` [ Handler (pure . Left . PGBackendErrorBugFmt),
                    Handler (pure . Left . PGBackendErrorBugQry),
                    Handler (pure . Left . PGBackendErrorBugRes),
                    Handler (pure . Left . PGBackendErrorBugSql)
                  ]

    removeExpired :: PG.Query
    removeExpired =
      fromString $
        unpack $
          T.intercalate
            " "
            [ "DELETE FROM",
              tableName,
              "WHERE key IN (SELECT key FROM",
              tableName,
              "WHERE expires_at < CURRENT_TIMESTAMP LIMIT",
              T.pack (show batchSize) <> ")"
            ]

-- | Create a PostgreSQL-backed rate-limiting 'Backend'.
--
-- Takes a connection pool and the name of the table used for storage, and
-- performs these steps in order:
--
-- 1. Create the table if it does not already exist.
-- 2. Start a background thread that periodically removes expired rows.
-- 3. Return a backend whose operations all use that pool and table.
postgresBackend :: Pool PG.Connection -> Text -> IO (Backend ByteString)
postgresBackend pool tableName =
  initPostgresBackend pool tableName
    *> pgBackendCleanup pool tableName
    *> pure backend
  where
    backend :: Backend ByteString
    backend =
      MkBackend
        { backendGetUsage = pgBackendGetUsage pool tableName,
          backendIncAndGetUsage = pgBackendIncAndGetUsage pool tableName,
          backendExpireIn = pgBackendExpireIn pool tableName
        }