{-# LANGUAGE CPP #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE UnboxedTuples #-}
module Streamly.SVar
(
MonadAsync
, SVar (..)
, SVarStyle (..)
, defaultMaxBuffer
, defaultMaxThreads
, State (..)
, defState
, rstState
, newAheadVar
, newParallelVar
, toStreamVar
, atomicModifyIORefCAS
, ChildEvent (..)
, AheadHeapEntry (..)
, sendYield
, sendStop
, enqueueLIFO
, workLoopLIFO
, workLoopFIFO
, enqueueFIFO
, enqueueAhead
, pushWorkerPar
, queueEmptyAhead
, dequeueAhead
, dequeueFromHeap
, postProcessBounded
, readOutputQBounded
, sendWorker
, delThread
)
where
import Control.Concurrent
(ThreadId, myThreadId, threadDelay, getNumCapabilities)
import Control.Concurrent.MVar
(MVar, newEmptyMVar, tryPutMVar, takeMVar)
import Control.Exception (SomeException(..), catch, mask)
import Control.Monad (when)
import Control.Monad.Catch (MonadThrow)
import Control.Monad.IO.Class (MonadIO(..))
import Control.Monad.Trans.Control (MonadBaseControl, control)
import Data.Atomics
(casIORef, readForCAS, peekTicket, atomicModifyIORefCAS_,
writeBarrier, storeLoadBarrier)
import Data.Concurrent.Queue.MichaelScott
(LinkedQueue, pushL, tryPopR)
import Data.Functor (void)
import Data.Heap (Heap, Entry(..))
import Data.IORef
(IORef, modifyIORef, newIORef, readIORef, atomicModifyIORef)
import Data.Maybe (fromJust)
import Data.Set (Set)
import GHC.Conc (ThreadId(..))
import GHC.Exts
import GHC.IO (IO(..))
import qualified Data.Heap as H
import qualified Data.Set as S
#ifdef DIAGNOSTICS
import Control.Concurrent.MVar (tryTakeMVar)
import Control.Exception
(catches, throwIO, Handler(..), BlockedIndefinitelyOnMVar(..),
BlockedIndefinitelyOnSTM(..))
import Data.IORef (writeIORef)
import System.IO (hPutStrLn, stderr)
#endif
-- | Events that a worker thread posts on the output queue of an 'SVar'.
data ChildEvent a =
      ChildYield a
      -- ^ A value produced by a worker.
    | ChildStop ThreadId (Maybe SomeException)
      -- ^ The worker with the given 'ThreadId' is stopping. 'sendStop' posts
      -- this with 'Nothing'; 'handleChildException' posts it with the
      -- exception that terminated the worker.
-- | Payload of an entry in the output heap of an ahead style 'SVar': either
-- a single value or a stream of type @t m a@.
data AheadHeapEntry (t :: (* -> *) -> * -> *) m a =
      AheadEntryPure a
    | AheadEntryStream (t m a)
-- | Tag identifying the concurrency style an 'SVar' was created for. In the
-- code visible in this module 'getAheadSVar' creates 'AheadVar' SVars and
-- 'getParallelSVar' creates 'ParallelVar' SVars; the other two styles are not
-- constructed here.
data SVarStyle =
      AsyncVar
    | WAsyncVar
    | ParallelVar
    | AheadVar
    deriving (Eq, Show)
-- | Shared state through which the worker threads of a concurrent stream and
-- its consumer communicate.
data SVar t m a =
    SVar {
      -- | Concurrency style this SVar was created for.
      svarStyle :: SVarStyle
      -- | Pending output events (most recently sent first, see 'send') paired
      -- with their count.
    , outputQueue :: IORef ([ChildEvent a], Int)
      -- | Optional counter decremented on every 'sendYield'; it also caps the
      -- number of workers in 'dispatchWorker'.
    , maxYieldLimit :: Maybe (IORef Int)
      -- | Signalled with 'tryPutMVar' by producers and taken by the consumer
      -- when it has to wait for output.
    , outputDoorBell :: MVar ()
      -- | Action that fetches the pending output events.
    , readOutputQ :: m [ChildEvent a]
      -- | Action stored by the SVar constructors: 'postProcessBounded' for
      -- ahead style SVars and 'allThreadsDone' for parallel style SVars.
    , postProcess :: m Bool
      -- | Add a stream to the work queue.
    , enqueue :: t m a -> IO ()
      -- | Reports whether no queued work remains.
    , isWorkDone :: IO Bool
      -- | Set by 'sendWorkerWait' to ask the enqueue functions to ring the
      -- door bell; they clear it again when they do so.
    , needDoorBell :: IORef Bool
      -- | Action run by every worker forked through 'pushWorker'.
    , workLoop :: m ()
      -- | Thread ids tracked by 'addThread', 'delThread' and 'modifyThread'.
    , workerThreads :: IORef (Set ThreadId)
      -- | Incremented by 'pushWorker' and decremented by 'sendStop'.
    , workerCount :: IORef Int
      -- | Thread accounting action: 'delThread' for ahead style SVars and
      -- 'modifyThread' for parallel style SVars.
    , accountThread :: ThreadId -> m ()
#ifdef DIAGNOSTICS
      -- | Output heap of an ahead style SVar; read by 'dumpSVar'.
    , outputHeap :: IORef (Heap (Entry Int (AheadHeapEntry t m a))
                          , Int
                          )
      -- | Work queue of an ahead style SVar; read by 'dumpSVar'.
    , aheadWorkQueue :: IORef ([t m a], Int)
      -- | Incremented by 'recordMaxWorkers' on every dispatch.
    , totalDispatches :: IORef Int
      -- | Highest 'workerCount' value seen by 'recordMaxWorkers'.
    , maxWorkers :: IORef Int
      -- | Largest output queue length seen by 'readOutputQRaw'.
    , maxOutQSize :: IORef Int
      -- | Reported by 'dumpSVar'; not updated in the code visible here.
    , maxHeapSize :: IORef Int
      -- | Initialised to 0; not otherwise used in the code visible here.
    , maxWorkQSize :: IORef Int
#endif
    }
-- | Settings and context passed along while a stream is evaluated.
data State t m a = State
    { streamVar :: Maybe (SVar t m a)
      -- ^ The 'SVar' in use, if any. The work loops in this module require it
      -- to be a 'Just'.
    , yieldLimit :: Maybe Int
      -- ^ Optional initial value for the 'maxYieldLimit' counter of a newly
      -- created ahead style 'SVar'.
    , threadsHigh :: Int
      -- ^ Worker limit that 'getAheadSVar' hands to 'readOutputQBounded'.
    , bufferHigh :: Int
      -- ^ Initialised to 'defaultMaxBuffer' by 'defState'; not read in the
      -- code visible here.
    }
-- | Default value of 'threadsHigh' used by 'defState'.
defaultMaxThreads :: Int
defaultMaxThreads = 1500

-- | Default value of 'bufferHigh' used by 'defState'.
defaultMaxBuffer :: Int
defaultMaxBuffer = 1500
-- | Initial 'State': no 'SVar', no yield limit, and the default thread and
-- buffer limits.
defState :: State t m a
defState =
    State
        { threadsHigh = defaultMaxThreads
        , bufferHigh = defaultMaxBuffer
        , streamVar = Nothing
        , yieldLimit = Nothing
        }
-- XXX If performance is affected, we can keep all the 'Nothing' parameters in
-- a single structure so that the reset is fast. We can also use rewrite rules
-- so that the reset occurs only in concurrent streams, reducing the impact on
-- serial streams.
-- We could optimize this to clear a field only when it holds a 'Just' value.
-- That gives slightly better performance for zip/zipM, but the performance of
-- scan worsens a lot because it does not fuse.
-- | Reset the per-stream parts of a 'State' ('streamVar' and 'yieldLimit')
-- to 'Nothing', keeping the thread and buffer limits. The element type of the
-- state may change.
rstState :: State t m a -> State t m b
rstState st = st { streamVar = Nothing, yieldLimit = Nothing }
#ifdef DIAGNOSTICS
-- | Render the current state of an 'SVar' as a multi-line string for
-- diagnostics. The ahead specific section (heap and work queue) is included
-- only when the style is 'AheadVar'.
--
-- NOTE: the door bell is sampled with 'tryTakeMVar', so a pending door bell
-- is consumed by taking the dump.
{-# NOINLINE dumpSVar #-}
dumpSVar :: SVar t m a -> IO String
dumpSVar sv = do
    tid <- myThreadId
    (oqList, oqLen) <- readIORef $ outputQueue sv
    db <- tryTakeMVar $ outputDoorBell sv
    aheadDump <-
        if svarStyle sv == AheadVar
        then do
            (oheap, oheapSeq) <- readIORef $ outputHeap sv
            (wq, wqSeq) <- readIORef $ aheadWorkQueue sv
            maxHp <- readIORef $ maxHeapSize sv
            return $ unlines
                [ "heap length = " ++ show (H.size oheap)
                , "heap sequence = " ++ show oheapSeq
                , "work queue length = " ++ show (length wq)
                , "work queue sequence = " ++ show wqSeq
                , "heap max size = " ++ show maxHp
                ]
        else return []
    waiting <- readIORef $ needDoorBell sv
    rthread <- readIORef $ workerThreads sv
    workers <- readIORef $ workerCount sv
    maxWrk <- readIORef $ maxWorkers sv
    dispatches <- readIORef $ totalDispatches sv
    maxOq <- readIORef $ maxOutQSize sv
    return $ unlines
        [ "tid = " ++ show tid
        , "style = " ++ show (svarStyle sv)
        -- Both the recomputed and the maintained length are printed so that
        -- a mismatch between them is visible in the dump.
        , "outputQueue length computed = " ++ show (length oqList)
        , "outputQueue length maintained = " ++ show oqLen
        , "output outputDoorBell = " ++ show db
        , "total dispatches = " ++ show dispatches
        , "max workers = " ++ show maxWrk
        , "max outQSize = " ++ show maxOq
        ]
        ++ aheadDump ++ unlines
        [ "needDoorBell = " ++ show waiting
        , "running threads = " ++ show rthread
        , "running thread count = " ++ show workers
        ]
-- | Handler for 'BlockedIndefinitelyOnMVar': prints the label and a dump of
-- the 'SVar' to stderr, then rethrows the exception.
{-# NOINLINE mvarExcHandler #-}
mvarExcHandler :: SVar t m a -> String -> BlockedIndefinitelyOnMVar -> IO ()
mvarExcHandler sv label e@BlockedIndefinitelyOnMVar =
    dumpSVar sv >>= report >> throwIO e
  where
    report dump =
        hPutStrLn stderr $
            label ++ " " ++ "BlockedIndefinitelyOnMVar\n" ++ dump
-- | Handler for 'BlockedIndefinitelyOnSTM': prints the label and a dump of
-- the 'SVar' to stderr, then rethrows the exception.
{-# NOINLINE stmExcHandler #-}
stmExcHandler :: SVar t m a -> String -> BlockedIndefinitelyOnSTM -> IO ()
stmExcHandler sv label e@BlockedIndefinitelyOnSTM =
    dumpSVar sv >>= report >> throwIO e
  where
    report dump =
        hPutStrLn stderr $
            label ++ " " ++ "BlockedIndefinitelyOnSTM\n" ++ dump
-- | Run an action; if it is aborted by 'BlockedIndefinitelyOnMVar' or
-- 'BlockedIndefinitelyOnSTM', dump the 'SVar' under the given label before
-- the exception is rethrown by the handler.
withDBGMVar :: SVar t m a -> String -> IO () -> IO ()
withDBGMVar sv label action = catches action handlers
  where
    handlers =
        [ Handler (mvarExcHandler sv label)
        , Handler (stmExcHandler sv label)
        ]
#else
-- | Without diagnostics this is a no-op wrapper: the 'SVar' and the label
-- are ignored and the action is run unchanged.
withDBGMVar :: SVar t m a -> String -> IO () -> IO ()
withDBGMVar _sv _label action = action
#endif
-- | Atomically modify an 'IORef' using compare-and-swap. The swap is
-- attempted up to 25 times; if every attempt loses the race, the update is
-- performed with 'atomicModifyIORef' instead.
{-# INLINE atomicModifyIORefCAS #-}
atomicModifyIORefCAS :: IORef a -> (a -> (a,b)) -> IO b
atomicModifyIORefCAS ref fn = readForCAS ref >>= go maxAttempts
  where
    maxAttempts = 25 :: Int

    -- Out of attempts: fall back to the standard atomic modify.
    go 0 _ = atomicModifyIORef ref fn
    go attempts ticket = do
        let (new, result) = fn $ peekTicket ticket
        (swapped, latest) <- casIORef ref ticket new
        if swapped
        then return result
        else go (attempts - 1) latest
-- | Constraint synonym bundling the capabilities required of a monad by the
-- concurrent operations in this module: 'MonadIO', 'MonadBaseControl' over
-- 'IO', and 'MonadThrow'.
type MonadAsync m = (MonadIO m, MonadBaseControl IO m, MonadThrow m)
-- | Fork a thread by calling the @fork#@ primitive directly and wrap the
-- resulting thread id in a 'ThreadId'.
{-# INLINE rawForkIO #-}
rawForkIO :: IO () -> IO ThreadId
rawForkIO action = IO $ \s0 ->
    case fork# action s0 of
        (# s1, tid #) -> (# s1, ThreadId tid #)
-- | Fork a thread running the given monadic action. The fork happens with
-- asynchronous exceptions masked; the action itself runs with the masking
-- state restored, and any exception escaping it is passed to the supplied
-- handler. Returns the id of the forked thread.
{-# INLINE doFork #-}
doFork :: MonadBaseControl IO m
    => m ()
    -> (SomeException -> IO ())
    -> m ThreadId
doFork action exHandler = control $ \runInIO -> mask $ \restore -> do
    let child = restore (void (runInIO action)) `catch` exHandler
    tid <- rawForkIO child
    runInIO (return tid)
-- | Prepend an event to the output queue. If the queue was empty before the
-- push, the door bell is rung after a write barrier. Returns 'True' when the
-- queue length before the push was below @maxOutputQLen@, or when
-- @maxOutputQLen@ is negative.
send :: Int -> SVar t m a -> ChildEvent a -> IO Bool
send maxOutputQLen sv msg = do
    oldLen <- atomicModifyIORefCAS (outputQueue sv) push
    when (oldLen <= 0) $ do
        -- Issue the barrier before ringing the door bell.
        writeBarrier
        void $ tryPutMVar (outputDoorBell sv) ()
    return (maxOutputQLen < 0 || oldLen < maxOutputQLen)
  where
    -- The result is the length *before* this event was added.
    push (events, count) = ((msg : events, count + 1), count)
-- | Send an event via 'send', first decrementing the yield limit counter if
-- the 'SVar' has one. The result is 'True' only if 'send' returned 'True' and
-- the yield counter (when present) was greater than 1 before the decrement.
{-# NOINLINE sendYield #-}
sendYield :: Int -> SVar t m a -> ChildEvent a -> IO Bool
sendYield maxOutputQLen sv msg = do
    withinLimit <- maybe (return True) consumeYield (maxYieldLimit sv)
    bufferOk <- send maxOutputQLen sv msg
    return $ bufferOk && withinLimit
  where
    consumeYield ref = atomicModifyIORefCAS ref $ \remaining ->
        (remaining - 1, remaining > 1)
-- | Called by a worker that is finishing: decrements the worker count and
-- posts a 'ChildStop' event (without an exception) for the current thread.
{-# NOINLINE sendStop #-}
sendStop :: SVar t m a -> IO ()
sendStop sv = do
    atomicModifyIORefCAS_ (workerCount sv) (subtract 1)
    tid <- myThreadId
    -- A negative limit means the output queue bound is not applied.
    void $ send (-1) sv (ChildStop tid Nothing)
-- | Push a stream onto the front of the LIFO work queue. After a store-load
-- barrier, if the 'needDoorBell' flag is set, clear it and ring the door
-- bell.
{-# INLINE enqueueLIFO #-}
enqueueLIFO :: SVar t m a -> IORef [t m a] -> t m a -> IO ()
enqueueLIFO sv q m = do
    atomicModifyIORefCAS_ q (m :)
    storeLoadBarrier
    wanted <- readIORef (needDoorBell sv)
    when wanted $ do
        atomicModifyIORefCAS_ (needDoorBell sv) (const False)
        void $ tryPutMVar (outputDoorBell sv) ()
-- | Worker loop over a LIFO work queue. Each iteration pops the head of the
-- queue and passes it to @f@ together with the loop itself as the
-- continuation. When the queue is empty the worker calls 'sendStop'.
--
-- The 'streamVar' of the given state must be a 'Just'.
{-# INLINE workLoopLIFO #-}
workLoopLIFO :: MonadIO m
    => (State t m a -> IORef [t m a] -> t m a -> m () -> m ())
    -> State t m a -> IORef [t m a] -> m ()
workLoopLIFO f st q = run
  where
    sv = fromJust $ streamVar st

    run = dequeue >>= maybe (liftIO $ sendStop sv) (\m -> f st q m run)

    dequeue = liftIO $ atomicModifyIORefCAS q pop

    pop [] = ([], Nothing)
    pop (x : xs) = (xs, Just x)
-- | Push a stream onto the FIFO work queue. After a store-load barrier, if
-- the 'needDoorBell' flag is set, clear it and ring the door bell.
{-# INLINE enqueueFIFO #-}
enqueueFIFO :: SVar t m a -> LinkedQueue (t m a) -> t m a -> IO ()
enqueueFIFO sv q m = do
    pushL q m
    storeLoadBarrier
    wanted <- readIORef (needDoorBell sv)
    when wanted $ do
        atomicModifyIORefCAS_ (needDoorBell sv) (const False)
        void $ tryPutMVar (outputDoorBell sv) ()
-- | Worker loop over a FIFO work queue. Each iteration pops an item with
-- 'tryPopR' and passes it to @f@ together with the loop itself as the
-- continuation. When the queue is empty the worker calls 'sendStop'.
--
-- The 'streamVar' of the given state must be a 'Just'.
{-# INLINE workLoopFIFO #-}
workLoopFIFO :: MonadIO m
    => (State t m a -> LinkedQueue (t m a) -> t m a -> m () -> m ())
    -> State t m a -> LinkedQueue (t m a) -> m ()
workLoopFIFO f st q = run
  where
    sv = fromJust $ streamVar st

    run = liftIO (tryPopR q) >>= \case
        Nothing -> liftIO $ sendStop sv
        Just m -> f st q m run
-- | Place a stream on the ahead work queue and increment the number stored
-- next to it. The queue must be empty; otherwise an error is raised. After a
-- store-load barrier, if the 'needDoorBell' flag is set, clear it and ring
-- the door bell.
{-# INLINE enqueueAhead #-}
enqueueAhead :: SVar t m a -> IORef ([t m a], Int) -> t m a -> IO ()
enqueueAhead sv q m = do
    atomicModifyIORefCAS_ q place
    storeLoadBarrier
    wanted <- readIORef (needDoorBell sv)
    when wanted $ do
        atomicModifyIORefCAS_ (needDoorBell sv) (const False)
        void $ tryPutMVar (outputDoorBell sv) ()
  where
    place ([], n) = ([m], n + 1)
    place _ = error "not empty"
-- | Report whether the ahead work queue currently holds no item.
{-# INLINE queueEmptyAhead #-}
queueEmptyAhead :: MonadIO m => IORef ([t m a], Int) -> m Bool
queueEmptyAhead q = liftIO $ null . fst <$> readIORef q
-- | Take the single item off the ahead work queue, returning it together
-- with the number stored next to the queue, or 'Nothing' if the queue is
-- empty. Raises an error if the queue holds more than one item.
{-# INLINE dequeueAhead #-}
dequeueAhead :: MonadIO m
    => IORef ([t m a], Int) -> m (Maybe (t m a, Int))
dequeueAhead q = liftIO $ atomicModifyIORefCAS q takeOne
  where
    takeOne ([], n) = (([], n), Nothing)
    takeOne ([x], n) = (([], n), Just (x, n))
    takeOne _ = error "more than one item on queue"
-- | Remove and return the minimum entry of the heap, but only if its
-- sequence number equals the number stored alongside the heap. Otherwise,
-- or when the heap is empty, the heap is left untouched and 'Nothing' is
-- returned.
{-# INLINE dequeueFromHeap #-}
dequeueFromHeap
    :: IORef (Heap (Entry Int (AheadHeapEntry t m a)), Int)
    -> IO (Maybe (Entry Int (AheadHeapEntry t m a)))
dequeueFromHeap hpRef = atomicModifyIORefCAS hpRef popIfNext
  where
    popIfNext unchanged@(heap, expected) =
        case H.uncons heap of
            Just (ent@(Entry seqNo _), rest)
                | seqNo == expected -> ((rest, seqNo), Just ent)
            _ -> (unchanged, Nothing)
-- | Insert a thread id into the set of worker threads of the 'SVar'.
{-# NOINLINE addThread #-}
addThread :: MonadIO m => SVar t m a -> ThreadId -> m ()
addThread sv = liftIO . modifyIORef (workerThreads sv) . S.insert
-- | Delete a thread id from the set of worker threads of the 'SVar'. Unlike
-- 'modifyThread' this never rings the door bell.
{-# INLINE delThread #-}
delThread :: MonadIO m => SVar t m a -> ThreadId -> m ()
delThread sv = liftIO . modifyIORef (workerThreads sv) . S.delete
-- | Toggle the membership of a thread id in the worker thread set: delete it
-- if it is present, insert it otherwise. The door bell is rung (after a write
-- barrier) when the set reported by the update is empty, i.e. the new set
-- after a deletion or the old set before an insertion.
{-# INLINE modifyThread #-}
modifyThread :: MonadIO m => SVar t m a -> ThreadId -> m ()
modifyThread sv tid = do
    changed <- liftIO $ atomicModifyIORefCAS (workerThreads sv) toggle
    when (null changed) $ liftIO $ do
        writeBarrier
        void $ tryPutMVar (outputDoorBell sv) ()
  where
    toggle old
        | S.member tid old = let new = S.delete tid old in (new, new)
        | otherwise = (S.insert tid old, old)
-- | 'True' when the set of worker threads of the 'SVar' is empty.
{-# INLINE allThreadsDone #-}
allThreadsDone :: MonadIO m => SVar t m a -> m Bool
allThreadsDone sv = liftIO $ do
    threads <- readIORef (workerThreads sv)
    return (S.null threads)
-- | Exception handler installed on forked workers: posts a 'ChildStop' event
-- carrying the exception for the current thread.
{-# NOINLINE handleChildException #-}
handleChildException :: SVar t m a -> SomeException -> IO ()
handleChildException sv e =
    myThreadId >>= \tid -> void $ send (-1) sv (ChildStop tid (Just e))
#ifdef DIAGNOSTICS
-- | Diagnostics bookkeeping for a dispatch: raise the recorded maximum
-- worker count if the current count exceeds it, and count the dispatch.
recordMaxWorkers :: MonadIO m => SVar t m a -> m ()
recordMaxWorkers sv = liftIO $ do
    current <- readIORef (workerCount sv)
    highWater <- readIORef (maxWorkers sv)
    when (current > highWater) $ writeIORef (maxWorkers sv) current
    modifyIORef (totalDispatches sv) (+ 1)
#endif
-- | Start one more worker: increment the worker count, fork a thread running
-- the 'workLoop' of the 'SVar' with 'handleChildException' as its exception
-- handler, and record the new thread id with 'addThread'.
{-# NOINLINE pushWorker #-}
pushWorker :: MonadAsync m => SVar t m a -> m ()
pushWorker sv = do
    liftIO $ atomicModifyIORefCAS_ (workerCount sv) (+ 1)
#ifdef DIAGNOSTICS
    recordMaxWorkers sv
#endif
    tid <- doFork (workLoop sv) (handleChildException sv)
    addThread sv tid
-- | Fork a worker running the given loop with 'handleChildException' as its
-- exception handler and account for its thread id with 'modifyThread'. The
-- worker count is maintained only when diagnostics are enabled.
{-# NOINLINE pushWorkerPar #-}
pushWorkerPar :: MonadAsync m => SVar t m a -> m () -> m ()
pushWorkerPar sv wloop = do
#ifdef DIAGNOSTICS
    liftIO $ atomicModifyIORefCAS_ (workerCount sv) (+ 1)
    recordMaxWorkers sv
#endif
    tid <- doFork wloop (handleChildException sv)
    modifyThread sv tid
-- | Start another worker if work remains and the worker limit allows it.
--
-- The effective limit is @maxWorkerLimit@, further capped by the current
-- yield limit counter when the 'SVar' has one (the cap applies only if
-- @maxWorkerLimit@ is positive; otherwise the counter itself is the limit).
-- A negative effective limit places no bound on the number of workers.
dispatchWorker :: MonadAsync m => Int -> SVar t m a -> m ()
dispatchWorker maxWorkerLimit sv = do
    done <- liftIO $ isWorkDone sv
    when (not done) $ do
        active <- liftIO $ readIORef (workerCount sv)
        limit <- maybe (return maxWorkerLimit) capByYields (maxYieldLimit sv)
        when (limit < 0 || active < limit) $ pushWorker sv
  where
    capByYields ref = do
        remaining <- liftIO $ readIORef ref
        return $
            if maxWorkerLimit > 0
            then min maxWorkerLimit remaining
            else remaining
-- | Wait for output to arrive on the output queue, dispatching workers as
-- needed.
--
-- Each round first sleeps for a short delay chosen from the SVar style and
-- the number of capabilities. If the output queue is still empty it sets
-- 'needDoorBell', tries to dispatch a worker, and then either blocks on the
-- door bell (when no work remains) or starts another round. The function
-- returns once the output queue is observed to be non-empty.
{-# NOINLINE sendWorkerWait #-}
sendWorkerWait :: MonadAsync m => Int -> SVar t m a -> m ()
sendWorkerWait maxWorkerLimit sv = do
    ncpu <- liftIO getNumCapabilities
    liftIO $ threadDelay (pollDelay ncpu)
    (_, n) <- liftIO $ readIORef (outputQueue sv)
    when (n <= 0) $ do
        -- Ask the enqueue functions for a door bell before checking for
        -- remaining work.
        liftIO $ atomicModifyIORefCAS_ (needDoorBell sv) (const True)
        liftIO storeLoadBarrier
        dispatchWorker maxWorkerLimit sv
        done <- liftIO $ isWorkDone sv
        if done
        then do
            liftIO $ withDBGMVar sv "sendWorkerWait: nothing to do"
                   $ takeMVar (outputDoorBell sv)
            (_, len) <- liftIO $ readIORef (outputQueue sv)
            when (len <= 0) again
        else again
  where
    again = sendWorkerWait maxWorkerLimit sv

    -- Delay in microseconds before the output queue is examined.
    pollDelay ncpu
        | svarStyle sv == AheadVar = 100
        | ncpu <= 1 = 25
        | otherwise = 10
-- | Atomically take everything off the output queue, leaving it empty, and
-- return the events together with their count. With diagnostics enabled the
-- largest count seen so far is recorded in 'maxOutQSize'.
{-# INLINE readOutputQRaw #-}
readOutputQRaw :: SVar t m a -> IO ([ChildEvent a], Int)
readOutputQRaw sv = do
    (events, count) <- atomicModifyIORefCAS (outputQueue sv) swapOut
#ifdef DIAGNOSTICS
    highWater <- readIORef (maxOutQSize sv)
    when (count > highWater) $ writeIORef (maxOutQSize sv) count
#endif
    return (events, count)
  where
    swapOut old = (([], 0), old)
-- | Read the pending output events. If some are available they are returned
-- right away, after making sure a worker is running when work remains and no
-- worker is active. If none are available, wait via 'sendWorkerWait' (with
-- worker limit @n@) and then read the queue again.
readOutputQBounded :: MonadAsync m => Int -> SVar t m a -> m [ChildEvent a]
readOutputQBounded n sv = do
    (events, len) <- liftIO $ readOutputQRaw sv
    if len > 0
    then sendOneWorker >> return events
    else blockingRead
  where
    -- Push a worker only when none is active and there is still work to do.
    sendOneWorker = do
        active <- liftIO $ readIORef (workerCount sv)
        when (active <= 0) $ do
            done <- liftIO $ isWorkDone sv
            when (not done) $ pushWorker sv

    {-# INLINE blockingRead #-}
    blockingRead = do
        sendWorkerWait n sv
        liftIO $ fst <$> readOutputQRaw sv
-- | Returns 'False' while worker threads are still registered. Once none are
-- registered it returns whether all work is done, pushing a new worker first
-- if work remains.
postProcessBounded :: MonadAsync m => SVar t m a -> m Bool
postProcessBounded sv = do
    workersDone <- allThreadsDone sv
    if not workersDone
    then return False
    else do
        finished <- liftIO $ isWorkDone sv
        when (not finished) $ pushWorker sv
        return finished
-- | Allocate an 'SVar' of style 'AheadVar'.
--
-- The work queue starts empty with the number -1 stored next to it, and the
-- output heap starts empty with the number 0 stored next to it. The supplied
-- function becomes the worker loop; it receives the given state with
-- 'streamVar' set to the new 'SVar', plus the work queue and the output heap.
-- A yield limit counter is created only when the state carries a
-- 'yieldLimit'.
getAheadSVar :: MonadAsync m
    => State t m a
    -> (   State t m a
        -> IORef ([t m a], Int)
        -> IORef (Heap (Entry Int (AheadHeapEntry t m a)), Int)
        -> m ())
    -> IO (SVar t m a)
getAheadSVar st f = do
    outQ <- newIORef ([], 0)
    outH <- newIORef (H.empty, 0)
    outQMv <- newEmptyMVar
    active <- newIORef 0
    wfw <- newIORef False
    running <- newIORef S.empty
    q <- newIORef ([], -1)
    yl <- traverse newIORef (yieldLimit st)
#ifdef DIAGNOSTICS
    disp <- newIORef 0
    maxWrk <- newIORef 0
    maxOq <- newIORef 0
    maxHs <- newIORef 0
    maxWq <- newIORef 0
#endif
    -- The record refers to itself: several fields are actions closed over
    -- the SVar being built.
    let sv = SVar
            { svarStyle = AheadVar
            , outputQueue = outQ
            , maxYieldLimit = yl
            , outputDoorBell = outQMv
            , readOutputQ = readOutputQBounded (threadsHigh st) sv
            , postProcess = postProcessBounded sv
            , enqueue = enqueueAhead sv q
            , isWorkDone = isWorkDoneAhead q outH
            , needDoorBell = wfw
            , workLoop = f st{streamVar = Just sv} q outH
            , workerThreads = running
            , workerCount = active
            , accountThread = delThread sv
#ifdef DIAGNOSTICS
            , aheadWorkQueue = q
            , outputHeap = outH
            , totalDispatches = disp
            , maxWorkers = maxWrk
            , maxOutQSize = maxOq
            , maxHeapSize = maxHs
            , maxWorkQSize = maxWq
#endif
            }
    return sv
  where
    -- Work is done when both the work queue and the output heap are empty.
    {-# INLINE isWorkDoneAhead #-}
    isWorkDoneAhead queueRef heapRef = do
        (hp, _) <- readIORef heapRef
        queueDone <- checkEmpty queueRef
        return $ queueDone && H.size hp <= 0

    checkEmpty queueRef = null . fst <$> readIORef queueRef
-- | Allocate an 'SVar' of style 'ParallelVar'.
--
-- A parallel SVar has no work queue: 'workLoop', 'enqueue' and 'isWorkDone'
-- are not available and raise a descriptive error if they are ever forced.
-- Output is read by blocking on the door bell and then draining the output
-- queue, and thread accounting is done with 'modifyThread'.
--
-- 'needDoorBell' is a real reference (initially 'False') rather than a
-- bottom, because the diagnostics dump reads it for every SVar style.
getParallelSVar :: MonadIO m => IO (SVar t m a)
getParallelSVar = do
    outQ <- newIORef ([], 0)
    outQMv <- newEmptyMVar
    active <- newIORef 0
    wfw <- newIORef False
    running <- newIORef S.empty
#ifdef DIAGNOSTICS
    disp <- newIORef 0
    maxWrk <- newIORef 0
    maxOq <- newIORef 0
    maxHs <- newIORef 0
    maxWq <- newIORef 0
#endif
    let sv = SVar
            { outputQueue = outQ
            , maxYieldLimit = Nothing
            , outputDoorBell = outQMv
            , readOutputQ = readOutputQPar sv
            , postProcess = allThreadsDone sv
            , workerThreads = running
            , workLoop = unsupported "workLoop"
            , enqueue = unsupported "enqueue"
            , isWorkDone = unsupported "isWorkDone"
            , needDoorBell = wfw
            , svarStyle = ParallelVar
            , workerCount = active
            , accountThread = modifyThread sv
#ifdef DIAGNOSTICS
            -- These two are only read by the dump for 'AheadVar' SVars.
            , aheadWorkQueue = unsupported "aheadWorkQueue"
            , outputHeap = unsupported "outputHeap"
            , totalDispatches = disp
            , maxWorkers = maxWrk
            , maxOutQSize = maxOq
            , maxHeapSize = maxHs
            , maxWorkQSize = maxWq
#endif
            }
    return sv
  where
    -- Placeholder for fields a parallel SVar does not provide.
    unsupported :: String -> b
    unsupported field = error $
        "Streamly.SVar.getParallelSVar: " ++ field
            ++ " is not available for ParallelVar"

    readOutputQPar sv = liftIO $ do
        withDBGMVar sv "readOutputQPar: doorbell" $ takeMVar (outputDoorBell sv)
        fst <$> readOutputQRaw sv
-- | Enqueue a stream on the 'SVar', start a worker for it, and return the
-- 'SVar'.
sendWorker :: MonadAsync m => SVar t m a -> t m a -> m (SVar t m a)
sendWorker sv m =
    -- Note: all the work must be on the queue before the worker is pushed;
    -- otherwise the worker may exit before we even get a chance to enqueue.
    liftIO (enqueue sv m) >> pushWorker sv >> return sv
-- | Create an ahead style 'SVar' with the given worker loop, enqueue the
-- supplied stream on it, and start the first worker.
{-# INLINABLE newAheadVar #-}
newAheadVar :: MonadAsync m
    => State t m a
    -> t m a
    -> (   State t m a
        -> IORef ([t m a], Int)
        -> IORef (Heap (Entry Int (AheadHeapEntry t m a)), Int)
        -> m ())
    -> m (SVar t m a)
newAheadVar st m wloop =
    liftIO (getAheadSVar st wloop) >>= \sv -> sendWorker sv m
-- | Create a parallel style 'SVar'. No worker is started here.
{-# INLINABLE newParallelVar #-}
newParallelVar :: MonadAsync m => m (SVar t m a)
newParallelVar = liftIO getParallelSVar
-- XXX This errors out for Parallel/Ahead SVars.
-- | Write a stream to an 'SVar' in a non-blocking manner. The stream can then
-- be read back from the SVar using 'fromSVar'.
--
-- The stream is enqueued on the SVar, and a worker is pushed if no worker
-- thread is currently registered.
toStreamVar :: MonadAsync m => SVar t m a -> t m a -> m ()
toStreamVar sv m = do
    liftIO $ enqueue sv m
    idle <- allThreadsDone sv
    -- XXX This is safe only when called from the consumer thread or when no
    -- consumer is present. There may be a race if we are not running in the
    -- consumer thread.
    when idle $ pushWorker sv