{-|
  This module has functions to send commands LISTEN and NOTIFY to the database server.
  It also has a function to wait for and handle notifications on a database connection.

  For more information check the [PostgreSQL documentation](https://www.postgresql.org/docs/current/libpq-notify.html).

-}
module Hasql.Notifications
  ( notifyPool
  , notify
  , listen
  , unlisten
  , waitForNotifications
  , PgIdentifier
  , toPgIdentifier
  , fromPgIdentifier
  ) where

import Hasql.Pool (Pool, UsageError, use)
import Hasql.Session (sql, run, statement)
import qualified Hasql.Session as S
import qualified Hasql.Statement as HST
import Hasql.Connection (Connection, withLibPQConnection)
import qualified Hasql.Decoders as HD
import qualified Hasql.Encoders as HE
import qualified Database.PostgreSQL.LibPQ as PQ
import Data.Text (Text)
import qualified Data.Text as T
import qualified Data.Text.Encoding as T
import Data.ByteString.Char8 (ByteString)
import Data.Functor.Contravariant (contramap)
import Control.Monad (void, forever)
import Control.Concurrent (threadWaitRead)
import Control.Exception (Exception, throw)

-- | A wrapped 'Text' that represents a properly escaped and quoted PostgreSQL
-- identifier.
--
-- The data constructor is not exported by this module: build values with
-- 'toPgIdentifier' and unwrap them with 'fromPgIdentifier'.
newtype PgIdentifier = PgIdentifier Text
  deriving (Show)

-- | Exception raised by this module when it hits a condition it cannot
-- recover from.
--
-- The type is not part of the module's export list, so callers cannot match on
-- it by name, and nothing in this module catches it.
newtype FatalError = FatalError
  { fatalErrorMessage :: String -- ^ Human-readable description of the failure
  }
  deriving (Show)

instance Exception FatalError

-- | Given a 'PgIdentifier' returns the wrapped text, exactly as it was stored
-- (no unquoting or unescaping is performed).
fromPgIdentifier :: PgIdentifier -> Text
fromPgIdentifier (PgIdentifier quoted) = quoted

-- | Given a text returns a properly quoted and escaped 'PgIdentifier'.
--
-- Every double quote in the input is doubled and the result is wrapped in a
-- pair of double quotes, e.g. @my"chan@ becomes @"my""chan"@.
toPgIdentifier :: Text -> PgIdentifier
toPgIdentifier x =
  PgIdentifier $ "\"" <> strictlyReplaceQuotes x <> "\""
  where
    -- Escape embedded double quotes by doubling them.
    strictlyReplaceQuotes :: Text -> Text
    strictlyReplaceQuotes = T.replace "\"" "\"\""

-- | Given a Hasql Pool, a channel and a message sends a notify command to the database.
--
-- The notification is issued as @SELECT pg_notify($1, $2)@ with the channel and
-- the payload bound as non-nullable text parameters, so neither value is
-- spliced into the SQL text. Unlike 'notify', the channel is plain 'Text'
-- here, not a quoted 'PgIdentifier'.
notifyPool :: Pool -- ^ Pool from which the connection will be used to issue a NOTIFY command.
           -> Text -- ^ Channel where to send the notification
           -> Text -- ^ Payload to be sent with the notification
           -> IO (Either UsageError ())
notifyPool pool channel mesg =
  use pool (statement (channel, mesg) callStatement)
  where
    -- Statement with two text parameters and no result rows.
    callStatement :: HST.Statement (Text, Text) ()
    callStatement =
      HST.Statement "SELECT pg_notify($1, $2)" encoder HD.noResult False

    -- Encodes the pair as ($1 = channel, $2 = payload).
    encoder :: HE.Params (Text, Text)
    encoder =
      contramap fst (HE.param $ HE.nonNullable HE.text)
        <> contramap snd (HE.param $ HE.nonNullable HE.text)

-- | Given a Hasql Connection, a channel and a message sends a notify command to the database.
--
-- The command is built as text: @NOTIFY \<channel\>, '\<payload\>'@. The channel
-- is already quoted by virtue of being a 'PgIdentifier'; the payload is
-- embedded as a single-quoted string literal with every single quote doubled,
-- so a payload containing @'@ cannot terminate the literal early.
notify :: Connection -- ^ Connection to be used to send the NOTIFY command
       -> PgIdentifier -- ^ Channel where to send the notification
       -> Text -- ^ Payload to be sent with the notification
       -> IO (Either S.QueryError ())
notify con channel mesg =
  run (sql command) con
  where
    command :: ByteString
    command =
      T.encodeUtf8 $
        "NOTIFY " <> fromPgIdentifier channel <> ", " <> quoteLiteral mesg

    -- Render a payload as a SQL string literal, doubling embedded single quotes.
    -- NOTE(review): this assumes the server treats backslashes in ordinary
    -- string literals literally (standard_conforming_strings = on) — confirm
    -- for deployments that change that setting.
    quoteLiteral :: Text -> Text
    quoteLiteral payload = "'" <> T.replace "'" "''" payload <> "'"

{-|
  Given a Hasql Connection and a channel sends a listen command to the database.
  Once the connection sends the LISTEN command the server registers its interest in the channel.
  Hence it's important to keep track of which connection was used to open the listen command.

  The result of the underlying libpq call is discarded, so a LISTEN that fails
  on the server is not reported to the caller.

  Example of listening and waiting for a notification:

  @
  import System.Exit (die)

  import Hasql.Connection
  import Hasql.Notifications

  main :: IO ()
  main = do
    dbOrError <- acquire "postgres://localhost/db_name"
    case dbOrError of
        Right db -> do
            let channelToListen = toPgIdentifier "sample-channel"
            listen db channelToListen
            waitForNotifications (\channel _ -> print $ "Just got notification on channel " <> channel) db
        _ -> die "Could not open database connection"
  @
-}
listen :: Connection -- ^ Connection to be used to send the LISTEN command
       -> PgIdentifier -- ^ Channel this connection will be registered to listen to
       -> IO ()
listen con channel =
  withLibPQConnection con execListen
  where
    -- Runs @LISTEN \<channel\>@ directly on the libpq connection; the
    -- 'Maybe' result returned by 'PQ.exec' is intentionally ignored.
    execListen :: PQ.Connection -> IO ()
    execListen pqCon =
      void $ PQ.exec pqCon $ T.encodeUtf8 $ "LISTEN " <> fromPgIdentifier channel

-- | Given a Hasql Connection and a channel sends an unlisten command to the database.
--
-- The result of the underlying libpq call is discarded, so an UNLISTEN that
-- fails on the server is not reported to the caller.
unlisten :: Connection -- ^ Connection currently registered by a previous 'listen' call
         -> PgIdentifier -- ^ Channel this connection will be deregistered from
         -> IO ()
unlisten con channel =
  withLibPQConnection con execUnlisten
  where
    -- Runs @UNLISTEN \<channel\>@ directly on the libpq connection; the
    -- 'Maybe' result returned by 'PQ.exec' is intentionally ignored.
    execUnlisten :: PQ.Connection -> IO ()
    execUnlisten pqCon =
      void $ PQ.exec pqCon $ T.encodeUtf8 $ "UNLISTEN " <> fromPgIdentifier channel


{-|
  Given a function that handles notifications and a Hasql connection it will listen
  on the database connection and call the handler every time a message arrives.

  The message handler passed as first argument needs two parameters: channel and payload.
  The handler is invoked inline by the waiting loop, so the next notification is
  not fetched until the handler returns; hand long-running work off to another thread.

  This function loops forever and only terminates by raising an exception: an
  internal fatal error is thrown when the connection has no socket or when
  reading from the connection fails.

  See an example of handling notification on a separate thread:

  @
  import Control.Concurrent.Async (async)
  import Control.Monad (void)
  import Data.ByteString (ByteString)
  import System.Exit (die)

  import Hasql.Connection
  import Hasql.Notifications

  notificationHandler :: ByteString -> ByteString -> IO ()
  notificationHandler channel payload =
    void $ async $
      print $ "Handle payload " <> payload <> " in its own thread"

  main :: IO ()
  main = do
    dbOrError <- acquire "postgres://localhost/db_name"
    case dbOrError of
        Right db -> do
            let channelToListen = toPgIdentifier "sample-channel"
            listen db channelToListen
            waitForNotifications notificationHandler db
        _ -> die "Could not open database connection"
  @
-}
waitForNotifications :: (ByteString -> ByteString -> IO()) -- ^ Callback function to handle incoming notifications
                     -> Connection -- ^ Connection where we will listen to
                     -> IO ()
waitForNotifications sendNotification con =
  -- NOTE(review): the libpq connection is held by 'withLibPQConnection' for the
  -- whole loop; whether other users of the same Hasql connection are blocked
  -- meanwhile depends on hasql — confirm.
  withLibPQConnection con (forever . pqFetch)
  where
    -- One iteration: deliver a pending notification if there is one, otherwise
    -- block until the socket is readable and pull the new input into libpq.
    pqFetch :: PQ.Connection -> IO ()
    pqFetch pqCon = do
      mNotification <- PQ.notifies pqCon
      case mNotification of
        Just notification ->
          sendNotification (PQ.notifyRelname notification) (PQ.notifyExtra notification)
        Nothing -> do
          mfd <- PQ.socket pqCon
          case mfd of
            Nothing -> panic "Error checking for PostgreSQL notifications"
            Just fd -> do
              threadWaitRead fd
              -- A 'False' result means libpq could not read from the
              -- connection; stop instead of looping on a dead connection.
              inputOk <- PQ.consumeInput pqCon
              if inputOk
                then pure ()
                else panic "Error checking for PostgreSQL notifications"

    -- Abort with a 'FatalError' carrying the given message.
    panic :: String -> a
    panic = throw . FatalError