-- NOTE: The first part of this module is based on Chapter 24 of
-- "Real World Haskell", cf.
-- <http://book.realworldhaskell.org/read/concurrent-and-multicore-programming.html>
module Control.Concurrent.ManagedThreads
  ( ThreadManager
  , newManager
  , forkManaged
  , getStatus
  , waitFor
  , waitAll
  , nParSequenceIO
  , nParSequenceIO_
  , nParMapIO
  , nParMapIO_
  , nParCmd_
  , parCmd_
  ) where

import Data.Ord (comparing)
import Data.List (sortBy)
import qualified Data.Map as M
import Control.Monad
import Control.Concurrent
import Control.Exception (IOException, SomeException, mask, try)
import GHC.Conc (numCapabilities)

------------------------------------------------------------------------------
-- Thread Manager
------------------------------------------------------------------------------

-- | Status of a managed thread.
data ThreadStatus
    = Running               -- still running
    | Finished              -- terminated normally
    | Threw SomeException   -- killed by an uncaught exception (of any type)
    deriving (Show)

-- | Maps every managed thread to an 'MVar' that its thread fills with the
-- final 'ThreadStatus' once it terminates.
newtype ThreadManager = Mgr (MVar (M.Map ThreadId (MVar ThreadStatus)))
    deriving (Eq)

-- | Create a new thread manager.
newManager :: IO ThreadManager
newManager = Mgr `fmap` newMVar M.empty

-- | Create a new managed thread running the given action.
--
-- Any exception escaping the action is caught and recorded as 'Threw'. This
-- covers every exception type, not only 'IOException'. As a result,
-- 'waitFor' and 'waitAll' always return once the thread has terminated.
forkManaged :: ThreadManager -> IO () -> IO ThreadId
forkManaged (Mgr mgr) body =
    modifyMVar mgr $ \m -> do
      state <- newEmptyMVar
      -- Fork with asynchronous exceptions masked and unmask only around
      -- 'body'. An exception arriving between 'try' returning and 'putMVar'
      -- therefore cannot leave 'state' empty, which would make every waiter
      -- block forever.
      tid <- mask $ \restore -> forkIO $ do
        result <- try (restore body)
        putMVar state (either Threw (const Finished) result)
      return (M.insert tid state m, tid)

-- | Immediately return the status of a managed thread.
--
-- Returns 'Nothing' if the thread is unknown to the manager. Once a
-- terminated thread's status has been returned, the thread is removed from
-- the manager.
getStatus :: ThreadManager -> ThreadId -> IO (Maybe ThreadStatus)
getStatus (Mgr mgr) tid =
    modifyMVar mgr $ \m ->
      case M.lookup tid m of
        Nothing -> return (m, Nothing)
        Just st -> do
          mst <- tryTakeMVar st
          return $ case mst of
            Nothing  -> (m, Just Running)
            Just sth -> (M.delete tid m, Just sth)

-- | Block until a specific managed thread terminates.
--
-- The thread is removed from the manager before blocking. Returns 'Nothing'
-- if the thread is unknown, e.g. because it was already reaped.
waitFor :: ThreadManager -> ThreadId -> IO (Maybe ThreadStatus)
waitFor (Mgr mgr) tid =
    join . modifyMVar mgr $ \m ->
      return $ case M.updateLookupWithKey (\_ _ -> Nothing) tid m of
        (Nothing, _)  -> (m, return Nothing)
        -- The blocking 'takeMVar' runs outside 'modifyMVar' (via 'join'), so
        -- the manager is not locked while we wait.
        (Just st, m') -> (m', Just `fmap` takeMVar st)

-- | Block until all managed threads terminate.
--
-- This empties the manager, then waits on each removed thread in turn.
waitAll :: ThreadManager -> IO ()
waitAll (Mgr mgr) = modifyMVar mgr elems >>= mapM_ takeMVar
  where
    elems m = return (M.empty, M.elems m)

------------------------------------------------------------------------------
-- Parallel Sequence
------------------------------------------------------------------------------

-- | Extract the head of the list in the MVar if possible.
takeHead :: MVar [a] -> IO (Maybe a)
takeHead v = modifyMVar v (return . extract)
  where
    extract []     = ([], Nothing)
    extract (x:xs) = (xs, Just x)

