module Streaming.Concurrent
(
Buffer
, unbounded
, bounded
, latest
, newest
, withBuffer
, InBasket(..)
, OutBasket(..)
, writeStreamBasket
, readStreamBasket
, mergeStreams
, writeByteStringBasket
, readByteStringBasket
, mergeByteStrings
) where
import Data.ByteString.Streaming (ByteString, reread, unconsChunk)
import Streaming (Of, Stream)
import qualified Streaming.Prelude as S
import Control.Applicative ((<|>))
import Control.Concurrent.Async.Lifted (concurrently,
forConcurrently_)
import qualified Control.Concurrent.STM as STM
import Control.Monad (when)
import Control.Monad.Base (MonadBase, liftBase)
import Control.Monad.Catch (MonadMask, bracket, bracket_)
import Control.Monad.Trans.Control (MonadBaseControl)
import qualified Data.ByteString as B
import Data.Foldable (forM_)
-- | Write every stream in @strs@ into one shared buffer concurrently (one
--   writer per stream), and hand the merged output stream to @f@.
--
--   The result is whatever @f@ returns; the input streams' own return
--   values are discarded.
mergeStreams :: (MonadMask m, MonadBaseControl IO m, MonadBase IO n, Foldable t)
                => Buffer a -> t (Stream (Of a) m v)
                -> (Stream (Of a) n () -> m r) -> m r
mergeStreams buff strs f = withBuffer buff feedAll consume
  where
    -- Every input stream gets its own concurrent writer into the basket.
    feedAll inb = forConcurrently_ strs (\str -> writeStreamBasket str inb)

    -- The reading side is exposed to the caller as a single stream.
    consume outb = readStreamBasket outb f
-- | Write every streaming 'ByteString' in @bss@ into one shared buffer
--   concurrently, chunk by chunk (one writer per input), and hand the merged
--   output to @f@.
--
--   Chunks from different inputs may be interleaved arbitrarily.
mergeByteStrings :: (MonadMask m, MonadBaseControl IO m, MonadBase IO n, Foldable t)
                    => Buffer B.ByteString -> t (ByteString m v)
                    -> (ByteString n () -> m r) -> m r
mergeByteStrings buff bss f = withBuffer buff feedAll consume
  where
    -- Every input gets its own concurrent writer into the basket.
    feedAll inb = forConcurrently_ bss (\bs -> writeByteStringBasket bs inb)

    -- The reading side is exposed to the caller as a single ByteString.
    consume outb = readByteStringBasket outb f
-- | Send each element of a stream into an 'InBasket', one STM transaction
--   per element.
--
--   Stops as soon as a send reports 'False' (the basket has been sealed),
--   or when the stream ends; the stream's return value is discarded.
writeStreamBasket :: (MonadBase IO m) => Stream (Of a) m r -> InBasket a -> m ()
writeStreamBasket stream (InBasket send) = loop stream
  where
    -- Step the stream; a 'Left' means it is exhausted.
    loop s = S.next s >>= either (const (return ())) push

    -- Deliver one element, and continue only while the basket accepts it.
    push (a, rest) = do
      keepGoing <- liftBase (STM.atomically (send a))
      when keepGoing (loop rest)
-- | Send each strict chunk of a streaming 'ByteString' into an 'InBasket',
--   one STM transaction per chunk.
--
--   Stops as soon as a send reports 'False' (the basket has been sealed),
--   or when the input ends; the input's return value is discarded.
writeByteStringBasket :: (MonadBase IO m) => ByteString m r -> InBasket B.ByteString -> m ()
writeByteStringBasket bstring (InBasket send) = loop bstring
  where
    -- Pull the next chunk; a 'Left' means the input is exhausted.
    loop bs = unconsChunk bs >>= either (const (return ())) push

    -- Deliver one chunk, and continue only while the basket accepts it.
    push (chunk, rest) = do
      keepGoing <- liftBase (STM.atomically (send chunk))
      when keepGoing (loop rest)
