{-# LANGUAGE TupleSections #-}

-- | Thread pool implementation. Tasks are added with 'addPool' at one of the
--   following 'PoolPriority' levels (highest to lowest):
--
-- * 'PoolException' - things that probably result in a build error,
--   so kick them off quickly.
--
-- * 'PoolResume' - things that started, blocked, and may have open
--   resources in their closure.
--
-- * 'PoolStart' - rules that haven't yet started.
--
-- * 'PoolBatch' - rules that might batch if other rules start first.
--
-- * 'PoolDeprioritize' - below all of the above; among themselves they are
--   ordered by the 'Double', smallest first.
module General.Pool(
    Pool, runPool,
    addPool, PoolPriority(..),
    increasePool, keepAlivePool
    ) where

import Control.Concurrent.Extra
import System.Time.Extra
import Control.Exception
import Control.Monad.Extra
import General.Timing
import General.Thread
import qualified Data.Heap as Heap
import qualified Data.HashSet as Set
import Data.IORef.Extra
import System.Random


---------------------------------------------------------------------
-- THREAD POOL

{-
We must keep a set of the active threads, so that we can raise exceptions in them in a timely manner.
If any worker throws an exception, we must signal all the other workers.
-}

-- | The mutable state of a 'Pool', stored in the pool's 'Var'.
data S = S
    {alive :: !Bool -- True until there's an exception, after which don't spawn more tasks
    ,threads :: !(Set.HashSet Thread) -- IMPORTANT: Must be strict or we leak thread stacks
    ,threadsLimit :: {-# UNPACK #-} !Int -- user supplied thread limit; new workers are only spawned while threadsCount < threadsLimit
                                         -- (lowering the limit does not stop running threads, so the count may briefly exceed it)
    ,threadsCount :: {-# UNPACK #-} !Int -- Set.size threads, but in O(1)
    ,threadsMax :: {-# UNPACK #-} !Int -- high water mark of Set.size threads (accounting only)
    ,threadsSum :: {-# UNPACK #-} !Int -- number of threads we have been through (accounting only)
    ,rand :: IO Int -- operation to give us the next tie-breaking Int for equal-priority tasks
    ,todo :: !(Heap.Heap (Heap.Entry (PoolPriority, Int) (IO ()))) -- operations waiting for a thread, keyed by (priority, rand)
    }


-- | Initial pool state: alive, a thread limit of @n@, no threads and no pending work.
--
--   The tie-breaker for tasks of equal priority comes from 'randomIO' when
--   @deterministic@ is 'False'. Otherwise it comes from a counter yielding 0, 1, 2, ...,
--   so equal-priority tasks are keyed in the order they were added.
emptyS :: Int -> Bool -> IO S
emptyS n deterministic = do
    nextRand <-
        if deterministic then do
            ref <- newIORef 0
            -- no need to be thread-safe - if two threads race they were basically the same time anyway
            pure $ do i <- readIORef ref; writeIORef' ref (i+1); pure i
        else pure randomIO
    pure S
        {alive = True
        ,threads = Set.empty
        ,threadsLimit = n
        ,threadsCount = 0
        ,threadsMax = 0
        ,threadsSum = 0
        ,rand = nextRand
        ,todo = Heap.empty
        }


-- | A running thread pool: its shared state, plus a barrier that is signalled once,
--   either with the final state when all work is done or with the first worker exception.
data Pool
    = Pool
        !(Var S)
        -- ^ Current state; 'alive' = False means we are aborting
        !(Barrier (Either SomeException S))
        -- ^ Signalled once the pool has finished (or failed)


-- | Apply @f@ to the pool state while holding the pool's 'Var', but only if the pool
--   is still 'alive'. Once the pool is dead, @f@ is skipped and nothing happens.
--   The follow-up action returned by @f@ runs after the 'Var' has been released.
withPool :: Pool -> (S -> IO (S, IO ())) -> IO ()
withPool (Pool var _) f = do
    afterwards <- modifyVar var $ \s ->
        if alive s then f s else pure (s, pure ())
    afterwards

-- | Like 'withPool', for updates that have nothing to run after the 'Var' is released.
withPool_ :: Pool -> (S -> IO S) -> IO ()
withPool_ pool act = withPool pool $ fmap (\s' -> (s', pure ())) . act


-- | Worker loop: pop the highest-priority pending task and run it, then repeat.
--   Stops when the queue is empty or the pool is no longer 'alive'.
--   Each task runs after the pool's 'Var' has been released.
worker :: Pool -> IO ()
worker pool = withPool pool $ \s -> pure $ case Heap.uncons $ todo s of
    Nothing -> (s, pure ())
    Just (Heap.Entry _ now, rest) -> (s{todo = rest}, now >> worker pool)

-- | Given a pool, and a function that breaks the 'S' invariants, restore them.
--   The function is only allowed to touch 'threadsLimit' or 'todo'.
--   Assumes at most one worker needs spawning (e.g. the limit can't be raised by more than one at a time).
--
--   @op@ is applied while holding the pool's 'Var', and only if the pool is 'alive'. Then:
--
--   * if a task is pending and 'threadsCount' < 'threadsLimit', that task is removed from
--     'todo'. A new worker thread is spawned to run it and then keep draining the queue.
--
--   * if nothing is pending and no threads are running, the done barrier is signalled
--     with the final state and the pool is marked dead.
--
--   * otherwise the updated state is simply stored.
--
--   When a spawned worker ends with an exception (while the pool is still alive), the done
--   barrier is signalled with that exception and the pool is marked dead. When it ends
--   normally, it is removed from 'threads' and 'step' runs again to restore the invariants.
step :: Pool -> (S -> IO S) -> IO ()
step pool@(Pool _ done) op =
    -- mask_ is so we don't spawn a thread and then fail to record it
    mask_ $ withPool_ pool $ \s0 -> do
        s <- op s0
        case Heap.uncons $ todo s of
            Just (Heap.Entry _ now, todo2) | threadsCount s < threadsLimit s -> do
                -- spawn a new worker
                t <- newThreadFinally (now >> worker pool) $ \t res -> case res of
                    Left e -> withPool_ pool $ \sNow -> do
                        signalBarrier done $ Left e
                        pure (remThread t sNow){alive = False}
                    Right _ ->
                        step pool $ pure . remThread t
                pure (addThread t s){todo = todo2}
            Nothing | threadsCount s == 0 -> do
                signalBarrier done $ Right s
                pure s{alive = False}
            _ -> pure s
    where
        -- record a freshly spawned thread, updating the accounting counters
        addThread :: Thread -> S -> S
        addThread t s = s
            {threads = Set.insert t $ threads s
            ,threadsCount = threadsCount s + 1
            ,threadsSum = threadsSum s + 1
            ,threadsMax = threadsMax s `max` (threadsCount s + 1)
            }

        -- forget a thread that has finished
        remThread :: Thread -> S -> S
        remThread t s = s{threads = Set.delete t $ threads s, threadsCount = threadsCount s - 1}


-- | Add a new task to the pool. See the top of the module for the relative ordering
--   and semantics.
--
--   Tasks of equal priority are ordered by the pool's 'rand' tie-breaker, and the task's
--   result is discarded. If the pool is no longer 'alive', the task is silently dropped.
--   Adding a task may immediately spawn a worker for it if there is spare capacity.
addPool :: PoolPriority -> Pool -> IO a -> IO ()
addPool priority pool act = step pool $ \s -> do
    i <- rand s
    pure s{todo = Heap.insert (Heap.Entry (priority, i) $ void act) $ todo s}


-- | Priority of a pool task, listed from highest to lowest priority.
--
--   The derived 'Ord' places earlier constructors first, and the pool's task heap
--   runs the smallest key first. Among themselves, 'PoolDeprioritize' values are
--   ordered by their 'Double', smallest first.
data PoolPriority
    = PoolException -- ^ things that probably result in a build error, so kick them off quickly
    | PoolResume -- ^ things that started, blocked, and may have open resources in their closure
    | PoolStart -- ^ rules that haven't yet started
    | PoolBatch -- ^ rules that might batch if other rules start first
    | PoolDeprioritize Double -- ^ below all of the above
      deriving (Eq, Ord)

-- | Temporarily increase the pool by 1 thread. Call the cleanup action to restore the value.
--   After calling cleanup you should requeue onto a new thread.
--
--   Raising the limit may immediately spawn a worker for pending work. The cleanup action
--   decrements the limit each time it is run, so it should be called exactly once.
increasePool :: Pool -> IO (IO ())
increasePool pool = do
    adjustLimit 1
    pure $ adjustLimit (-1)
    where
        adjustLimit :: Int -> IO ()
        adjustLimit delta = step pool $ \s -> pure s{threadsLimit = threadsLimit s + delta}


-- | Make sure the pool cannot run out of tasks (and thus everything finishes) until after the cancel is called.
--   Ensures that a pool that will requeue in time doesn't go idle.
--
--   This works by queueing a 'PoolResume' task. The task raises the thread limit by one
--   (so it does not consume a worker slot) and then blocks until the returned cancel
--   action is run. While the task is queued or running, the pool never sees "no work and
--   no threads", so it cannot finish.
--
--   The cancel action signals a 'Barrier', so call it at most once
--   (NOTE(review): 'signalBarrier' is Partial; a second call presumably fails).
keepAlivePool :: Pool -> IO (IO ())
keepAlivePool pool = do
    bar <- newBarrier
    addPool PoolResume pool $ do
        cancel <- increasePool pool
        waitBarrier bar
        cancel
    pure $ signalBarrier bar ()


-- | Run all the tasks in the pool on the given number of workers.
--   If any thread throws an exception, the exception will be reraised.
--
--   Parameters:
--
--   * @deterministic@ uses a counter instead of 'randomIO' to order tasks of equal priority.
--
--   * @n@ is the thread limit.
--
--   * @act@ is queued as the first 'PoolStart' task.
--
--   However this returns, the pool is marked dead and 'stopThreads' is called on every
--   thread still recorded in it.
runPool :: Bool -> Int -> (Pool -> IO ()) -> IO ()
runPool deterministic n act = do
    var <- newVar =<< emptyS n deterministic
    done <- newBarrier
    let pool = Pool var done

    -- if someone kills our thread, make sure we kill our child threads
    let cleanup = join $ modifyVar var $ \s ->
            pure (s{alive = False}, stopThreads $ Set.toList $ threads s)

    let ghc10793 = do
            -- if this thread dies because it is blocked on an MVar there's a chance we have
            -- a better error in the done barrier, and GHC raised the exception wrongly, see:
            -- https://gitlab.haskell.org/ghc/ghc/-/issues/10793
            sleep 1 -- give it a little bit of time for the finally to run
                    -- no big deal, since the blocked indefinitely takes a while to fire anyway
            res <- waitBarrierMaybe done
            case res of
                Just (Left e) -> throwIO e
                _ -> throwIO BlockedIndefinitelyOnMVar

    flip finally cleanup $ handle (\BlockedIndefinitelyOnMVar -> ghc10793) $ do
        addPool PoolStart pool $ act pool
        res <- waitBarrier done
        case res of
            Left e -> throwIO e
            Right final -> addTiming $ "Pool finished (" ++ show (threadsSum final) ++ " threads, " ++ show (threadsMax final) ++ " max)"