-- |
-- Module      : PostgresWebsockets.HasqlBroadcast
-- Description : Build a Hasql.Notifications based producer 'Multiplexer'.
--
-- Builds on the Broadcast module, adding the database as a source producer.
-- This module provides a function to produce a 'Multiplexer' from a Hasql 'Connection'.
-- The producer issues a LISTEN command upon Open commands and UNLISTEN upon Close.
module PostgresWebsockets.HasqlBroadcast
  ( newHasqlBroadcaster,
    newHasqlBroadcasterOrError,
    -- re-export
    acquire,
    relayMessages,
    relayMessagesForever,
  )
where

import Control.Retry (RetryStatus (..), capDelay, exponentialBackoff, retrying)
import Data.Aeson (Value (..), decode)
import Data.Either.Combinators (mapBoth)
import Data.Function (id)
import Data.HashMap.Lazy (lookupDefault)
import GHC.Show
import Hasql.Connection
import qualified Hasql.Decoders as HD
import qualified Hasql.Encoders as HE
import Hasql.Notifications
import qualified Hasql.Session as H
import qualified Hasql.Statement as H
import PostgresWebsockets.Broadcast
import Protolude hiding (putErrLn, show, toS)
import Protolude.Conv

-- | Returns a multiplexer for channel @ch@ from a connection URI, retrying the
--   connection when it fails.
--
--   Connections come from 'tryUntilConnected'. Each attempt uses 'acquire',
--   and failures are retried with a capped exponential backoff for up to
--   @maxRetries@ attempts. If every attempt fails, the resulting 'Connection'
--   is a 'panic' thunk. The connecting action is re-run every time a
--   connection is needed.
--
--   @onConnectionFailure@ is given to 'newMultiplexer' as the callback for the
--   producer's @Either SomeException@ outcome; the outcome itself is ignored.
--   When @checkInterval@ is 'Just', that value is passed to
--   'superviseMultiplexer' together with a check reporting whether the channel
--   is no longer being listened to.
--
--   This function also spawns a thread that keeps relaying the messages from
--   the database to the multiplexer's listeners.
newHasqlBroadcaster :: IO () -> Text -> Int -> Maybe Int -> ByteString -> IO Multiplexer
newHasqlBroadcaster onConnectionFailure ch maxRetries checkInterval connectionUri =
  newHasqlBroadcasterForChannel onConnectionFailure ch checkInterval connect
  where
    -- An action, not a connection: every run performs a new (retried) connect.
    connect :: IO Connection
    connect = tryUntilConnected maxRetries connectionUri

-- | Returns a multiplexer for channel @ch@ from a connection URI, or an error
--   message in the 'Left' case.
--
--   A single connection attempt is made with 'acquire'. If it fails, the
--   'show'n 'ConnectionError' is returned as a 'ByteString'. If it succeeds,
--   that one connection is handed to the producer every time the producer
--   asks for a connection. No supervision is set up, because the check
--   interval is 'Nothing'.
--
--   This function also spawns a thread that keeps relaying the messages from
--   the database to the multiplexer's listeners.
newHasqlBroadcasterOrError :: IO () -> Text -> ByteString -> IO (Either ByteString Multiplexer)
newHasqlBroadcasterOrError onConnectionFailure ch connectionUri = do
  acquired <- acquire connectionUri
  case acquired of
    Left connectionError ->
      pure (Left (toSL (show connectionError)))
    Right con ->
      Right <$> newHasqlBroadcasterForChannel onConnectionFailure ch Nothing (pure con)

-- | Connects to the database at the given URI, retrying on failure.
--
--   Each attempt runs 'acquire'. Between attempts the delay follows an
--   exponential backoff that starts at 1 second and is capped at 32 seconds
--   (values are in microseconds). Every failed attempt is logged to stderr.
--   Retrying stops once the retry iteration number (0 for the first try)
--   reaches @maxRetries - 1@. So at most @maxRetries@ attempts are made, and
--   always at least one.
--
--   If the final attempt also fails, the returned 'Connection' is a lazy
--   'panic': the error is raised when the connection is first forced, not when
--   this action completes.
tryUntilConnected :: Int -> ByteString -> IO Connection
tryUntilConnected maxRetries connectionUri =
  either (panic "Failure on connection retry") id
    <$> retrying backoff keepTrying (const (acquire connectionUri))
  where
    firstDelayInMicroseconds, maxDelayInMicroseconds :: Int
    firstDelayInMicroseconds = 1000000
    maxDelayInMicroseconds = 32000000

    -- Exponential backoff, with each individual delay capped.
    backoff = capDelay maxDelayInMicroseconds (exponentialBackoff firstDelayInMicroseconds)

    -- Decides whether another attempt is made; a success never retries.
    keepTrying :: RetryStatus -> Either ConnectionError Connection -> IO Bool
    keepTrying _ (Right _) = pure False
    keepTrying status (Left err) = do
      putErrLn $ "Error connecting notification listener to database: " <> toS (show err)
      pure (rsIterNumber status < maxRetries - 1)

