{-# LANGUAGE Strict #-}

{-|
Module : Database.PostgreSQL.Replicant.Connection
Description : Create replication handling connections to PostgreSQL

A ReplicantConnection is different from a regular Connection because
it uses a special mode that can send replication commands that regular
Connection objects cannot send.
-}
module Database.PostgreSQL.Replicant.Connection
  ( -- * Types
    ReplicantConnection
    -- * Constructor
  , connect
  , getConnection
  , unsafeCreateConnection
  )
where

import Control.Concurrent
import Control.Exception
import Database.PostgreSQL.LibPQ
import Network.Socket.KeepAlive
import System.Posix.Types

import Database.PostgreSQL.Replicant.Exception
import Database.PostgreSQL.Replicant.Settings
import Database.PostgreSQL.Replicant.Util

-- | A libpq 'Connection' wrapped to mark it as one that is meant to
-- send replication commands.
newtype ReplicantConnection
  = ReplicantConnection
  { getConnection :: Connection
    -- ^ Unwrap the underlying libpq 'Connection'.
  }
  deriving Eq

-- | Outcome of polling a non-blocking libpq connection attempt until
-- it finishes.
data ConnectResult
  = ConnectSuccess
    -- ^ Polling ended with 'PollingOk'.
  | ConnectFailure
    -- ^ Polling ended with 'PollingFailed'.
  deriving (Eq, Show)

-- | Connect to the PostgreSQL server in replication mode.
--
-- Starts a non-blocking libpq connection from the connection string
-- produced by 'pgConnectionString' and polls it until the handshake
-- finishes.
--
-- Throws a 'ReplicantException' when libpq reports no socket for the
-- new connection, or when polling ends in failure.  If any exception
-- (synchronous or asynchronous) escapes before the connection is
-- handed back, the half-open connection is closed with 'finish' first,
-- so a failed attempt does not keep its socket open until the garbage
-- collector gets around to it.
connect :: PgSettings -> IO ReplicantConnection
connect settings =
  bracketOnError
    (connectStart $ pgConnectionString settings)
    finish
    handshake
  where
    -- Runs the polling handshake on a freshly started connection and
    -- wraps it once libpq reports success.
    handshake conn = do
      mFd <- socket conn
      sockFd <- maybeThrow
        (ReplicantException "withLogicalStream: could not get socket fd")
        mFd
      pollResult <- pollConnectStart conn sockFd
      case pollResult of
        ConnectFailure ->
          throwIO $
            ReplicantException
              "withLogicalStream: Unable to connect to the database"
        ConnectSuccess ->
          pure $ ReplicantConnection conn

-- | Drive a connection begun with 'connectStart' until libpq reports
-- that the handshake has either succeeded or failed.
--
-- Between polls the calling thread waits on the connection's socket
-- for the kind of readiness libpq asked for.  The descriptor is looked
-- up again after every poll, because libpq does not promise that the
-- socket stays the same across 'connectPoll' calls; the @fd@ argument
-- is only used when that lookup yields nothing.
--
-- On success a TCP keep-alive is requested on the socket.  The result
-- of that request is discarded, so a keep-alive failure does not fail
-- the connection.
pollConnectStart :: Connection -> Fd -> IO ConnectResult
pollConnectStart conn fd = do
  pollStatus <- connectPoll conn
  case pollStatus of
    PollingReading -> do
      fd' <- currentFd
      threadWaitRead fd'
      pollConnectStart conn fd'
    PollingWriting -> do
      fd' <- currentFd
      threadWaitWrite fd'
      pollConnectStart conn fd'
    PollingOk -> do
      Fd cint <- currentFd
      -- NOTE(review): presumably 60 is the idle time and 2 the probe
      -- interval, both in seconds -- confirm against
      -- Network.Socket.KeepAlive.
      _ <- setKeepAlive cint $ KeepAlive True 60 2
      pure ConnectSuccess
    PollingFailed -> pure ConnectFailure
  where
    -- The socket libpq is using right now, or the descriptor we were
    -- given if libpq does not report one.
    currentFd = maybe fd id <$> socket conn

-- | Unsafe function for wrapping a regular libpq 'Connection'.  This
-- is unsafe because the 'Connection' needs to be set up to send
-- replication commands.  Improperly constructed connections can lead
-- to runtime exceptions.
unsafeCreateConnection :: Connection -> ReplicantConnection
unsafeCreateConnection = ReplicantConnection