-- |
-- Module      : PostgresWebsockets.HasqlBroadcast
-- Description : Build a Hasql.Notifications based producer 'Multiplexer'.
--
-- Uses the "PostgresWebsockets.Broadcast" module, adding the database as a source producer.
-- This module provides a function to produce a 'Multiplexer' from a Hasql 'Connection'.
-- The producer issues a LISTEN command upon Open commands and UNLISTEN upon Close.
module PostgresWebsockets.HasqlBroadcast
  ( newHasqlBroadcaster,
    newHasqlBroadcasterOrError,
    -- re-export
    acquire,
    relayMessages,
    relayMessagesForever,
  )
where

import Control.Retry (RetryStatus (..), capDelay, exponentialBackoff, retrying)
import Data.Aeson (Value (..), decode)
import qualified Data.Aeson.KeyMap as JSON
import qualified Data.Aeson.Key as Key

import Data.Either.Combinators (mapBoth)
import Data.Function (id)
import GHC.Show
import Hasql.Connection
import qualified Hasql.Decoders as HD
import qualified Hasql.Encoders as HE
import Hasql.Notifications
import qualified Hasql.Session as H
import qualified Hasql.Statement as H
import PostgresWebsockets.Broadcast
import Protolude hiding (putErrLn, show, toS)
import Protolude.Conv

-- | Returns a multiplexer from a connection URI, keeps trying to connect in case there is any error.
--   This function also spawns a thread that keeps relaying the messages from the database to the multiplexer's listeners.
--
--   Arguments, in order: the action handed to the multiplexer as its
--   producer-termination handler, the channel to listen on, the retry limit
--   passed to 'tryUntilConnected', an optional supervision interval (when
--   'Nothing' no supervisor is started) and the connection string.
newHasqlBroadcaster :: IO () -> Text -> Int -> Maybe Int -> ByteString -> IO Multiplexer
newHasqlBroadcaster onConnectionFailure ch maxRetries checkInterval =
  newHasqlBroadcasterForConnection . tryUntilConnected maxRetries
  where
    -- The connection is passed along as an IO action, so 'tryUntilConnected'
    -- runs again every time the multiplexer machinery asks for a connection.
    newHasqlBroadcasterForConnection :: IO Connection -> IO Multiplexer
    newHasqlBroadcasterForConnection =
      newHasqlBroadcasterForChannel onConnectionFailure ch checkInterval

-- | Returns a multiplexer from a connection URI or an error message on the left case.
--   This function also spawns a thread that keeps relaying the messages from the database to the multiplexer's listeners.
--
--   The connection is acquired exactly once; when 'acquire' fails, the shown
--   'ConnectionError' is returned as a 'ByteString' and no multiplexer is
--   created. No supervision interval is configured for multiplexers built
--   through this function.
newHasqlBroadcasterOrError :: IO () -> Text -> ByteString -> IO (Either ByteString Multiplexer)
newHasqlBroadcasterOrError onConnectionFailure ch =
  acquire >=> (sequence . mapBoth (toSL . show) (newHasqlBroadcasterForConnection . return))
  where
    -- 'Nothing' disables the periodic "is the connection still listening" check.
    newHasqlBroadcasterForConnection :: IO Connection -> IO Multiplexer
    newHasqlBroadcasterForConnection =
      newHasqlBroadcasterForChannel onConnectionFailure ch Nothing

-- | Acquires a database connection, retrying with a capped exponential backoff while 'acquire' fails.
--
--   Every failed attempt is reported on stderr. Another attempt is made as
--   long as @rsIterNumber < maxRetries - 1@; when the final attempt fails as
--   well, the function panics with @"Failure on connection retry"@.
tryUntilConnected :: Int -> ByteString -> IO Connection
tryUntilConnected maxRetries =
  fmap (either (panic "Failure on connection retry") id) . retryConnection
  where
    retryConnection :: ByteString -> IO (Either ConnectionError Connection)
    retryConnection conStr = retrying retryPolicy shouldRetry (const $ acquire conStr)

    -- Upper bound for the delay between two attempts (32 seconds).
    maxDelayInMicroseconds :: Int
    maxDelayInMicroseconds = 32000000

    -- Base delay of the exponential backoff (1 second).
    firstDelayInMicroseconds :: Int
    firstDelayInMicroseconds = 1000000

    retryPolicy = capDelay maxDelayInMicroseconds $ exponentialBackoff firstDelayInMicroseconds

    -- Decides whether another attempt is made: never after a successful
    -- acquire, and after a failure only while the iteration counter is below
    -- the limit.
    -- NOTE(review): presumably rsIterNumber starts at 0, which would make
    -- maxRetries the total number of attempts; confirm against Control.Retry.
    shouldRetry :: RetryStatus -> Either ConnectionError Connection -> IO Bool
    shouldRetry RetryStatus {..} con =
      case con of
        Left err -> do
          putErrLn $ "Error connecting notification listener to database: " <> (toS . show) err
          pure $ rsIterNumber < maxRetries - 1
        Right _ -> pure False

