-- | This is a low-level interface to the bag of tasks, which should only be
-- used if it is not possible to use "Control.Concurrent.Bag.Safe" for your
-- task. This module allows more control but it also requires a deeper
-- knowledge about the implementation.
module Control.Concurrent.Bag.Basic
  ( Bag
  , newBag_
  , addEval
  , addTask
  , getResult
  , terminateBag
  , noMoreTasks
  , TaskBufferSTM (..)
  , SplitFunction
  , takeFirst
  ) where

import Data.Maybe (fromMaybe, isJust)
import Control.Concurrent
  ( ThreadId
  , forkIO
  , forkIOWithUnmask
  , killThread
  , getNumCapabilities
  , myThreadId
  )
import qualified Control.Concurrent.STM as STM (atomically)
import Control.Concurrent.STM
  ( STM
  , TVar
  , TChan
  , newTVar
  , newTChan
  , writeTChan
  , readTChan
  , writeTVar
  , tryReadTChan
  , isEmptyTChan
  , retry
  , readTVar
  )
-- Imported explicitly from base: newer mtl versions do not reliably
-- re-export these through "Control.Monad.Writer".
import Control.Monad (forM_, forever, replicateM, unless, when)
import Control.Monad.IO.Class (MonadIO, liftIO)
import Control.Monad.Writer (execWriterT, tell)
import qualified Data.Map.Strict as Map
import Control.Concurrent.Bag.TaskBuffer
  ( TaskBufferSTM (..)
  , SplitFunction
  , takeFirst
  )

-- | A bag of tasks processed by a fixed pool of worker threads.
data Bag b r = Bag
  { workerStates :: Map.Map ThreadId (TVar Bool)
    -- ^ per worker thread: is it currently running a task?
  , taskBuffers  :: Either (b (IO (Maybe r)))
                           (Map.Map ThreadId (b (IO (Maybe r))))
    -- ^ one buffer shared by all workers ('Left'), or one buffer per
    -- worker thread ('Right') when a split function was given
  , resultChan   :: TChan r
    -- ^ results produced by tasks that returned 'Just'
  , terminateVar :: TVar Bool
    -- ^ set by 'terminateBag'
  , moreTasksVar :: TVar Bool
    -- ^ 'False' once no more tasks will be added from the outside
  }

-- | All task buffers of the bag, regardless of the buffer layout.
bufferList :: Bag b r -> [b (IO (Maybe r))]
bufferList bag = either (: []) Map.elems (taskBuffers bag)

-- | Create a new bag with one worker thread per capability (at least one).
-- If a split function is given, every worker gets its own buffer and idle
-- workers use the split function to take work from other workers' buffers.
newBag_ :: (MonadIO m, TaskBufferSTM b)
        => Maybe (SplitFunction b r)
        -> m (Bag b r)
newBag_ split = liftIO getNumCapabilities >>= newBag split . max 1

-- | Create a new bag of tasks
newBag :: (MonadIO m, TaskBufferSTM b)
       => Maybe (SplitFunction b r) -- ^ split function -- if this function
                                    -- is given use one buffer per thread
       -> Int                       -- ^ Number of threads
       -> m (Bag b r)
newBag split n = do
  results <- atomically newTChan
  -- An infinite (cycled) list, so that @buffers !! i@ and rotations starting
  -- at index @i@ are always defined.
  buffers <- cycle <$> if isJust split
                         then replicateM n (atomically newBufferSTM)
                         else (: []) <$> atomically newBufferSTM
  -- Without a split function every worker has an empty list of foreign
  -- buffers, so the split function is never called in that case.
  let splitf = fromMaybe
        (error "Control.Concurrent.Bag.Basic.newBag: split function used without per-thread buffers")
        split
  (states, bufferMap) <- execWriterT $ forM_ [0 .. n - 1] $ \i -> do
    let ownBuffer = buffers !! i
        -- All other workers' buffers, starting with the next one; the
        -- worker's own buffer is not included.
        foreignBuffers
          | isJust split = take (n - 1) (drop (i + 1) buffers)
          | otherwise    = []
    state <- atomically $ newTVar False
    tid <- liftIO $ forkIOWithUnmask $ \unmask ->
      unmask $ worker splitf state ownBuffer foreignBuffers results
    tell (Map.singleton tid state, Map.singleton tid ownBuffer)
  terminated <- atomically $ newTVar False
  moreTasks <- atomically $ newTVar True
  let eitherBuffer = case split of
        Nothing -> Left (head buffers)
        Just _  -> Right bufferMap
  return $ Bag states eitherBuffer results terminated moreTasks

-- | Main loop of a worker thread: repeatedly take a task (from the own
-- buffer, or via the split function from the first non-empty foreign
-- buffer), run it and publish its result, if any.
worker :: TaskBufferSTM b
       => SplitFunction b r   -- ^ split function
       -> TVar Bool           -- ^ worker state, running?
       -> b (IO (Maybe r))    -- ^ own buffer
       -> [b (IO (Maybe r))]  -- ^ foreign buffers
       -> TChan r             -- ^ result channel
       -> IO ()
