-- | @LISTEN@/@NOTIFY@ with @hasql@.
module Hasql.ListenNotify
  ( -- * Listen
    Identifier (..),
    listen,
    unlisten,
    unlistenAll,
    escapeIdentifier,
    Notification (..),
    await,
    poll,
    backendPid,

    -- * Notify
    Notify (..),
    notify,
  )
where

import Control.Exception (throwIO, try)
import Control.Monad.Except (throwError)
import Control.Monad.IO.Class
import Control.Monad.Reader (ask)
import Data.ByteString (ByteString)
import qualified Data.ByteString.Builder as ByteString (Builder)
import qualified Data.ByteString.Builder as ByteString.Builder
import qualified Data.ByteString.Lazy as ByteString.Lazy
import Data.Functor.Contravariant ((>$<))
import Data.Text (Text)
import qualified Data.Text.Encoding as Text
import qualified Database.PostgreSQL.LibPQ as LibPQ
import GHC.Conc.IO (threadWaitRead)
import GHC.Generics (Generic)
import qualified Hasql.Connection as Connection
import qualified Hasql.Decoders as Decoders
import qualified Hasql.Encoders as Encoders
import Hasql.Session (Session)
import qualified Hasql.Session as Session
import Hasql.Statement (Statement (..))
import System.Posix.Types (CPid)

-- | Listen to a channel.
--
-- The channel name is spliced verbatim into the @LISTEN@ command (it cannot be sent as a statement parameter), so the
-- 'Identifier' should normally be obtained from 'escapeIdentifier'. The statement is built as non-preparable.
--
-- https://www.postgresql.org/docs/current/sql-listen.html
listen :: Identifier -> Statement () ()
listen (Identifier chan) =
  Statement (builderToByteString sql) Encoders.noParams Decoders.noResult False
  where
    sql :: ByteString.Builder
    sql =
      "LISTEN " <> ByteString.Builder.byteString chan

-- | Stop listening to a channel.
--
-- The channel name is spliced verbatim into the @UNLISTEN@ command, so the 'Identifier' should normally be obtained
-- from 'escapeIdentifier'. The statement is built as non-preparable.
--
-- https://www.postgresql.org/docs/current/sql-unlisten.html
unlisten :: Identifier -> Statement () ()
unlisten (Identifier chan) =
  Statement (builderToByteString sql) Encoders.noParams Decoders.noResult False
  where
    sql :: ByteString.Builder
    sql =
      "UNLISTEN " <> ByteString.Builder.byteString chan

-- | Stop listening to all channels.
--
-- Runs the fixed command @UNLISTEN *@ with no parameters and no result; the statement is built as non-preparable.
--
-- https://www.postgresql.org/docs/current/sql-unlisten.html
unlistenAll :: Statement () ()
unlistenAll =
  Statement "UNLISTEN *" Encoders.noParams Decoders.noResult False

-- | A Postgres identifier.
--
-- The wrapped bytes are inserted as-is into the SQL text built by 'listen' and 'unlisten'; use 'escapeIdentifier' to
-- obtain a properly quoted value from arbitrary text.
newtype Identifier
  = Identifier ByteString
  deriving newtype (Eq, Ord, Show)

-- | Escape a string as a Postgres identifier.
--
-- The text is UTF-8 encoded and passed to libpq's @PQescapeIdentifier()@ on the session's connection. If libpq reports
-- a failure, the resulting 'Session.QueryError' is re-thrown inside the 'Session'.
--
-- https://www.postgresql.org/docs/15/libpq-exec.html
escapeIdentifier :: Text -> Session Identifier
escapeIdentifier text =
  libpq (\conn -> try (escapeIdentifier_ conn text)) >>= \case
    Left err -> throwError err
    Right identifier -> pure (Identifier identifier)

-- IO-level worker for 'escapeIdentifier': UTF-8 encodes the text and hands it to libpq. Throws a QueryError (with
-- context "PQescapeIdentifier()" and the original text as its only parameter) when libpq returns no result.
escapeIdentifier_ :: LibPQ.Connection -> Text -> IO ByteString
escapeIdentifier_ conn text =
  LibPQ.escapeIdentifier conn (Text.encodeUtf8 text) >>= \case
    Nothing -> throwQueryError conn "PQescapeIdentifier()" [text]
    Just identifier -> pure identifier

-- | An incoming notification.
data Notification = Notification
  { -- | The channel the notification was delivered on.
    channel :: !Text,
    -- | The payload string attached to the notification.
    payload :: !Text,
    -- | The PID of the backend process that sent the notification (compare with 'backendPid').
    pid :: !CPid
  }
  deriving stock (Eq, Generic, Show)

-- | Get the next notification received from the server, blocking until one is available.
--
-- A 'Session.QueryError' raised by the underlying libpq calls is re-thrown inside the 'Session'.
--
-- https://www.postgresql.org/docs/current/libpq-notify.html
await :: Session Notification
await =
  libpq (\conn -> try (await_ conn)) >>= \case
    Left err -> throwError err
    Right notification -> pure (parseNotification notification)

