module Ribosome.Data.Conduit where

import Conduit (ConduitT, MonadResource, bracketP, runConduit, yield, (.|))
import Control.Concurrent (forkIO)
import Control.Concurrent.Lifted (fork, killThread)
import Control.Concurrent.STM.TBMChan (TBMChan, closeTBMChan, newTBMChan, readTBMChan, writeTBMChan)
import Control.Exception.Lifted (bracket, finally)
import Control.Monad.Trans.Control (embed)
import qualified Data.Conduit.Combinators as Conduit (mapM_)

import Ribosome.Control.Monad.Ribo (modifyTMVar)

-- | Create a bounded 'TBMChan' with capacity @bound@, hand it to @use@, and
-- close the channel once @use@ finishes, whether it returns or throws
-- (the release step runs through 'bracket').
withTBMChan ::
  MonadIO m =>
  MonadBaseControl IO m =>
  Int ->
  (TBMChan a -> m b) ->
  m b
withTBMChan bound use =
  bracket create dispose use
  where
    create = atomically (newTBMChan bound)
    dispose channel = atomically (closeTBMChan channel)

-- | A source that yields every element read from @chan@ and stops once the
-- channel has been closed and drained.
--
-- 'readTBMChan' blocks while the channel is open and empty, and returns
-- 'Nothing' once it is closed and empty. That 'Nothing' ends the stream.
sourceChan ::
  MonadIO m =>
  TBMChan a ->
  ConduitT () a m ()
sourceChan chan =
  loop
  where
    -- Recurse in tail position so that nothing runs after the recursive
    -- call. Appending an action after 'loop' (as @traverse_@ over 'Maybe'
    -- does with its trailing @pure ()@) grows the conduit continuation by
    -- one closure per element. That is unbounded for long-lived channels.
    loop = do
      next <- atomically (readTBMChan chan)
      case next of
        Just a -> yield a >> loop
        Nothing -> pure ()

-- | Record that one source has finished: decrement the counter in @var@ and,
-- when 'modifyTMVar' reports zero, close @chan@ so readers see end-of-stream.
--
-- NOTE(review): this relies on 'modifyTMVar' returning the /updated/ count,
-- so that the last of N sources observes 0. Confirm against
-- "Ribosome.Control.Monad.Ribo".
sourceTerminated ::
  MonadIO m =>
  TMVar Int ->
  TBMChan a ->
  m ()
sourceTerminated var chan =
  modifyTMVar (subtract 1) var >>= \remaining ->
    when (remaining == 0) $
      atomically (closeTBMChan chan)

-- | Fork one 'forkIO' thread per source, each running it through
-- @sourceRunner@, and return a source that yields whatever arrives on @chan@.
--
-- @activeSources@ counts the sources still running. Each thread calls
-- 'sourceTerminated' when it ends, normally or by exception, which closes
-- @chan@ once the counter reaches zero. The returned source then terminates.
-- When the 'bracketP' scope is released, all threads are killed and @chan@
-- is closed.
--
-- NOTE(review): 'bracketP' may run @acquire@ with asynchronous exceptions
-- masked. Threads forked there would inherit that masking state, which can
-- delay 'killThread' in @release@. Consider 'forkIOWithUnmask'.
mergeSourcesWith ::
  MonadResource m =>
  TMVar Int ->
  TBMChan a ->
  (ConduitT () a m () -> IO (StM m ())) ->
  [ConduitT () a m ()] ->
  ConduitT () a m ()
mergeSourcesWith activeSources chan sourceRunner sources =
  bracketP acquire release (const combinedSource)
  where
    acquire = do
      -- With no sources nothing would ever close the channel, and
      -- 'combinedSource' would block forever on its first read.
      when (null sources) (atomically (closeTBMChan chan))
      traverse (forkIO . start) sources
    -- 'finally' makes a source that throws still count as terminated.
    -- Otherwise the counter never reaches zero, the channel is never
    -- closed, and the consumer hangs.
    start source =
      void (sourceRunner source) `finally` sourceTerminated activeSources chan
    release ids =
      traverse_ killThread ids *>
      atomically (closeTBMChan chan)
    combinedSource =
      sourceChan chan