-- | Returns a multiplexer for channel @ch@, obtaining database connections by
--   running @getCon@.
--
--   Each time the producer is opened, it runs @getCon@, issues @LISTEN@ on
--   @ch@ and writes every notification received into the multiplexer's queue
--   as a 'Message'. If a payload is a JSON object with a string @channel@
--   field, that string is used as the message channel. In every other case,
--   the channel the notification arrived on is used. The payload text is kept
--   unchanged.
--
--   @onConnectionFailure@ is given to 'newMultiplexer' as the callback for the
--   producer's @Either SomeException@ outcome; the outcome itself is ignored.
--
--   When @checkInterval@ is 'Just', the multiplexer is passed to
--   'superviseMultiplexer' together with a check that does three things:
--
--   * runs @getCon@;
--   * asks 'isListening' whether some session is still listening on @ch@;
--   * releases that probe connection, even if the query throws.
--
--   In that case @getCon@ must produce a fresh connection on each run, because
--   the probe connection is closed after every check.
--
--   This function also spawns a thread that keeps relaying the messages from
--   the database to the multiplexer's listeners.
--
--   To listen on channel @chat@: the example uses the exported
--   'newHasqlBroadcaster', and its string literals assume @OverloadedStrings@.
--
--   > import Protolude
--   > import PostgresWebsockets.HasqlBroadcast
--   > import PostgresWebsockets.Broadcast
--   >
--   > main = do
--   >   multi <- newHasqlBroadcaster (pure ()) "chat" 5 Nothing "postgres://localhost/test_database"
--   >   onMessage multi "chat" (\ch -> forever $ fmap print (atomically $ readTChan ch))
newHasqlBroadcasterForChannel :: IO () -> Text -> Maybe Int -> IO Connection -> IO Multiplexer
newHasqlBroadcasterForChannel onConnectionFailure ch checkInterval getCon = do
  multi <- newMultiplexer openProducer $ const onConnectionFailure
  case checkInterval of
    Just i -> superviseMultiplexer multi i shouldRestart
    _ -> pure ()
  void $ relayMessagesForever multi
  return multi
  where
    -- Builds a message; a JSON payload may override the channel (see channelDef).
    toMsg :: Text -> Text -> Message
    toMsg c m = case decode (toS m) of
      Just v -> Message (channelDef c v) m
      Nothing -> Message c m

    -- String field @key@ of a JSON object, or the default @d@ for a missing
    -- or non-string field and for any non-object value.
    lookupStringDef :: Text -> Text -> Value -> Text
    lookupStringDef key d (Object obj) =
      case lookupDefault (String d) key obj of
        String s -> s
        _ -> d
    lookupStringDef _ d _ = d

    channelDef :: Text -> Value -> Text
    channelDef = lookupStringDef "channel"

    -- Probes with a dedicated connection and always releases it afterwards.
    -- Without the release, every check would leave an unreleased connection
    -- behind, because getCon opens a new one each time it runs.
    shouldRestart :: IO Bool
    shouldRestart =
      bracket getCon release $ \con -> not <$> isListening con ch

    -- Runs the LISTEN/notification loop on a freshly obtained connection.
    openProducer msgQ = do
      con <- getCon
      listen con $ toPgIdentifier ch
      waitForNotifications
        (\c m -> atomically $ writeTQueue msgQ $ toMsg (toS c) (toS m))
        con

-- | Writes a line of 'Text' to standard error.
putErrLn :: Text -> IO ()
putErrLn message = hPutStrLn stderr message

-- | Runs 'isListeningStatement' on @con@ and checks whether any
--   @pg_stat_activity@ row for the current database has a @query@ matching,
--   case-insensitively, the pattern @listen%CH%@, where @CH@ is @ch@.
--   If the query fails, the result is 'False'.
--
--   NOTE(review): @ch@ is spliced into the ILIKE pattern without escaping, so
--   @%@ or @_@ inside a channel name act as wildcards. Other channels whose
--   names contain @ch@ also match.
isListening :: Connection -> Text -> IO Bool
isListening con ch =
  fromRight False <$> H.run (H.statement likePattern isListeningStatement) con
  where
    likePattern :: Text
    likePattern = "listen%" <> ch <> "%"

-- | Statement taking an ILIKE pattern as its single non-null text parameter.
--   It returns the single non-null boolean produced by the @EXISTS@ query over
--   @pg_stat_activity@, restricted to the current database. The final 'True'
--   is hasql's flag for preparing the statement.
isListeningStatement :: H.Statement Text Bool
isListeningStatement = H.Statement sqlText patternParam existsResult True
  where
    sqlText :: ByteString
    sqlText = "select exists (select * from pg_stat_activity where datname = current_database() and query ilike $1);"

    patternParam :: HE.Params Text
    patternParam = HE.param (HE.nonNullable HE.text)

    existsResult :: HD.Result Bool
    existsResult = HD.singleRow $ HD.column $ HD.nonNullable HD.bool