module Control.Concurrent.ParallelTasks.Base (ExtendedParTaskOpts(..), ParTaskOpts(..), SimpleParTaskOpts(..), TaskOutcome(Success, TookTooLong), defaultExtendedParTaskOpts, defaultParTaskOpts, parallelTasks) where
import Control.Applicative ((<$), (<$>))
import Control.Concurrent (forkIO, getNumCapabilities, killThread, myThreadId, threadDelay)
import Control.Concurrent.STM (TVar, atomically, newTVarIO, readTVar, retry, writeTVar)
import Control.Concurrent.STM.TMVar (newTMVar, putTMVar, takeTMVar)
import Control.Exception.Base (Exception, bracket_, evaluate, handle, onException, throwTo)
import Control.Monad (replicateM_, when)
import Control.Monad.IO.Class (MonadIO, liftIO)
import Data.Maybe (isJust)
import Data.String.Here.Interpolated (i)
import Data.Time.Clock (UTCTime, addUTCTime, diffUTCTime, getCurrentTime)
import Data.Time.Format (formatTime)
import Data.Typeable
import qualified Data.Vector.Mutable as V (IOVector, new, write)
import System.IO (Handle, hFlush, hPutStrLn, stderr)
import System.Locale (defaultTimeLocale)
-- | The options that do not depend on the task monad or on the result type.
data SimpleParTaskOpts = SimpleParTaskOpts
  { numberWorkers :: Maybe Int
    -- ^ How many worker threads to fork. With 'Nothing', 'parallelTasks'
    -- asks the runtime via 'getNumCapabilities'.
  , printProgress :: Maybe Int
    -- ^ @Just n@ reports the number of tasks still waiting in the queue each
    -- time that number has dropped by @n@ since the last report, and also
    -- makes 'parallelTasks' print the start and finish times. 'Nothing'
    -- turns both off.
  , printEstimate :: Maybe Int
    -- ^ @Just n@ reports an estimated time of completion each time the
    -- number of queued tasks has dropped by @n@ since the last estimate.
    -- 'Nothing' turns the estimate off.
  }
-- | The core options for running tasks in monad @m@ that produce values of
-- type @a@.
data ParTaskOpts m a = ParTaskOpts
  { simpleOpts :: SimpleParTaskOpts
    -- ^ Worker-count and reporting settings.
  , wrapWorker :: forall r. m (m r -> IO r)
    -- ^ An action, executed once in the caller's monad by 'parallelTasks',
    -- that yields the function used to run each worker's @m@ computation
    -- in 'IO'.
  , timeLimit :: Maybe (Integer, a)
    -- ^ Optional per-task time limit. The 'Integer' is the delay handed to
    -- 'threadDelay' (so it is in microseconds) and the @a@ is the value
    -- recorded for a task that overruns it. 'Nothing' means no limit.
  }
-- | 'ParTaskOpts' plus the hooks that control where messages go and what is
-- reported after each task.
data ExtendedParTaskOpts m a = ExtendedParTaskOpts
  { coreOpts :: ParTaskOpts m a
    -- ^ The underlying core options.
  , printTo :: forall r. (Handle -> IO r) -> IO r
    -- ^ Supplies the 'Handle' that all messages are written to; the whole
    -- run of 'parallelTasks' takes place inside the continuation passed
    -- to this function.
  , afterFinish :: Double -> Int -> TaskOutcome -> IO (Maybe String)
    -- ^ Called after each task result is recorded, with the elapsed time in
    -- seconds, the task's zero-based index and its outcome. A returned
    -- @Just msg@ is written to the output handle; 'Nothing' prints nothing.
  }
-- | How a single task ended.
data TaskOutcome
  = Success
    -- ^ The task ran to completion and its own result was recorded.
  | TookTooLong
    -- ^ The task overran the configured time limit and the default value
    -- was recorded in its place.
-- | Internal signal thrown at a worker thread whose current task has overrun
-- the time limit. It is not exported from this module.
data TookTooLongException = TookTooLongException
  deriving (Show, Typeable)

