{-# LANGUAGE DeriveGeneric #-}

-- |
-- Module      : PostgresWebsockets.Broadcast
-- Description : Distribute messages from one producer to several consumers.
--
-- PostgresWebsockets functions to broadcast messages to several listening clients.
-- This module provides a type called 'Multiplexer'.
-- The multiplexer contains a map of channels and a producer thread.
--
-- This module avoids any database implementation details; it is used by HasqlBroadcast, where
-- the database logic is combined.
module PostgresWebsockets.Broadcast
  ( Multiplexer,
    Message (..),
    newMultiplexer,
    onMessage,
    relayMessages,
    relayMessagesForever,
    superviseMultiplexer,

    -- * Re-exports
    readTQueue,
    writeTQueue,
    readTChan,
  )
where

import Control.Concurrent.STM.TChan
import Control.Concurrent.STM.TQueue
import qualified Data.Aeson as A
import Protolude hiding (toS)
import Protolude.Conv (toS)
import qualified StmContainers.Map as M

-- | A single message travelling through the multiplexer.
data Message = Message
  { -- | Name of the channel the message belongs to; 'relayMessages' uses it as the
    --   lookup key into the multiplexer's channel map.
    channel :: Text,
    -- | Body of the message. This module never inspects it.
    payload :: Text
  }
  deriving (Eq, Show)

-- | State shared between the producer thread, the relay and the listeners.
data Multiplexer = Multiplexer
  { -- | Currently open channels, keyed by channel name.
    channels :: M.Map Text Channel,
    -- | Queue the producer writes to and 'relayMessages' reads from.
    messages :: TQueue Message,
    -- | Id of the producer thread currently registered with this multiplexer.
    producerThreadId :: MVar ThreadId,
    -- | Action that forks a fresh producer thread and returns its id.
    reopenProducer :: IO ThreadId
  }

-- | Serialisable summary of a 'Multiplexer', as produced by 'takeSnapshot'.
data MultiplexerSnapshot = MultiplexerSnapshot
  { -- | Number of entries in the channel map at snapshot time.
    channelsSize :: Int,
    -- | Whether the message queue was empty at snapshot time.
    messageQueueEmpty :: Bool,
    -- | Textual rendering ('show') of the producer's 'ThreadId'.
    producerId :: Text
  }
  deriving (Generic)

-- | One named channel inside the multiplexer.
data Channel = Channel
  { -- | Write end used by 'relayMessages'; every listener reads from its own
    --   duplicate of this channel.
    broadcast :: TChan Message,
    -- | Number of listeners currently registered on the channel.
    listeners :: Integer
  }

-- | JSON encoding used when logging snapshots. No methods are defined here, so the
--   class defaults apply (they rely on the derived 'Generic' instance).
instance A.ToJSON MultiplexerSnapshot

-- | Capture the observable state of a multiplexer in a value that can be printed
--   or JSON-encoded for debugging and logging.
--
--   The three readings (channel count, queue emptiness, producer thread id) are taken
--   one after another, not in a single transaction.
takeSnapshot :: Multiplexer -> IO MultiplexerSnapshot
takeSnapshot multi = do
  channelCount <- atomically (M.size (channels multi))
  queueIsEmpty <- atomically (isEmptyTQueue (messages multi))
  producer <- readMVar (producerThreadId multi)
  pure
    MultiplexerSnapshot
      { channelsSize = channelCount,
        messageQueueEmpty = queueIsEmpty,
        producerId = show producer
      }

-- | Forks a thread that calls 'relayMessages' in an endless loop and returns its id.
relayMessagesForever :: Multiplexer -> IO ThreadId
relayMessagesForever multi = forkIO (forever (relayMessages multi))

-- | Takes one message off the producer queue and writes it to the broadcast channel
--   registered under the message's channel name.
--
--   Runs as a single STM transaction. When no channel is registered under that name
--   the message is dropped.
relayMessages :: Multiplexer -> IO ()
relayMessages multi = atomically $ do
  msg <- readTQueue (messages multi)
  target <- M.lookup (channel msg) (channels multi)
  maybe (return ()) (\ch -> writeTChan (broadcast ch) msg) target

-- | Builds a multiplexer and starts its producer thread.
--
--   * The first argument is the producer: it receives the multiplexer's message queue
--     and runs in a freshly forked thread.
--   * The second argument is invoked (via 'forkFinally') when a producer thread ends,
--     with either the exception that ended it or its result.
--
--   The same fork action is stored in 'reopenProducer', so later restarts reuse the
--   same queue and the same pair of callbacks.
newMultiplexer ::
  (TQueue Message -> IO a) ->
  (Either SomeException a -> IO ()) ->
  IO Multiplexer
newMultiplexer openProducer closeProducer = do
  queue <- newTQueueIO
  let spawnProducer = forkFinally (openProducer queue) closeProducer
  tidVar <- newMVar =<< spawnProducer
  channelMap <- M.newIO
  pure
    Multiplexer
      { channels = channelMap,
        messages = queue,
        producerThreadId = tidVar,
        reopenProducer = spawnProducer
      }

-- | Periodically checks a condition and restarts the multiplexer's producer thread
--   whenever the condition holds.
--
--   * @multi@: the multiplexer whose producer is supervised.
--   * @msInterval@: polling interval in milliseconds. When it is 0 (or negative) this
--     function is a no-op, so the minimum effective interval is 1ms.
--   * @shouldRestart@: IO computation run after every interval; a 'True' result
--     triggers a restart.
--
--   The supervision loop runs in its own thread, so this function returns immediately.
--   On restart the current producer thread is killed (which also fires the close
--   callback it was forked with), a new producer is forked through 'reopenProducer' and
--   its id replaces the one stored in 'producerThreadId'. Snapshots of the multiplexer
--   taken before and after the restart are printed.
superviseMultiplexer :: Multiplexer -> Int -> IO Bool -> IO ()
superviseMultiplexer multi msInterval shouldRestart =
  -- A non-positive interval would turn the loop into a busy spin, so nothing is forked.
  when (msInterval > 0) $
    void . forkIO . forever $ do
      -- threadDelay expects microseconds.
      threadDelay $ msInterval * 1000
      sr <- shouldRestart
      when sr restartProducer
  where
    restartProducer = do
      snapBefore <- takeSnapshot multi
      -- Sequence with >>= so the kill actually runs; @killThread \<$\> readMVar ...@
      -- would only build the action without ever executing it.
      readMVar (producerThreadId multi) >>= killThread
      new <- reopenProducer multi
      void $ swapMVar (producerThreadId multi) new
      snapAfter <- takeSnapshot multi
      putStrLn $
        "Restarting producer. Multiplexer updated: "
          <> A.encode snapBefore
          <> " -> "
          <> A.encode snapAfter

-- | Creates a channel with a fresh broadcast 'TChan' and zero listeners, registers it
--   in the multiplexer under the given name and returns it.
openChannel :: Multiplexer -> Text -> STM Channel
openChannel multi chan = do
  bcast <- newBroadcastTChan
  let fresh = Channel {broadcast = bcast, listeners = 0}
  M.insert fresh chan (channels multi)
  pure fresh

-- |  Adds a listener to one of the multiplexer's channels.
--      The listener is a callback that receives each 'Message' relayed to the channel
--      and may perform any IO action with it.
--      Every listener runs in its own thread.
--      The first listener opens the channel; when a listener's thread ends it decrements
--      the channel's listener count and removes the channel once no listeners are left.
onMessage :: Multiplexer -> Text -> (Message -> IO ()) -> IO ()
onMessage multi chan action = do
  -- Finding/opening the channel and registering the listener happen in one transaction.
  listener <- atomically (openChannelWhenNotFound >>= addListener)
  void (forkFinally (consume listener) (const disposeListener))
  where
    registry = channels multi

    -- Deliver every message arriving on the listener's private copy of the channel.
    consume source = forever $ do
      msg <- atomically (readTChan source)
      action msg

    -- Runs when the listener thread ends, whatever the reason.
    disposeListener = atomically $ do
      found <- M.lookup chan registry
      let current = fromMaybe (panic $ "trying to remove listener from non existing channel: " <> toS chan) found
          remaining = listeners current - 1
      if remaining > 0
        then M.insert current {listeners = remaining} chan registry
        else M.delete chan registry

    openChannelWhenNotFound =
      M.lookup chan registry >>= maybe (openChannel multi chan) return

    -- Bump the listener count (insert replaces the existing entry) and hand back a
    -- duplicate of the broadcast channel for this listener to read from.
    addListener ch = do
      let updated = ch {listeners = listeners ch + 1}
      M.insert updated chan registry
      dupTChan (broadcast updated)