-- | Returns a multiplexer from a channel and an IO Connection, listen for different database notifications on the provided channel using the connection produced.
--
--   This function also spawns a thread that keeps relaying the messages from the database to the multiplexer's listeners.
--
--   Arguments, in order: the action handed to the multiplexer as its
--   producer-termination handler, the channel to LISTEN on, an optional
--   interval for 'superviseMultiplexer' ('Nothing' starts no supervisor) and
--   the action used to obtain a connection.
--
--   A notification whose payload is a JSON object with a string field
--   @"channel"@ is delivered on that channel; any other payload is delivered
--   on the channel the notification arrived on.
--
--   To listen on the channel /chat/ from inside this module (the function is
--   not exported):
--
--   @
--   import Protolude
--   import PostgresWebsockets.HasqlBroadcast
--   import PostgresWebsockets.Broadcast
--   import Hasql.Connection
--
--   main = do
--    conOrError <- acquire "postgres://localhost/test_database"
--    let con = either (panic . show) id conOrError :: Connection
--    multi <- newHasqlBroadcasterForChannel (pure ()) "chat" Nothing (pure con)
--
--    -- NOTE(review): the 'onMessage' usage below is carried over from older
--    -- documentation; check it against "PostgresWebsockets.Broadcast".
--    onMessage multi "chat" (\ch ->
--      forever $ fmap print (atomically $ readTChan ch))
--   @
newHasqlBroadcasterForChannel :: IO () -> Text -> Maybe Int -> IO Connection -> IO Multiplexer
newHasqlBroadcasterForChannel onConnectionFailure ch checkInterval getCon = do
  multi <- newMultiplexer openProducer $ const onConnectionFailure
  case checkInterval of
    Just i -> superviseMultiplexer multi i shouldRestart
    Nothing -> pure ()
  void $ relayMessagesForever multi
  return multi
  where
    -- Builds the message for a notification received on channel @c@ with
    -- payload @m@. The payload itself is always forwarded unchanged; only the
    -- target channel may be overridden by the payload's "channel" field.
    toMsg :: Text -> Text -> Message
    toMsg c m = case decode (toS m) of
      Just v -> Message (channelDef c v) m
      Nothing -> Message c m

    -- Looks up @key@ in a JSON object and returns its value when it is a JSON
    -- string; returns the default @d@ for a missing key, a non-string value
    -- or a non-object JSON value.
    lookupStringDef :: Text -> Text -> Value -> Text
    lookupStringDef key d (Object obj) =
      case JSON.lookup (Key.fromText key) obj of
        Just (String s) -> s
        _ -> d
    lookupStringDef _ d _ = d

    channelDef :: Text -> Value -> Text
    channelDef = lookupStringDef "channel"

    -- Supervisor predicate: True when no backend appears to be listening on
    -- @ch@ (see 'isListening').
    -- NOTE(review): getCon is run on every check and the connection it yields
    -- is not released here. 'newHasqlBroadcaster' passes 'tryUntilConnected',
    -- which acquires a new connection each time; confirm this is intended.
    shouldRestart :: IO Bool
    shouldRestart = do
      con <- getCon
      not <$> isListening con ch

    -- Producer run by the multiplexer: obtains a connection, issues LISTEN
    -- for @ch@ and pushes every notification received onto the queue.
    openProducer :: TQueue Message -> IO ()
    openProducer msgQ = do
      con <- getCon
      listen con $ toPgIdentifier ch
      waitForNotifications
        (\c m -> atomically $ writeTQueue msgQ $ toMsg (toS c) (toS m))
        con

-- | Writes a line of text to stderr.
putErrLn :: Text -> IO ()
putErrLn = hPutStrLn stderr

-- | Reports whether 'isListeningStatement' finds a query matching
--   @listen%\<channel\>%@ for the given channel, using the given connection.
--
--   A failure to run the statement is reported as 'False', i.e. as
--   "not listening".
isListening :: Connection -> Text -> IO Bool
isListening con ch =
  fromRight False <$> H.run session con
  where
    session :: H.Session Bool
    session = H.statement chPattern isListeningStatement

    -- The channel name is embedded verbatim in the ILIKE pattern.
    -- NOTE(review): any @%@ or @_@ inside the channel name will act as a
    -- wildcard in the pattern; confirm that is acceptable.
    chPattern :: Text
    chPattern = "listen%" <> ch <> "%"

-- | Statement that takes an ILIKE pattern and returns whether any row of
--   @pg_stat_activity@ for the current database has a @query@ matching it.
--
--   The final 'True' argument of 'H.Statement' is passed unchanged from the
--   original definition.
--   NOTE(review): presumably this is hasql's "preparable" flag; confirm
--   against the hasql version in use.
isListeningStatement :: H.Statement Text Bool
isListeningStatement =
  H.Statement sql encoder decoder True
  where
    sql :: ByteString
    sql = "select exists (select * from pg_stat_activity where datname = current_database() and query ilike $1);"

    -- Single non-null text parameter: the ILIKE pattern.
    encoder :: HE.Params Text
    encoder = HE.param $ HE.nonNullable HE.text

    -- Exactly one row with one non-null boolean column.
    decoder :: HD.Result Bool
    decoder = HD.singleRow (HD.column (HD.nonNullable HD.bool))