{-|
Module      : Control.Concurrent.Bag.Basic
Description : Low-level bag of tasks implementation using STM
Copyright   : (c) Bastian Holst, 2014
License     : BSD3
Maintainer  : bastianholst@gmx.de
Stability   : experimental
Portability : POSIX

This is a low-level interface to the bag of tasks, which should only be used
if it is not possible to use "Control.Concurrent.Bag.Safe" for your task.
This module allows more control but it also requires a deeper knowledge
about the implementation.
-}
module Control.Concurrent.Bag.Basic
  ( Bag ()
  , newBag
  , newBag_
  , addEval
  , addTask
  , getResult
  , writeResult
  , terminateBag
  , noMoreTasks
  , module Control.Concurrent.Bag.TaskBufferSTM
  ) where

import Data.Maybe (fromJust, isJust)
import Control.Concurrent
  ( ThreadId
  , forkIO
  , forkIOWithUnmask
  , killThread
  , getNumCapabilities
  , myThreadId
  )
import qualified Control.Concurrent.STM as STM (atomically)
import Control.Concurrent.STM
  ( STM
  , TVar
  , TChan
  , newTVar
  , newTChan
  , writeTChan
  , readTChan
  , writeTVar
  , tryReadTChan
  , isEmptyTChan
  , retry
  , readTVar
  )
-- mtl >= 2.3 no longer re-exports Control.Monad from Control.Monad.Writer,
-- so the monad utilities are imported explicitly.
import Control.Monad (forever, liftM, replicateM, unless, when)
import Control.Monad.IO.Class (MonadIO, liftIO)
import Control.Monad.Writer
import qualified Data.Map.Strict as Map

import Control.Concurrent.Bag.TaskBufferSTM

-- | The bag of tasks type.
data Bag r = Bag
  { workerStates :: Map.Map ThreadId (TVar Bool)
    -- ^ per-worker flag, 'True' while the worker is running a task
  , taskBuffers  :: Either (TaskBufferSTM (IO (Maybe r)))
                           (Map.Map ThreadId (TaskBufferSTM (IO (Maybe r))))
    -- ^ one buffer shared by all workers ('Left'), or one buffer per
    -- worker keyed by the worker's thread id ('Right', split mode)
  , resultChan   :: TChan r
    -- ^ results of the tasks and of 'writeResult'
  , terminateVar :: TVar Bool
    -- ^ set by 'terminateBag'
  , moreTasksVar :: TVar Bool
    -- ^ 'False' once no more tasks will be added from the outside
  }

-- | Get a list of all task buffers from a 'Bag': the shared buffer, or the
-- buffers of all workers.
bufferList :: Bag r -> [TaskBufferSTM (IO (Maybe r))]
bufferList bag = either (: []) Map.elems (taskBuffers bag)

-- | Build and return a new bag of tasks with one worker thread per
-- capability (as reported by 'getNumCapabilities'), but at least one.
newBag_ :: MonadIO m
        => BufferType               -- ^ type of the used buffer(s)
        -> Maybe (SplitFunction r)  -- ^ Possible split function
                                    --
                                    -- If the function is given, we will create
                                    -- a bag with one buffer per worker
                                    -- reducing the communication between the
                                    -- workers.
        -> m (Bag r)
newBag_ buf split = liftIO getNumCapabilities >>= newBag buf split . max 1

-- | Build and return a new bag of tasks.
--
-- A thread count below one is treated as one: a bag without workers could
-- never run a task (and in split mode it would have no buffer at all).
newBag :: MonadIO m
       => BufferType               -- ^ type of the used buffer(s)
       -> Maybe (SplitFunction r)  -- ^ Possible split function
                                   --
                                   -- If the function is given, we will create
                                   -- a bag with one buffer per worker
                                   -- reducing the communication between the
                                   -- workers.
       -> Int                      -- ^ number of threads
       -> m (Bag r)
newBag buf split requested = do
  let n = max 1 requested
      newBuffer = case buf of
        Queue -> newChanBufferSTM
        Stack -> newStackBufferSTM
  results <- atomically newTChan
  -- Split mode: one buffer per worker; otherwise a single shared buffer.
  created <- if isJust split
               then replicateM n (atomically newBuffer)
               else (: []) `liftM` atomically newBuffer
  -- Cycling lets every worker see the other buffers as a rotation that
  -- starts right after its own one. 'created' is never empty since n >= 1.
  let buffers = cycle created
      splitf = case split of
        Just s  -> s
        -- Only used by 'worker' when it has foreign buffers, which exist
        -- only in split mode.
        Nothing -> error ("Control.Concurrent.Bag.Basic.newBag: "
                          ++ "split function used outside split mode")
      spawn i = do
        let ownBuffer = buffers !! i
            -- The buffers of all other workers, starting with the next one.
            foreignBuffers
              | isJust split = take (n - 1) (drop (i + 1) buffers)
              | otherwise    = []
        state <- atomically $ newTVar False
        tid <- liftIO $ forkIOWithUnmask $ \unmask ->
          unmask $ worker splitf state ownBuffer foreignBuffers results
        return (tid, state, ownBuffer)
  workers <- mapM spawn [0 .. n - 1]
  terminated <- atomically $ newTVar False
  moreTasks <- atomically $ newTVar True
  let states = Map.fromList [ (tid, state) | (tid, state, _) <- workers ]
      eitherBuffer = case split of
        Nothing -> Left (head buffers)
        Just _  -> Right (Map.fromList [ (tid, b) | (tid, _, b) <- workers ])
  return $ Bag states eitherBuffer results terminated moreTasks

