{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE QuasiQuotes #-}
{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE ScopedTypeVariables #-}
-- | Module with the internal workhorse for the library, 'parallelTasks'. You only
-- need to use this module if you want to alter 'ExtendedParTaskOpts', which allows
-- you to redirect the logging output or store information about task timing.
module Control.Concurrent.ParallelTasks.Base
  ( ExtendedParTaskOpts(..)
  , ParTaskOpts(..)
  , SimpleParTaskOpts(..)
  , TaskOutcome(Success, TookTooLong)
  , defaultExtendedParTaskOpts
  , defaultParTaskOpts
  , parallelTasks
  ) where

import Control.Applicative ((<$), (<$>))
import Control.Concurrent (forkIO, getNumCapabilities, killThread, myThreadId, threadDelay)
import Control.Concurrent.STM (TVar, atomically, newTVarIO, readTVar, retry, writeTVar)
import Control.Concurrent.STM.TMVar (newTMVar, putTMVar, takeTMVar)
import Control.Exception.Base (Exception, bracket_, evaluate, handle, onException, throwTo)
import Control.Monad (replicateM_, when)
import Control.Monad.IO.Class (MonadIO, liftIO)
import Data.Maybe (isJust)
import Data.String.Here.Interpolated (i)
import Data.Time.Clock (UTCTime, addUTCTime, diffUTCTime, getCurrentTime)
import Data.Time.Format (formatTime)
import Data.Typeable
import qualified Data.Vector.Mutable as V (IOVector, new, write)
import System.IO (Handle, hFlush, hPutStrLn, stderr)
import System.Locale (defaultTimeLocale)

-- | Options controlling how many workers are used and how often progress is reported.
data SimpleParTaskOpts = SimpleParTaskOpts
  { numberWorkers :: Maybe Int
    -- ^ Number of worker threads to use. When this is Nothing, defaults to the number
    -- of capabilities (see 'getNumCapabilities').
  , printProgress :: Maybe Int
    -- ^ How often to print the progress of the tasks. E.g. when Just 100, print the
    -- number of tasks still waiting roughly every time another 100 tasks have been taken
    -- off the queue by the workers. When Just, a message with the time is also printed
    -- at the start and at the end of the run.
  , printEstimate :: Maybe Int
    -- ^ How often to print an estimate of the completion time. E.g. when Just 100,
    -- print an estimate roughly every time another 100 tasks have been taken off the
    -- queue by the workers.
  } deriving (Eq, Show)

-- | Options controlling the general running of parallel tasks. The @m@ parameter is the monad (which must be an instance
-- of 'MonadIO') in which the tasks will be run, and the @a@ parameter is the return value of the tasks.
data ParTaskOpts m a = ParTaskOpts
  { simpleOpts :: SimpleParTaskOpts
    -- ^ The simple options.
  , wrapWorker :: forall r. m (m r -> IO r)
    -- ^ Function to use to run the @m@ monad on top of IO. The returned function is run at least once per
    -- worker, so should support being run multiple times in parallel, and should clean up after itself.
    -- A suitable value for IO is simply @return id@.
  , timeLimit :: Maybe (Integer, a)
    -- ^ When Just, the number of microseconds to let each task run for, before assuming it will
    -- not complete, and killing it off. In the case that the task is killed off, the second
    -- part of the pair is the value that will be stored in the vector.
  }

-- | Advanced options controlling the behaviour of parallel tasks. The @m@ parameter
-- is the monad that the tasks execute in, and the @a@ parameter is the output value of the
-- tasks, which is also the type of the values stored in the results vector.
data ExtendedParTaskOpts m a = ExtendedParTaskOpts
  { coreOpts :: ParTaskOpts m a
    -- ^ Core options
  , printTo :: forall r. (Handle -> IO r) -> IO r
    -- ^ Function that supplies a handle to an inner block to write messages to.
    -- To use stdout or stderr, you can just supply @($ stdout)@. To write to a file,
    -- use @withFile \"blah\" WriteMode@.
  , afterFinish :: Double -> Int -> TaskOutcome -> IO (Maybe String)
    -- ^ Hook run after the outcome of a task has been written to the results vector.
    -- It may be called concurrently from several worker threads. Arguments are (in order):
    --
    -- * Time that the task took to complete, or for a task that took too long, the time
    --   until it was given up on (in seconds)
    --
    -- * Index at which the result was stored (same as index of the task in the original tasks list)
    --
    -- * The outcome of the task
    --
    -- If a String is returned, it is logged.
  }

-- | Value indicating whether a task successfully completed, or was killed off for taking too long
data TaskOutcome = Success | TookTooLong
  deriving (Eq, Show)

-- | Thrown asynchronously at a worker whose current task has exceeded the time limit.
data TookTooLongException = TookTooLongException
  deriving (Show, Typeable)

instance Exception TookTooLongException

-- | A version of 'threadDelay' that accommodates delays longer than @maxBound :: Int@
-- microseconds, by sleeping in chunks of at most @maxBound :: Int@.
threadDelay' :: Integer -> IO ()
threadDelay' target
  | target > mx = threadDelay maxBound >> threadDelay' (target - mx)
  | otherwise = threadDelay (fromInteger target)
  where
    mx :: Integer
    mx = toInteger (maxBound :: Int)

-- | Default extended options. Prints messages to stderr, and writes a message when a
-- task is killed
defaultExtendedParTaskOpts :: MonadIO m => ParTaskOpts m a -> ExtendedParTaskOpts m a
defaultExtendedParTaskOpts opts = ExtendedParTaskOpts
  { coreOpts = opts
  , printTo = ($ stderr)
  , afterFinish = printKill
  }
  where
    printKill _ n TookTooLong = return $ Just [i|*** Killed task ${n} for taking too long|]
    printKill _ _ _ = return Nothing

-- | Default parallel task options. The number of workers defaults to the number of capabilities,
-- with no time limit, and printing progress every 50 tasks and an estimated time every 200
defaultParTaskOpts :: ParTaskOpts IO a
defaultParTaskOpts = ParTaskOpts
  { simpleOpts = SimpleParTaskOpts
      { numberWorkers = Nothing
      , printProgress = Just 50
      , printEstimate = Just 200
      }
  , wrapWorker = return id
  , timeLimit = Nothing
  }

-- | Runs the given set of computations in parallel, and once they are all finished, returns their results.
-- Note that they won't all be run in parallel from the start; rather, a set of
-- workers will be spawned that work their way through the (potentially large) set of jobs.
--
-- The result of the task at position @k@ in the list is stored at index @k@ of the
-- returned vector. If a task throws an exception, the worker running it stops and that
-- slot may be left unwritten; the remaining workers carry on with the other tasks.
parallelTasks :: MonadIO m => ExtendedParTaskOpts m a -> [m a] -> m (V.IOVector a)
parallelTasks _ [] = liftIO $ V.new 0
parallelTasks opts tasks = do
  run <- wrapWorker core
  liftIO $ printTo opts $ \h -> do
    let numTasks = length tasks
    requested <- maybe getNumCapabilities return (numberWorkers simple)
    -- At least one worker is needed for any task to run (with none we would return a
    -- vector of unwritten slots); workers beyond the number of tasks would find the
    -- queue empty and exit straight away.
    let numWorkers = max 1 (min numTasks requested)
    vValue <- V.new numTasks
    tvWork <- newTVarIO (numTasks, zip [0..] tasks)
    tvDone <- newTVarIO numWorkers
    startTime <- getCurrentTime
    -- The start/end banners are tied to progress printing being enabled.
    let printStartEnd = isJust (printProgress simple)
    when printStartEnd $ hPrintTime h [i|Total tasks: ${numTasks}, starting at: |]
    -- Serialises log lines written by the monitor loop and by the workers.
    hMutex <- atomically $ newTMVar ()
    let safeLog s = bracket_ (atomically $ takeTMVar hMutex)
                             (atomically $ putTMVar hMutex ())
                             (hPutStrLnFlush h s)
        store t n x o = do
          V.write vValue n x
          afterFinish opts t n o >>= maybe (return ()) safeLog
    replicateM_ numWorkers $ forkIO $ worker (timeLimit core) run store tvWork tvDone
    waitForWorkers simple safeLog startTime numTasks tvWork tvDone
    when printStartEnd $ hPrintTime h "Finished at: "
    return vValue
  where
    core = coreOpts opts
    simple = simpleOpts core

-- | Prints the message followed by the current time, then flushes the handle.
hPrintTime :: Handle -> String -> IO ()
hPrintTime h msg = getCurrentTime >>= hPutStrLnFlush h . (msg ++) . show

-- | Primarily waits for the worker count in the last TVar to reach zero, but in the
-- mean time prints the number of tasks still waiting in the queue and the estimated
-- completion time, at the intervals given in the options.
waitForWorkers :: SimpleParTaskOpts -> (String -> IO ()) -> UTCTime -> Int -> TVar (Int, _x) -> TVar Int -> IO ()
waitForWorkers opts safeLog startTime totalTasks tvWork tvDone = loop totalTasks totalTasks
  where
    -- Whether another report is due. Non-positive intervals are treated like Nothing:
    -- they would otherwise always be due, making this loop spin (and, for the ETA,
    -- divide by zero tasks started).
    isDue :: Maybe Int -> Int -> Int -> Bool
    isDue (Just interval) lastReported remaining
      | interval > 0 = remaining <= lastReported - interval
    isDue _ _ _ = False

    loop :: Int -> Int -> IO ()
    loop lastProgress lastETA = do
      (workersRemaining, tasksRemaining, dueProgress, dueETA) <- atomically $ do
        remaining <- fst <$> readTVar tvWork
        workers <- readTVar tvDone
        let progress = isDue (printProgress opts) lastProgress remaining
            estimate = isDue (printEstimate opts) lastETA remaining
        -- Block until all workers are done or a report becomes due.
        when (workers > 0 && not progress && not estimate) retry
        return (workers, remaining, progress, estimate)
      -- Skip logging when nothing is due (e.g. the final wake-up once all workers are
      -- done), which would otherwise print an empty line.
      when (dueProgress || dueETA) $ do
        curTime <- getCurrentTime
        safeLog $ concat
          [ if dueProgress then [i|Tasks remaining: ${tasksRemaining} |] else ""
          , if dueETA then [i|ETA: ${eta startTime totalTasks curTime tasksRemaining}|] else ""
          ]
      when (workersRemaining > 0) $
        loop (if dueProgress then tasksRemaining else lastProgress)
             (if dueETA then tasksRemaining else lastETA)

-- | Estimates the remaining time by extrapolating the average time per task taken
-- off the queue so far.
eta :: UTCTime -> Int -> UTCTime -> Int -> String
eta startTime totalTasks curTime tasksRemaining
  | tasksStarted <= 0 = "unknown"
  | otherwise = [i|${secondsLeft} seconds (${showTime timeFinish})|]
  where
    tasksStarted = totalTasks - tasksRemaining
    timeSoFar = curTime `diffUTCTime` startTime
    timePerTask = timeSoFar / fromIntegral tasksStarted
    timeLeft = timePerTask * fromIntegral tasksRemaining
    -- Shown as whole seconds: 'show' on NominalDiffTime appends its own "s" suffix,
    -- which would read as "12.3s seconds".
    secondsLeft = round timeLeft :: Integer
    timeFinish = timeLeft `addUTCTime` curTime

showTime :: UTCTime -> String
showTime = formatTime defaultTimeLocale "%T %F"

-- | Repeatedly takes the next task from the shared queue and runs it, until the queue
-- is empty, then decrements the worker count.
--
-- With a time limit, each task is supervised by a watchdog thread, and the task and
-- its watchdog race to claim the task's outcome (an atomic test-and-set on a flag
-- private to that task):
--
-- * If the task finishes first, the worker kills the watchdog and stores a 'Success'.
--
-- * If the watchdog fires first, it stores the default value as 'TookTooLong', forks a
--   replacement worker (we cannot guarantee this worker will ever die, e.g. if it is
--   blocked or not allocating memory) and throws 'TookTooLongException' at this worker,
--   which then exits without decrementing the worker count. A worker that finishes its
--   task after losing the race just waits for that exception.
--
-- Because exactly one side wins, each result is stored once, and a replacement is only
-- forked for a worker that will not count itself as finished.
worker :: forall m a. MonadIO m
       => Maybe (Integer, a)
       -- ^ Limit in microseconds for each task, and the value to store if it is exceeded
       -> (m () -> IO ())
       -- ^ Function for running the monad
       -> (Double -> Int -> a -> TaskOutcome -> IO ())
       -- ^ Function for storing the result of a task (time taken, index, value, outcome)
       -> TVar (Int, [(Int, m a)])
       -- ^ Queue of indexed tasks, paired with the number of tasks left in it
       -> TVar Int
       -- ^ Number of live workers, decremented when this worker finishes
       -> IO ()
worker mlimit run store tvWork tvDone = do
  -- Holds the "disarm" action for the watchdog of the task in progress, if any.
  tvArmed <- newTVarIO Nothing
  -- TookTooLongException is caught inside the onException, so it does not count as
  -- this worker finishing: the watchdog's replacement worker takes over our place.
  onException (handle (\TookTooLongException -> return ()) (run (go tvArmed True)))
              (abandon tvArmed)
  where
    finish :: IO ()
    finish = atomically $ readTVar tvDone >>= writeTVar tvDone . pred

    -- Atomic test-and-set on a task's claim flag; True for whoever gets there first.
    claim tvClaimed = do
      claimed <- readTVar tvClaimed
      if claimed then return False else True <$ writeTVar tvClaimed True

    -- Run when this worker dies from any other exception (e.g. a task threw). If a
    -- watchdog is armed we must beat it to the claim before counting ourselves as
    -- finished; if the watchdog won, its replacement worker inherits our place.
    abandon :: TVar (Maybe (IO Bool)) -> IO ()
    abandon tvArmed = do
      armed <- atomically $ readTVar tvArmed
      won <- maybe (return True) id armed
      when won finish

    -- Blocks until the watchdog's TookTooLongException arrives.
    awaitKill :: IO ()
    awaitKill = threadDelay maxBound >> awaitKill

    go :: TVar (Maybe (IO Bool)) -> Bool -> m ()
    go _ False = liftIO finish
    go tvArmed True = do
      next <- liftIO $ atomically $ do
        (count, queue) <- readTVar tvWork
        case queue of
          [] -> return Nothing
          (task : rest) -> Just task <$ writeTVar tvWork (pred count, rest)
      case next of
        Nothing -> go tvArmed False
        Just (n, work) -> do
          case mlimit of
            Nothing -> do
              (x, t) <- withTime work
              liftIO $ store t n x Success
            Just (limit, timeoutValue) -> runLimited tvArmed limit timeoutValue n work
          go tvArmed True

    -- Runs one task under a watchdog that gives up on it after 'limit' microseconds.
    runLimited :: TVar (Maybe (IO Bool)) -> Integer -> a -> Int -> m a -> m ()
    runLimited tvArmed limit timeoutValue n work = do
      myId <- liftIO myThreadId
      start <- liftIO getCurrentTime
      tvClaimed <- liftIO $ newTVarIO False
      theirId <- liftIO $ forkIO $ do
        threadDelay' limit
        won <- atomically $ claim tvClaimed
        when won $ do
          end <- getCurrentTime
          -- Hand over to a replacement and stop the old worker even if storing fails;
          -- otherwise the old worker would wait forever for its exception.
          let retire = do
                _ <- forkIO $ worker mlimit run store tvWork tvDone
                -- This call may block (effectively) forever, waiting for the other thread to die:
                throwTo myId TookTooLongException
          store (end `timeDiff` start) n timeoutValue TookTooLong `onException` retire
          retire
      liftIO $ atomically $ writeTVar tvArmed $ Just $ do
        won <- atomically $ claim tvClaimed
        when won $ killThread theirId
        return won
      (x, t) <- withTime work
      won <- liftIO $ atomically $ do
        w <- claim tvClaimed
        when w $ writeTVar tvArmed Nothing
        return w
      liftIO $ if won
        then killThread theirId >> store t n x Success -- Kill the watchdog, keep the result
        else awaitKill -- The watchdog has already stored the default value

-- | A version of putStrLn that flushes after writing (and works in any MonadIO monad)
hPutStrLnFlush :: MonadIO m => Handle -> String -> m ()
hPutStrLnFlush h s = liftIO $ hPutStrLn h s >> hFlush h

-- | Runs an action and returns its result with the wall-clock time it took, in seconds.
withTime :: MonadIO m => m a -> m (a, Double)
withTime m = do
  start <- liftIO getCurrentTime
  x <- m
  end <- liftIO getCurrentTime
  t <- liftIO $ evaluate $ end `timeDiff` start
  return (x, t)

-- | Difference between two times in seconds.
timeDiff :: UTCTime -> UTCTime -> Double
timeDiff end start = fromRational $ toRational (end `diffUTCTime` start)