module Control.Concurrent.Bag.Concurrent
( Bag
, newBag
, newBag_
, addEval
, addTask
, getResult
, writeResult
, terminateBag
, noMoreTasks
, module Control.Concurrent.Bag.TaskBuffer )
where
import Data.Maybe (fromJust, isJust)
import Control.Concurrent
( ThreadId
, forkIO
, forkIOWithUnmask
, killThread
, getNumCapabilities
, myThreadId )
import Control.Concurrent.Chan
import Control.Concurrent.MVar
import Control.Monad.Writer
import Control.Concurrent.Bag.TaskBuffer
-- | A message travelling on a bag's result channel.
data Result a
  = NoResult
    -- ^ a task finished and produced no value
  | Result a
    -- ^ a value produced by a task or written with 'writeResult'
  | NoMoreTasks
    -- ^ marker posted by 'noMoreTasks'
-- | A bag of tasks processed by a set of worker threads, together with the
-- bookkeeping shared by its producers and consumers.
data Bag r = Bag
  { workers :: [ThreadId]
    -- ^ threads forked when the bag was created
  , taskBuffer :: TaskBuffer (IO (Maybe r))
    -- ^ pending tasks, read by the workers
  , resultChan :: Chan (Result r)
    -- ^ task outcomes and 'NoMoreTasks' markers, consumed by 'getResult'
  , taskCountVar :: MVar Int
    -- ^ tasks added / results written that no reader has claimed yet
  , moreTasksVar :: MVar Bool
    -- ^ 'True' until 'noMoreTasks' is called
  , waitingOneMoreTasksVar :: MVar Int
    -- ^ number of 'getResult' calls registered as waiting on 'resultChan'
  }
-- | Create a bag with one worker per RTS capability (and never fewer than
-- one worker).
newBag_ :: (MonadIO m) =>
  BufferType
  -> m (Bag r)
newBag_ buf = do
  capabilities <- liftIO getNumCapabilities
  newBag buf (max 1 capabilities)
-- | Create a bag with a fresh result channel and task buffer, and fork @n@
-- worker threads that run tasks taken from the buffer.
--
-- The buffer discipline is chosen by @buf@: 'Queue' uses 'newChanBuffer',
-- 'Stack' uses 'newStackBuffer'. A non-positive @n@ yields a bag without
-- workers.
--
-- Each worker is started with 'forkIOWithUnmask' and immediately unmasks,
-- so it can receive asynchronous exceptions (e.g. from 'killThread') even
-- when 'newBag' itself is called with exceptions masked.
newBag :: (MonadIO m) =>
  BufferType
  -> Int
  -> m (Bag r)
newBag buf n = do
  results <- liftIO newChan
  buffer <- liftIO $ case buf of
    Queue -> newChanBuffer
    Stack -> newStackBuffer
  -- Exactly n workers (none when n <= 0).
  tids <- liftIO $ replicateM n $
    forkIOWithUnmask $ \unmask -> unmask $ worker buffer results
  taskCount <- liftIO $ newMVar 0
  moreTasks <- liftIO $ newMVar True
  waitingOnMoreTasks <- liftIO $ newMVar 0
  return $ Bag tids buffer results taskCount moreTasks waitingOnMoreTasks
-- | Worker loop: repeatedly take a task from the buffer, run it, and post
-- its outcome on the result channel ('NoResult' for @Nothing@, 'Result'
-- for @Just@). Never returns normally.
--
-- NOTE(review): an exception thrown by a task is not caught here. It ends
-- this worker thread, and no outcome is posted for that task.
worker :: TaskBuffer (IO (Maybe r))
       -> Chan (Result r)
       -> IO ()
worker buffer out = forever (readBuffer buffer >>= runTask)
  where
    runTask task = task >>= writeChan out . maybe NoResult Result
-- | Submit a task to the bag. When run by a worker, a @Just@ outcome becomes
-- a result and @Nothing@ means the task produced no result.
--
-- The counter is bumped before the task is written to the buffer, so no
-- outcome of this task can be observed before it has been counted.
addTask :: (MonadIO m) => Bag r -> IO (Maybe r) -> m ()
addTask bag task = liftIO $ do
  -- Strict increment: a lazy (+1) would leave a growing thunk chain in the
  -- MVar when many tasks are added before the counter is inspected.
  modifyMVar_ (taskCountVar bag) (\c -> return $! c + 1)
  writeBuffer (taskBuffer bag) task