instance Exception TookTooLongException
-- | A variant of 'threadDelay' that accepts an arbitrarily large 'Integer'
-- number of microseconds.
--
-- 'threadDelay' itself takes an 'Int', so a delay that does not fit is
-- split up: sleep for @maxBound :: Int@ microseconds, then recurse on
-- whatever is left.
threadDelay' :: Integer -> IO ()
threadDelay' target
  -- Too big for one call: sleep the maximum and carry on with the remainder.
  | target > mx = threadDelay maxBound >> threadDelay' (target - mx)
  | otherwise = threadDelay (fromInteger target)
  where
    -- The largest delay a single 'threadDelay' call can express.
    mx :: Integer
    mx = toInteger (maxBound :: Int)
-- | Wrap core options with the default extras: messages go to 'stderr', and
-- the only per-task message is a notice for each task that was killed for
-- exceeding the time limit.
defaultExtendedParTaskOpts :: MonadIO m => ParTaskOpts m a -> ExtendedParTaskOpts m a
defaultExtendedParTaskOpts opts =
  ExtendedParTaskOpts
    { coreOpts = opts
    , printTo = \useHandle -> useHandle stderr
    , afterFinish = \_ taskIx outcome -> return (killNotice taskIx outcome)
    }
  where
    -- The message (if any) to log for a finished task; the elapsed time is
    -- not used by the default reporter.
    killNotice :: Int -> TaskOutcome -> Maybe String
    killNotice taskIx outcome = case outcome of
      TookTooLong -> Just [i|*** Killed task ${taskIx} for taking too long|]
      Success -> Nothing
-- | Default core options for tasks that run directly in 'IO': worker count
-- left to 'parallelTasks' to choose, a progress line every 50 tasks, an
-- estimate every 200 tasks, and no time limit.
defaultParTaskOpts :: ParTaskOpts IO a
defaultParTaskOpts =
  ParTaskOpts
    { simpleOpts = reporting
    , wrapWorker = return id -- tasks are already in IO, so nothing to unwrap
    , timeLimit = Nothing
    }
  where
    reporting =
      SimpleParTaskOpts
        { numberWorkers = Nothing
        , printProgress = Just 50
        , printEstimate = Just 200
        }
-- | Run the given tasks on a pool of worker threads and collect their
-- results.
--
-- The returned vector has one slot per task, in the order the tasks were
-- given: slot @n@ holds the result of the @n@th task (counting from zero),
-- or the default value from 'timeLimit' if that task was killed for taking
-- too long. An empty task list yields an empty vector without forking
-- anything.
--
-- The number of workers comes from 'numberWorkers', falling back to
-- 'getNumCapabilities'; it is clamped to at least one so that a zero or
-- negative setting cannot return a vector whose slots were never written.
-- The call blocks until the workers have finished.
parallelTasks :: MonadIO m => ExtendedParTaskOpts m a -> [m a] -> m (V.IOVector a)
parallelTasks _ [] = liftIO $ V.new 0
parallelTasks opts tasks = wrapWorker (coreOpts opts) >>= \run -> liftIO $ printTo opts $ \h -> do
    let numTasks = length tasks
    numWorkers <- max 1 <$> maybe getNumCapabilities return (numberWorkers $ simpleOpts $ coreOpts opts)
    vValue <- V.new numTasks
    -- Shared queue: the count of unclaimed tasks and the tasks themselves,
    -- each tagged with the slot its result belongs in.
    tvWork <- newTVarIO (numTasks, zip [0..] tasks)
    -- Count of workers that have not yet finished.
    tvDone <- newTVarIO numWorkers
    startTime <- getCurrentTime
    let printStartEnd = isJust $ printProgress $ simpleOpts $ coreOpts opts
    when printStartEnd $
      hPrintTime h [i|Total tasks: ${numTasks}, starting at: |]
    -- All later output goes through this lock so that lines written by
    -- different threads do not interleave.
    hMutex <- atomically $ newTMVar ()
    let safeLog s = bracket_ (atomically $ takeTMVar hMutex) (atomically $ putTMVar hMutex ()) (hPutStrLnFlush h s)
    replicateM_ numWorkers $ forkIO $
      worker (timeLimit $ coreOpts opts) run
        -- Record the value first, then let the caller's hook decide whether
        -- there is anything to log about this task.
        (\t n x o -> V.write vValue n x >> afterFinish opts t n o >>= maybe (return ()) safeLog)
        tvWork tvDone
    waitForWorkers (simpleOpts $ coreOpts opts) safeLog startTime numTasks tvWork tvDone
    when printStartEnd $
      hPrintTime h "Finished at: "
    return vValue