-- | Present the reading end of a buffer as a 'Stream' and pass it to the
--   continuation @f@.
--
--   Each element is fetched with one STM transaction; the stream ends when
--   'receiveMsg' yields 'Nothing'.
readStreamBasket :: (MonadBase IO m) => OutBasket a
                    -> (Stream (Of a) m () -> r)
                    -> r
readStreamBasket (OutBasket receive) f = f (S.untilRight pull)
  where
    -- 'S.untilRight' yields every 'Left' and stops at the first 'Right'.
    pull = do
      msg <- liftBase (STM.atomically receive)
      return (case msg of
                Just a  -> Left a
                Nothing -> Right ())
-- | Present the reading end of a buffer of strict chunks as a streaming
--   'ByteString' and pass it to the continuation @f@.
--
--   Each chunk is fetched with one STM transaction; the output ends when
--   'receiveMsg' yields 'Nothing'.
readByteStringBasket :: (MonadBase IO m) => OutBasket B.ByteString
                        -> (ByteString m () -> r)
                        -> r
readByteStringBasket (OutBasket receive) f = f chunks
  where
    -- Run a single receive transaction in the base monad.
    fetch stm = liftBase (STM.atomically stm)

    -- 'reread' keeps calling 'fetch' until it returns 'Nothing'.
    chunks = reread fetch receive
-- | Describes how messages are held between the writing and the reading
--   side of a 'withBuffer' call.  Build values with 'unbounded',
--   'bounded', 'latest' and 'newest'.
data Buffer a = Unbounded   -- ^ unlimited FIFO queue
              | Bounded Int -- ^ FIFO queue with the given capacity
              | Single      -- ^ at most one pending message
              | Latest a    -- ^ only the most recent value, seeded with this one
              | Newest Int  -- ^ capacity-limited queue that drops the oldest entry
              | New         -- ^ one slot that the newest message overwrites
-- | Store an unlimited number of messages in FIFO order.  Writers never
--   have to wait for the reader, so memory use grows if the reader falls
--   behind.
unbounded :: Buffer a
unbounded = Unbounded
-- | Store at most @n@ messages in FIFO order; writers wait while the buffer
--   is full.
--
--   A capacity of one uses the single-slot representation.  Capacities
--   below one are treated as one, because a buffer that can never hold a
--   message would block the writer and the reader forever.
bounded :: Int -> Buffer a
bounded n
  | n <= 1    = Single
  | otherwise = Bounded n
-- | Keep only the most recent message, starting from @initial@.
--
--   Reads never wait and never report the end of the input; after the
--   writers finish, the last value is returned again and again.  The reader
--   must decide for itself how many values to take.
latest :: a -> Buffer a
latest initial = Latest initial
-- | Store at most @n@ messages; when full, the oldest pending message is
--   discarded to make room, so writers never wait for the reader.
--
--   A capacity of one uses the single-slot representation.  Capacities
--   below one are treated as one.  A zero- or negative-capacity queue would
--   make the drop-oldest retry recurse forever, because there is never
--   anything to drop.
newest :: Int -> Buffer a
newest n
  | n <= 1    = New
  | otherwise = Newest n
-- | The reading end of a buffer.
newtype OutBasket a = OutBasket
  { receiveMsg :: STM.STM (Maybe a)
    -- ^ Fetch the next message; 'Nothing' signals that no more will arrive.
  }
-- | The writing end of a buffer.
newtype InBasket a = InBasket
  { sendMsg :: a -> STM.STM Bool
    -- ^ Offer a message; 'False' means it was rejected and the writer
    --   should stop.
  }
