{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE Safe #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TupleSections #-}
-------------------------------------------------------------------------------
-- |
-- Module      :  BroadcastChan
-- Copyright   :  (C) 2014-2021 Merijn Verstraaten
-- License     :  BSD-style (see the file LICENSE)
-- Maintainer  :  Merijn Verstraaten <merijn@inconsistent.nl>
-- Stability   :  experimental
-- Portability :  haha
--
-- A closable, fair, single-wakeup channel that avoids the 0 reader space leak
-- that @"Control.Concurrent.Chan"@ from base suffers from.
--
-- The @Chan@ type from @"Control.Concurrent.Chan"@ consists of both a read
-- and write end combined into a single value. This means there is always at
-- least 1 read end for a @Chan@, which keeps any values written to it alive.
-- This is a problem for applications/libraries that want to have a channel
-- that can have zero listeners.
--
-- Suppose we have a library that produces events and we want to let users
-- register to receive events. If we use a channel and write all events to it,
-- we would like to drop and garbage collect any events that take place when
-- there are 0 listeners. The always present read end of @Chan@ from base
-- makes this impossible. We end up with a @Chan@ that forever accumulates
-- more and more events that will never get removed, resulting in a memory
-- leak.
--
-- @"BroadcastChan"@ splits channels into separate read and write ends. Any
-- message written to a channel with no existing read end is immediately
-- dropped so it can be garbage collected. Once a read end is created, all
-- messages written to the channel will be accessible to that read end.
--
-- Once all read ends for a channel have disappeared and been garbage
-- collected, the channel will return to dropping messages as soon as they are
-- written.
--
-- __Why should I use "BroadcastChan" over "Control.Concurrent.Chan"?__
--
-- * @"BroadcastChan"@ is closable,
--
-- * @"BroadcastChan"@ has no 0 reader space leak,
--
-- * @"BroadcastChan"@ has comparable or better performance.
--
-- __Why should I use "BroadcastChan" over various (closable) STM channels?__
--
-- * @"BroadcastChan"@ is single-wakeup,
--
-- * @"BroadcastChan"@ is fair,
--
-- * @"BroadcastChan"@ performs better under contention.
-------------------------------------------------------------------------------
module BroadcastChan (
    -- * Datatypes
      BroadcastChan
    , Direction(..)
    , In
    , Out
    -- * Construction
    , newBroadcastChan
    , newBChanListener
    -- * Basic Operations
    , readBChan
    , writeBChan
    , closeBChan
    , isClosedBChan
    , getBChanContents
    -- * Parallel processing
    , Action(..)
    , Handler(..)
    , parMapM_
    , parFoldMap
    , parFoldMapM
    -- * Foldl combinators
    -- | Combinators for use with Tekmo's @foldl@ package.
    , foldBChan
    , foldBChanM
    ) where

import Control.Exception
    (SomeException(..), mask, throwIO, try, uninterruptibleMask_)
import Control.Monad (liftM)
import Control.Monad.IO.Unlift
    (MonadUnliftIO, UnliftIO(..), askUnliftIO, withRunInIO)
import Data.Foldable as F (Foldable(..), foldlM, forM_)

import BroadcastChan.Extra
import BroadcastChan.Internal

-- | Variant of 'Control.Exception.bracketOnError' whose body runs in any
-- 'MonadUnliftIO' monad.
--
-- The acquire action runs with asynchronous exceptions masked. The body runs
-- with the caller's original masking state restored. If the body throws, the
-- release action is run under 'uninterruptibleMask_' (any exception it throws
-- is discarded) and the body's exception is rethrown. If the body returns
-- normally, the release action is __not__ run.
bracketOnError
    :: forall m a b c
     . MonadUnliftIO m
    => IO a
    -- ^ Acquire action (runs masked)
    -> (a -> IO b)
    -- ^ Release action, only run when the body throws
    -> m c
    -- ^ Body
    -> m c
bracketOnError before after thing = withRunInIO $ \run -> mask $ \restore -> do
    x <- before
    res1 <- try . restore . run $ thing
    case res1 of
        Left (exc :: SomeException) -> do
            -- Best-effort cleanup: the body's exception is the one callers
            -- should observe, so a failure inside 'after' is dropped on
            -- purpose.
            _ :: Either SomeException b <- try . uninterruptibleMask_ $ after x
            -- Rethrow the caught 'SomeException' itself instead of unwrapping
            -- and re-wrapping its payload, so nothing attached to it is lost.
            throwIO exc
        Right y -> return y

-- | Map a monadic function over a 'Foldable', processing elements in parallel.
--
-- The number of concurrent workers is given by the thread count argument.
-- Exceptions thrown by the work function are dealt with according to the
-- supplied 'Handler'.
--
-- This function does __NOT__ guarantee that elements are processed in a
-- deterministic order!
parMapM_
    :: (F.Foldable f, MonadUnliftIO m)
    => Handler m a
    -- ^ Exception handler
    -> Int
    -- ^ Number of parallel threads to use
    -> (a -> m ())
    -- ^ Function to run in parallel
    -> f a
    -- ^ The 'Foldable' to process in parallel
    -> m ()
parMapM_ hndl threads workFun input = do
    -- Obtain a runner for @m@ actions in 'IO', so the work function and the
    -- exception handler can be executed from 'IO'.
    UnliftIO runInIO <- askUnliftIO

    Bracket{allocate,cleanup,action} <- runParallel_
        (mapHandler runInIO hndl)
        threads
        (runInIO . workFun)
        (forM_ input)

    -- 'cleanup' is only invoked on the resources from 'allocate' when
    -- 'action' throws.
    bracketOnError allocate cleanup action

-- | Like 'parMapM_', but folds the individual results into a single result
-- value.
--
-- This is 'parFoldMapM' with a pure fold function. Because elements are not
-- processed in a deterministic order, the final value is only deterministic
-- when the fold function gives the same result regardless of the order in
-- which results are folded in.
--
-- This function does __NOT__ guarantee that elements are processed in a
-- deterministic order!
parFoldMap
    :: (F.Foldable f, MonadUnliftIO m)
    => Handler m a
    -- ^ Exception handler
    -> Int
    -- ^ Number of parallel threads to use
    -> (a -> m b)
    -- ^ Function to run in parallel
    -> (r -> b -> r)
    -- ^ Function to fold results with
    -> r
    -- ^ Zero element for the fold
    -> f a
    -- ^ The 'Foldable' to process
    -> m r
parFoldMap hndl threads work f =
    parFoldMapM hndl threads work (\x y -> return (f x y))

-- | Like 'parFoldMap', but uses a monadic fold function.
--
-- This function does __NOT__ guarantee that elements are processed in a
-- deterministic order!
parFoldMapM
    :: forall a b f m r
     . (F.Foldable f, MonadUnliftIO m)
    => Handler m a
    -- ^ Exception handler
    -> Int
    -- ^ Number of parallel threads to use
    -> (a -> m b)
    -- ^ Function to run in parallel
    -> (r -> b -> m r)
    -- ^ Monadic function to fold results with
    -> r
    -- ^ Zero element for the fold
    -> f a
    -- ^ The 'Foldable' to process
    -> m r
parFoldMapM hndl threads workFun f z input = do
    -- Obtain a runner for @m@ actions in 'IO', so the work function and the
    -- exception handler can be executed from 'IO'.
    UnliftIO runInIO <- askUnliftIO

    -- NOTE(review): 'Right f' presumably lets 'runParallel' fold in results
    -- still outstanding after 'body' finishes; confirm in BroadcastChan.Extra.
    Bracket{allocate,cleanup,action} <- runParallel
        (Right f)
        (mapHandler runInIO hndl)
        threads
        (runInIO . workFun)
        body

    -- 'cleanup' is only invoked on the resources from 'allocate' when
    -- 'action' throws.
    bracketOnError allocate cleanup action
  where
    -- Walk the input, starting the fold at @z@.
    --
    -- The first @threads@ elements are submitted with @send@ alone. Every
    -- later element goes through @sendRecv@, and the @Maybe b@ it yields is
    -- folded into the accumulator with @f@ (@Nothing@ leaves it unchanged).
    body :: (a -> m ()) -> (a -> m (Maybe b)) -> m r
    body send sendRecv = snd `liftM` foldlM wrappedFoldFun (0, z) input
      where
        -- The 'Int' counts elements submitted via @send@. It stops growing
        -- once it equals @threads@, after which 'sendRecv' is used.
        wrappedFoldFun :: (Int, r) -> a -> m (Int, r)
        wrappedFoldFun (i, x) a
            | i == threads = liftM (i,) $ sendRecv a >>= maybe (return x) (f x)
            | otherwise = const (i+1, x) `liftM` send a