worker split state ownBuffer foreignBuffers results = forever $ do
  -- Taking the task and marking the worker as running happen in the SAME
  -- transaction. Otherwise 'getResult' could observe empty buffers and an
  -- idle worker while this task is in flight, and wrongly return 'Nothing'.
  task <- atomically $ do
    mTask <- tryReadBufferSTM ownBuffer
    next <- case mTask of
      Just t  -> return t
      Nothing -> splitBuffers split ownBuffer foreignBuffers
    writeTVar state True
    return next
  result <- task
  -- Publish the result before (atomically with) marking the worker idle.
  atomically $ do
    forM_ result (writeTChan results)
    writeTVar state False
  where
    -- Take work from the first non-empty foreign buffer; block (retry) if
    -- all of them are empty.
    splitBuffers :: TaskBufferSTM b
                 => SplitFunction b r
                 -> b (IO (Maybe r))
                 -> [b (IO (Maybe r))]
                 -> STM (IO (Maybe r))
    splitBuffers _ _ [] = retry
    splitBuffers splitf own (f:fs) = do
      isEmpty <- isEmptyBufferSTM f
      if isEmpty
        then splitBuffers splitf own fs
        else splitf own f

-- | Add a task to the bag. A task returning @Just r@ makes @r@ available
-- through 'getResult'; a task returning 'Nothing' produces no result.
addTask :: (MonadIO m, TaskBufferSTM b) => Bag b r -> IO (Maybe r) -> m ()
addTask bag task = do
  buffer <- case taskBuffers bag of
    Left shared -> return shared
    Right bufferMap -> do
      tid <- liftIO myThreadId
      -- A worker thread adds to its own buffer; any other thread adds to
      -- the buffer of the worker with the smallest 'ThreadId'. The map is
      -- never empty because 'newBag_' starts at least one worker.
      return $ Map.findWithDefault (head $ Map.elems bufferMap) tid bufferMap
  atomically $ writeBufferSTM buffer task

-- | Add a value whose evaluation (to weak head normal form) is done by a
-- worker thread; the evaluated value becomes a result of the bag.
addEval :: (MonadIO m, TaskBufferSTM b) => Bag b r -> r -> m ()
addEval bag e = addTask bag (e `seq` return (Just e))

-- | Tell the bag that there will be no more tasks from the outside,
-- however, queued tasks may still add new tasks.
noMoreTasks :: MonadIO m => Bag b r -> m ()
noMoreTasks bag = atomically $ writeTVar (moreTasksVar bag) False

-- | Get the next result of the bag.
--
-- Returns @Just r@ when a result is available. Otherwise returns 'Nothing'
-- immediately if the bag was terminated. If it was not, blocks until a
-- result becomes available, or returns 'Nothing' once 'noMoreTasks' has
-- been called, all buffers are empty and no worker is running a task.
getResult :: (MonadIO m, TaskBufferSTM b) => Bag b r -> m (Maybe r)
getResult bag = atomically $ do
  result <- tryReadTChan (resultChan bag)
  case result of
    Just r -> return (Just r)
    Nothing -> do
      terminated <- readTVar (terminateVar bag)
      unless terminated $ do
        -- Tasks may still be added from the outside.
        moreTasks <- readTVar (moreTasksVar bag)
        when moreTasks retry
        -- Queued tasks are still waiting to be run.
        noTasks <- and <$> mapM isEmptyBufferSTM (bufferList bag)
        unless noTasks retry
        -- A running task may still produce a result or add new tasks.
        forM_ (Map.elems (workerStates bag)) $ \running ->
          readTVar running >>= (`when` retry)
      return Nothing

-- | Terminates all threads running in the thread pool. 'terminateBag' is
-- non-blocking and therefore does not guarantee that all threads are
-- terminated after its execution. Additionally it is not guaranteed
-- that the threads are terminated promptly. It is implemented using
-- asynchronous exceptions and therefore it terminates a thread once it
-- reaches a /safe point/. A safe point is where memory allocation occurs.
-- Although most tasks will have some point of memory allocation now and
-- then, there are loops that will never reach a safe point.
terminateBag :: MonadIO m => Bag b r -> m ()
terminateBag bag = do
  atomically $ do
    writeTVar (terminateVar bag) True
    writeTVar (moreTasksVar bag) False
  -- killThread is blocking and exceptions are only received on safe points,
  -- memory allocations. However some calculations never see a safe point,
  -- but we want to continue here, so every kill is issued from its own thread.
  forM_ (Map.keys (workerStates bag)) $ \tid ->
    liftIO (forkIO (killThread tid))

-- Helper functions --

-- | Run an STM transaction in any 'MonadIO'.
atomically :: MonadIO m => STM a -> m a
atomically = liftIO . STM.atomically