-- | This is a low-level interface to the bag of tasks. This implementation does
-- not use stm in contrast to the other implementation. This implementation
-- is provided for performance comparison. This implementation should only be
-- used if it is not possible to use /Control.Concurrent.Bag.SafeConcurrent/
-- for your task. This module allows more control but it also requires a deeper
-- knowledge about the implementation.
module Control.Concurrent.Bag.Concurrent
  ( Bag
  , newBag_
  , addEval
  , addTask
  , getResult
  , writeResult
  , terminateBag
  , noMoreTasks
  , module Control.Concurrent.Bag.TaskBuffer
  ) where

import Data.Maybe (fromJust, isJust)
import Control.Concurrent
  ( ThreadId
  , forkIO
  , forkIOWithUnmask
  , killThread
  , getNumCapabilities
  , myThreadId
  )
import Control.Concurrent.Chan
import Control.Concurrent.MVar
-- Imported explicitly: newer mtl releases no longer re-export
-- "Control.Monad" from "Control.Monad.Writer".
import Control.Monad (forever, replicateM, void, when)
import Control.Monad.Writer
import Control.Concurrent.Bag.TaskBuffer

-- | Messages travelling through the result channel of a 'Bag'.
data Result a
  = NoResult     -- ^ A task finished without producing a value.
  | Result a     -- ^ A task (or 'writeResult') produced a value.
  | NoMoreTasks  -- ^ Marker written by 'noMoreTasks'.

-- | A bag of tasks together with the worker threads processing it.
data Bag r = Bag
  { workers                :: [ThreadId]
    -- ^ Threads running 'worker'.
  , taskBuffer             :: TaskBuffer (IO (Maybe r))
    -- ^ Tasks waiting to be executed.
  , resultChan             :: Chan (Result r)
    -- ^ Channel the workers (and 'writeResult') write to.
  , taskCountVar           :: MVar Int
    -- ^ Incremented by 'addTask' and 'writeResult', decremented by
    -- 'getResult' for every message it is going to consume.
  , moreTasksVar           :: MVar Bool
    -- ^ 'True' until 'noMoreTasks' is called.
  , waitingOneMoreTasksVar :: MVar Int
    -- ^ Number of 'getResult' calls currently reading the channel while
    -- they still expected more tasks.
  }

-- | Strictly add one to a counter, so that no chain of thunks accumulates
-- inside the 'MVar'.
incrementCounter :: MVar Int -> IO ()
incrementCounter var = modifyMVar_ var (\c -> return $! c + 1)

-- | Create a new bag and use the number of capabilities as the thread count.
newBag_ :: (MonadIO m) => BufferType -> m (Bag r)
newBag_ buf = liftIO getNumCapabilities >>= newBag buf . max 1

-- | Create a new bag of tasks
newBag :: (MonadIO m)
       => BufferType
       -> Int         -- ^ Number of threads
       -> m (Bag r)
newBag buf n = liftIO $ do
  results <- newChan
  buffer  <- case buf of
    Queue -> newChanBuffer
    Stack -> newStackBuffer
  -- Workers run with asynchronous exceptions unmasked so that
  -- 'terminateBag' can kill them.
  tids <- replicateM n $
    forkIOWithUnmask (\unmask -> unmask (worker buffer results))
  taskCount          <- newMVar 0
  moreTasks          <- newMVar True
  waitingOnMoreTasks <- newMVar 0
  return $ Bag tids buffer results taskCount moreTasks waitingOnMoreTasks

-- | Worker loop: repeatedly take a task from the buffer, run it and publish
-- its outcome on the result channel. There is no exception handler here, so
-- a task that throws ends the loop of this worker without writing a result.
worker :: TaskBuffer (IO (Maybe r)) -- ^ own buffer
       -> Chan (Result r)           -- ^ result channel
       -> IO ()
worker ownBuffer results = forever $ do
  task   <- readBuffer ownBuffer
  result <- task
  writeChan results (maybe NoResult Result result)

-- | Add a task to the bag. The task counter is incremented before the task
-- is written to the buffer.
addTask :: (MonadIO m) => Bag r -> IO (Maybe r) -> m ()
addTask bag task = liftIO $ do
  incrementCounter (taskCountVar bag)
  writeBuffer (taskBuffer bag) task