-- | Run a writer (@sendIn@) and a reader (@readOut@) concurrently, connected
--   by a freshly created buffer of the given kind, and return the reader's
--   result.
--
--   * Writer side: 'sendMsg' enqueues a message and returns 'True', or
--     returns 'False' without enqueuing once the buffer has been sealed.
--   * Reader side: 'receiveMsg' returns a buffered message if there is one.
--     It returns 'Nothing' only when the buffer is empty /and/ sealed, so
--     messages written before the seal are still delivered.
--   * The buffer is sealed when either side finishes, normally or by
--     exception, and again when the whole call exits.
--
--   Note: a 'latest' buffer is never empty, so its 'receiveMsg' never
--   returns 'Nothing'; the reader must decide for itself when to stop.
withBuffer :: (MonadMask m, MonadBaseControl IO m)
              => Buffer a -> (InBasket a -> m i)
              -> (OutBasket a -> m r) -> m r
withBuffer buffer sendIn readOut =
  bracket
    (liftBase openBasket)
    (\(_, _, _, seal) -> liftBase (STM.atomically seal)) $
      \(writeB, readB, sealed, seal) ->
        snd <$> concurrently (withIn writeB sealed seal)
                             (withOut readB sealed seal)
  where
    -- Create a bounded STM queue holding at least one element.
    --
    -- A capacity below one would make 'Bounded' deadlock and make the
    -- drop-oldest loop of 'Newest' recurse forever, so it is clamped.
    -- 'fromIntegral' lets this compile whether the installed @stm@ takes
    -- the capacity as an 'Int' (stm < 2.5) or a 'Natural' (stm >= 2.5);
    -- the clamp also keeps the 'Natural' conversion from underflowing.
    newQueue n = STM.newTBQueueIO (fromIntegral (max 1 n))

    -- Allocate the storage for the chosen buffer kind plus the shared
    -- "sealed" flag.  Returns (write, read, sealed flag, seal action).
    openBasket = do
      (writeB, readB) <- case buffer of
        Bounded n -> do
          q <- newQueue n
          return (STM.writeTBQueue q, STM.readTBQueue q)
        Unbounded -> do
          q <- STM.newTQueueIO
          return (STM.writeTQueue q, STM.readTQueue q)
        Single -> do
          m <- STM.newEmptyTMVarIO
          return (STM.putTMVar m, STM.takeTMVar m)
        Latest a -> do
          -- Reading a TVar never retries, so this buffer is never "empty".
          t <- STM.newTVarIO a
          return (STM.writeTVar t, STM.readTVar t)
        New -> do
          -- Discard any unread message, then store the new one.
          m <- STM.newEmptyTMVarIO
          return (\x -> STM.tryTakeTMVar m *> STM.putTMVar m x, STM.takeTMVar m)
        Newest n -> do
          q <- newQueue n
          -- If the queue is full, drop the oldest message and try again.
          let push x = STM.writeTBQueue q x <|> (STM.tryReadTBQueue q *> push x)
          return (push, STM.readTBQueue q)
      sealed <- STM.newTVarIO False
      let seal = STM.writeTVar sealed True
      return (writeB, readB, sealed, seal)

    -- Run the writer; seal the buffer when it finishes or throws.
    withIn writeB sealed seal =
      bracket_ (return ())
               (liftBase (STM.atomically seal))
               (sendIn (InBasket sendOrEnd))
      where
        -- Reading 'sealed' in the same transaction means a writer blocked
        -- on a full buffer is woken up (and told to stop) once it is sealed.
        sendOrEnd a = do
          canWrite <- not <$> STM.readTVar sealed
          when canWrite (writeB a)
          return canWrite

    -- Run the reader; seal the buffer when it finishes or throws.
    withOut readB sealed seal =
      bracket_ (return ())
               (liftBase (STM.atomically seal))
               (readOut (OutBasket readOrEnd))
      where
        -- Prefer a buffered message; only when reading would block (retry)
        -- fall back to reporting end-of-input, and only once sealed.
        readOrEnd = (Just <$> readB) <|> (do
          b <- STM.readTVar sealed
          STM.check b
          return Nothing )