{-# LANGUAGE CPP #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE RecordWildCards #-}

-----------------------------------------------------------------------------
-- |
-- Module      :  Database.PostgreSQL.Simple.Notification
-- Copyright   :  (c) 2011-2015 Leon P Smith
--                (c) 2012 Joey Adams
-- License     :  BSD3
--
-- Maintainer  :  leon@melding-monads.com
-- Stability   :  experimental
--
-- Support for receiving asynchronous notifications via PostgreSQL's
-- Listen/Notify mechanism.  See
-- <https://www.postgresql.org/docs/current/sql-notify.html> for more
-- information.
--
-- Note that on Windows,  @getNotification@ currently uses a polling loop
-- of 1 second to check for more notifications,  due to some inadequacies
-- in GHC's IO implementation and interface on that platform.   See GHC
-- issue #7353 for more information.  While this workaround is less than
-- ideal,  notifications are still better than polling the database directly.
-- Notifications do not create any extra work for the backend,  and are
-- likely cheaper on the client side as well.
--
-- <https://gitlab.haskell.org/ghc/ghc/-/issues/7353>
--
-----------------------------------------------------------------------------

module Database.PostgreSQL.Simple.Notification
     ( Notification(..)
     , getNotification
     , getNotificationNonBlocking
     , getBackendPID
     ) where

import           Control.Monad ( join, void )
import           Control.Exception ( throwIO, catch )
import qualified Data.ByteString as B
import qualified Data.ByteString.Char8 as B8
import           Database.PostgreSQL.Simple.Internal
import qualified Database.PostgreSQL.LibPQ as PQ
import           System.Posix.Types ( CPid )
import           GHC.IO.Exception ( ioe_location )

#if defined(mingw32_HOST_OS)
import           Control.Concurrent ( threadDelay )
#elif !MIN_VERSION_base(4,7,0)
import           Control.Concurrent ( threadWaitRead )
#else
import           GHC.Conc           ( atomically )
import           Control.Concurrent ( threadWaitReadSTM )
#endif

-- | A single asynchronous notification received from the backend via
--   PostgreSQL's LISTEN/NOTIFY mechanism.
data Notification = Notification
   { notificationPid     :: {-# UNPACK #-} !CPid
     -- ^ Process ID of the backend that sent the notification (this may be
     --   compared against 'getBackendPID').
   , notificationChannel :: {-# UNPACK #-} !B.ByteString
     -- ^ Name of the channel the notification was sent on.
   , notificationData    :: {-# UNPACK #-} !B.ByteString
     -- ^ Payload carried by the notification.
   } deriving (Show, Eq)

-- | Translate libpq's raw notification record into the public
--   'Notification' type.  The sender's backend PID, the channel name
--   ('PQ.notifyRelname') and the payload ('PQ.notifyExtra') are copied
--   across unchanged.
convertNotice :: PQ.Notify -> Notification
convertNotice PQ.Notify{..} =
    Notification { notificationPid     = notifyBePid
                 , notificationChannel = notifyRelname
                 , notificationData    = notifyExtra
                 }

-- | Returns a single notification.  If no notifications are available,
--   'getNotification' blocks until one arrives.
--
--   It is safe to call 'getNotification' on a connection that is
--   concurrently being used for other purposes.  Note, however, that
--   PostgreSQL does not deliver notifications while a connection is inside
--   a transaction.
--
--   Throws an 'IOError' if libpq reports no socket for the connection
--   (presumably because the connection is not open).  'IOError's raised
--   while waiting on the socket are re-thrown with their location set to
--   this function's name.
getNotification :: Connection -> IO Notification
getNotification conn = join $ withConnection conn fetch
  where
    -- Location string attached to every error raised by this function.
    -- Built with 'B8.pack' so the block does not depend on OverloadedStrings.
    funcName :: B.ByteString
    funcName = B8.pack "Database.PostgreSQL.Simple.Notification.getNotification"

    -- Runs while holding the connection lock.  It does not block; instead
    -- it returns the action to run /after/ the lock has been released:
    -- yield a notification, throw an error, or wait for input and retry.
    -- Waiting outside the lock lets other threads use the connection.
    fetch :: PQ.Connection -> IO (IO Notification)
    fetch c = do
        mmsg <- PQ.notifies c
        case mmsg of
          Just msg -> return (return $! convertNotice msg)
          Nothing -> do
              mfd <- PQ.socket c
              case mfd of
                Nothing  -> return (throwIO $! fdError funcName)
#if defined(mingw32_HOST_OS)
                -- threadWaitRead doesn't work for sockets on Windows, so just
                -- poll for input every second (PQconsumeInput is non-blocking).
                --
                -- We could call select(), but FFI calls can't be interrupted
                -- with async exceptions, whereas threadDelay can.
                Just _fd -> do
                  return (threadDelay 1000000 >> loop)
#elif !MIN_VERSION_base(4,7,0)
                -- Technically there's a race condition that is usually benign.
                -- If the connection is closed or reset after we drop the
                -- lock, and then the fd index is reallocated to a new
                -- descriptor before we call threadWaitRead, then
                -- we could end up waiting on the wrong descriptor.
                --
                -- Now, if the descriptor becomes readable promptly, then
                -- it's no big deal as we'll wake up and notice the change
                -- on the next iteration of the loop.  But if we are very
                -- unlucky, then we could end up waiting a long time.
                Just fd  -> do
                  return $ do
                    threadWaitRead fd `catch` (throwIO . setIOErrorLocation)
                    loop
#else
                -- This case fixes the race condition above.  By registering
                -- our interest in the descriptor before we drop the lock,
                -- there is no opportunity for the descriptor index to be
                -- reallocated on us.
                --
                -- (That is, assuming there isn't concurrently executing
                -- code that manipulates the descriptor without holding
                -- the lock... but such a major bug is likely to exhibit
                -- itself in an at least somewhat more dramatic fashion.)
                Just fd  -> do
                  (waitRead, _) <- threadWaitReadSTM fd
                  return $ do
                    atomically waitRead `catch` (throwIO . setIOErrorLocation)
                    loop
#endif

    -- Re-acquire the lock, pull any newly arrived input into libpq's
    -- buffer (the success flag is ignored), then look for a notification
    -- again.
    loop :: IO Notification
    loop = join $ withConnection conn $ \c -> do
             void $ PQ.consumeInput c
             fetch c

    -- Tag an 'IOError' raised while waiting on the socket with this
    -- function's name.
    setIOErrorLocation :: IOError -> IOError
    setIOErrorLocation err = err { ioe_location = B8.unpack funcName }


-- | Non-blocking variant of 'getNotification'.  Returns a single
--   notification if one is available; otherwise returns 'Nothing'.
--
--   A notification that libpq has already buffered is returned first.  Only
--   if none is buffered is pending input consumed from the connection, after
--   which the buffer is checked exactly once more.
getNotificationNonBlocking :: Connection -> IO (Maybe Notification)
getNotificationNonBlocking conn =
    withConnection conn $ \c -> do
        mmsg <- PQ.notifies c
        case mmsg of
          Just msg -> deliver msg
          Nothing  -> do
              -- The success flag of consumeInput is ignored here, as it is
              -- in 'getNotification'.
              _ <- PQ.consumeInput c
              maybe (return Nothing) deliver =<< PQ.notifies c
  where
    -- Evaluate the converted notification before handing it to the caller.
    deliver :: PQ.Notify -> IO (Maybe Notification)
    deliver msg = return $! Just $! convertNotice msg

-- | Returns the process 'CPid' of the backend server process
-- handling this connection.
--
-- The backend PID is useful for debugging purposes and for comparison
-- to NOTIFY messages (which include the PID of the notifying backend
-- process; see 'notificationPid'). Note that the PID belongs to a process
-- executing on the database server host, not the local host!
getBackendPID :: Connection -> IO CPid
getBackendPID conn = withConnection conn PQ.backendPID