-- IO-level worker for 'await': loops until 'poll_' yields a notification, waiting on the connection's socket between
-- attempts.
await_ :: LibPQ.Connection -> IO LibPQ.Notify
await_ conn =
  pollForNotification
  where
    pollForNotification :: IO LibPQ.Notify
    pollForNotification =
      poll_ conn >>= \case
        -- Block until a notification arrives. Snag: the connection might be closed (what). If so, attempt to reset it
        -- and poll for a notification on the new connection.
        Nothing ->
          LibPQ.socket conn >>= \case
            -- "No connection is currently open"
            Nothing -> do
              pqReset conn
              pollForNotification
            Just socket -> do
              threadWaitRead socket
              -- Data has appeared on the socket, but libPQ won't buffer it for us unless we do something (PQexec, etc).
              -- PQconsumeInput is provided for when we don't have anything to do except populate the notification
              -- buffer.
              pqConsumeInput conn
              pollForNotification
        Just notification -> pure notification

-- | Variant of 'await' that doesn't block.
--
-- Returns 'Nothing' when no notification is currently available. A 'Session.QueryError' raised by the underlying libpq
-- calls is re-thrown inside the 'Session'.
poll :: Session (Maybe Notification)
poll =
  libpq (\conn -> try (poll_ conn)) >>= \case
    Left err -> throwError err
    Right maybeNotification -> pure (parseNotification <$> maybeNotification)

-- First call `notifies` to pop a notification off of the buffer, if there is one. If there isn't, try `consumeInput` to
-- populate the buffer, followed by another `notifies`.
poll_ :: LibPQ.Connection -> IO (Maybe LibPQ.Notify)
poll_ conn =
  LibPQ.notifies conn >>= \case
    Nothing -> do
      pqConsumeInput conn
      LibPQ.notifies conn
    notification -> pure notification

-- | Get the PID of the backend process handling this session. This can be used to filter out notifications that
-- originate from this session (compare against the @pid@ field of a 'Notification').
--
-- https://www.postgresql.org/docs/current/libpq-status.html
backendPid :: Session CPid
backendPid =
  libpq LibPQ.backendPID

-- | An outgoing notification.
data Notify = Notify
  { -- | The channel to notify; passed as the first argument of @pg_notify@ by 'notify'.
    channel :: !Text,
    -- | The payload to send; passed as the second argument of @pg_notify@ by 'notify'.
    payload :: !Text
  }
  deriving stock (Eq, Generic, Show)

-- | Notify a channel.
--
-- Runs @SELECT pg_notify($1, $2)@, so the channel and payload are sent as ordinary (non-nullable text) statement
-- parameters rather than spliced into the SQL. The statement is built as preparable.
--
-- https://www.postgresql.org/docs/current/sql-notify.html
notify :: Statement Notify ()
notify =
  Statement sql encoder Decoders.noResult True
  where
    sql :: ByteString
    sql =
      "SELECT pg_notify($1, $2)"

    -- $1 is the channel, $2 is the payload; the order of the two halves matters.
    encoder :: Encoders.Params Notify
    encoder =
      ((\Notify {channel} -> channel) >$< Encoders.param (Encoders.nonNullable Encoders.text))
        <> ((\Notify {payload} -> payload) >$< Encoders.param (Encoders.nonNullable Encoders.text))

------------------------------------------------------------------------------------------------------------------------
-- Little wrappers that throw

-- Wrapper around PQconsumeInput() that throws a QueryError (context "PQconsumeInput()") when libpq reports failure.
pqConsumeInput :: LibPQ.Connection -> IO ()
pqConsumeInput conn =
  LibPQ.consumeInput conn >>= \case
    False -> throwQueryError conn "PQconsumeInput()" []
    True -> pure ()

-- Wrapper around PQreset() that throws a QueryError (context "PQreset()") when the reset did not succeed.
--
-- PQreset() itself reports nothing, so success is determined by checking the connection status afterwards: only
-- ConnectionOk counts as success; any other status is an error.
pqReset :: LibPQ.Connection -> IO ()
pqReset conn = do
  LibPQ.reset conn
  LibPQ.status conn >>= \case
    LibPQ.ConnectionOk -> pure ()
    _ -> throwQueryError conn "PQreset()" []

-- Throws a QueryError built from the given context string and parameter list, wrapping the connection's current libpq
-- error message (if any) in a ClientError. Never returns.
throwQueryError :: LibPQ.Connection -> ByteString -> [Text] -> IO void
throwQueryError conn context params = do
  message <- LibPQ.errorMessage conn
  throwIO (Session.QueryError context params (Session.ClientError message))

--

-- Run an IO action against the LibPQ connection underlying the current session's hasql connection.
libpq :: (LibPQ.Connection -> IO a) -> Session a
libpq action = do
  conn <- ask
  liftIO (Connection.withLibPQConnection conn action)

-- Render a Builder to a strict ByteString (by way of a lazy ByteString).
builderToByteString :: ByteString.Builder -> ByteString
builderToByteString =
  ByteString.Lazy.toStrict . ByteString.Builder.toLazyByteString
{-# INLINE builderToByteString #-}

-- Convert a LibPQ.Notify into a Notification, decoding the channel name and payload as UTF-8.
--
-- NOTE(review): Text.decodeUtf8 throws on invalid UTF-8; presumably the server always delivers valid UTF-8 here, but
-- that depends on the connection's client encoding - confirm.
parseNotification :: LibPQ.Notify -> Notification
parseNotification notification =
  Notification
    { channel = Text.decodeUtf8 (LibPQ.notifyRelname notification),
      payload = Text.decodeUtf8 (LibPQ.notifyExtra notification),
      pid = LibPQ.notifyBePid notification
    }