-- | Merge @sources@ into a single source backed by a fresh channel of
-- capacity @bound@.
--
-- Each source is drained into the shared channel in the base monad (made
-- callable from 'IO' via 'embed'). The threads are managed by
-- 'mergeSourcesWith'. The resulting source yields the elements read back
-- from that channel.
mergeSources ::
  MonadResource m =>
  MonadBaseControl IO m =>
  Int ->
  [ConduitT () a m ()] ->
  ConduitT () a m ()
mergeSources bound sources = do
  counter <- atomically (newTMVar (length sources))
  channel <- atomically (newTBMChan bound)
  runner <- lift (embed (pumpInto channel))
  mergeSourcesWith counter channel runner sources
  where
    -- Run a source to completion, writing every element into the channel.
    pumpInto target source =
      runConduit $ source .| Conduit.mapM_ (atomically . writeTBMChan target)

-- | Feed @sources@ into @chan@ concurrently, one 'fork'ed thread per source,
-- while @executor@ consumes the channel through 'sourceChan'.
--
-- Every source thread calls 'sourceTerminated' when it ends, normally or by
-- exception, so @chan@ is closed once the last source is done and the
-- executor's stream ends. After the executor returns or throws, all source
-- threads are killed.
withSourcesInChanAs ::
  MonadIO m =>
  MonadBaseControl IO m =>
  (ConduitT () a m () -> m b) ->
  [ConduitT () a m ()] ->
  TBMChan a ->
  m b
withSourcesInChanAs executor sources chan = do
  activeSources <- atomically $ newTMVar (length sources)
  -- With no sources nothing would ever close the channel and the executor
  -- would block forever waiting for input.
  when (null sources) (atomically (closeTBMChan chan))
  threadIds <- traverse (fork . start activeSources) sources
  listen `finally` release threadIds
  where
    release =
      traverse_ killThread
    listen =
      executor $ sourceChan chan
    -- 'finally' makes a source that throws still count as terminated, so the
    -- channel gets closed instead of leaving the executor waiting forever.
    start activeSources source =
      runConduit (source .| Conduit.mapM_ (atomically . writeTBMChan chan))
        `finally` sourceTerminated activeSources chan

-- | Connect @source@ to @consumer@ and run the resulting closed pipeline,
-- returning the consumer's result.
simpleExecutor ::
  Monad m =>
  ConduitT a Void m b ->
  ConduitT () a m () ->
  m b
simpleExecutor consumer source =
  runConduit (source .| consumer)

-- | 'withSourcesInChanAs' with an executor that pipes the channel's contents
-- into @consumer@ (see 'simpleExecutor').
withSourcesInChan ::
  MonadIO m =>
  MonadBaseControl IO m =>
  ConduitT a Void m b ->
  [ConduitT () a m ()] ->
  TBMChan a ->
  m b
withSourcesInChan consumer =
  withSourcesInChanAs (simpleExecutor consumer)

-- | Allocate a channel of capacity @bound@ with 'withTBMChan', stream
-- @sources@ into it, and let @executor@ consume the merged stream.
withMergedSourcesAs ::
  MonadIO m =>
  MonadBaseControl IO m =>
  (ConduitT () a m () -> m b) ->
  Int ->
  [ConduitT () a m ()] ->
  m b
withMergedSourcesAs executor bound sources =
  withTBMChan bound $ \chan ->
    withSourcesInChanAs executor sources chan

-- | 'withMergedSourcesAs' with an executor that pipes the merged stream into
-- @consumer@ (see 'simpleExecutor').
withMergedSources ::
  MonadIO m =>
  MonadBaseControl IO m =>
  ConduitT a Void m b ->
  Int ->
  [ConduitT () a m ()] ->
  m b
withMergedSources consumer =
  withMergedSourcesAs (simpleExecutor consumer)