module Control.Concurrent.Pool
(Task (..),
Pool,
newPool, newPoolIO,
isPoolWaiting,
queue,
noMoreTasks, noMoreTasksIO,
readResult,
resultsReader,
waitFor, waitForIO,
waitForTasks,
terminatePool, terminatePoolIO
) where
import Control.Monad
import Control.Monad.State
import Control.Monad.Reader
import Control.Monad.Trans
import Control.Concurrent
import Control.Concurrent.STM
import qualified Data.Map as M
-- | Monads whose actions can be executed by the pool's worker threads.
--
-- A task is an action of type @m r@ that the pool runs down to 'IO' with the
-- help of an extra argument of type @a@. In this module's instances that
-- argument is @()@ for plain 'IO', the environment for 'ReaderT' and the
-- initial state for 'StateT'.
class (Monad m, MonadIO m) => Task m a r where
  -- | Execute one task action in 'IO'.
  runTask
    :: m r  -- ^ the task action
    -> a    -- ^ extra argument handed to the action's runner
    -> IO r
-- | Plain 'IO' tasks need no extra argument; the unit is matched and ignored.
instance Task IO () r where
  runTask action () = action
-- | 'ReaderT' tasks are run against the supplied environment.
instance Task (ReaderT a IO) a r where
  runTask = runReaderT
-- | 'StateT' tasks start from the supplied state; the final state is dropped.
instance Task (StateT a IO) a r where
  runTask = evalStateT
-- | A fixed set of worker threads that run queued tasks of type @m r@, each
-- paired with an extra argument of type @a@.
--
-- The type carries no class context. The deprecated @DatatypeContexts@ form
-- added nothing here, because every operation on a pool already demands the
-- 'Task' constraint itself.
data Pool m a r = Pool
  { poolThreads   :: ![ThreadId]
    -- ^ worker threads forked when the pool was created
  , threadsStates :: !(TVar (M.Map ThreadId Bool))
    -- ^ per-worker busy flag ('True' while the worker runs a task)
  , poolQueue     :: !(TChan (Integer, m r, a))
    -- ^ pending tasks, each tagged with its task number
  , returnResults :: !Bool
    -- ^ whether workers publish results on 'poolResults'
  , poolResults   :: !(TChan (Integer, r))
    -- ^ published results, each tagged with its task number
  , poolLastTask  :: !(TVar Integer)
    -- ^ number given to the most recently queued task (0 before any)
  , poolDone      :: !(TVar Bool)
    -- ^ set once the caller declares that no more tasks will be queued
  , doneTasks     :: !(TVar [Integer])
    -- ^ task numbers that the workers have recorded as finished
  }
-- | Create a pool from inside any 'Task' monad by lifting 'newPoolIO'.
newPool :: Task m a r
        => Int   -- ^ number of worker threads
        -> Bool  -- ^ whether workers publish task results
        -> m (Pool m a r)
newPool n = liftIO . newPoolIO n
-- | Create a pool in 'IO'.
--
-- Allocates the task queue, the result channel and the bookkeeping
-- variables, then forks @n@ worker threads that share them.
newPoolIO :: Task m a r
          => Int   -- ^ number of worker threads to fork
          -> Bool  -- ^ whether workers publish results on the result channel
          -> IO (Pool m a r)
newPoolIO n ret = do
  tasksChan   <- newTChanIO
  resultsChan <- newTChanIO
  lastTask    <- newTVarIO 0
  busyMap     <- newTVarIO M.empty
  doneFlag    <- newTVarIO False
  finished    <- newTVarIO []
  workers     <- replicateM n $
    forkIO (worker busyMap tasksChan ret resultsChan finished)
  return Pool
    { poolThreads   = workers
    , threadsStates = busyMap
    , poolQueue     = tasksChan
    , returnResults = ret
    , poolResults   = resultsChan
    , poolLastTask  = lastTask
    , poolDone      = doneFlag
    , doneTasks     = finished
    }
-- | Record the calling thread's busy flag in the shared state map, inserting
-- a new entry or overwriting the existing one.
--
-- 'modifyTVar'' forces the updated map before storing it. A lazy
-- 'writeTVar' would instead keep a growing chain of unevaluated 'M.insert'
-- thunks inside the 'TVar' until some reader forced it.
setMyState :: TVar (M.Map ThreadId Bool) -> Bool -> IO ()
setMyState var st = do
  me <- myThreadId
  atomically $ modifyTVar' var (M.insert me st)
-- | Body of each pool thread: loop forever, taking a task from the queue,
-- running it, and recording its completion.
--
-- Taking a task and marking this worker busy happen in one STM transaction.
-- That way no observer can see the task gone from the queue while this
-- worker still appears idle.
--
-- The finished task number is recorded in @doneVar@ whether or not results
-- are published. This lets 'waitForTasks' also complete for pools created
-- with @ret = False@. Only the write to the result channel depends on @ret@.
--
-- NOTE(review): if 'runTask' throws, this thread dies with its busy flag
-- still 'True', so the pool would never again look idle.
worker :: Task m a r
       => TVar (M.Map ThreadId Bool)  -- ^ per-thread busy flags
       -> TChan (Integer, m r, a)     -- ^ task queue
       -> Bool                        -- ^ publish results?
       -> TChan (Integer, r)          -- ^ result channel
       -> TVar [Integer]              -- ^ numbers of finished tasks
       -> IO ()