-- | Write the given prefix followed by the current time (rendered with
-- 'show') as one line on the handle, then flush it.
hPrintTime :: Handle -> String -> IO ()
hPrintTime h msg = do
  now <- getCurrentTime
  hPutStrLnFlush h (msg ++ show now)
-- | Block until every worker has finished, emitting progress and ETA
-- reports along the way.
--
-- Arguments, in order: the reporting options, the function used to emit a
-- report line, the time the run started, the total number of tasks, the
-- shared work queue (only its remaining-task count is read) and the counter
-- of workers still running.
--
-- Each round waits inside an STM transaction (using 'retry') until either
-- no workers remain or the remaining-task count has dropped by the interval
-- configured in 'printProgress' or 'printEstimate' since the last report of
-- that kind. Nothing is logged on a round where neither report is due, so
-- the final wake-up caused by the workers finishing does not emit an empty
-- line.
waitForWorkers :: SimpleParTaskOpts -> (String -> IO ()) -> UTCTime -> Int -> TVar (Int, _x) -> TVar Int -> IO ()
waitForWorkers opts safeLog startTime totalTasks tvWork tvDone = go True totalTasks totalTasks
  where
    -- Whether a report with the given interval is due: the remaining count
    -- must have fallen at least that far below the count at the last
    -- report. A disabled report ('Nothing') is never due.
    due :: Maybe Int -> Int -> Int -> Bool
    due Nothing _ _ = False
    due (Just n) lastReported remaining = remaining <= lastReported - n

    -- The flag says whether any workers were still running after the
    -- previous round; the two counts are the remaining-task figures at the
    -- last progress report and the last ETA report.
    go :: Bool -> Int -> Int -> IO ()
    go False _ _ = return ()
    go True lastProgress lastETA = do
      (workersRemaining, tasksRemaining, timeToPrintProgress, timeToPrintETA) <- atomically $ do
        tasksRemaining <- fst <$> readTVar tvWork
        workersRemaining <- readTVar tvDone
        let timeToPrintProgress = due (printProgress opts) lastProgress tasksRemaining
            timeToPrintETA = due (printEstimate opts) lastETA tasksRemaining
        -- Sleep until one of the two variables changes, unless there is
        -- already something to report or the workers are all done.
        when (workersRemaining > 0 && not timeToPrintETA && not timeToPrintProgress) retry
        return (workersRemaining, tasksRemaining, timeToPrintProgress, timeToPrintETA)
      when (timeToPrintProgress || timeToPrintETA) $ do
        curTime <- getCurrentTime
        safeLog $ concat
          [ if timeToPrintProgress then [i|Tasks remaining: ${tasksRemaining} |] else ""
          , if timeToPrintETA then [i|ETA: ${eta startTime totalTasks curTime tasksRemaining}|] else ""
          ]
      go (workersRemaining > 0)
         (if timeToPrintProgress then tasksRemaining else lastProgress)
         (if timeToPrintETA then tasksRemaining else lastETA)
-- | Describe the estimated time still needed and the estimated finish time.
--
-- Arguments, in order: the time the run started, the total number of tasks,
-- the current time and the number of tasks remaining. The estimate assumes
-- each remaining task takes the average time per task observed so far
-- (elapsed time divided by @totalTasks - tasksRemaining@).
--
-- When no task has been accounted for yet there is no average to
-- extrapolate from (and the division would raise), so the result is the
-- literal string @"unknown"@.
eta :: UTCTime -> Int -> UTCTime -> Int -> String
eta startTime totalTasks curTime tasksRemaining
  | tasksDone <= 0 = "unknown"
  | otherwise = [i|${timeLeft} seconds (${showTime timeFinish})|]
  where
    tasksDone = totalTasks - tasksRemaining
    timeSoFar = curTime `diffUTCTime` startTime
    timePerTask = timeSoFar / fromIntegral tasksDone
    timeLeft = timePerTask * fromIntegral tasksRemaining
    timeFinish = timeLeft `addUTCTime` curTime

    -- Render as time of day followed by the date.
    showTime :: UTCTime -> String
    showTime = formatTime defaultTimeLocale "%T %F"
