-- | Thread pool implementation.
--
--   Tasks are queued with 'addPool' and executed on at most @n@ concurrently
--   working threads ('runPool'). A task may temporarily give up its slot with
--   'blockPool'. The first exception thrown by any task kills the remaining
--   workers and is re-raised by 'runPool'.
module Development.Shake.Pool(Pool, addPool, blockPool, runPool) where

import Control.Concurrent
import Control.Exception hiding (blocked)
import Development.Shake.Locks
import qualified Data.HashSet as Set


---------------------------------------------------------------------
-- SUPER QUEUE

-- FIXME: The super queue should use randomness for the normal priority pile

-- | Two piles of work: a priority pile and a normal pile.
--   'dequeue' always drains the priority pile first. Both piles are plain
--   lists that are pushed onto and popped from the front, so within a pile
--   the most recently added item comes out first (LIFO).
data SuperQueue a = SuperQueue [a] [a]

-- | A queue with both piles empty.
newSuperQueue :: SuperQueue a
newSuperQueue = SuperQueue [] []

-- | Push an item onto the priority pile.
enqueuePriority :: a -> SuperQueue a -> SuperQueue a
enqueuePriority x (SuperQueue p n) = SuperQueue (x:p) n

-- | Push an item onto the normal pile.
enqueue :: a -> SuperQueue a -> SuperQueue a
enqueue x (SuperQueue p n) = SuperQueue p (x:n)

-- | Pop the next item: the priority pile first, then the normal pile.
--   Returns 'Nothing' when both piles are empty.
dequeue :: SuperQueue a -> Maybe (a, SuperQueue a)
dequeue (SuperQueue (p:ps) ns) = Just (p, SuperQueue ps ns)
dequeue (SuperQueue [] (n:ns)) = Just (n, SuperQueue [] ns)
dequeue (SuperQueue [] []) = Nothing


---------------------------------------------------------------------
-- THREAD POOL

{-
Must keep a list of active threads, so can raise exceptions in a timely manner
Must spawn a fresh thread to do blockPool
If any worker throws an exception, must signal to all the other workers
-}

-- | A pool: the maximum number of working threads, the shared state
--   ('Nothing' once the pool has finished or failed), and the barrier that is
--   signalled exactly once with the outcome ('Nothing' on success, or the
--   first exception raised by a task).
data Pool = Pool Int (Var (Maybe S)) (Barrier (Maybe SomeException))

-- | Mutable bookkeeping for a running pool.
data S = S
    {threads :: Set.HashSet ThreadId -- worker threads that have been forked and not yet finished
    ,working :: Int -- threads which are actively working
    ,blocked :: Int -- threads which are blocked
    ,todo :: SuperQueue (IO ()) -- tasks waiting for a free worker slot
    }

-- | The state of a pool with no threads and no work.
emptyS :: S
emptyS = S Set.empty 0 0 newSuperQueue


-- | Apply an update to the pool state, then react to the new state.
--
--   The update @op@ records an event, such as enqueueing a task or adjusting
--   the @working@/@blocked@ counters. After applying it, 'step' does exactly
--   one of the following:
--
--   * If a task is queued and fewer than @n@ threads are working, start one
--     new worker for that task. At most one task is started per call.
--   * If nothing is queued and no thread is working or blocked, signal the
--     completion barrier with 'Nothing' and mark the pool as finished.
--   * Otherwise, store the updated state.
--
--   If the pool has already finished or failed (state is 'Nothing'), @op@ is
--   ignored.
step :: Pool -> (S -> S) -> IO ()
step pool@(Pool n var done) op = do
    let onVar act = modifyVar_ var $ maybe (return Nothing) act
    onVar $ \s0 -> do
        let s = op s0
        case dequeue (todo s) of
            Just (now, todo2) | working s < n -> do
                -- spawn a new worker; it is registered in 'threads' below while
                -- we still hold the var, so its own completion step (which
                -- needs the var) always runs after the registration
                t <- forkIO $ do
                    me <- myThreadId
                    res <- try now :: IO (Either SomeException ())
                    case res of
                        Left e -> onVar $ \st -> do
                            -- first failure wins: kill every other worker,
                            -- report the exception, and shut the pool down
                            mapM_ killThread $ Set.toList $ Set.delete me $ threads st
                            signalBarrier done $ Just e
                            return Nothing
                        Right _ -> step pool $ \st ->
                            st{working = working st - 1, threads = Set.delete me $ threads st}
                return $ Just s{working = working s + 1, todo = todo2, threads = Set.insert t $ threads s}
            Nothing | working s == 0 && blocked s == 0 -> do
                signalBarrier done Nothing
                return Nothing
            _ -> return $ Just s


-- | Add a new task to the normal pile of the pool. It starts immediately if a
--   worker slot is free, otherwise it waits until one becomes free.
addPool :: Pool -> IO a -> IO ()
addPool pool act = step pool $ \s -> s{todo = enqueue (act >> return ()) (todo s)}


-- | Run a blocking action while on the pool, yielding your worker slot for
--   its duration. Should only be called by an action under 'addPool'.
--
--   The calling thread stops counting as working while @act@ runs, so another
--   queued task may start. Afterwards it enqueues a priority task and waits.
--   When a worker runs that task, the caller counts as working again, so it
--   only continues once a slot has been handed back to it.
blockPool :: Pool -> IO a -> IO a
blockPool pool act = do
    step pool $ \s -> s{working = working s - 1, blocked = blocked s + 1}
    res <- act
    resumed <- newBarrier
    let resume = do
            step pool $ \s -> s{working = working s + 1, blocked = blocked s - 1}
            signalBarrier resumed ()
    step pool $ \s -> s{todo = enqueuePriority resume $ todo s}
    waitBarrier resumed
    return res


-- | Run all the tasks in the pool on the given number of workers.
--   A worker count below 1 is treated as 1. With no worker, no task could
--   ever start and the pool would never finish.
--   If any thread throws an exception, the exception will be reraised.
--   If the calling thread itself receives an exception, all worker threads
--   are killed before it propagates.
runPool :: Int -> (Pool -> IO ()) -> IO () -- run all tasks in the pool
runPool n act = do
    var <- newVar $ Just emptyS
    let cleanup = modifyVar_ var $ \ms -> do
            -- if someone kills our thread, make sure we kill our child threads
            case ms of
                Just s -> mapM_ killThread $ Set.toList $ threads s
                Nothing -> return ()
            return Nothing
    flip onException cleanup $ do
        done <- newBarrier
        let pool = Pool (max 1 n) var done
        addPool pool $ act pool
        outcome <- waitBarrier done
        case outcome of
            Nothing -> return ()
            Just e -> throwIO e