{-# LANGUAGE DeriveAnyClass, DerivingVia #-}
{-# LANGUAGE FlexibleContexts, ScopedTypeVariables #-}
{-# LANGUAGE OverloadedStrings, RankNTypes #-}

{-|
Module      : Database.PostgreSQL.Resilient
Description : PostgreSQL single-connection pool with automatic reconnection support, built on top of postgresql-simple.
License     : Apache-2.0
Maintainer  : volpegabriel@gmail.com
Stability   : experimental

The `withResilientConnection` function gives us a `ResilientConnection` from which we can always get a healthy connection, while automatic reconnection, with retries and exponential back-off, is handled in the background.

@
import           Database.PostgreSQL.Resilient
import qualified Database.PostgreSQL.Simple    as P

withResilientConnection settings logHandler connectInfo $ \\pool -> do
  (conn :: P.Connection) <- getConnection pool
  res <- P.query_ conn "SELECT * FROM foo"
  putStrLn $ show res

logHandler :: String -> IO ()
logHandler = putStrLn

connectInfo :: P.ConnectInfo
connectInfo = P.ConnectInfo
  { P.connectHost     = "localhost"
  , P.connectPort     = 5432
  , P.connectUser     = "postgres"
  , P.connectPassword = ""
  , P.connectDatabase = "store"
  }

-- Same values as the exported 'defaultResilientSettings'.
settings :: ResilientSettings
settings = ResilientSettings
  { healthCheckEvery            = 3
  , exponentialBackoffThreshold = 10
  }
@
-}
module Database.PostgreSQL.Resilient
  ( ResilientConnection(..)
  , ResilientSettings(..)
  , Seconds
  , withResilientConnection
  , defaultResilientSettings
  )
where

import           Control.Concurrent             ( forkIO
                                                , killThread
                                                , threadDelay
                                                )
import           Control.Concurrent.MVar
import           Control.Monad                  ( forever )
import           Control.Monad.Catch
import           Data.IORef
import           Data.Functor                   ( void )
import           Data.Maybe                     ( fromJust )
import qualified Database.PostgreSQL.Simple    as P
import           GHC.IO.Exception
import           Prelude                 hiding ( init )

{- | Raised by the keep-alive thread when a health check fails with an IO error
that is not treated as a lost connection. -}
data DBConnectionError = DBConnectionError
  deriving (Show)

-- Empty instance: every 'Exception' method keeps its default implementation.
instance Exception DBConnectionError

{- | Single-connection pool with built-in reconnection.

The connection behind it may be replaced over time; 'getConnection' always
yields the most recently stored one. -}
data ResilientConnection m = ResilientConnection
  { getConnection :: m P.Connection
    -- ^ Get the latest healthy connection.
  }

{- | Sink for the library's human-readable log messages (e.g. 'putStrLn'). -}
type LogHandler = String -> IO ()

{- | A whole number of seconds.

Every instance is borrowed from 'Int', so numeric literals can be used
directly (e.g. @3 :: Seconds@) and 'show' renders just the number. -}
newtype Seconds = Seconds Int
  deriving (Eq, Ord) via Int
  deriving (Num, Show) via Int

{- | Tuning knobs for 'withResilientConnection'. -}
data ResilientSettings = ResilientSettings
  { healthCheckEvery            :: Seconds
    -- ^ How often to check the connection status.
  , exponentialBackoffThreshold :: Seconds
    -- ^ Threshold for the exponential back-off between reconnection attempts:
    -- once the delay reaches (or passes) it, subsequent retries wait this
    -- many seconds.
  }
  deriving (Show)

{- | Default resilient settings: check the connection every 3 seconds and use
an exponential back-off threshold of 10 seconds. -}
defaultResilientSettings :: ResilientSettings
defaultResilientSettings = ResilientSettings
  { healthCheckEvery            = 3
  , exponentialBackoffThreshold = 10
  }

{- | Blocks the current thread for the given number of seconds
(converted to microseconds for 'threadDelay'). -}
sleep :: Seconds -> IO ()
sleep (Seconds secs) = threadDelay micros
  where
    micros :: Int
    micros = secs * microsPerSecond

    microsPerSecond :: Int
    microsPerSecond = 1000000

{- | Runs @SELECT version();@ on the given connection and logs the returned
rows. Any exception raised by the query propagates to the caller. -}
healthCheck :: LogHandler -> P.Connection -> IO ()
healthCheck logger conn =
  logger . show =<< (P.query_ conn "SELECT version();" :: IO [P.Only String])

{- | Returns a `ResilientConnection` from which you can always acquire the latest connection available.

Reconnections with configurable retries and exponential back-offs, as well as closing the
connection once done using it (guaranteed by `bracket`), are also handled by this function.

How it works:

  * A connection is opened up front and a background keep-alive thread is forked.

  * Every 'healthCheckEvery' seconds that thread runs 'healthCheck' on the current
    connection. If it fails with an IO error of type 'ResourceVanished' or 'OtherError',
    the stale connection is closed and the thread keeps trying to reconnect, starting
    with a 1 second delay and doubling it after every failure, capped at
    'exponentialBackoffThreshold'.

  * Any other IO error from the health check is logged and 'DBConnectionError' is thrown
    inside the keep-alive thread, which ends that thread; it is not re-thrown to the
    caller of this function.

  * When the callback returns (or throws), the keep-alive thread is stopped first and the
    current connection is closed afterwards.
-}
withResilientConnection
  :: forall a
   . ResilientSettings
  -> LogHandler
  -> P.ConnectInfo
  -> (ResilientConnection IO -> IO a)
  -> IO a
withResilientConnection settings logger info f = do
  connRef <- newIORef Nothing
  signal  <- newEmptyMVar
  let shutdown = readMVar signal >>= killThread -- ends keep-alive process
      -- 'init' stores a connection in connRef before 'pool' is handed to 'f',
      -- so this 'fromJust' never sees 'Nothing' inside the bracket body.
      pool     = ResilientConnection (fromJust <$> readIORef connRef)
      ka       = keepAlive (reconnect connRef) pool
      init     = acquire connRef >> ka >>= putMVar signal
  bracket (pool <$ init) (release shutdown) f
 where
  -- Opens a fresh connection and publishes it through the shared reference.
  acquire ref = do
    logger "Connecting to PostgreSQL"
    conn <- P.connect info
    conn <$ atomicWriteIORef ref (Just conn)

  -- Stops the keep-alive thread *before* closing the connection. Otherwise the
  -- background thread could health-check the just-closed connection, decide it
  -- was lost and open a replacement that nobody would ever close.
  release shutdown pool = do
    logger "Shutdown PostgreSQL reconnection process"
    stopped <- shutdown
    logger "Closing PostgreSQL connection"
    conn <- getConnection pool
    P.close conn
    pure stopped

  -- Closes a connection that failed its health check.
  clean conn = do
    logger "Closing no longer valid PostgreSQL connection"
    P.close conn

  -- Keeps calling 'acquire' until it succeeds, sleeping @n@ seconds after each
  -- failure; the delay doubles every time but never exceeds the threshold.
  reconnect ref n = catch (void $ acquire ref) $ \(e :: SomeException) ->
    case fromException e of
      -- Re-throw asynchronous exceptions (e.g. the 'ThreadKilled' sent by
      -- 'shutdown'); swallowing them would keep this loop running forever.
      Just (_ :: SomeAsyncException) -> throwM e
      Nothing -> logger (retries e) >> sleep n >> reconnect ref n'
   where
    retries err = show err <> "\n >>> Retrying in " <> show n <> " seconds."
    t  = exponentialBackoffThreshold settings
    -- Cap the next delay at the threshold (8 -> 10, not 16, when t = 10).
    n' = min t (n * 2)

  -- Forks the background thread that periodically health-checks the current
  -- connection and calls 'recover' when the connection has gone away.
  keepAlive recover pool = forkIO $ forever $ do
    sleep $ healthCheckEvery settings
    logger "Checking PostgreSQL connection status"
    conn   <- getConnection pool
    -- 'try' rather than 'catch': the recovery below may retry for a long time,
    -- and 'catch' handlers run with asynchronous exceptions masked.
    result <- try (healthCheck logger conn)
    case result of
      Right () -> pure ()
      Left (e :: IOError)
        -- OtherError is thrown on every internal libpq error such as connection error
        | ioe_type e == ResourceVanished || ioe_type e == OtherError
        -> clean conn >> recover 1
        | otherwise
        -> logger (show e) >> throwM DBConnectionError