{-# LANGUAGE DeriveAnyClass, DerivingVia #-}
{-# LANGUAGE FlexibleContexts, ScopedTypeVariables #-}
{-# LANGUAGE OverloadedStrings, RankNTypes #-}
module Database.PostgreSQL.Resilient
( ResilientConnection(..)
, ResilientSettings(..)
, Seconds
, withResilientConnection
, defaultResilientSettings
)
where
import Control.Concurrent ( forkIO
                          , forkIOWithUnmask
                          , killThread
                          , threadDelay
                          )
import Control.Concurrent.MVar
import Control.Monad ( forever )
import Control.Monad.Catch
import Data.IORef
import Data.Functor ( void )
import Data.Maybe ( fromJust )
import qualified Database.PostgreSQL.Simple as P
import GHC.IO.Exception
import Prelude hiding ( init )
-- | Thrown by the keep-alive thread of 'withResilientConnection' when a
-- health check fails with an 'IOError' that is not treated as a lost
-- connection.
data DBConnectionError = DBConnectionError
  deriving stock Show
  deriving anyclass Exception
-- | Handle given to the callback of 'withResilientConnection'.
data ResilientConnection m = ResilientConnection
  { getConnection :: m P.Connection
    -- ^ Action yielding the connection to use. 'withResilientConnection'
    -- builds it so that every run reads the most recently acquired
    -- connection, so run it each time instead of caching its result.
  }
-- | Callback receiving the human-readable messages emitted while connecting,
-- health-checking, reconnecting and closing.
type LogHandler
  = String
  -> IO ()
-- | A whole number of seconds.
--
-- Every instance is borrowed from the wrapped 'Int'. 'Num' lets callers write
-- plain literals (the constructor is not exported), e.g. @3 :: Seconds@.
-- 'Show' prints just the number (@show (Seconds 3) == "3"@), which keeps the
-- retry log messages readable.
newtype Seconds = Seconds Int
  deriving (Eq, Ord) via Int
  deriving (Num, Show) via Int
-- | Tuning knobs for the background keep-alive / reconnection logic of
-- 'withResilientConnection'.
data ResilientSettings = ResilientSettings
  { healthCheckEvery            :: Seconds
    -- ^ How long the keep-alive thread sleeps before each health check.
  , exponentialBackoffThreshold :: Seconds
    -- ^ Once the delay between reconnection attempts reaches this value it
    -- stops doubling.
  }
  deriving stock Show
-- | Health check every 3 seconds; reconnection backoff threshold of 10
-- seconds.
defaultResilientSettings :: ResilientSettings
defaultResilientSettings = ResilientSettings
  { healthCheckEvery            = Seconds 3
  , exponentialBackoffThreshold = Seconds 10
  }
-- | Suspend the current thread for the given number of seconds
-- ('threadDelay' works in microseconds).
sleep :: Seconds -> IO ()
sleep (Seconds secs) =
  let micros = secs * 1000000
  in  threadDelay micros
-- | Probe the connection with @SELECT version();@ and log the rows returned.
-- Any exception raised by the query propagates to the caller.
healthCheck :: LogHandler -> P.Connection -> IO ()
healthCheck logger conn =
  P.query_ conn "SELECT version();" >>= \(rows :: [P.Only String]) ->
    logger (show rows)
-- | Run an action with a PostgreSQL connection that is health-checked in the
-- background and re-established when it is lost.
--
-- * On entry a connection is opened with 'P.connect'. A failure here is not
--   retried; the exception propagates to the caller. A keep-alive thread is
--   then forked.
--
-- * Every 'healthCheckEvery' seconds the keep-alive thread runs
--   @SELECT version();@. If that fails with an 'IOError' whose type is
--   'ResourceVanished' or 'OtherError', the old connection is closed. A new
--   one is then acquired, retrying indefinitely. The retry delay starts at
--   1 second and doubles after each failure until it reaches
--   'exponentialBackoffThreshold'. Any other 'IOError' is logged, and the
--   keep-alive thread ends by throwing 'DBConnectionError'. Other exception
--   types from the health check are not caught and also end the thread. In
--   both cases the exception is not propagated to the caller.
--
-- * The 'ResilientConnection' passed to the callback re-reads the latest
--   connection on every 'getConnection'. While a reconnection is in progress
--   it still yields the old, already closed connection.
--
-- * On exit (normal or exceptional) the keep-alive thread is stopped first,
--   then the current connection is closed.
withResilientConnection
  :: forall a
   . ResilientSettings
  -> LogHandler
  -> P.ConnectInfo
  -> (ResilientConnection IO -> IO a)
  -> IO a
withResilientConnection settings logger info f = do
  connRef <- newIORef Nothing
  signal  <- newEmptyMVar
  -- 'init' stores @Just conn@ in connRef before 'pool' is handed to 'f', and
  -- nothing ever writes 'Nothing' back, so 'fromJust' in 'pool' cannot fail.
  let pool     = ResilientConnection (fromJust <$> readIORef connRef)
      init     = do
        _   <- acquire connRef
        tid <- keepAlive (reconnect connRef) pool
        putMVar signal tid
      shutdown = readMVar signal >>= killThread
  bracket (pool <$ init) (release shutdown) f
 where
  -- Open a fresh connection and publish it through the shared reference.
  acquire :: IORef (Maybe P.Connection) -> IO P.Connection
  acquire ref = do
    logger "Connecting to PostgreSQL"
    conn <- P.connect info
    conn <$ atomicWriteIORef ref (Just conn)

  -- Stop the keep-alive thread, then close the current connection. The
  -- order matters: if the connection were closed first, the still-running
  -- keep-alive thread could see it as lost and open a new connection that
  -- nobody would ever close.
  release :: IO () -> ResilientConnection IO -> IO ()
  release stopKeepAlive pool = do
    logger "Shutdown PostgreSQL reconnection process"
    stopKeepAlive
    logger "Closing PostgreSQL connection"
    conn <- getConnection pool
    P.close conn

  -- Close a connection that failed its health check.
  clean :: P.Connection -> IO ()
  clean conn = do
    logger "Closing no longer valid PostgreSQL connection"
    P.close conn

  -- Keep trying to acquire a connection, sleeping @n@ seconds after a
  -- failure. Uses 'try' instead of 'catch' so each retry runs outside an
  -- exception handler (unmasked, no growing stack of nested handlers).
  reconnect :: IORef (Maybe P.Connection) -> Seconds -> IO ()
  reconnect ref n = do
    result <- try (void $ acquire ref)
    case result of
      Right () -> pure ()
      Left (e :: SomeException)
        -- Never swallow asynchronous exceptions: 'killThread' from
        -- 'release' must be able to stop a reconnection in progress.
        | isAsync e -> throwM e
        | otherwise -> do
          logger (retries e)
          sleep n
          reconnect ref n'
   where
    isAsync :: SomeException -> Bool
    isAsync e = case fromException e of
      Just (_ :: SomeAsyncException) -> True
      Nothing                        -> False
    retries :: SomeException -> String
    retries e = show e <> "\n >>> Retrying in " <> show n <> " seconds."
    -- Double the delay, but never beyond the configured threshold.
    n' = min (exponentialBackoffThreshold settings) (n * 2)

  -- Fork the background health-check loop. It is forked from inside
  -- 'bracket''s masked acquire step, so 'forkIOWithUnmask' is used to
  -- avoid inheriting that masked state.
  keepAlive reconnectWith pool = forkIOWithUnmask (\unmask -> unmask loop)
   where
    loop = forever $ do
      sleep $ healthCheckEvery settings
      logger "Checking PostgreSQL connection status"
      conn   <- getConnection pool
      result <- try (healthCheck logger conn)
      case result of
        Right () -> pure ()
        Left (e :: IOError)
          | ioe_type e `elem` [ResourceVanished, OtherError] ->
            clean conn >> reconnectWith 1
          | otherwise -> logger (show e) >> throwM DBConnectionError