{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE LambdaCase   #-}
{-|
Module: Supervisors
Description: Monitor a pool of threads.

This module exposes a 'Supervisor' construct, which can be used to safely
spawn threads while guaranteeing that:

* When the supervisor is killed, all of the threads it supervises will be
  killed.
* Child threads can terminate in any order, and memory usage will always
  be proportional to the number of *live* supervised threads.
-}
module Supervisors
    ( Supervisor
    , withSupervisor
    , supervise
    , superviseSTM
    ) where

import Control.Concurrent.STM

import Control.Concurrent       (ThreadId, forkIO, myThreadId, throwTo)
import Control.Concurrent.Async (withAsync)
import Control.Exception.Safe
    ( Exception
    , SomeException
    , bracket
    , bracket_
    , finally
    , toException
    , withException
    )
import Control.Monad            (forever, void)

import qualified Data.Set as S

-- | A handle for a supervisor, which monitors a pool of threads.
data Supervisor = Supervisor
    { stateVar :: TVar (Either SomeException (S.Set ThreadId))
      -- ^ While running: 'Right' with the set of live child threads.
      -- Once shut down: 'Left' with the exception it was shut down with.
    , runQ     :: TQueue (IO ())
      -- ^ Actions queued from 'STM' that are waiting to be launched
      -- as supervised children by the supervisor's own thread.
    }

-- | Allocate a fresh supervisor with no children and an empty run queue.
-- This only creates the handle; nothing is forked here (the supervisor's
-- loop is 'runSupervisor').
newSupervisor :: IO Supervisor
newSupervisor = do
    children <- newTVarIO (Right S.empty)
    pending  <- newTQueueIO
    pure Supervisor { stateVar = children, runQ = pending }

-- | Run the logic associated with the supervisor. This never returns until
-- the supervisor receives an (asynchronous) exception. When it does return,
-- all of the supervised threads will be killed.
--
-- Each iteration blocks until an action appears on the run queue, then
-- launches it with 'supervise'. Whatever exception ends the loop is passed
-- to 'throwKids' for every live child and then re-thrown by 'withException'.
runSupervisor :: Supervisor -> IO ()
runSupervisor sup = loop `withException` onFailure
  where
    -- Take queued actions one at a time and start each as a child.
    loop :: IO ()
    loop = forever $ do
        action <- atomically (readTQueue (runQ sup))
        supervise sup action

    -- Forward the terminating exception to all children.
    onFailure :: SomeException -> IO ()
    onFailure = throwKids sup

-- | Run an IO action with access to a supervisor. Threads spawned using the
-- supervisor will be killed when the action returns.
--
-- The supervisor's loop ('runSupervisor') runs in a background thread
-- managed by 'withAsync'. When the callback finishes (normally or by
-- exception), 'withAsync' cancels that thread, which in turn throws to
-- every child.
withSupervisor :: (Supervisor -> IO a) -> IO a
withSupervisor f =
    newSupervisor >>= \sup ->
        withAsync (runSupervisor sup) (\_ -> f sup)

-- | Throw an exception to all of a supervisor's children, using 'throwTo'.
--
-- First, in a single transaction, the supervisor is marked as stopped by
-- storing the exception as 'Left', and the set of children live at that
-- moment is captured. If the supervisor was already stopped, there is
-- nobody to signal. The throwing happens in the release step of 'bracket',
-- so it runs even if this thread is interrupted after the transaction.
throwKids :: Exception e => Supervisor -> e -> IO ()
throwKids Supervisor{stateVar=st} exn =
    bracket detach signalAll (\_ -> pure ())
  where
    -- Flip the state to 'Left' and return the children we must notify.
    detach :: IO (S.Set ThreadId)
    detach = atomically $ do
        current <- readTVar st
        case current of
            Left _ ->
                pure S.empty
            Right kids -> do
                writeTVar st (Left (toException exn))
                pure kids

    -- Important: chain the throws with `finally` rather than (>>=) or
    -- friends, so that if one throwTo raises we still signal the others.
    signalAll :: S.Set ThreadId -> IO ()
    signalAll kids =
        foldr (\kid rest -> throwTo kid exn `finally` rest) (pure ()) kids

-- | Launch the IO action in a thread, monitored by the 'Supervisor'. If the
-- supervisor receives an exception, the exception will also be raised in the
-- child thread.
--
-- The child adds its own 'ThreadId' to the supervisor before running the
-- action and removes it afterwards (via 'bracket_'). The supervisor
-- therefore only tracks live children. If the supervisor is already
-- stopped, registration throws the stored exception inside the child and
-- the action is never run.
supervise :: Supervisor -> IO () -> IO ()
supervise Supervisor{stateVar=st} task =
    void (forkIO (bracket_ register unregister task))
  where
    -- Add this thread to the live set, or fail with the stop reason.
    register :: IO ()
    register = do
        self <- myThreadId
        atomically $ do
            current <- readTVar st
            case current of
                Left stopReason ->
                    throwSTM stopReason
                Right kids ->
                    -- Build the new set before storing it, so no thunk
                    -- ends up in the TVar.
                    writeTVar st . Right $! S.insert self kids

    -- Remove this thread from the live set so its ThreadId isn't leaked
    -- until the supervisor exits.
    unregister :: IO ()
    unregister = do
        self <- myThreadId
        atomically (modifyTVar' st (forget self))

    -- A stopped supervisor is left untouched. Otherwise, delete @self@.
    -- The '$!' matters: modifyTVar' only forces the new state to WHNF,
    -- i.e. as far as @Right (S.delete self kids)@, which would keep the
    -- old set (and thus @self@) alive inside an unevaluated thunk.
    forget
        :: ThreadId
        -> Either SomeException (S.Set ThreadId)
        -> Either SomeException (S.Set ThreadId)
    forget _    stopped@(Left _) = stopped
    forget self (Right kids)     = Right $! S.delete self kids

-- | Like 'supervise', but can be used from inside 'STM'. The thread will be
-- spawned if and only if the transaction commits.
--
-- This only enqueues the action on the supervisor's run queue; the
-- supervisor's own loop ('runSupervisor') later picks it up and launches it.
superviseSTM :: Supervisor -> IO () -> STM ()
superviseSTM sup task = writeTQueue (runQ sup) task