{-# OPTIONS_GHC -fno-warn-deprecations #-}
-- |
-- Module      : Streamly.Internal.Data.Stream.SVar.Eliminate
-- Copyright   : (c) 2017 Composewell Technologies
-- License     : BSD-3-Clause
-- Maintainer  : streamly@composewell.com
-- Stability   : experimental
-- Portability : GHC
--
-- Eliminate a stream by distributing it to multiple SVars concurrently.
--
module Streamly.Internal.Data.Stream.SVar.Eliminate
    (
    -- * Concurrent Function Application
      toSVarParallel

    -- * Concurrent folds
    -- $concurrentFolds
    , newFoldSVar
    , newFoldSVarF

    , fromConsumer
    , pushToFold
    , teeToSVar
    )
where

#include "inline.hs"

import Control.Concurrent (myThreadId, takeMVar)
import Control.Monad (when, void)
import Control.Monad.Catch (throwM)
import Control.Monad.IO.Class (MonadIO(liftIO))
import Data.IORef (newIORef, readIORef, writeIORef)
import Streamly.Internal.Control.Concurrent (MonadAsync)
import Streamly.Internal.Control.ForkLifted (doFork)
import Streamly.Internal.Data.Atomics (atomicModifyIORefCAS_)
import Streamly.Internal.Data.Fold.SVar (write, writeLimited)
import Streamly.Internal.Data.Fold.Type (Fold(..))
import Streamly.Internal.Data.Stream.Serial (SerialT)
import Streamly.Internal.Data.Time.Clock (Clock(Monotonic), getTime)

import qualified Streamly.Internal.Data.Stream.StreamD.Type as D
    (Stream(..), Step(..), fold)
import qualified Streamly.Internal.Data.Stream.StreamK.Type as K
    (Stream, mkStream, foldStream, foldStreamShared, nilM)
import qualified Streamly.Internal.Data.Stream.Serial as Stream
    (fromStreamK, toStreamK)

import Streamly.Internal.Data.SVar

-------------------------------------------------------------------------------
-- Concurrent function application
-------------------------------------------------------------------------------

-- Using StreamD the worker stream producing code can fuse with the code to
-- queue output to the SVar giving some perf boost.
--
-- Note that StreamD can only be used in limited situations, specifically, we
-- cannot implement joinStreamVarPar using this.
--
-- XXX make sure that the SVar passed is a Parallel style SVar.

-- | Fold the supplied stream to the SVar asynchronously using Parallel
-- concurrency style.
--
-- A worker thread is forked (via 'doFork', using the SVar's 'svarMrun' and
-- reporting failures through 'handleChildException') which folds the stream
-- @xs@ into the SVar. The forked thread's id is then registered with the SVar
-- using 'modifyThread'.
--
-- When the state carries a yield limit ('getYieldLimit' returns 'Just') the
-- worker uses 'writeLimited', otherwise it uses 'write'.
--
-- When the SVar is in inspect mode ('svarInspectMode') the worker count is
-- additionally maintained and, if the SVar has 'yieldRateInfo', a
-- 'WorkerInfo' record is allocated and handed to the worker.
--
-- {-# INLINE_NORMAL toSVarParallel #-}
{-# INLINE toSVarParallel #-}
toSVarParallel :: MonadAsync m
    => State t m a -> SVar t m a -> D.Stream m a -> m ()
toSVarParallel st sv xs =
    if svarInspectMode sv
    then forkWithDiag
    else forkWorker Nothing

    where

    {-# NOINLINE work #-}
    work info = D.fold (write sv info) xs

    {-# NOINLINE workLim #-}
    workLim info = D.fold (writeLimited sv info) xs

    -- Pick the worker matching the yield limit setting, fork it and record
    -- the new thread in the SVar. Shared by the plain and the diagnostics
    -- code paths.
    {-# INLINE forkWorker #-}
    forkWorker winfo = do
        let worker =
                case getYieldLimit st of
                    Nothing -> work winfo
                    Just _ -> workLim winfo
        tid <- doFork worker (svarMrun sv) (handleChildException sv)
        modifyThread sv tid

    {-# NOINLINE forkWithDiag #-}
    forkWithDiag = do
        -- We do not use workerCount in case of ParallelVar but still there is
        -- no harm in maintaining it correctly.
        liftIO $ atomicModifyIORefCAS_ (workerCount sv) $ \n -> n + 1
        recordMaxWorkers sv
        -- This allocation matters when significant number of workers are being
        -- sent. We allocate it only when needed. The overhead increases by 4x.
        winfo <-
            case yieldRateInfo sv of
                Nothing -> return Nothing
                Just _ -> liftIO $ do
                    cntRef <- newIORef 0
                    t <- getTime Monotonic
                    lat <- newIORef (0, t)
                    return $ Just WorkerInfo
                        { workerYieldMax = 0
                        , workerYieldCount = cntRef
                        , workerLatencyStart = lat
                        }
        forkWorker winfo

-------------------------------------------------------------------------------
-- Support for running folds concurrently
-------------------------------------------------------------------------------

-- $concurrentFolds
--
-- To run folds concurrently, we need to decouple the fold execution from the
-- stream production. We use the SVar to do that, we have a single worker
-- pushing the stream elements to the SVar and on the consumer side a fold
-- driver pulls the values and folds them.
--
-- @
--
-- Fold worker <------SVar<------input stream
--     |  exceptions  |
--     --------------->
--
-- @
--
-- We need a channel for pushing exceptions from the fold worker to the stream
-- pusher. The stream may be pushed to multiple folds at the same time. For
-- that we need one SVar per fold:
--
-- @
--
-- Fold worker <------SVar<---
--                    |       |
-- Fold worker <------SVar<------input stream
--                    |       |
-- Fold worker <------SVar<---
--
-- @
--
-- Unlike in the case of concurrent stream evaluation, the puller does not
-- drive the scheduling and concurrent execution of the stream. The stream is
-- simply pushed by the stream producer at its own rate. The fold worker just
-- pulls it and folds it.
--
-- Note: If the stream pusher terminates due to an exception, we do not
-- actively terminate the fold. It gets cleaned up by the GC.

-------------------------------------------------------------------------------
-- Process events received by a fold consumer from a stream producer
-------------------------------------------------------------------------------

-- | Pull a stream from an SVar to fold it. Like 'fromSVar' except that it does
-- not drive the evaluation of the stream. It just pulls whatever is available
-- on the SVar. Also, when the fold stops it sends a notification to the stream
-- pusher/producer. No exceptions are expected to be propagated from the stream
-- pusher to the fold puller.
--
-- Each time the stream is evaluated a batch of events is read from the SVar
-- with 'readOutputQ' and replayed in the order in which they were queued.
-- When a batch is exhausted the next batch is read. A 'ChildStop' event
-- without an exception ends the stream after notifying the producer; a
-- 'ChildStop' carrying an exception is treated as a bug and calls 'error'.
--
{-# NOINLINE fromProducer #-}
fromProducer :: forall m a . MonadAsync m => SVar K.Stream m a -> K.Stream m a
fromProducer sv = K.mkStream $ \st yld sng stp -> do
    list <- readOutputQ sv
    -- Reversing the output is important to guarantee that we process the
    -- outputs in the same order as they were generated by the constituent
    -- streams.
    K.foldStream st yld sng stp $ processEvents $ reverse list

    where

    -- Record the stop time and dump the SVar (inspect mode only), tell the
    -- producer that we are done and then run the stop continuation.
    allDone :: m r -> m r
    allDone stp = do
        when (svarInspectMode sv) $ liftIO $ do
            t <- getTime Monotonic
            writeIORef (svarStopTime (svarStats sv)) (Just t)
            printSVar sv "SVar Done"
        sendStopToProducer sv
        stp

    {-# INLINE processEvents #-}
    processEvents :: [ChildEvent a] -> K.Stream m a
    -- The current batch is exhausted, go back and read the next batch.
    processEvents [] = K.mkStream $ \st yld sng stp -> do
        K.foldStream st yld sng stp $ fromProducer sv

    processEvents (ev : es) = K.mkStream $ \_ yld _ stp -> do
        let rest = processEvents es
        case ev of
            ChildYield a -> yld a rest
            ChildStop tid e -> do
                accountThread sv tid
                case e of
                    Nothing -> allDone stp
                    Just _ -> error "Bug: fromProducer: received exception"

-- | Create a Fold style SVar that runs a supplied fold function as the
-- consumer.  Any elements sent to the SVar are consumed by the supplied fold
-- function.
--
-- The calling thread is registered with the SVar as the producer thread. The
-- fold function is run in a forked thread on the stream produced by
-- 'fromProducer'; its result is discarded. Exceptions in the forked thread
-- are reported through 'handleFoldException'. The new SVar is returned.
--
{-# INLINE newFoldSVar #-}
newFoldSVar :: MonadAsync m
    => State K.Stream m a -> (SerialT m a -> m b) -> m (SVar K.Stream m a)
newFoldSVar stt f = do
    -- Buffer size for the SVar is derived from the current state
    sv <- newParallelVar StopAny (adaptState stt)

    -- Add the producer thread-id to the SVar.
    liftIO myThreadId >>= modifyThread sv

    void $ doFork (void $ f $ Stream.fromStreamK $ fromProducer sv)
                  (svarMrun sv)
                  (handleFoldException sv)
    return sv

-- State of the stream generated by 'fromProducerD'.
data FromSVarState t m a =
    -- Not used by 'fromProducerD': the stream starts in 'FromSVarRead' and
    -- never transitions to this state; its step function treats it as
    -- unreachable.
      FromSVarInit
    -- Read the next batch of events from the SVar.
    | FromSVarRead (SVar t m a)
    -- Process the remaining events of the batch that was read last.
    | FromSVarLoop (SVar t m a) [ChildEvent a]
    -- A stop event was received; finish up and stop the stream.
    | FromSVarDone (SVar t m a)

-- | Like 'fromProducer' but generates a StreamD style stream instead of
-- StreamK.
--
-- The stream alternates between reading a batch of events from the SVar
-- ('FromSVarRead') and replaying that batch ('FromSVarLoop'). On a
-- 'ChildStop' event without an exception the producer is notified using
-- 'sendStopToProducer' and the stream stops via 'FromSVarDone'. A 'ChildStop'
-- carrying an exception is treated as a bug and calls 'error'.
--
{-# INLINE_NORMAL fromProducerD #-}
fromProducerD :: (MonadAsync m) => SVar t m a -> D.Stream m a
fromProducerD svar = D.Stream step (FromSVarRead svar)
    where

    {-# INLINE_LATE step #-}
    step _ (FromSVarRead sv) = do
        list <- readOutputQ sv
        -- Reversing the output is important to guarantee that we process the
        -- outputs in the same order as they were generated by the constituent
        -- streams.
        return $ D.Skip $ FromSVarLoop sv (Prelude.reverse list)

    -- The current batch is exhausted, read the next one.
    step _ (FromSVarLoop sv []) = return $ D.Skip $ FromSVarRead sv
    step _ (FromSVarLoop sv (ev : es)) = do
        case ev of
            ChildYield a -> return $ D.Yield a (FromSVarLoop sv es)
            ChildStop tid e -> do
                accountThread sv tid
                case e of
                    Nothing -> do
                        sendStopToProducer sv
                        return $ D.Skip (FromSVarDone sv)
                    Just _ -> error "Bug: fromProducerD: received exception"

    step _ (FromSVarDone sv) = do
        -- In inspect mode record the stop time and dump the SVar.
        when (svarInspectMode sv) $ liftIO $ do
            t <- getTime Monotonic
            writeIORef (svarStopTime (svarStats sv)) (Just t)
            printSVar sv "SVar Done"
        return D.Stop

    -- The stream never starts in or moves to this state.
    step _ FromSVarInit =
        error "Bug: fromProducerD: unexpected FromSVarInit state"

-- | Like 'newFoldSVar' except that it uses a 'Fold' instead of a fold
-- function.
--
-- The calling thread is registered with the SVar as the producer thread. The
-- fold is run in a forked thread on the stream produced by 'fromProducerD';
-- its result is discarded. Exceptions in the forked thread are reported
-- through 'handleFoldException'. The new SVar is returned.
--
{-# INLINE newFoldSVarF #-}
newFoldSVarF :: MonadAsync m => State t m a -> Fold m a b -> m (SVar t m a)
newFoldSVarF stt f = do
    -- Buffer size for the SVar is derived from the current state
    sv <- newParallelVar StopAny (adaptState stt)
    -- Add the producer thread-id to the SVar.
    liftIO myThreadId >>= modifyThread sv
    void $ doFork (work sv) (svarMrun sv) (handleFoldException sv)
    return sv

    where

    {-# NOINLINE work #-}
    work sv = void $ D.fold f $ fromProducerD sv

-------------------------------------------------------------------------------
-- Process events received by the producer thread from the consumer side
-------------------------------------------------------------------------------

-- XXX currently only one event is sent by a fold consumer to the stream
-- producer. But we can potentially have multiple events e.g. the fold step can
-- generate exception more than once and the producer can ignore those
-- exceptions or handle them and still keep driving the fold.
--
-- | Poll for events sent by the fold consumer to the stream pusher. The fold
-- consumer can send a "Stop" event or an exception. When a "Stop" is received
-- this function returns 'True'. If an exception is received then it throws the
-- exception. If no event is pending it returns 'False'.
--
-- Only the first (oldest) pending event is examined, any further events read
-- from the queue in the same batch are discarded. A 'ChildYield' event is not
-- expected from the consumer and is reported as a bug using 'error'.
--
{-# NOINLINE fromConsumer #-}
fromConsumer :: MonadAsync m => SVar K.Stream m a -> m Bool
fromConsumer sv = do
    (list, _) <- liftIO $ readOutputQBasic (outputQueueFromConsumer sv)
    -- Reversing the output is important to guarantee that we process the
    -- outputs in the same order as they were generated by the constituent
    -- streams.
    processEvents $ reverse list

    where

    {-# INLINE processEvents #-}
    processEvents [] = return False
    processEvents (ChildStop _ Nothing : _) = return True
    processEvents (ChildStop _ (Just ex) : _) = throwM ex
    processEvents (ChildYield _ : _) =
        error "Bug: fromConsumer: invalid ChildYield event"

-- | Push values from a stream to a fold worker via an SVar. Before pushing a
-- value to the SVar it polls for events received from the fold consumer.  If a
-- stop event is received then it returns 'True' otherwise false.  Propagates
-- exceptions received from the fold consumer.
--
-- When 'True' is returned the value has not been sent to the SVar.
--
{-# INLINE pushToFold #-}
pushToFold :: MonadAsync m => SVar K.Stream m a -> a -> m Bool
pushToFold sv a = do
    -- Check for exceptions before decrement so that we do not
    -- block forever if the child already exited with an exception.
    --
    -- We avoid a race between the consumer fold sending an event and we
    -- blocking on decrementBufferLimit by waking up the producer thread in
    -- sendToProducer before any event is sent by the fold to the producer
    -- stream.
    let qref = outputQueueFromConsumer sv
    done <- do
        -- Peek at the pending event count; process the events only when
        -- there are any.
        (_, n) <- liftIO $ readIORef qref
        if n > 0
        then fromConsumer sv
        else return False
    if done
    then return True
    else liftIO $ do
        decrementBufferLimit sv
        void $ send sv (ChildYield a)
        return False

------------------------------------------------------------------------------
-- Clone and distribute a stream in parallel
------------------------------------------------------------------------------

-- XXX this could be written in StreamD style for better efficiency with fusion.
--
-- | Tap a stream and send the elements to the specified SVar in addition to
-- yielding them again. The SVar runs a fold consumer. Elements are tapped and
-- sent to the SVar until the fold finishes. Any exceptions from the fold
-- evaluation are propagated in the current thread.
--
-- @
--
-- ------input stream---------output stream----->
--                    /|\\   |
--         exceptions  |    |  input
--                     |   \\|/
--                     ----SVar
--                          |
--                         Fold
--
-- @
--
{-# INLINE teeToSVar #-}
teeToSVar :: MonadAsync m =>
    SVar K.Stream m a -> SerialT m a -> SerialT m a
teeToSVar svr m = Stream.fromStreamK $ K.mkStream $ \st yld sng stp -> do
    K.foldStreamShared st yld sng stp (go False $ Stream.toStreamK m)

    where

    -- The flag tells whether the fold has already signalled a stop. While it
    -- is False every element is pushed to the fold before being yielded.
    go False m0 = K.mkStream $ \st yld _ stp -> do
        let drain = do
                -- In general, a Stop event would come equipped with the result
                -- of the fold. It is not used here but it would be useful in
                -- applicative and distribute.
                done <- fromConsumer svr
                when (not done) $ do
                    liftIO $ withDiagMVar svr "teeToSVar: waiting to drain"
                           $ takeMVar (outputDoorBellFromConsumer svr)
                    drain

            stopFold = do
                liftIO $ sendStop svr Nothing
                -- drain/wait until a stop event arrives from the fold.
                drain

            stop       = stopFold >> stp
            single a   = do
                done <- pushToFold svr a
                yld a (go done (K.nilM stopFold))
            yieldk a r = pushToFold svr a >>= \done -> yld a (go done r)
         in K.foldStreamShared st yieldk single stop m0

    -- The fold is done, pass the rest of the stream through untouched.
    go True m0 = m0