{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Control.Scheduler
(
Comp(..)
, Scheduler(..)
, withScheduler
, withScheduler_
, traverseConcurrently
, traverseConcurrently_
, traverse_
) where
import Control.Concurrent
import Control.Exception
import Control.Scheduler.Computation
import Control.Scheduler.Queue
import Control.Monad
import Control.Monad.IO.Unlift
import Data.Atomics (atomicModifyIORefCAS, atomicModifyIORefCAS_)
import Data.Foldable as F (foldl')
import Data.IORef
import Data.Traversable
-- | Internal state shared between the code that submits jobs and the worker
-- loops that execute them.
data Jobs m a = Jobs
  { jobsNumWorkers :: {-# UNPACK #-} !Int
    -- ^ Number of workers; this many 'Retire' entries are queued once the
    -- outstanding-job counter drops to zero.
  , jobsQueue      :: !(JQueue m a)
    -- ^ Queue that jobs are pushed onto and popped from.
  , jobsCountRef   :: !(IORef Int)
    -- ^ Number of submitted jobs that have not finished yet.
  }
-- | Handle passed to the callback of 'withScheduler' / 'withScheduler_'.
data Scheduler m a = Scheduler
  { numWorkers   :: {-# UNPACK #-} !Int
    -- ^ Number of workers selected for the computation strategy.
  , scheduleWork :: m a -> m ()
    -- ^ Submit an action to be executed as a job of this scheduler.
  }
-- | Run an action for every element of a container, left to right, and
-- discard the results.
--
-- The combined action is built with a strict left fold ('F.foldl'') over the
-- whole container, so the container must be finite.
traverse_ :: (Applicative f, Foldable t) => (a -> f ()) -> t a -> f ()
traverse_ act = F.foldl' andThen (pure ())
  where
    andThen acc x = acc *> act x
-- | Apply a monadic function to every element of a 'Traversable' structure.
-- Each application is submitted as a job to a scheduler created for @comp@,
-- and the structure is rebuilt from the collected results.
--
-- NOTE(review): rebuilding relies on 'withScheduler' returning the results
-- in submission order (via 'flushResults') — confirm in
-- "Control.Scheduler.Queue".
traverseConcurrently ::
     (MonadUnliftIO m, Traversable t)
  => Comp
  -> (a -> m b)
  -> t a
  -> m (t b)
traverseConcurrently comp f xs =
  (`transList` xs) <$> withScheduler comp scheduleAll
  where
    scheduleAll scheduler = traverse_ (scheduleWork scheduler . f) xs
-- | Replace the elements of a traversable structure, left to right, with
-- successive elements of the list. Extra list elements are ignored; running
-- out of list elements is reported with 'error'.
transList :: Traversable t => [a] -> t b -> t a
transList source = snd . mapAccumL takeNext source
  where
    takeNext remaining _ =
      case remaining of
        next : rest -> (rest, next)
        [] -> error "Impossible<traverseConcurrently> - Mismatched sizes"
-- | Like 'traverseConcurrently', but the results of the applications are
-- discarded.
traverseConcurrently_ ::
     (MonadUnliftIO m, Foldable t)
  => Comp
  -> (a -> m b)
  -> t a
  -> m ()
traverseConcurrently_ comp f xs = withScheduler_ comp submitAll
  where
    submitAll scheduler = traverse_ (\x -> scheduleWork scheduler (f x)) xs
-- | Submit an action as a job built with 'mkJob'. 'withScheduler' pairs this
-- submitter with 'flushResults' to obtain the jobs' results.
scheduleJobs :: MonadIO m => Jobs m a -> m a -> m ()
scheduleJobs jobs action = scheduleJobsWith mkJob jobs action
-- | Submit an action whose result is thrown away: the (wrapped) action is
-- passed through 'void' and stored as a 'Job_'.
scheduleJobs_ :: MonadIO m => Jobs m a -> m b -> m ()
scheduleJobs_ jobs action =
  scheduleJobsWith (\wrapped -> return (Job_ (void wrapped))) jobs action
-- | Submit one job.
--
-- The outstanding-job counter is incremented first. @action@ is then wrapped
-- so that, when it finishes, its result is forced to WHNF and the counter is
-- decremented. The wrapped action is turned into a job with @makeJob@ and
-- pushed onto the queue. When the counter drops to zero, @jobsNumWorkers@
-- 'Retire' entries are pushed onto the queue.
--
-- NOTE(review): a 'Retire' entry presumably makes 'popJQueue' yield
-- 'Nothing' in 'runWorker', retiring that worker — confirm in
-- "Control.Scheduler.Queue".
scheduleJobsWith :: MonadIO m => (m b -> m (Job m a)) -> Jobs m a -> m b -> m ()
scheduleJobsWith makeJob Jobs {jobsQueue, jobsCountRef, jobsNumWorkers} action = do
  liftIO $ atomicModifyIORefCAS_ jobsCountRef (+ 1)
  makeJob (action >>= finish) >>= pushJQueue jobsQueue
  where
    -- Force the result before announcing completion, so the job only counts
    -- as finished once its value has been evaluated to WHNF.
    finish res = do
      res `seq` dropCounterOnZero jobsCountRef (retireWorkersN jobsQueue jobsNumWorkers)
      return res
-- | Push @n@ 'Retire' entries onto the queue. Nothing is pushed when
-- @n <= 0@.
retireWorkersN :: MonadIO m => JQueue m a -> Int -> m ()
retireWorkersN jobsQueue n = replicateM_ n (pushJQueue jobsQueue Retire)
-- | Atomically decrement the counter. If the decremented value is exactly
-- zero, run @onZero@. Because the decrement is atomic, only the caller whose
-- decrement produced zero runs @onZero@.
dropCounterOnZero :: MonadIO m => IORef Int -> m () -> m ()
dropCounterOnZero counterRef onZero = do
  remaining <- liftIO (atomicModifyIORefCAS counterRef decrement)
  if remaining == 0 then onZero else pure ()
  where
    -- Strict in both the old and the new value so no thunk is left in the ref.
    decrement !old =
      let !new = old - 1
       in (new, new)
-- | Worker loop: pop a job from the queue and run it, repeatedly, until
-- 'popJQueue' yields 'Nothing'; then run @onRetire@ once.
runWorker :: MonadIO m => JQueue m a -> m () -> m ()
runWorker jQueue onRetire = loop
  where
    loop = popJQueue jQueue >>= maybe onRetire (>> loop)
-- | Create a scheduler for the computation strategy and pass it to the
-- callback, which submits jobs with 'scheduleWork'. Wait for the jobs to
-- finish and return the results collected with 'flushResults'.
--
-- The callback's own return value is discarded. An exception raised by a job
-- is rethrown in the calling thread.
withScheduler ::
     MonadUnliftIO m
  => Comp
  -> (Scheduler m a -> m b)
  -> m [a]
withScheduler comp onScheduler =
  withSchedulerInternal comp scheduleJobs flushResults onScheduler
-- | Like 'withScheduler', but job results are not kept: jobs are submitted
-- with 'scheduleJobs_' and nothing is collected at the end.
withScheduler_ ::
     MonadUnliftIO m
  => Comp
  -> (Scheduler m a -> m b)
  -> m ()
withScheduler_ comp onScheduler =
  withSchedulerInternal comp scheduleJobs_ (\_ -> pure ()) onScheduler
-- | Shared implementation of 'withScheduler' and 'withScheduler_'.
--
-- Steps, in order:
--
-- 1. Pick the number of workers from @comp@.
-- 2. Create the job queue, the outstanding-job counter, the live-worker
--    counter and the MVar used to signal completion.
-- 3. Run @onScheduler@ so it can submit work through @submitWork@. This
--    happens before any worker exists; its result is discarded.
-- 4. If nothing was submitted, submit a no-op job. Its completion still
--    drives the job counter to zero, which queues the 'Retire' entries.
-- 5. Fork the workers, wait on the completion MVar, and then either return
--    @collect jobsQueue@ or rethrow the exception a worker reported. For
--    'Seq' no thread is forked; the calling thread runs the worker loop.
--
-- If waiting fails for any reason, every forked worker is sent
-- 'WorkerTerminateException' (wrapped in 'SomeAsyncException').
withSchedulerInternal ::
     MonadUnliftIO m
  => Comp -- ^ Computation strategy; decides worker count and how they are forked
  -> (Jobs m a -> m a -> m ()) -- ^ How a single action is submitted as a job
  -> (JQueue m a -> m c) -- ^ Extracts the final result once all workers retired
  -> (Scheduler m a -> m b) -- ^ User callback that submits the work
  -> m c
withSchedulerInternal comp submitWork collect onScheduler = do
  jobsNumWorkers <-
    case comp of
      Seq -> return 1
      -- NOTE(review): 'Par' is matched before 'ParOn ws'. If 'Par' is a
      -- pattern synonym for 'ParOn []' (PatternSynonyms is enabled here), an
      -- empty capability list lands in this branch rather than getting zero
      -- workers — confirm in Control.Scheduler.Computation.
      Par -> liftIO getNumCapabilities
      ParOn ws -> return $ length ws
      -- ParN 0 means "as many workers as capabilities".
      ParN 0 -> liftIO getNumCapabilities
      ParN n -> return $ fromIntegral n
  -- Number of workers that have not retired yet. The last one to retire
  -- signals normal completion by putting Nothing into workDoneMVar.
  sWorkersCounterRef <- liftIO $ newIORef jobsNumWorkers
  jobsQueue <- newJQueue
  -- Number of submitted jobs that have not finished yet.
  jobsCountRef <- liftIO $ newIORef 0
  -- Filled with Nothing on normal completion, or with Just a WorkerException
  -- by the first worker whose job failed (see handleWorkerException).
  workDoneMVar <- liftIO newEmptyMVar
  let jobs = Jobs {..}
      scheduler = Scheduler {numWorkers = jobsNumWorkers, scheduleWork = submitWork jobs}
      onRetire = dropCounterOnZero sWorkersCounterRef $ liftIO (putMVar workDoneMVar Nothing)
  -- The initial work is submitted before any worker is spawned. No job can
  -- run yet, so the job counter cannot reach zero while the pool is still
  -- being filled.
  _ <- onScheduler scheduler
  -- If nothing was scheduled, queue a no-op job. Its completion brings the
  -- counter to zero, which pushes the Retire entries for the workers.
  jc <- liftIO $ readIORef jobsCountRef
  when (jc == 0) $ scheduleJobs_ jobs (pure ())
  -- Fork one worker per element of ws.
  --
  -- The forked threads start masked: this action runs as the "before" step
  -- of safeBracketOnError, inside mask, and forked threads inherit the
  -- parent's masking state. Each worker unmasks only around its loop.
  -- Anything escaping the loop is passed to handleWorkerException.
  let spawnWorkersWith fork ws =
        withRunInIO $ \run ->
          forM ws $ \w ->
            fork w $ \unmask ->
              catch
                (unmask $ run $ runWorker jobsQueue onRetire)
                (run . handleWorkerException jobsQueue workDoneMVar jobsNumWorkers)
      {-# INLINE spawnWorkersWith #-}
      spawnWorkers =
        case comp of
          -- No threads for a sequential computation; doWork runs the loop.
          Seq -> return []
          -- Pin workers to capabilities 1..n.
          Par -> spawnWorkersWith forkOnWithUnmask [1 .. jobsNumWorkers]
          -- Pin workers to the capabilities requested by the caller.
          ParOn ws -> spawnWorkersWith forkOnWithUnmask ws
          -- Unpinned workers; the index is ignored.
          ParN _ -> spawnWorkersWith (\_ -> forkIOWithUnmask) [1 .. jobsNumWorkers]
      {-# INLINE spawnWorkers #-}
      doWork = do
        -- In Seq mode the calling thread is the only worker.
        when (comp == Seq) $ runWorker jobsQueue onRetire
        -- Block until either all workers retired (Nothing) or a worker
        -- reported a failure (Just).
        mExc <- liftIO $ readMVar workDoneMVar
        case mExc of
          Nothing -> collect jobsQueue
          -- Rethrow the job's original exception, not the wrapper.
          Just (WorkerException exc) -> liftIO $ throwIO exc
      {-# INLINE doWork #-}
  -- If doWork throws (a rethrown job exception or an async exception sent
  -- to this thread), stop every forked worker before propagating.
  safeBracketOnError
    spawnWorkers
    (liftIO . traverse_ (`throwTo` SomeAsyncException WorkerTerminateException))
    (const doWork)
-- | Exception handler installed around each forked worker loop.
--
-- A 'WorkerTerminateException' (delivered wrapped in 'SomeAsyncException')
-- means the scheduler itself is stopping the worker, so it is ignored.
--
-- Any other exception is stored in @workDoneMVar@ as a 'WorkerException'.
-- 'tryPutMVar' leaves an already-filled MVar unchanged, so only the first
-- failure is kept. Then @nWorkers - 1@ 'Retire' entries are pushed for the
-- remaining workers.
handleWorkerException ::
     MonadIO m => JQueue m a -> MVar (Maybe WorkerException) -> Int -> SomeException -> m ()
handleWorkerException jQueue workDoneMVar nWorkers exc
  | Just WorkerTerminateException <- asyncExceptionFromException exc = return ()
  | otherwise = do
      _ <- liftIO (tryPutMVar workDoneMVar (Just (WorkerException exc)))
      retireWorkersN jQueue (nWorkers - 1)
-- | Carries the exception that escaped a job on a worker thread back to the
-- thread waiting for the scheduler to finish.
newtype WorkerException =
  WorkerException SomeException
  deriving Show

instance Exception WorkerException where
  displayException (WorkerException inner) =
    "A worker handled a job that ended with exception: " ++ displayException inner
-- | Sent by the scheduler to its worker threads (via 'throwTo') to stop them
-- when waiting for the jobs fails or is interrupted.
--
-- This is an asynchronous exception. 'toException' wraps it in
-- 'SomeAsyncException' and 'fromException' unwraps it. It can therefore be
-- recognised both with 'asyncExceptionFromException' and with an ordinary
-- 'fromException' / 'catch' at this type.
data WorkerTerminateException =
  WorkerTerminateException
  deriving (Show)

instance Exception WorkerTerminateException where
  -- The scheduler throws this value wrapped in 'SomeAsyncException'. With
  -- the default (synchronous) methods, 'fromException' at this type would
  -- never recognise the value actually thrown, and a plain 'throwIO' would
  -- deliver it as a synchronous exception. The async helpers from
  -- "Control.Exception" fix both. 'asyncExceptionFromException' goes through
  -- 'SomeAsyncException''s own instance, so there is no recursion here.
  toException = asyncExceptionToException
  fromException = asyncExceptionFromException
  displayException WorkerTerminateException = "A worker was terminated by the scheduler"
-- | Acquire a value with @before@ (asynchronous exceptions masked), then run
-- @thing@ on it with the caller's masking state restored.
--
-- If @thing@ throws any exception, synchronous or asynchronous:
--
-- * @after@ is run on the acquired value under 'uninterruptibleMask_';
-- * any exception from @after@ itself is discarded;
-- * the original exception is rethrown.
--
-- When @thing@ succeeds, @after@ is not run.
safeBracketOnError :: MonadUnliftIO m => m a -> (a -> m b) -> (a -> m c) -> m c
safeBracketOnError before after thing =
  withRunInIO $ \runInIO ->
    mask $ \restore -> do
      acquired <- runInIO before
      outcome <- try (restore (runInIO (thing acquired)))
      either
        (\(thingErr :: SomeException) -> do
           (_ :: Either SomeException b) <-
             try (uninterruptibleMask_ (runInIO (after acquired)))
           throwIO thingErr)
        return
        outcome