-- | Do a parallel sequencing of a list of IO commands using n worker threads
-- (at least one). The results are gathered in a list in input order.
--
-- NOTE(review): a job that throws terminates the worker running it. The
-- exception is only recorded in the internal thread status and is not
-- rethrown, so that job's result is silently missing from the returned list.
-- If every worker dies, the remaining jobs are never run.
nParSequenceIO :: Int -> [IO a] -> IO [a]
nParSequenceIO n ios = do
    inMv <- newMVar $ zip [(1::Int)..] ios
    resMv <- newMVar []
    mgr <- newManager
    mapM_ (forkManaged mgr) (replicate (max 1 n) $ worker inMv resMv)
    waitAll mgr
    -- Workers finish in arbitrary order; restore input order by index.
    (map snd . sortBy (comparing fst)) `liftM` takeMVar resMv
  where
    worker inMv resMv = do
      nextJob <- takeHead inMv
      case nextJob of
        Nothing -> return ()
        Just (i, io) -> do
          out <- io
          modifyMVar_ resMv (\res -> return $ (i, out) : res)
          worker inMv resMv

-- | Do a parallel sequencing of a list of IO commands using n worker threads
-- (at least one), discarding the results.
nParSequenceIO_ :: Int -> [IO a] -> IO ()
nParSequenceIO_ n ios = do
    inMv <- newMVar ios
    mgr <- newManager
    -- Use the same lower bound as 'nParSequenceIO'. Without it, a
    -- non-positive n would start no workers and silently run none of the
    -- jobs.
    mapM_ (forkManaged mgr) (replicate (max 1 n) $ worker inMv)
    waitAll mgr
  where
    worker inMv = do
      nextJob <- takeHead inMv
      case nextJob of
        Nothing -> return ()
        Just io -> io >> worker inMv

-- | Do a parallel map of an IO cmd over a list using n worker threads.
nParMapIO :: Int -> (a -> IO b) -> [a] -> IO [b]
nParMapIO n f = nParSequenceIO n . map f

-- | Do a parallel map of an IO cmd over a list using n worker threads while
-- ignoring the results.
nParMapIO_ :: Int -> (a -> IO b) -> [a] -> IO ()
nParMapIO_ n f = nParSequenceIO_ n . map f

-- | Parallel execution of a list of commands using @nThreads@ worker threads
-- (at least one).
--
-- Every command receives a shared channel and must write exactly one (!)
-- progress value to it. A separate display thread reads one value per
-- command and calls @display total index value@ for it. Here @total@ is the
-- number of commands and @index@ counts from 1.
--
-- NOTE: If a command never writes its value (e.g. because it throws first),
-- the display thread never finishes and this function may hang.
-- NOTE: If the executed commands block, you need to use the threaded
-- runtime.
nParCmd_ :: Int -> (Int -> Int -> b -> IO ()) -> [Chan b -> IO a] -> IO ()
nParCmd_ nThreads display cmds = do
    chan <- newChan
    mgr <- newManager
    _ <- forkManaged mgr (displayThread chan 1)
    -- At least one worker is required. Otherwise no command would run and
    -- the display thread would wait forever for progress values.
    _ <- forkManaged mgr (nParMapIO_ (max 1 nThreads) (\cmd -> cmd chan) cmds)
    waitAll mgr
  where
    nCmds = length cmds
    displayThread ch i
      | nCmds < i = return ()
      | otherwise = do
          msg <- readChan ch
          display nCmds i msg
          displayThread ch (succ i)

-- | Like 'nParCmd_', but uses @numCapabilities + 1@ worker threads.
--
-- 'numCapabilities' is the number of capabilities the RTS was started with.
-- It is 1 unless the program is linked with the threaded runtime and started
-- with e.g.
--
-- > +RTS -N -RTS
--
-- Plain @-N@ uses one capability per processor core.
parCmd_ :: (Int -> Int -> b -> IO ()) -> [Chan b -> IO a] -> IO ()
parCmd_ display cmds = nParCmd_ (numCapabilities + 1) display cmds