module Ki.Thread
  ( Thread (..),
    async,
    asyncWithUnmask,
    await,
    awaitSTM,
    awaitFor,
    fork,
    fork_,
    forkWithUnmask,
    forkWithUnmask_,
  )
where

import Control.Exception (Exception (fromException))
import qualified Ki.Context as Context
import Ki.Duration (Duration)
import Ki.Prelude
import Ki.Scope (Scope (Scope))
import qualified Ki.Scope as Scope
import Ki.Timeout (timeoutSTM)

-- | A running __thread__.
--
-- A 'Thread' pairs the identifier of the underlying thread with an @STM@ action that yields the thread's result.
-- Both fields are strict.
data Thread a
  = Thread !ThreadId !(STM a)
  deriving stock (Functor, Generic)

-- | Two 'Thread's are equal exactly when their thread ids are equal; the result actions are not compared.
instance Eq (Thread a) where
  Thread id1 _ == Thread id2 _ =
    id1 == id2

-- | 'Thread's are ordered by their thread ids; the result actions are not compared.
instance Ord (Thread a) where
  compare (Thread id1 _) (Thread id2 _) =
    compare id1 id2

-- | Create a __thread__ within a __scope__.
--
-- The __thread__'s outcome is made available to 'await' as an 'Either': 'Left' for an exception, 'Right' for a
-- result.
--
-- /Throws/:
--
--   * Calls 'error' if the __scope__ is /closed/.
async :: Scope -> IO a -> IO (Thread (Either SomeException a))
async scope action =
  -- Run the whole action under the @restore@ function handed to us by 'asyncWithRestore'.
  asyncWithRestore scope \restore -> restore action

-- | Variant of 'async' that provides the __thread__ a function that unmasks asynchronous exceptions.
--
-- /Throws/:
--
--   * Calls 'error' if the __scope__ is /closed/.
asyncWithUnmask :: Scope -> ((forall x. IO x -> IO x) -> IO a) -> IO (Thread (Either SomeException a))
asyncWithUnmask scope action =
  -- The user's action receives 'unsafeUnmask' as its unmasking function; the result of applying it is then run under
  -- the @restore@ function handed to us by 'asyncWithRestore'.
  asyncWithRestore scope \restore -> restore (action unsafeUnmask)

-- | Shared implementation of 'async' and 'asyncWithUnmask'.
--
-- Forks the action with 'Scope.scopeFork', passing a callback that stores the @Either SomeException a@ it is given in
-- a fresh, initially empty @TMVar@. The returned 'Thread' carries the forked thread's id and an @STM@ action that
-- reads (without emptying) that @TMVar@.
asyncWithRestore :: forall a. Scope -> ((forall x. IO x -> IO x) -> IO a) -> IO (Thread (Either SomeException a))
asyncWithRestore scope action = do
  resultVar <- newEmptyTMVarIO
  childThreadId <- Scope.scopeFork scope action (putTMVarIO resultVar)
  pure (Thread childThreadId (readTMVar resultVar))

-- | Wait for a __thread__ to finish.
--
-- This is 'awaitSTM' run in its own @STM@ transaction.
await :: Thread a -> IO a
await =
  atomically . awaitSTM

-- | @STM@ variant of 'await'.
--
-- Simply returns the @STM@ action stored in the 'Thread'.
awaitSTM :: Thread a -> STM a
awaitSTM (Thread _ action) =
  action

-- | Variant of 'await' that gives up after the given duration.
--
-- Yields @Just@ the __thread__'s result when 'awaitSTM' produces one, and 'Nothing' when 'timeoutSTM' falls back to
-- its alternative action instead.
awaitFor :: Thread a -> Duration -> IO (Maybe a)
awaitFor thread duration =
  timeoutSTM duration (pure . Just <$> awaitSTM thread) (pure Nothing)

-- | Create a __thread__ within a __scope__.
--
-- If the __thread__ throws an exception, the exception is immediately propagated up the call tree to the __thread__
-- that opened its __scope__.
--
-- /Throws/:
--
--   * Calls 'error' if the __scope__ is /closed/.
fork :: Scope -> IO a -> IO (Thread a)
fork scope action =
  -- Run the whole action under the @restore@ function handed to us by 'forkWithRestore'.
  forkWithRestore scope \restore -> restore action

-- | Variant of 'fork' that does not return a handle to the created __thread__.
--
-- /Throws/:
--
--   * Calls 'error' if the __scope__ is /closed/.
fork_ :: Scope -> IO () -> IO ()
fork_ scope action =
  -- Run the whole action under the @restore@ function handed to us by 'forkWithRestore_'.
  forkWithRestore_ scope \restore -> restore action

-- | Variant of 'fork' that provides the __thread__ a function that unmasks asynchronous exceptions.
--
-- /Throws/:
--
--   * Calls 'error' if the __scope__ is /closed/.
forkWithUnmask :: Scope -> ((forall x. IO x -> IO x) -> IO a) -> IO (Thread a)
forkWithUnmask scope action =
  -- The user's action receives 'unsafeUnmask' as its unmasking function; the result of applying it is then run under
  -- the @restore@ function handed to us by 'forkWithRestore'.
  forkWithRestore scope \restore -> restore (action unsafeUnmask)

