{-# LANGUAGE LambdaCase #-}
module Control.Distributed.Fork.Utils where
import Control.Concurrent (threadDelay)
import Control.Exception (throwIO)
import Control.Monad (forM, unless)
import Control.Monad.IO.Class (liftIO)
import Data.Function (fix)
import System.IO (hFlush, stdout)

import Control.Concurrent.Async
  (async, concurrently, mapConcurrently, wait, waitBoth)
import Control.Concurrent.STM
  (atomically, modifyTVar, modifyTVar', newTVar, readTVar, retry, writeTVar)
import Control.Monad.Trans.State (evalStateT, get, put)
import qualified System.Console.Terminal.Size as TS

import Control.Distributed.Fork
import Control.Distributed.Fork.Internal (Handle (..))

-- | Fork every closure on @backend@, wait for all of them, and return their
-- results in the order of the input list, drawing a one-line progress bar on
-- stdout while they run.
--
-- Reading the bar left to right: @#@ marks finished jobs, @:@ started jobs,
-- @.@ submitted jobs, and the blank remainder covers jobs that are waiting
-- or have not reported a status yet. The bar is scaled to the terminal
-- width, or to 40 columns when the width cannot be determined.
--
-- If a job fails, its 'ExecutorFailedException' is rethrown; the other
-- watcher threads and the progress-bar thread are cancelled rather than
-- left running in the background.
mapConcurrentlyWithProgress :: Backend
                            -> Closure (Dict (Serializable a))
                            -> [Closure (IO a)]
                            -> IO [a]
mapConcurrentlyWithProgress backend dict xs = do
  -- (dirty, waiting, submitted, started, finished). The dirty flag lets the
  -- renderer block until some counter has actually changed.
  st <- atomically $ newTVar (True, 0 :: Int, 0 :: Int, 0 :: Int, 0 :: Int)
  handles <- mapM (fork backend dict) xs
  termWidth <- maybe 40 TS.width <$> TS.size :: IO Int
  let total = length xs
      -- Columns available between the brackets; never negative.
      barWidth = max 0 (termWidth - 2)

      -- Columns covered by @n@ of the @total@ jobs. Integer arithmetic
      -- avoids the 0/0 that a floating-point ratio yields for an empty list.
      scale n
        | total == 0 = 0
        | otherwise = n * barWidth `div` total

      -- Follow one job until it finishes, keeping the shared counters in
      -- step with its status, then return its result or throw its failure.
      watch (Handle tv) = go (0, 0, 0, 0)
        where
          go old = do
            (new, outcome) <- atomically $ do
              (now, res) <- readTVar tv >>= \case
                ExecutorPending (ExecutorWaiting _) ->
                  pure ((1, 0, 0, 0), Nothing)
                ExecutorPending (ExecutorSubmitted _) ->
                  pure ((0, 1, 0, 0), Nothing)
                ExecutorPending (ExecutorStarted _) ->
                  pure ((0, 0, 1, 0), Nothing)
                ExecutorFinished fr ->
                  pure ((0, 0, 0, 1), Just fr)
              -- Block until the job has moved to a different state.
              unless (now /= old) retry
              -- Reading the status and updating the counters in a single
              -- transaction keeps the counters consistent with every job.
              modifyTVar' st (shift old now)
              pure (now, res)
            case outcome of
              Nothing -> go new
              Just (ExecutorFailed err) -> throwIO (ExecutorFailedException err)
              Just (ExecutorSucceeded x) -> pure x

          -- Move one job's contribution from the old one-hot state to the
          -- new one and mark the counters dirty so the renderer redraws.
          shift (o1, o2, o3, o4) (n1, n2, n3, n4) (_, s1, s2, s3, s4) =
            (True, s1 - o1 + n1, s2 - o2 + n2, s3 - o3 + n3, s4 - o4 + n4)

      -- Redraw the bar whenever a counter changes, at most once per 10 ms,
      -- until every job has finished.
      render = do
        (submitted, started, finished) <- atomically $
          readTVar st >>= \case
            (False, _, _, _, _) -> retry
            (True, w, s, r, f) -> do
              writeTVar st (False, w, s, r, f)
              pure (s, r, f)
        -- Segment boundaries come from cumulative counts, so every frame is
        -- exactly barWidth columns wide and fully overwrites the last one.
        let done = scale finished
            running = scale (finished + started)
            queued = scale (finished + started + submitted)
        putStr $ concat
          [ "\r["
          , replicate done '#'
          , replicate (running - done) ':'
          , replicate (queued - running) '.'
          , replicate (barWidth - queued) ' '
          , "]"
          ]
        -- stdout is typically line-buffered on a terminal; without a flush
        -- nothing would appear until the final newline.
        hFlush stdout
        if finished < total
          then threadDelay 10000 >> render
          else putStrLn ""

  -- Both combinators cancel their sibling threads if one side throws.
  fst <$> concurrently (mapConcurrently watch handles) render