module Network.Wai.RateLimit.Postgres
( PGBackendError (..),
postgresBackend,
)
where
import Control.Concurrent (forkIO, threadDelay)
import Control.Exception (Exception, Handler (..), catches, throwIO, try)
import Control.Monad (forever, void)
import Data.ByteString (ByteString)
import Data.Pool (Pool, withResource)
import Data.String (fromString)
import Data.Text (Text, unpack)
import qualified Data.Text as T
import qualified Database.PostgreSQL.Simple as PG
import Network.Wai.RateLimit.Backend (Backend (..), BackendError (..))
-- | Failures specific to the PostgreSQL rate-limit backend.
--
-- Within this module these values are always thrown wrapped in
-- 'BackendError', never on their own.
data PGBackendError
  = -- | The @CREATE TABLE IF NOT EXISTS@ statement run at start-up failed.
    PGBackendErrorInit PG.SqlError
  | -- | postgresql-simple reported a 'PG.FormatError' while running a query.
    PGBackendErrorBugFmt PG.FormatError
  | -- | postgresql-simple reported a 'PG.QueryError' while running a query.
    PGBackendErrorBugQry PG.QueryError
  | -- | postgresql-simple reported a 'PG.ResultError' while decoding a result.
    PGBackendErrorBugRes PG.ResultError
  | -- | The server reported a 'PG.SqlError' while running a query.
    PGBackendErrorBugSql PG.SqlError
  | -- | A usage lookup returned more than one row for a single key.
    PGBackendErrorAtMostOneRow
  | -- | An increment-and-return did not return exactly one row.
    PGBackendErrorExactlyOneRow
  | -- | An expiry update did not affect exactly one row.
    PGBackendErrorExactlyOneUpdate
  deriving stock (Eq, Show)

instance Exception PGBackendError
-- | Create the rate-limit table named @tableName@ unless it already exists.
--
-- The table has the following columns:
--
-- * @key@: a @BYTEA@ primary key.
-- * @usage@: an @INT8@ counter.
-- * @expires_at@: a timestamp defaulting to one week from insertion.
--
-- A 'PG.SqlError' raised by the statement is rethrown as
-- @'BackendError' ('PGBackendErrorInit' err)@. Other exception types
-- propagate unchanged.
--
-- @tableName@ is spliced verbatim into the SQL text, so it must come from
-- trusted configuration, never from request data.
initPostgresBackend :: Pool PG.Connection -> Text -> IO ()
initPostgresBackend pool tableName =
  withResource pool $ \conn -> do
    outcome <- try (PG.execute_ conn createTableQuery)
    case outcome of
      Left sqlErr -> throwIO (BackendError (PGBackendErrorInit sqlErr))
      Right _rowCount -> pure ()
  where
    createTableQuery :: PG.Query
    createTableQuery =
      fromString . unpack . T.unwords $
        [ "CREATE TABLE IF NOT EXISTS",
          tableName,
          "(key BYTEA PRIMARY KEY,",
          "usage INT8 NOT NULL,",
          "expires_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP + '1 week'::INTERVAL)"
        ]
-- | Exception handlers for use with 'catches' around postgresql-simple calls.
--
-- Each postgresql-simple error is wrapped in the matching 'PGBackendError'
-- constructor, then in 'BackendError', and rethrown:
--
-- * 'PG.FormatError' becomes 'PGBackendErrorBugFmt'.
-- * 'PG.QueryError' becomes 'PGBackendErrorBugQry'.
-- * 'PG.ResultError' becomes 'PGBackendErrorBugRes'.
-- * 'PG.SqlError' becomes 'PGBackendErrorBugSql'.
--
-- Callers that want to recover must therefore catch 'BackendError'.
sqlHandlers :: [Handler a]
sqlHandlers =
  [ Handler (rethrow . PGBackendErrorBugFmt),
    Handler (rethrow . PGBackendErrorBugQry),
    Handler (rethrow . PGBackendErrorBugRes),
    Handler (rethrow . PGBackendErrorBugSql)
  ]
  where
    rethrow :: PGBackendError -> IO b
    rethrow = throwIO . BackendError