-- | The body of one worker thread: keep claiming the next task from the
-- shared queue, run it and pass its result to the store callback, until the
-- queue is empty.
--
-- Arguments, in order:
--
-- * the optional time limit (a delay for 'threadDelay'') paired with the
--   value to store for a task that overruns it;
-- * the function that runs an @m@ computation in 'IO';
-- * the callback that receives the elapsed time, the task index, the value
--   to record and the outcome;
-- * the shared queue: the count of unclaimed tasks and the indexed tasks;
-- * the counter of running workers.
--
-- The counter is decremented when the loop finds the queue empty, or when
-- the loop is aborted by an exception. The one exception to that is
-- 'TookTooLongException': it is swallowed without touching the counter,
-- because the watchdog that throws it forks a replacement worker first.
worker :: forall m a. MonadIO m =>
          Maybe (Integer, a) ->
          (m () -> IO ()) ->
          (Double -> Int -> a -> TaskOutcome -> IO ()) ->
          TVar (Int, [(Int, m a)]) ->
          TVar Int ->
          IO ()
worker mlimit run store tvWork tvDone =
    absorbTimeout (run (loop True)) `onException` signalDone
  where
    -- Only install the timeout handler when a limit is configured.
    absorbTimeout :: IO () -> IO ()
    absorbTimeout = case mlimit of
      Just _ -> handle (\TookTooLongException -> return ())
      Nothing -> id

    -- Record that one worker fewer is running.
    signalDone :: IO ()
    signalDone = atomically $ do
      running <- readTVar tvDone
      writeTVar tvDone (pred running)

    -- The flag says whether the previous round found a task to run.
    loop :: Bool -> m ()
    loop False = liftIO signalDone
    loop True = claimTask >>= maybe (return False) runTask >>= loop

    -- Atomically take the task at the head of the queue, if there is one,
    -- and lower the unclaimed count to match.
    claimTask :: m (Maybe (Int, m a))
    claimTask = liftIO . atomically $ do
      (remaining, queue) <- readTVar tvWork
      case queue of
        [] -> return Nothing
        (next : rest) -> do
          writeTVar tvWork (pred remaining, rest)
          return (Just next)

    -- Run one claimed task and store its result; always asks for another
    -- round afterwards.
    runTask :: (Int, m a) -> m Bool
    runTask (ix, task) = do
      case mlimit of
        Nothing -> do
          (result, elapsed) <- withTime task
          liftIO $ store elapsed ix result Success
        Just (limit, fallback) -> do
          self <- liftIO myThreadId
          began <- liftIO getCurrentTime
          -- Watchdog: once the limit has passed, store the fallback value,
          -- fork a fresh worker to take over, then interrupt this thread.
          watchdog <- liftIO . forkIO $ do
            threadDelay' limit
            expired <- getCurrentTime
            store (expired `timeDiff` began) ix fallback TookTooLong
            _ <- forkIO (worker mlimit run store tvWork tvDone)
            throwTo self TookTooLongException
          (result, elapsed) <- withTime task
          -- Finished in time: cancel the watchdog before recording.
          liftIO $ killThread watchdog
          liftIO $ store elapsed ix result Success
      return True
-- | Write a line to the handle and flush the handle straight away.
hPutStrLnFlush :: MonadIO m => Handle -> String -> m ()
hPutStrLnFlush h s = liftIO $ do
  hPutStrLn h s
  hFlush h
-- | Run an action and return its result together with the wall-clock time
-- it took, in seconds.
withTime :: MonadIO m => m a -> m (a, Double)
withTime m = do
  began <- liftIO getCurrentTime
  result <- m
  elapsed <- liftIO $ do
    finished <- getCurrentTime
    -- Force the duration now rather than leaving a thunk that holds on to
    -- both timestamps.
    evaluate (finished `timeDiff` began)
  return (result, elapsed)
-- | The number of seconds from the second argument to the first, as a
-- 'Double'. The result is negative when the first time is the earlier one.
timeDiff :: UTCTime -> UTCTime -> Double
timeDiff end start = realToFrac (diffUTCTime end start)