-- | The main loop of a worker thread. It repeatedly takes a task from its
-- own buffer, or else by splitting the first non-empty foreign buffer. It
-- then runs the task and writes the task's result, if any, to the result
-- channel.
--
-- The state flag is 'True' from the transaction that takes a task until the
-- task has run and its result has been written.
worker :: SplitFunction r                 -- ^ split function
       -> TVar Bool                       -- ^ worker state, running?
       -> TaskBufferSTM (IO (Maybe r))    -- ^ own buffer
       -> [TaskBufferSTM (IO (Maybe r))]  -- ^ foreign buffers
       -> TChan r                         -- ^ result channel
       -> IO ()
worker split state ownBuffer foreignBuffers results = forever $ do
  atomically $ writeTVar state False
  task <- atomically $ do
    mTask <- tryReadBufferSTM ownBuffer
    -- Written in the same transaction that takes the task; if the
    -- transaction retries, this write is discarded as well.
    writeTVar state True
    case mTask of
      Just t  -> return t
      Nothing -> splitBuffers foreignBuffers
  result <- task
  case result of
    Just r  -> atomically $ writeTChan results r
    Nothing -> return ()
  where
    -- Find a non-empty foreign buffer and split it into the own buffer;
    -- retry (block) when all foreign buffers are empty.
    splitBuffers [] = retry
    splitBuffers (f : fs) = do
      isEmpty <- isEmptyBufferSTM f
      if isEmpty then splitBuffers fs else split ownBuffer f

-- | Add a task to the given bag of tasks.
--
-- In split mode a task added from a worker thread goes to that worker's own
-- buffer. A task added from any other thread goes to the buffer of the
-- worker with the smallest 'ThreadId'.
addTask :: MonadIO m => Bag r -> IO (Maybe r) -> m ()
addTask bag task = do
  buffer <- case taskBuffers bag of
    Left shared -> return shared
    Right bufferMap -> do
      tid <- liftIO myThreadId
      -- 'newBag' always starts at least one worker, so the map is non-empty.
      return $ Map.findWithDefault (head (Map.elems bufferMap)) tid bufferMap
  atomically $ writeBufferSTM buffer task

-- | Add the evaluation of a Haskell expression.
--
-- The given expression will be evaluated to weak head normal form.
addEval :: MonadIO m => Bag r -> r -> m ()
addEval bag e = addTask bag (e `seq` return (Just e))

-- | Write back a result to be read by 'getResult'.
writeResult :: MonadIO m => Bag r -> r -> m ()
writeResult bag x = atomically $ writeTChan (resultChan bag) x

-- | Tell the bag that there will be no more tasks from the outside,
-- however, queued tasks may still add new tasks.
noMoreTasks :: MonadIO m => Bag r -> m ()
noMoreTasks bag = atomically $ writeTVar (moreTasksVar bag) False

-- | Get the next result written by 'writeResult'.
--
-- If no result is available yet, this blocks until one is. It returns
-- 'Nothing' once no further result can appear, which is the case when:
--
--   * the bag has been terminated with 'terminateBag', or
--
--   * 'noMoreTasks' has been called, all task buffers are empty and no
--     worker is running a task.
--
-- Results that are already in the result channel are always returned first.
getResult :: MonadIO m => Bag r -> m (Maybe r)
getResult bag = atomically $ do
  result <- tryReadTChan (resultChan bag)
  case result of
    Just r  -> return (Just r)
    Nothing -> do
      terminated <- readTVar (terminateVar bag)
      unless terminated $ do
        -- Tasks may still be added from the outside.
        moreTasks <- readTVar (moreTasksVar bag)
        when moreTasks retry
        -- Queued tasks are still waiting to be run.
        noTasks <- and `liftM` mapM isEmptyBufferSTM (bufferList bag)
        unless noTasks retry
        -- A running task may still produce a result or add new tasks.
        mapM_ (\busy -> readTVar busy >>= flip when retry)
              (Map.elems (workerStates bag))
      return Nothing

-- | Terminates all threads running in the thread pool. 'terminateBag' is
-- non-blocking and therefore does not guarantee that all threads are
-- terminated after its execution. Additionally it is not guaranteed
-- that the threads are terminated promptly. It is implemented using
-- asynchronous exceptions and therefore it terminates a thread once it
-- reaches a /safe point/. A safe point is where memory allocation occurs.
-- Although most tasks will have some point of memory allocation now and
-- then, there are loops that will never reach a safe point.
--
-- Results already in the result channel can still be read with
-- 'getResult'; after that it returns 'Nothing'.
terminateBag :: MonadIO m => Bag r -> m ()
terminateBag bag = do
  -- Update both flags in a single transaction so that no observer can see
  -- the bag terminated while it still claims to accept more tasks.
  atomically $ do
    writeTVar (terminateVar bag) True
    writeTVar (moreTasksVar bag) False
  mapM_ killInBackground (Map.keys (workerStates bag))
  where
    -- killThread is blocking and exceptions are only received on safe
    -- points (memory allocations). However, some calculations never see a
    -- safe point, but we want to continue here, so every worker is killed
    -- from a separate thread.
    killInBackground :: MonadIO m => ThreadId -> m ()
    killInBackground tid = liftIO (forkIO (killThread tid)) >> return ()

-- Helper functions --

-- | Run an STM transaction from any 'MonadIO' context.
atomically :: MonadIO m => STM a -> m a
atomically = liftIO . STM.atomically