-- | Read the current usage counter stored for @key@ in @tableName@.
--
-- Only rows whose @expires_at@ is still in the future are considered, so a
-- missing or expired key yields @0@.
--
-- Error behaviour:
--
-- * postgresql-simple errors are rethrown through 'sqlHandlers'.
-- * More than one matching row raises 'BackendError' wrapping
--   'PGBackendErrorAtMostOneRow'.
pgBackendGetUsage :: Pool PG.Connection -> Text -> ByteString -> IO Integer
pgBackendGetUsage pool tableName key =
  withResource pool $ \conn ->
    fromRows
      =<< (PG.query conn selectUsage (PG.Only (PG.Binary key)) `catches` sqlHandlers)
  where
    -- Interpret the result set: none means no live entry, one is the answer.
    fromRows [] = pure 0
    fromRows [PG.Only current] = pure current
    fromRows _ = throwIO (BackendError PGBackendErrorAtMostOneRow)

    selectUsage :: PG.Query
    selectUsage =
      fromString . unpack . T.unwords $
        [ "SELECT usage FROM",
          tableName,
          "WHERE key = ?",
          "AND expires_at > CURRENT_TIMESTAMP"
        ]
-- | Add @usage@ to the counter for @key@ and return the new total.
--
-- This runs as a single upsert:
--
-- * An absent key is inserted with @usage@ and the default expiry.
-- * A live (unexpired) row has @usage@ added to its counter and keeps its
--   expiry.
-- * An expired row is reset to @usage@ with a fresh one-week expiry.
--
-- Error behaviour:
--
-- * postgresql-simple errors are rethrown through 'sqlHandlers'.
-- * Anything other than exactly one returned row raises 'BackendError'
--   wrapping 'PGBackendErrorExactlyOneRow'.
pgBackendIncAndGetUsage :: Pool PG.Connection -> Text -> ByteString -> Integer -> IO Integer
pgBackendIncAndGetUsage pool tableName key usage =
  withResource pool $ \conn -> do
    returned <-
      PG.query conn upsertQuery (PG.Binary key, usage)
        `catches` sqlHandlers
    expectOne returned
  where
    -- RETURNING on a single-row upsert must yield exactly one row.
    expectOne [PG.Only total] = pure total
    expectOne _ = throwIO (BackendError PGBackendErrorExactlyOneRow)

    upsertQuery :: PG.Query
    upsertQuery =
      fromString . unpack . T.unwords $
        [ "INSERT INTO",
          tableName,
          "as rl",
          "(key, usage) VALUES (?, ?)",
          "ON CONFLICT (key) DO UPDATE SET",
          "usage = CASE WHEN rl.expires_at > CURRENT_TIMESTAMP THEN rl.usage + EXCLUDED.usage ELSE EXCLUDED.usage END,",
          "expires_at = CASE WHEN rl.expires_at > CURRENT_TIMESTAMP THEN rl.expires_at ELSE CURRENT_TIMESTAMP + '1 week'::INTERVAL END",
          "RETURNING usage"
        ]
