module Control.Concurrent.Bag.Basic
( Bag
, newBag_
, addEval
, addTask
, getResult
, terminateBag
, noMoreTasks
, module Control.Concurrent.Bag.TaskBuffer )
where
import Data.Maybe (fromJust, isJust)
import Control.Concurrent
( ThreadId
, forkIO
, forkIOWithUnmask
, killThread
, getNumCapabilities
, myThreadId )
import qualified Control.Concurrent.STM as STM (atomically)
import Control.Concurrent.STM
( STM
, TVar
, TChan
, newTVar
, newTChan
, writeTChan
, readTChan
, writeTVar
, tryReadTChan
, isEmptyTChan
, retry
, readTVar )
import Control.Monad.Writer
import qualified Data.Map.Strict as Map
import Control.Concurrent.Bag.TaskBuffer
-- | A bag of tasks served by a fixed pool of worker threads, together with
-- the channel on which the workers deliver their results.
data Bag r = Bag
  { workerStates :: Map.Map ThreadId (TVar Bool)
    -- ^ One flag per worker thread: 'True' from the moment the worker takes
    -- a task until its result (if any) has been written, 'False' while it is
    -- looking for work.
  , taskBuffers  :: Either (TaskBufferSTM (IO (Maybe r)))
                           (Map.Map ThreadId (TaskBufferSTM (IO (Maybe r))))
    -- ^ 'Left': a single buffer shared by all workers (no split function).
    -- 'Right': one buffer per worker, keyed by the worker's thread id.
  , resultChan   :: TChan r
    -- ^ Results of the tasks that returned 'Just'.
  , terminateVar :: TVar Bool
    -- ^ Set to 'True' by 'terminateBag'.
  , moreTasksVar :: TVar Bool
    -- ^ 'True' until 'noMoreTasks' or 'terminateBag' is called; while it is
    -- 'True', 'getResult' waits instead of reporting that the bag is done.
  }
-- | Every task buffer of the bag.
--
-- Returns the single shared buffer, or all per-worker buffers in the order
-- given by 'Map.elems' (ascending 'ThreadId').
bufferList :: Bag r -> [TaskBufferSTM (IO (Maybe r))]
bufferList = either (: []) Map.elems . taskBuffers
-- | Create a bag with one worker thread per capability, as reported by
-- 'getNumCapabilities', but never fewer than one worker.
newBag_ :: MonadIO m =>
           BufferType
        -> Maybe (SplitFunction r)
        -> m (Bag r)
newBag_ buf split = do
  capabilities <- liftIO getNumCapabilities
  newBag buf split (max 1 capabilities)
-- | Create a bag served by @n@ worker threads.
--
-- With a split function, every worker gets its own task buffer. A worker
-- whose own buffer is empty hands the first non-empty buffer in its
-- "foreign" list to the split function to obtain work. Without a split
-- function, all workers share a single buffer.
--
-- Assumes @n >= 1@, which 'newBag_' guarantees. With @n == 0@ no worker
-- threads are started.
newBag :: MonadIO m =>
          BufferType
       -> Maybe (SplitFunction r)
       -> Int
       -> m (Bag r)
newBag buf split n = do
  results <- atomically newTChan
  let newBuffer = case buf of { Queue -> newChanBuffer; Stack -> newStackBuffer }
  -- 'cycle' makes the list infinite, so worker i can index its own buffer
  -- and take the following n buffers starting at its own position.
  buffers <- cycle `liftM` if isJust split
    then replicateM n $ atomically newBuffer
    else atomically $ do
      buffer <- newBuffer
      return [buffer]
  -- Without a split function the workers get no foreign buffers, so the
  -- split function is never called. Fail loudly (instead of 'undefined')
  -- should that invariant ever break.
  let splitf =
        case split of
          Nothing -> error "Control.Concurrent.Bag.Basic.newBag: split function used, but none was given"
          Just s  -> s
  (states, bufferMap) <-
    execWriterT $ mapM_ (\i -> do
      let ownBuffer = buffers !! i
          -- Starts with ownBuffer itself; the worker only consults this
          -- list after ownBuffer turned out to be empty.
          foreignBuffers = if isJust split
                             then take n $ drop i buffers
                             else []
      state <- atomically $ newTVar False
      tid <- liftIO $ forkIOWithUnmask $ \unmask ->
        unmask $ worker splitf state ownBuffer foreignBuffers results
      tell (Map.singleton tid state, Map.singleton tid ownBuffer)) [0 .. n - 1]
  terminated <- atomically $ newTVar False
  moreTasks <- atomically $ newTVar True
  let eitherBuffer =
        case split of
          Nothing -> Left $ head buffers
          Just _  -> Right bufferMap
  return $ Bag states eitherBuffer results terminated moreTasks