-- | Submit a pure value as a task whose result is that value.
--
-- The value is forced to weak head normal form when the task action itself
-- is evaluated, which normally happens when a worker runs it.
addEval :: (MonadIO m) => Bag r -> r -> m ()
addEval bag e = addTask bag evalTask
  where
    evalTask = seq e (return (Just e))
-- | Declare that no further tasks will be added.
--
-- The "more tasks" flag is cleared first. Then a 'NoMoreTasks' marker is
-- posted on the result channel, so readers blocked in 'getResult' wake up
-- and re-check.
noMoreTasks :: MonadIO m => Bag r -> m ()
noMoreTasks bag = liftIO (clearFlag >> wakeReaders)
  where
    clearFlag   = modifyMVar_ (moreTasksVar bag) (const (return False))
    wakeReaders = writeChan (resultChan bag) NoMoreTasks
-- | Put a result directly into the bag without running a task. It is
-- counted like an added task, so 'getResult' will hand it out.
writeResult :: MonadIO m => Bag r -> r -> m ()
writeResult bag x = liftIO $ do
  -- Strict increment: a lazy (+1) would leave a growing thunk chain in the
  -- MVar when many results are written before the counter is inspected.
  modifyMVar_ (taskCountVar bag) (\c -> return $! c + 1)
  writeChan (resultChan bag) (Result x)
-- | Take the next result out of the bag, blocking until one is available.
--
-- Returns @Just v@ for the next 'Result' on the channel. Tasks that
-- finished with @Nothing@ ('NoResult') are skipped. Returns @Nothing@ once
-- 'noMoreTasks' has been called and the task counter has dropped to zero,
-- i.e. every added task / written result has been claimed.
--
-- Counter updates are strict. Lazy (+1)/(-1) updates would otherwise pile
-- up thunks in the MVars on paths that never inspect the value.
getResult :: (MonadIO m) => Bag r -> m (Maybe r)
getResult bag = liftIO $ do
  moreTasks <- readMVar (moreTasksVar bag)
  if moreTasks
    then do
      -- Claim one task; 'taskCount' is the counter value before the claim.
      taskCount <- modifyMVar (taskCountVar bag) $ \c ->
        let c' = c - 1 in c' `seq` return (c', c)
      -- Register as a reader waiting on the channel for the blocking read.
      modifyMVar_ (waitingOneMoreTasksVar bag) (\w -> return $! w + 1)
      result <- readChan (resultChan bag)
      waiting <- modifyMVar (waitingOneMoreTasksVar bag) $ \w ->
        let w' = w - 1 in w' `seq` return (w', w')
      case result of
        NoMoreTasks -> do
          -- Re-post the marker so other registered readers also see it.
          when (waiting > 0) (writeChan (resultChan bag) NoMoreTasks)
          -- If nothing was outstanding, undo the claim and start over.
          -- 'noMoreTasks' clears the flag before posting the marker, so the
          -- retry takes the 'else' branch below.
          if taskCount <= 0
            then do
              modifyMVar_ (taskCountVar bag) (\c -> return $! c + 1)
              getResult bag
            else readOneResult
        NoResult -> getResult bag
        Result v -> return $ Just v
    else do
      -- Claim one task only if any are outstanding. modifyMVar keeps the
      -- counter MVar full even if an exception arrives mid-update.
      claimed <- modifyMVar (taskCountVar bag) $ \c ->
        if c <= 0
          then return (c, False)
          else let c' = c - 1 in c' `seq` return (c', True)
      if claimed then readOneResult else return Nothing
  where
    -- Read the outcome for a task that has already been claimed. Markers are
    -- skipped (and re-posted while other readers are registered as waiting).
    -- A 'NoResult' outcome restarts 'getResult'.
    readOneResult = do
      result <- readChan (resultChan bag)
      case result of
        NoMoreTasks -> do
          waiting <- readMVar (waitingOneMoreTasksVar bag)
          when (waiting > 0) (writeChan (resultChan bag) NoMoreTasks)
          readOneResult
        NoResult -> getResult bag
        Result v -> return $ Just v
-- | Shut the bag down: signal 'noMoreTasks', then kill every worker thread.
--
-- Each kill is issued from its own freshly forked thread, presumably so the
-- caller does not wait on 'killThread' (which blocks until the exception
-- has been delivered to the target).
terminateBag :: MonadIO m => Bag r -> m ()
terminateBag bag = noMoreTasks bag >> liftIO (mapM_ killAsync (workers bag))
  where
    killAsync :: ThreadId -> IO ()
    killAsync = void . forkIO . killThread