{-# OPTIONS_HADDOCK not-home #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE LambdaCase #-}
module Hedgehog.Internal.Queue (
    TaskIndex(..)
  , TasksRemaining(..)

  , runTasks
  , finalizeTask

  , runActiveFinalizers
  , dequeueMVar

  , updateNumCapabilities
  ) where

import           Control.Concurrent (rtsSupportsBoundThreads)
import           Control.Concurrent.Async (forConcurrently)
import           Control.Concurrent.MVar (MVar)
import qualified Control.Concurrent.MVar as MVar
import           Control.Monad (when)
import           Control.Monad.IO.Class (MonadIO(..))

import           Data.Map.Strict (Map)
import qualified Data.Map.Strict as Map

import qualified GHC.Conc as Conc

import           Hedgehog.Internal.Config


-- | The position of a task in the work queue.
--
--   'runTasks' numbers tasks from @0@ in the order they were supplied, and
--   seeds the finalizer state with @-1@ so that the finalizer of task @0@ is
--   the first one allowed to run.
--
--   The 'Num' instance is lifted from the underlying 'Int', which is what
--   allows numeric literals and expressions such as @[0..]@ and @minIx + 1@
--   to be used at this type.
--
newtype TaskIndex =
  TaskIndex Int
  deriving (Eq, Ord, Enum, Num)

-- | The number of tasks still waiting in the queue.
--
--   'dequeueMVar' passes this to its @start@ callback, computed as the length
--   of the queue after the task being started has been removed.
--
newtype TasksRemaining = TasksRemaining Int

-- | Take the task at the head of the queue and run the @start@ callback on
--   it.
--
--   The callback is given the number of tasks left in the queue once this
--   one has been removed, the index of the task, and the task itself. It is
--   executed inside the 'MVar.modifyMVar' call, i.e. while the queue is held
--   by the calling thread, so callbacks for successive tasks never overlap.
--
--   Returns 'Nothing' (leaving the queue untouched) when the queue is empty,
--   otherwise the index of the dequeued task paired with the callback's
--   result.
--
dequeueMVar ::
     MVar [(TaskIndex, a)]
  -> (TasksRemaining -> TaskIndex -> a -> IO b)
  -> IO (Maybe (TaskIndex, b))
dequeueMVar mvar start =
  MVar.modifyMVar mvar $ \case
    [] ->
      pure ([], Nothing)
    (ix, x) : xs -> do
      -- 'length xs' counts only the tasks behind this one.
      y <- start (TasksRemaining $ length xs) ix x
      pure (xs, Just (ix, y))

-- | Run a list of tasks on a pool of concurrent workers.
--
--   The tasks are numbered from @0@ and placed in a shared queue. Each of the
--   @max 1 n@ workers repeatedly:
--
--   1. dequeues a task with 'dequeueMVar', which runs @start@ on it;
--
--   2. runs @runTask@ on the value produced by @start@;
--
--   3. runs @finish@ on that same value;
--
--   4. registers @finalize@ for the task via 'finalizeTask', which runs
--      finalizers strictly in task-index order.
--
--   A worker stops when the queue is empty. Every worker accumulates its own
--   results most-recent-first, and the per-worker lists are concatenated, so
--   the order of the returned list does not follow the order of @tasks@.
--
runTasks ::
     WorkerCount
  -> [a]
  -> (TasksRemaining -> TaskIndex -> a -> IO b)
  -> (b -> IO ())
  -> (b -> IO ())
  -> (b -> IO c)
  -> IO [c]
runTasks n tasks start finish finalize runTask = do
  qvar <- MVar.newMVar (zip [0..] tasks)
  -- Finalizer state: the index of the last finalizer that was run (starting
  -- at -1, i.e. none yet) and the finalizers still waiting for their turn.
  fvar <- MVar.newMVar (-1, Map.empty)

  let
    worker rs = do
      mx <- dequeueMVar qvar start
      case mx of
        Nothing ->
          pure rs
        Just (ix, x) -> do
          r <- runTask x
          finish x
          finalizeTask fvar ix (finalize x)
          worker (r : rs)

  -- FIXME ensure all workers have finished running
  fmap concat . forConcurrently [1..max 1 n] $ \_ix ->
    worker []

-- | Run every pending finalizer whose turn has come.
--
--   The state holds the index of the last finalizer that was run together
--   with the finalizers that are still pending, keyed by task index. The
--   pending finalizer with the smallest index is run only when that index is
--   exactly one greater than the last one run; the state is then advanced and
--   the check is repeated. If there is nothing pending, or the next pending
--   index would leave a gap, the state is left unchanged and the loop stops.
--
--   Each finalizer is executed inside the 'MVar.modifyMVar' call, i.e. while
--   the state is held by the calling thread.
--
runActiveFinalizers ::
     MonadIO m
  => MVar (TaskIndex, Map TaskIndex (IO ()))
  -> m ()
runActiveFinalizers mvar =
  liftIO $ do
    again <-
      MVar.modifyMVar mvar $ \original@(minIx, finalizers0) ->
        case Map.minViewWithKey finalizers0 of
          Nothing ->
            pure (original, False)

          Just ((ix, finalize), finalizers) ->
            if ix == minIx + 1 then do
              finalize
              pure ((ix, finalizers), True)
            else
              -- An earlier task has not registered its finalizer yet.
              pure (original, False)

    when again $
      runActiveFinalizers mvar

-- | Register the finalizer for the task with the given index, then run any
--   finalizers that are now ready.
--
--   The finalizer is stored in the pending map under its task index; the
--   index of the last finalizer run is left untouched. 'runActiveFinalizers'
--   is then called, so the new finalizer runs immediately if all finalizers
--   with smaller indices have already run, and otherwise waits until a later
--   call fills the gap.
--
finalizeTask ::
     MonadIO m
  => MVar (TaskIndex, Map TaskIndex (IO ()))
  -> TaskIndex
  -> IO ()
  -> m ()
finalizeTask mvar ix finalize = do
  liftIO . MVar.modifyMVar_ mvar $ \(minIx, finalizers) ->
    pure (minIx, Map.insert ix finalize finalizers)
  runActiveFinalizers mvar

-- | Update the number of capabilities but never set it lower than it already
--   is.
--
--   The new value is the larger of the requested worker count and the
--   current number of capabilities. Nothing is done at all when
--   'rtsSupportsBoundThreads' is 'False'.
--
updateNumCapabilities :: WorkerCount -> IO ()
updateNumCapabilities (WorkerCount n) = when rtsSupportsBoundThreads $ do
  ncaps <- Conc.getNumCapabilities
  Conc.setNumCapabilities (max n ncaps)