-- | Add a task that evaluates the given value to weak head normal form and
-- returns it as its result.
addEval :: (MonadIO m) => Bag r -> r -> m ()
addEval bag e = addTask bag (e `seq` return (Just e))

-- | Tell the bag that there will be no more tasks from the outside,
-- however, queued tasks may still add new tasks.
noMoreTasks :: MonadIO m => Bag r -> m ()
noMoreTasks bag = liftIO $ do
  modifyMVar_ (moreTasksVar bag) (\_ -> return False)
  writeChan (resultChan bag) NoMoreTasks

-- | Write a result directly to the result channel, bypassing the workers.
-- The task counter is incremented so that 'getResult' accounts for it.
writeResult :: MonadIO m => Bag r -> r -> m ()
writeResult bag x = liftIO $ do
  incrementCounter (taskCountVar bag)
  writeChan (resultChan bag) (Result x)

-- | Get the next result from the bag. Blocks until a result is available.
-- Returns 'Nothing' once 'noMoreTasks' has been called and the task counter
-- shows no outstanding results.
getResult :: (MonadIO m) => Bag r -> m (Maybe r)
getResult bag = liftIO go
  where
    go = do
      moreTasks <- readMVar (moreTasksVar bag)
      if moreTasks then awaitWhileOpen else awaitAfterClose

    -- More tasks may still be added from the outside.
    awaitWhileOpen = do
      -- Claim one result up front; 'taskCount' is the value before the claim.
      taskCount <- modifyMVar (taskCountVar bag) $ \c ->
        let c' = c - 1 in c' `seq` return (c', c)
      incrementCounter (waitingOneMoreTasksVar bag)
      result  <- readChan (resultChan bag)
      waiting <- modifyMVar (waitingOneMoreTasksVar bag) $ \c ->
        let c' = c - 1 in c' `seq` return (c', c')
      case result of
        NoMoreTasks -> do
          -- when other threads wait on this message, send it again
          when (waiting > 0) (writeChan (resultChan bag) NoMoreTasks)
          if taskCount <= 0
            -- Nothing was outstanding: undo the claim and re-check.
            then incrementCounter (taskCountVar bag) >> go
            else readOneResult
        NoResult -> go
        Result v -> return $ Just v

    -- 'noMoreTasks' has been called.
    awaitAfterClose = do
      -- 'modifyMVar' instead of a 'takeMVar' / 'putMVar' pair: the counter
      -- is restored even if an asynchronous exception arrives in between.
      claimed <- modifyMVar (taskCountVar bag) $ \c ->
        if c <= 0
          then return (c, False)
          else let c' = c - 1 in c' `seq` return (c', True)
      if claimed then readOneResult else return Nothing

    readOneResult = do
      result <- readChan (resultChan bag)
      case result of
        NoMoreTasks -> do
          -- when other threads wait on this message, send it again
          -- the value in the MVar does not increase anymore, because the
          -- 'moreTasksVar' is set to False before this message is sent.
          -- After this, no new threads will start to wait on this message.
          waiting <- readMVar (waitingOneMoreTasksVar bag)
          when (waiting > 0) (writeChan (resultChan bag) NoMoreTasks)
          readOneResult
        NoResult -> go
        Result v -> return $ Just v

-- | Terminates all threads running in the thread pool. 'terminateBag' is
-- non-blocking and therefore does not guarantee that all threads are
-- terminated after its execution. Additionally it is not guaranteed
-- that the threads are terminated promptly. It is implemented using
-- asynchronous exceptions and therefore it terminates a thread once it uses
-- a /safe point/. A safe point is where memory allocation occurs. Although
-- most tasks will have some point of memory allocation now and then, there
-- are loops that will never reach a safe point.
terminateBag :: MonadIO m => Bag r -> m ()
terminateBag bag = do
  noMoreTasks bag
  liftIO $ mapM_ terminateThread $ workers bag
  where
    terminateThread :: ThreadId -> IO ()
    terminateThread tid =
      -- killThread is blocking and exceptions are only received on safe
      -- points, memory allocations. However some calculations never see a
      -- safe point, but we want to continue here.
      void (forkIO $ killThread tid)