worker var chan ret res doneVar = do
  me <- myThreadId
  forever $ do
    setMyState var False
    (n, m, x) <- atomically $ do
      task <- readTChan chan
      modifyTVar' var (M.insert me True)
      return task
    putStrLn $ ">>> Starting task #" ++ show n
    y <- runTask m x
    putStrLn $ ">>> Task #" ++ show n ++ " done."
    atomically $ do
      when ret $ writeTChan res (n, y)
      modifyTVar' doneVar (n :)
-- | 'True' when no worker is currently marked busy. This also holds before
-- any worker has registered itself. The task queue is not consulted.
isPoolWaiting :: Task m a r
              => Pool m a r
              -> IO Bool
isPoolWaiting pool =
  atomically $ not . or . M.elems <$> readTVar (threadsStates pool)
-- | Queue a task together with its extra argument. Returns the task number
-- assigned to it; numbers start at 1 and increase with every call.
--
-- The number is allocated and the task is enqueued in a single transaction.
-- Tasks therefore sit on the queue in strictly increasing number order, even
-- when several threads queue concurrently.
queue :: Task m a r
      => Pool m a r  -- ^ target pool
      -> m r         -- ^ task action
      -> a           -- ^ extra argument passed to 'runTask'
      -> m Integer
queue pool m x = liftIO $ atomically $ do
  let counter = poolLastTask pool
  next <- (+ 1) <$> readTVar counter
  writeTVar counter $! next
  writeTChan (poolQueue pool) (next, m, x)
  return next
-- | Declare, from any 'Task' monad, that no further tasks will be queued.
noMoreTasks :: Task m a r => Pool m a r -> m ()
noMoreTasks = liftIO . noMoreTasksIO
-- | Declare that no further tasks will be queued by raising the pool's
-- done flag.
noMoreTasksIO :: Task m a r => Pool m a r -> IO ()
noMoreTasksIO pool = atomically (writeTVar doneFlag True)
  where
    doneFlag = poolDone pool
-- | Take the next published result, paired with its task number. Blocks
-- until one is available.
readResult :: Task m a r
           => Pool m a r
           -> m (Integer, r)
readResult = liftIO . atomically . readTChan . poolResults
-- | Loop forever, handing each published result and its task number to the
-- callback. The callback's own return value is discarded.
resultsReader :: Task m a r => Pool m a r -> (Integer -> r -> IO b) -> IO ()
resultsReader pool fn =
  forever $ atomically (readTChan (poolResults pool)) >>= \(n, r) -> fn n r
-- | Transactional test for "the pool has finished". All three must hold:
-- the done flag is set, the task queue is empty, and no worker is marked
-- busy.
poolFinished :: Task m a r => Pool m a r -> STM Bool
poolFinished pool = do
  closed  <- readTVar (poolDone pool)
  drained <- isEmptyTChan (poolQueue pool)
  states  <- readTVar (threadsStates pool)
  return (closed && drained && not (or (M.elems states)))

-- | Block, from any 'Task' monad, until the pool has finished; see
-- 'waitForIO'.
waitFor :: Task m a r => Pool m a r -> m ()
waitFor pool = liftIO (waitForIO pool)

-- | Block until the pool has finished. That means the done flag is set,
-- every queued task has been taken by a worker, and no worker is busy.
--
-- Rather than polling every millisecond, the transaction 'check's the
-- condition and retries. The thread sleeps until one of the variables read
-- in 'poolFinished' changes.
--
-- The queue is checked as well as the busy flags. Otherwise a moment where
-- every worker was idle between tasks, with tasks still queued, would be
-- mistaken for completion.
waitForIO :: Task m a r => Pool m a r -> IO ()
waitForIO pool = atomically (poolFinished pool >>= check)
-- | Poll, about once per millisecond, until every listed task number appears
-- in the pool's finished-task list.
--
-- Each round prints the finished tasks and the ones still awaited.
waitForTasks :: Task m a r => Pool m a r -> [Integer] -> m ()
waitForTasks pool tasks = do
  finished <- liftIO $ readTVarIO (doneTasks pool)
  liftIO $ putStrLn $ ">>> Done tasks: " ++ show finished
  let pending = filter (`notElem` finished) tasks
  liftIO $ putStrLn $ ">>> Tasks to wait: " ++ show pending
  unless (null pending) $ do
    liftIO $ threadDelay 1000
    waitForTasks pool pending
-- | Kill all worker threads, from any 'Task' monad.
terminatePool :: Task m a r => Pool m a r -> m ()
terminatePool = liftIO . terminatePoolIO
-- | Kill every worker thread forked for the pool, one after another.
terminatePoolIO :: Task m a r => Pool m a r -> IO ()
terminatePoolIO pool =
  forM_ (poolThreads pool) killThread