module Ki.Internal.Scope
  ( Scope,
    scoped,
    awaitAll,
    fork,
    forkWith,
    forkWith_,
    fork_,
    forkTry,
    forkTryWith,
  )
where

import qualified Control.Concurrent
import Control.Exception
  ( Exception (fromException, toException),
    MaskingState (..),
    SomeAsyncException,
    asyncExceptionFromException,
    asyncExceptionToException,
    catch,
    pattern ErrorCall,
  )
import qualified Data.IntMap.Lazy as IntMap
import Data.Void (Void, absurd)
import GHC.Conc
  ( STM,
    TVar,
    atomically,
    enableAllocationLimit,
    labelThread,
    newTVarIO,
    readTVar,
    retry,
    setAllocationCounter,
    throwSTM,
    writeTVar,
  )
import GHC.IO (unsafeUnmask)
import Ki.Internal.ByteCount
import Ki.Internal.Counter
import Ki.Internal.Prelude
import Ki.Internal.Thread
import GHC.Conc.Sync (readTVarIO)

-- | A scope.
--
-- ==== __👉 Details__
--
-- * A scope delimits the lifetime of all threads created within it.
--
-- * A scope is only valid during the callback provided to 'Ki.scoped'.
--
-- * The thread that creates a scope is considered the parent of all threads created within it.
--
-- * All threads created within a scope can be awaited together (see 'Ki.awaitAll').
--
-- * All threads created within a scope are terminated when the scope closes.
data Scope = Scope
  { -- The MVar that a child tries to put to, in the case that it tries to propagate an exception to its parent, but
    -- gets delivered an exception from its parent concurrently (which interrupts the throw). The parent must raise
    -- exceptions in its children with asynchronous exceptions uninterruptibly masked for correctness, yet we don't want
    -- a parent in the process of tearing down to miss or ignore the exception that a child is trying to propagate; this
    -- MVar gives the child a way to hand it over without a throwTo.
    --
    -- Why a single-celled MVar? What if two siblings are fighting to inform their parent of their death? Well, only
    -- one exception can be propagated by the parent anyway, so we wouldn't need or want both.
    childExceptionVar :: {-# UNPACK #-} !(MVar SomeException),
    -- The set of child threads that are currently running, each keyed by a monotonically increasing int.
    childrenVar :: {-# UNPACK #-} !(TVar (IntMap ThreadId)),
    -- The counter that holds the (int) key to use for the next child thread.
    nextChildIdCounter :: {-# UNPACK #-} !Counter,
    -- The id of the thread that created the scope, which is considered the parent of all threads created within it.
    parentThreadId :: {-# UNPACK #-} !ThreadId,
    -- The number of child threads that are guaranteed to be about to start, in the sense that only the GHC scheduler
    -- can continue to delay; there's no opportunity for an async exception to strike and prevent one of these threads
    -- from starting.
    --
    -- Sentinel value: any negative value means the scope is closed (the closing parent writes -1), and it is an error
    -- to create a thread within it.
    startingVar :: {-# UNPACK #-} !(TVar Int)
  }

-- Internal async exception thrown by a parent thread to its children when the scope is closing.
--
-- It is not exported from this module, so only code in here can throw it.
data ScopeClosing
  = ScopeClosing

-- Renders as just the constructor name (without inspecting the value).
instance Show ScopeClosing where
  show = const "ScopeClosing"

-- Registered as an asynchronous exception: 'asyncExceptionToException' wraps it in 'SomeAsyncException', and
-- 'asyncExceptionFromException' unwraps it from there.
instance Exception ScopeClosing where
  toException = asyncExceptionToException
  fromException = asyncExceptionFromException

-- Does this exception hold our internal 'ScopeClosing'?
--
-- Trust without verifying that any 'ScopeClosing' exception, which is not exported by this module, was indeed thrown
-- to a thread by its parent. It is possible to write a program that violates this (just catch the async exception and
-- throw it to some other thread)... but who would do that?
isScopeClosingException :: SomeException -> Bool
isScopeClosingException exception =
  case fromException exception of
    Just ScopeClosing -> True
    Nothing -> False

-- Match-only pattern for a 'SomeException' that holds our internal 'ScopeClosing' exception.
pattern IsScopeClosingException :: SomeException
pattern IsScopeClosingException <- (fromException -> Just ScopeClosing)

-- | Open a scope, perform an IO action with it, then close the scope.
--
-- ==== __👉 Details__
--
-- * The thread that creates a scope is considered the parent of all threads created within it.
--
-- * A scope is only valid during the callback provided to 'Ki.scoped'.
--
-- * When a scope closes (/i.e./ just before 'Ki.scoped' returns):
--
--     * The parent thread raises an exception in all of its living children.
--     * The parent thread blocks until those threads terminate.
scoped :: (Scope -> IO a) -> IO a
scoped action = do
  scope@Scope {childExceptionVar, childrenVar, startingVar} <- allocateScope

  uninterruptibleMask \restore -> do
    -- Run the callback at the caller's masking state, capturing whatever it throws.
    result <- try (restore (action scope))

    -- Close the scope and snapshot the children that must be killed.
    !livingChildren <- do
      snapshot <-
        atomically do
          -- Block until we haven't committed to starting any threads. Without this, we may create a thread
          -- concurrently with closing its scope, and not grab its thread id to throw an exception to.
          blockUntil0 startingVar
          -- Write the sentinel value indicating that this scope is closed, and it is an error to try to create a
          -- thread within it.
          writeTVar startingVar (-1)
          -- Return the currently-running children to kill. Some of them may have *just* started (e.g. if we
          -- initially retried in `blockUntil0` above). That's fine - kill them all!
          readTVar childrenVar

      -- If one of our children propagated an exception to us, then we know it's about to terminate, so we don't
      -- bother throwing an exception to it.
      pure case result of
        Left (fromException -> Just ThreadFailed {childId}) -> IntMap.delete childId snapshot
        _ -> snapshot

    -- Deliver a ScopeClosing exception to every living child. `IntMap.elems` yields them in ascending key order,
    -- which (keys being handed out by an increasing counter) is the order they were created in; nothing relies on it.
    for_ (IntMap.elems livingChildren) \child ->
      throwTo child ScopeClosing

    -- Block until all children have terminated; this relies on children respecting the async exception, which they
    -- must, for correctness. Otherwise, a thread could indeed outlive the scope in which it's created, which is
    -- definitely not structured concurrency!
    atomically (blockUntilEmpty childrenVar)

    -- By now there are three sources of exception:
    --
    --   1) A sync or async exception thrown during the callback, captured in `result`. If applicable, we want to unwrap
    --      the `ThreadFailed` off of this, which was only used to indicate it came from one of our children.
    --
    --   2) A sync or async exception left for us in `childExceptionVar` by a child that tried to propagate it to us
    --      directly, but failed (because we killed it concurrently).
    --
    --   3) An async exception waiting in our exception queue, because we still have async exceptions uninterruptibly
    --      masked; it is delivered once we leave `uninterruptibleMask`.
    --
    -- We cannot throw more than one, so throw them in that priority order.
    either
      (throwIO . unwrapThreadFailed)
      (\value -> tryTakeMVar childExceptionVar >>= maybe (pure value) throwIO)
      result

-- Allocate a new scope: no pending child exception, no children, a fresh child-id counter, the calling thread as
-- parent, and no children about to start (startingVar = 0, i.e. open).
allocateScope :: IO Scope
allocateScope =
  Scope
    <$> newEmptyMVar
    <*> newTVarIO IntMap.empty
    <*> newCounter
    <*> myThreadId
    <*> newTVarIO 0

-- Spawn a thread in a scope, providing it its child id and a function that sets the masking state to the requested
-- masking state. The given action is called with async exceptions interruptibly masked.
--
-- Throws an 'ErrorCall' ("ki: scope closed") if the scope has already closed.
--
-- If forking itself throws (for example, 'Control.Concurrent.forkOS' fails when the program is not linked with the
-- threaded runtime), the slot reserved in @startingVar@ is released before the exception is rethrown. Otherwise
-- @startingVar@ would never return to 0, and closing the scope (which waits for exactly that) would hang.
spawn :: Scope -> ThreadOptions -> (Int -> (forall x. IO x -> IO x) -> UnexceptionalIO ()) -> IO ThreadId
spawn
  Scope {childrenVar, nextChildIdCounter, startingVar}
  ThreadOptions {affinity, allocationLimit, label, maskingState = requestedChildMaskingState}
  action =
    -- Interruptible mask is enough so long as none of the STM operations below block.
    --
    -- Unconditionally set masking state to MaskedInterruptible, even though we might already be at MaskedInterruptible
    -- or MaskedUninterruptible, to avoid a branch on parentMaskingState.
    interruptiblyMasked do
      -- Record the thread as being about to start. Not allowed to retry.
      atomically do
        n <- readTVar startingVar
        if n < 0
          then throwSTM (ErrorCall "ki: scope closed")
          else writeTVar startingVar $! n + 1

      childId <- incrCounter nextChildIdCounter

      childThreadId <-
        try @SomeException (forkWithAffinity affinity (runChild childId)) >>= \case
          Left exception -> do
            -- The child never started, so give back the slot we reserved above before rethrowing.
            atomically releaseStartingSlot
            throwIO exception
          Right childThreadId -> pure childThreadId

      -- Record the child as having started. Not allowed to retry.
      atomically do
        releaseStartingSlot
        recordChild childrenVar childId childThreadId

      pure childThreadId
  where
    -- The body of the child thread; it starts out interruptibly masked, inherited from the forking thread.
    runChild :: Int -> IO ()
    runChild childId = do
      when (not (null label)) do
        childThreadId <- myThreadId
        labelThread childThreadId label

      case allocationLimit of
        Nothing -> pure ()
        Just bytes -> do
          setAllocationCounter (byteCountToInt64 bytes)
          enableAllocationLimit

      let -- Action that sets the masking state from the current (MaskedInterruptible) to the requested one.
          atRequestedMaskingState :: IO a -> IO a
          atRequestedMaskingState =
            case requestedChildMaskingState of
              Unmasked -> unsafeUnmask
              MaskedInterruptible -> id
              MaskedUninterruptible -> uninterruptiblyMasked

      runUnexceptionalIO (action childId atRequestedMaskingState)

      atomically (unrecordChild childrenVar childId)

    -- Undo one "about to start" reservation in startingVar. Never retries.
    releaseStartingSlot :: STM ()
    releaseStartingSlot = do
      n <- readTVar startingVar
      writeTVar startingVar $! n - 1 -- it's actually ok to go from e.g. -1 to -2 here (very unlikely)

-- Record our child (called by the parent, right after forking it) by toggling its entry:
--
--   * Absent -> present with `childThreadId` (common case: we record the child before it unrecords itself).
--   * Present -> absent (uncommon case: the child already finished and unrecorded itself, leaving a placeholder).
--
-- Never retries.
recordChild :: TVar (IntMap ThreadId) -> Int -> ThreadId -> STM ()
recordChild childrenVar childId childThreadId = do
  children <- readTVar childrenVar
  let children' = IntMap.alter toggle childId children
  writeTVar childrenVar $! children'
  where
    toggle Nothing = Just childThreadId
    toggle (Just _) = Nothing

-- Unrecord a child (called by the child itself, as it finishes) by either:
--
--   * Flipping `Just childThreadId` to `Nothing` (common case: parent recorded us first)
--   * Flipping `Nothing` to `Just placeholder` (uncommon case: we terminate and unrecord before parent can record us).
--
-- The placeholder is never meant to be forced. The parent's 'recordChild' sees the `Just` and deletes the entry. Until
-- then, the parent's spawn still holds a slot in `startingVar`, so the scope cannot finish closing (and read this map's
-- elements) in the meantime. If it is ever forced anyway, it fails with a descriptive internal error rather than a bare
-- `Prelude.undefined`.
--
-- Never retries.
unrecordChild :: TVar (IntMap ThreadId) -> Int -> STM ()
unrecordChild childrenVar childId = do
  children <- readTVar childrenVar
  writeTVar childrenVar $! IntMap.alter toggle childId children
  where
    toggle Nothing = Just placeholder
    toggle (Just _) = Nothing

    placeholder :: ThreadId
    placeholder =
      error "ki: internal error: forced the placeholder ThreadId of a child that unrecorded itself before being recorded"

-- Pick the fork primitive matching the requested affinity: an unbound green thread ('forkIO'), a thread pinned to a
-- capability ('forkOn'), or a bound OS thread ('Control.Concurrent.forkOS').
forkWithAffinity :: ThreadAffinity -> IO () -> IO ThreadId
forkWithAffinity affinity =
  case affinity of
    Unbound -> forkIO
    Capability n -> forkOn n
    OsThread -> Control.Concurrent.forkOS

-- | Wait until all threads created within a scope terminate.
--
-- (Internally: retries until there are no running children and no children about to start.)
awaitAll :: Scope -> STM ()
awaitAll Scope {childrenVar, startingVar} =
  blockUntilEmpty childrenVar *> blockUntil0 startingVar

-- Block (via STM 'retry') until the IntMap in the given TVar has no entries.
blockUntilEmpty :: TVar (IntMap a) -> STM ()
blockUntilEmpty var =
  readTVar var >>= \m ->
    if IntMap.null m then pure () else retry

-- Block (via STM 'retry') until the Int in the given TVar is exactly 0; any other value, including a negative one,
-- keeps retrying.
blockUntil0 :: TVar Int -> STM ()
blockUntil0 var = do
  n <- readTVar var
  when (n /= 0) retry

-- | Create a child thread to execute an action within a scope.
--
-- /Note/: The child thread does not mask asynchronous exceptions, regardless of the parent thread's masking state. To
-- create a child thread with a different initial masking state, use 'Ki.forkWith'.
fork :: Scope -> IO a -> IO (Thread a)
fork scope action =
  forkWith scope defaultThreadOptions action

-- | Variant of 'Ki.fork' for threads that never return.
fork_ :: Scope -> IO Void -> IO ()
fork_ scope action =
  forkWith_ scope defaultThreadOptions action

-- | Variant of 'Ki.fork' that takes an additional options argument.
forkWith :: Scope -> ThreadOptions -> IO a -> IO (Thread a)
forkWith scope opts action = do
  resultVar <- newTVarIO Nothing
  childThreadId <-
    spawn scope opts \childId masking -> do
      result <- unexceptionalTry (masking action)
      case result of
        Left exception
          | not (isScopeClosingException exception) ->
              propagateException scope childId exception
        _ -> pure ()
      -- Even put async exceptions that we propagated. This isn't totally ideal, because a caller awaiting this thread
      -- would not be able to distinguish between async exceptions delivered to this thread, or itself.
      UnexceptionalIO (atomically (writeTVar resultVar (Just result)))
  -- Awaiting retries until the child has published its outcome, then rethrows its exception or returns its value.
  let awaitOutcome = readTVar resultVar >>= maybe retry (either throwSTM pure)
  pure (makeThread childThreadId awaitOutcome)

-- | Variant of 'Ki.forkWith' for threads that never return.
forkWith_ :: Scope -> ThreadOptions -> IO Void -> IO ()
forkWith_ scope opts action =
  void do
    spawn scope opts \childId masking ->
      unexceptionalTryEither
        ( \exception ->
            if isScopeClosingException exception
              then pure ()
              else propagateException scope childId exception
        )
        absurd
        (masking action)

-- | Like 'Ki.fork', but the child thread does not propagate exceptions that are both:
--
-- * Synchronous (/i.e./ not an instance of 'SomeAsyncException').
-- * An instance of @e@.
forkTry :: forall e a. Exception e => Scope -> IO a -> IO (Thread (Either e a))
forkTry scope action =
  forkTryWith scope defaultThreadOptions action

-- | Variant of 'Ki.forkTry' that takes an additional options argument.
forkTryWith :: forall e a. Exception e => Scope -> ThreadOptions -> IO a -> IO (Thread (Either e a))
forkTryWith scope opts action = do
  resultVar <- newTVarIO Nothing
  childThreadId <-
    spawn scope opts \childId masking -> do
      result <- unexceptionalTry (masking action)
      case result of
        Left exception
          | shouldPropagate exception ->
              propagateException scope childId exception
        _ -> pure ()
      UnexceptionalIO (atomically (writeTVar resultVar (Just result)))
  let awaitOutcome = readTVar resultVar >>= maybe retry (either rethrowUnexpected (pure . Right))
  pure (makeThread childThreadId awaitOutcome)
  where
    -- A failed child propagates its exception to the parent unless it is our ScopeClosing, or it is a synchronous
    -- exception of type @e@ (which is reported only through the thread's result).
    shouldPropagate :: SomeException -> Bool
    shouldPropagate exception
      | isScopeClosingException exception = False
      -- if the user calls `forkTry @MyAsyncException`, we still want to propagate the async exception
      | Just _ <- fromException @e exception = isAsyncException exception
      | otherwise = True

    -- When awaiting: an exception of type @e@ is handed back as 'Left'; anything else is rethrown.
    rethrowUnexpected :: SomeException -> STM (Either e a)
    rethrowUnexpected exception =
      maybe (throwSTM exception) (pure . Left) (fromException @e exception)

    isAsyncException :: SomeException -> Bool
    isAsyncException exception =
      maybe False (const True) (fromException @SomeAsyncException exception)

-- We have a non-`ScopeClosing` exception to propagate to our parent.
--
-- If our scope has already begun closing (`startingVar` is negative; a closing parent writes -1), then either...
--
--   (A) We already received a `ScopeClosing`, but then ended up trying to propagate an exception anyway, because we
--   threw a synchronous exception (or were hit by a different asynchronous exception) during our teardown procedure.
--
--   or
--
--   (B) We will receive a `ScopeClosing` imminently, because our parent has *just* finished setting `startingVar` to
--   -1, and will proceed to throw ScopeClosing to all of its children.
--
-- If (A), our parent has asynchronous exceptions masked, so we must inform it of our exception via `childExceptionVar`
-- rather than throwTo. If (B), either mechanism would work. And because we don't know whether we're in case (A) or
-- (B), we just use `childExceptionVar`.
--
-- Any negative value counts as closed, matching the `n < 0` check in 'spawn'. Testing for exactly -1 would send us down
-- the throwTo path towards a parent that has asynchronous exceptions uninterruptibly masked.
--
-- And if our scope has not already begun closing (`startingVar` is non-negative), then we ought to throw our exception
-- to our parent. But that might fail due to either...
--
--   (C) Our parent concurrently closing the scope and sending us a `ScopeClosing`; because it has asynchronous
--   exceptions uninterruptibly masked and we only have asynchronous exceptions *interruptibly* masked, its `throwTo`
--   will return `()`, and ours will throw that `ScopeClosing` asynchronous exception. In this case, since we now know
--   our parent is tearing down and has asynchronous exceptions masked, we again inform it via `childExceptionVar`.
--
--   (D) Some *other* non-`ScopeClosing` asynchronous exception is raised here. This is truly odd: maybe it's a heap
--   overflow exception from the GHC runtime? Maybe some other thread has smuggled our `ThreadId` out and has manually
--   thrown us an exception for some reason? Either way, because we already have an exception that we are trying to
--   propagate, we just scoot these freaky exceptions under the rug and retry the throwTo.
--
-- Precondition: interruptibly masked
propagateException :: Scope -> Int -> SomeException -> UnexceptionalIO ()
propagateException Scope {childExceptionVar, parentThreadId, startingVar} childId exception =
  UnexceptionalIO (readTVarIO startingVar) >>= \case
    n | n < 0 -> tryPutChildExceptionVar -- (A) / (B)
    _ -> loop
  where
    loop :: UnexceptionalIO ()
    loop =
      unexceptionalTry (throwTo parentThreadId ThreadFailed {childId, exception}) >>= \case
        Left IsScopeClosingException -> tryPutChildExceptionVar -- (C)
        Left _ -> loop -- (D)
        Right () -> pure ()

    -- Hand the exception to the parent through the single-celled MVar; if it is already full, ours is dropped.
    tryPutChildExceptionVar :: UnexceptionalIO ()
    tryPutChildExceptionVar =
      UnexceptionalIO (void (tryPutMVar childExceptionVar exception))


-- A little promise that this IO action cannot throw an exception.
--
-- Yeah it's verbose, and maybe not that necessary, but the code that bothers to use it really does require
-- un-exceptiony IO actions for correctness, so here we are.
--
-- Nothing checks the promise: wrapping an action with the 'UnexceptionalIO' constructor is the author's own claim that
-- it does not throw.
newtype UnexceptionalIO a
  = UnexceptionalIO (IO a)
  deriving newtype (Functor, Applicative, Monad)

-- Run an 'UnexceptionalIO' action as an ordinary 'IO' action.
runUnexceptionalIO :: UnexceptionalIO a -> IO a
runUnexceptionalIO (UnexceptionalIO io) = io

-- Run an IO action, catching every exception it throws into a 'Left'; the result can therefore never throw.
unexceptionalTry :: forall a. IO a -> UnexceptionalIO (Either SomeException a)
unexceptionalTry action =
  UnexceptionalIO (try action)

-- Like try, but with continuations. Also, catches all exceptions, because that's the only flavor we need.
--
-- Only @action@ runs under the handler; whichever continuation is chosen runs afterwards, outside of it.
unexceptionalTryEither ::
  forall a b.
  (SomeException -> UnexceptionalIO b) ->
  (a -> UnexceptionalIO b) ->
  IO a ->
  UnexceptionalIO b
unexceptionalTryEither onFailure onSuccess action =
  UnexceptionalIO (join (attempt `catch` handler))
  where
    -- On success, yield (without running) the success continuation.
    attempt :: IO (IO b)
    attempt = fmap (runUnexceptionalIO . onSuccess) action

    -- On any exception, yield (without running) the failure continuation.
    handler :: SomeException -> IO (IO b)
    handler exception = pure (runUnexceptionalIO (onFailure exception))