-- | Main loop of a worker thread.
--
-- Each iteration does the following:
--
-- 1. It marks the worker idle.
-- 2. In a single transaction, it takes a task from its own buffer. If that
--    buffer is empty, it hands the first non-empty buffer in
--    @foreignBuffers@ to the split function. It then marks itself busy.
-- 3. If no task can be obtained, the transaction retries. This blocks the
--    worker and also rolls back the busy flag. With an empty
--    @foreignBuffers@ the worker only ever waits on its own buffer.
-- 4. It runs the task. A 'Just' result is written to @results@ before the
--    worker marks itself idle again.
--
-- NOTE(review): nothing here catches exceptions thrown by a task. If one
-- escapes, the thread dies with its state still 'True'. Confirm this is
-- intended, since 'getResult' waits for every worker to be idle.
worker :: SplitFunction r
       -> TVar Bool                       -- ^ busy flag of this worker
       -> TaskBufferSTM (IO (Maybe r))    -- ^ this worker's own buffer
       -> [TaskBufferSTM (IO (Maybe r))]  -- ^ buffers to obtain work from
       -> TChan r                         -- ^ where results are written
       -> IO ()
worker split state ownBuffer foreignBuffers results = forever $ do
  atomically $ writeTVar state False
  task <- atomically $ do
    mTask <- tryReadBufferSTM ownBuffer
    writeTVar state True
    case mTask of
      Just t  -> return t
      Nothing -> splitBuffers split ownBuffer foreignBuffers
  result <- task
  -- Tasks that return Nothing contribute no result.
  maybe (return ()) (atomically . writeTChan results) result
  where
    -- Use the split function on the first non-empty buffer of the list;
    -- retry (block) when every buffer is empty.
    splitBuffers :: SplitFunction r
                 -> TaskBufferSTM (IO (Maybe r))
                 -> [TaskBufferSTM (IO (Maybe r))]
                 -> STM (IO (Maybe r))
    splitBuffers _ _ [] = retry
    splitBuffers splitFn own (f:fs) = do
      isEmpty <- isEmptyBufferSTM f
      if isEmpty
        then splitBuffers splitFn own fs
        else splitFn own f
-- | Put a task into the bag.
--
-- Where the task goes depends on how the bag stores its buffers:
--
-- * With a single shared buffer, the task goes there.
-- * With per-worker buffers, a worker thread adds to its own buffer. Any
--   other thread adds to the first buffer in 'Map.elems' order.
addTask :: MonadIO m => Bag r -> IO (Maybe r) -> m ()
addTask bag task = do
  target <- either return ownOrFirst (taskBuffers bag)
  atomically (writeBufferSTM target task)
  where
    ownOrFirst perWorker = do
      me <- liftIO myThreadId
      return (Map.findWithDefault (head (Map.elems perWorker)) me perWorker)
-- | Add a task whose only job is to evaluate @e@ (to weak head normal form)
-- and yield it as a result.
--
-- The @seq@ is inside the task value, so @e@ is forced when the task action
-- itself is evaluated. Presumably that happens in the worker that runs it,
-- assuming 'writeBufferSTM' stores tasks lazily; confirm.
addEval :: MonadIO m => Bag r -> r -> m ()
addEval bag e = addTask bag evalTask
  where
    evalTask = e `seq` return (Just e)
-- | Declare that no further tasks will be added.
--
-- From then on, 'getResult' answers 'Nothing' once the result channel and
-- all buffers are empty and every worker is idle. Before that call it keeps
-- blocking in that situation.
noMoreTasks :: MonadIO m => Bag r -> m ()
noMoreTasks = atomically . flip writeTVar False . moreTasksVar
-- | Take the next result from the bag.
--
-- If a result is waiting, it is returned immediately. Otherwise:
--
-- * If the bag has been terminated, the answer is 'Nothing' right away.
-- * If not, the transaction retries (blocks) until either a result arrives
--   or the bag is exhausted, and then answers 'Nothing'. The bag is
--   exhausted when 'noMoreTasks' has been signalled, every task buffer is
--   empty, and every worker is idle.
getResult :: MonadIO m => Bag r -> m (Maybe r)
getResult bag = atomically $
  tryReadTChan (resultChan bag) >>= maybe whenEmpty (return . Just)
  where
    whenEmpty = do
      stopped <- readTVar (terminateVar bag)
      unless stopped $ do
        pending <- readTVar (moreTasksVar bag)
        when pending retry
        buffersEmpty <- and `liftM` mapM isEmptyBufferSTM (bufferList bag)
        unless buffersEmpty retry
        mapM_ waitIdle (Map.elems (workerStates bag))
      return Nothing
    -- Block while the given worker is still busy with a task.
    waitIdle busyVar = do
      busy <- readTVar busyVar
      when busy retry
-- | Shut the bag down.
--
-- Marks the bag as terminated and closed for new tasks, then kills every
-- worker thread in ascending 'ThreadId' order. Each 'killThread' runs in a
-- freshly forked thread, so this call does not wait for the workers to
-- receive the exception.
terminateBag :: MonadIO m => Bag r -> m ()
terminateBag bag = do
  atomically $ writeTVar (terminateVar bag) True
  atomically $ writeTVar (moreTasksVar bag) False
  mapM_ killDetached (Map.keys (workerStates bag))
  where
    killDetached tid = liftIO (forkIO (killThread tid)) >> return ()
-- | Run an STM transaction atomically from any 'MonadIO' monad.
--
-- This is 'STM.atomically' lifted with 'liftIO'.
atomically :: MonadIO m => STM a -> m a
atomically transaction = liftIO (STM.atomically transaction)