-- | Set the expiry of @key@'s row in @tableName@ to @seconds@ from now.
--
-- Error behaviour:
--
-- * postgresql-simple errors are rethrown through 'sqlHandlers'.
-- * If the @UPDATE@ does not affect exactly one row (for example because
--   @key@ has no row yet), a 'BackendError' wrapping
--   'PGBackendErrorExactlyOneUpdate' is thrown.
pgBackendExpireIn :: Pool PG.Connection -> Text -> ByteString -> Integer -> IO ()
pgBackendExpireIn p tableName key seconds = withResource p $ \c -> do
  count <-
    PG.execute c expireInQuery (seconds, PG.Binary key)
      `catches` sqlHandlers
  if count /= 1
    then throwIO $ BackendError PGBackendErrorExactlyOneUpdate
    else pure ()
  where
    -- The number of seconds is bound as an ordinary query parameter and
    -- scaled by a one-second interval.
    --
    -- Putting the placeholder inside a quoted literal (e.g. @'? second'@)
    -- would depend on postgresql-simple substituting into string literals.
    -- It would also break as soon as the value rendered quoted.
    --
    -- Parameter order matches the tuple above: seconds first, then key.
    expireInQuery :: PG.Query
    expireInQuery =
      fromString . unpack . T.unwords $
        [ "UPDATE",
          tableName,
          "SET expires_at = CURRENT_TIMESTAMP + ? * INTERVAL '1 second'",
          "WHERE key = ?"
        ]
-- | Start a background thread that keeps deleting expired rows from
-- @tableName@.
--
-- Each iteration deletes at most 'batchSize' rows whose @expires_at@ is in
-- the past, then sleeps:
--
-- * 100 ms when a full batch was deleted, since more expired rows may remain.
-- * 1 s when some rows, but fewer than a full batch, were deleted.
-- * 10 s when nothing was deleted or the database call failed.
--
-- The forked thread's @ThreadId@ is discarded, so the loop is never stopped
-- by this module.
pgBackendCleanup :: Pool PG.Connection -> Text -> IO ()
pgBackendCleanup p tableName = void . forkIO . forever $ do
  -- 'sqlHandlers' rethrows every postgresql-simple error as a 'BackendError'.
  -- That is therefore the type we must catch here: catching 'PGBackendError'
  -- never matches, and the first failure would kill this thread.
  --
  -- Handlers wrap the whole 'withResource' call, so errors of those types
  -- raised while the pool acquires a connection are retried as well.
  res <-
    tryDBErr $
      withResource p (\c -> PG.execute_ c removeExpired)
        `catches` sqlHandlers
  case res of
    Left _ -> threadDelay d10s
    Right n -> delay n
  where
    -- Maximum number of rows removed per DELETE, keeping each statement short.
    batchSize :: Int
    batchSize = 5_000

    d10s :: Int
    d10s = 10_000_000

    d1s :: Int
    d1s = 1_000_000

    d100ms :: Int
    d100ms = 100_000

    -- Pick the next sleep from how many rows the last DELETE removed.
    delay :: (Ord a, Num a) => a -> IO ()
    delay n
      | n >= fromIntegral batchSize = threadDelay d100ms
      | n > 0 = threadDelay d1s
      | otherwise = threadDelay d10s

    tryDBErr :: IO a -> IO (Either BackendError a)
    tryDBErr = try

    removeExpired :: PG.Query
    removeExpired =
      fromString . unpack . T.unwords $
        [ "DELETE FROM",
          tableName,
          "WHERE key IN (SELECT key FROM",
          tableName,
          "WHERE expires_at < CURRENT_TIMESTAMP LIMIT",
          T.pack (show batchSize) <> ")"
        ]
-- | Build a rate-limit 'Backend' stored in the PostgreSQL table @tableName@.
--
-- Before returning, this:
--
-- * creates the table if it is missing ('initPostgresBackend');
-- * forks the periodic expired-row cleanup thread ('pgBackendCleanup').
--
-- Every backend operation borrows a connection from @pool@ for the duration
-- of a single statement.
postgresBackend :: Pool PG.Connection -> Text -> IO (Backend ByteString)
postgresBackend pool tableName = do
  initPostgresBackend pool tableName
  pgBackendCleanup pool tableName
  pure
    MkBackend
      { backendGetUsage = pgBackendGetUsage pool tableName,
        backendIncAndGetUsage = pgBackendIncAndGetUsage pool tableName,
        backendExpireIn = pgBackendExpireIn pool tableName
      }