-- | Variant of 'forkWithUnmask' that does not return a handle to the created __thread__.
--
-- /Throws/:
--
--   * Calls 'error' if the __scope__ is /closed/.
forkWithUnmask_ :: Scope -> ((forall x. IO x -> IO x) -> IO ()) -> IO ()
forkWithUnmask_ scope action =
  -- The user's action receives 'unsafeUnmask' as its unmasking function; the result of applying it is then run under
  -- the @restore@ function handed to us by 'forkWithRestore_'.
  forkWithRestore_ scope \restore -> restore (action unsafeUnmask)

-- | Shared implementation of 'fork' and 'forkWithUnmask'.
--
-- Records the calling thread's id as the parent, then forks the action with 'Scope.scopeFork'. On a 'Right' outcome
-- the result is stored in a fresh @TMVar@; on a 'Left' outcome the exception is handed to 'maybePropagateException'
-- and the @TMVar@ is left empty. The returned 'Thread' carries the forked thread's id and an @STM@ action that reads
-- (without emptying) that @TMVar@.
forkWithRestore :: Scope -> ((forall x. IO x -> IO x) -> IO a) -> IO (Thread a)
forkWithRestore scope action = do
  parentThreadId <- myThreadId
  resultVar <- newEmptyTMVarIO
  childThreadId <-
    Scope.scopeFork scope action \case
      Left exception ->
        -- Intentionally don't fill the result var.
        --
        -- Prior to 0.2.0, we did put a 'Left exception' in the result var, so that if another thread awaited it, we'd
        -- promptly deliver them the exception that brought this thread down. However, that exception was *wrapped* in
        -- a 'ThreadFailed' exception, so the caller could distinguish between async exceptions *delivered to them* and
        -- async exceptions coming *synchronously* out of the call to 'await'.
        --
        -- At some point I reasoned that if one is following some basic structured concurrency guidelines, and not doing
        -- weird/complicated things like passing threads around, then it is likely that a failed forked thread is just
        -- about to propagate its exception to all callers of 'await' (presumably, its direct parent).
        --
        -- Might GHC deliver a BlockedIndefinitelyOnSTM in the meantime, though?
        maybePropagateException scope parentThreadId exception
      Right result -> putTMVarIO resultVar result
  pure (Thread childThreadId (readTMVar resultVar))

-- | Shared implementation of 'fork_' and 'forkWithUnmask_'.
--
-- Records the calling thread's id as the parent, then forks the action with 'Scope.scopeFork'. A 'Left' outcome is
-- handed to 'maybePropagateException' (via 'onLeft'); no result is stored and the child's thread id is discarded.
forkWithRestore_ :: Scope -> ((forall x. IO x -> IO x) -> IO ()) -> IO ()
forkWithRestore_ scope action = do
  parentThreadId <- myThreadId
  _childThreadId <- Scope.scopeFork scope action (onLeft (maybePropagateException scope parentThreadId))
  pure ()

-- | Decide whether an exception that ended a forked __thread__ should be reported to its parent and, if so, throw it
-- to the parent thread with 'throwTo', wrapped in 'Scope.ThreadFailed'.
--
-- The exception is /not/ propagated in two cases:
--
--   * It is a 'Scope.ScopeClosing' and the scope's @closedVar@ currently holds 'True'.
--
--   * It is a cancel token equal to the one currently returned by 'Context.contextCancelTokenSTM' for the scope's
--     context.
--
-- Every other exception is propagated.
maybePropagateException :: Scope -> ThreadId -> SomeException -> IO ()
maybePropagateException Scope {closedVar, context} parentThreadId exception =
  whenM shouldPropagateException (throwTo parentThreadId (Scope.ThreadFailed exception))
  where
    shouldPropagateException :: IO Bool
    shouldPropagateException =
      case fromException exception of
        -- Our scope is (presumably) closing, so don't propagate this exception that presumably just came from our
        -- parent. But if our scope's closedVar isn't True, that means this 'ScopeClosing' definitely came from
        -- somewhere else...
        Just Scope.ScopeClosing -> not <$> readTVarIO closedVar
        Nothing ->
          case fromException exception of
            -- We (presumably) are honoring our own cancellation request, so don't propagate that either.
            -- It's a bit complicated looking because we *do* want to throw this token if we (somehow) threw it
            -- "inappropriately" in the sense that it wasn't ours to throw - it was smuggled from elsewhere.
            --
            -- The '<|> pure True' alternative means we also propagate when reading the context's cancel token
            -- retries (presumably because the context has not been cancelled - TODO confirm against "Ki.Context").
            Just token -> atomically ((/= token) <$> Context.contextCancelTokenSTM context <|> pure True)
            Nothing -> pure True