{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE UnboxedTuples #-}
{-# LANGUAGE UnliftedFFITypes #-}
-- |
-- Module      : Control.Prim.Concurrent
-- Copyright   : (c) Alexey Kuleshevich 2020
-- License     : BSD3
-- Maintainer  : Alexey Kuleshevich <alexey@kuleshevi.ch>
-- Stability   : experimental
-- Portability : non-portable
--
module Control.Prim.Concurrent
  ( GHC.ThreadId(..)
  , fork
  , forkFinally
  , forkOn
  , forkOnFinally
  , forkOS
  , killThread
  , yield

  , threadDelay
  , timeout
  , timeout_

  , myThreadId
  , threadIdToCInt
  , threadStatus
  , labelThread
  , isCurrentThreadBound
  , threadCapability
  , getNumCapabilities
  , setNumCapabilities
  -- * Sparks
  , spark
  , numSparks
  , runSparks
  -- * Single threaded RTS
  , delay
  , waitRead
  , waitWrite
  , module Control.Prim.Monad
  ) where

import qualified Control.Exception as GHC
import qualified Control.Concurrent as GHC
import Control.Prim.Exception
import Control.Prim.Monad
import Foreign.Prim
import qualified GHC.Conc as GHC
import qualified System.Timeout as GHC

-- | Wrapper around the `spark#` primop: create a spark for the given value,
-- making it a candidate for parallel evaluation, and return that same value.
spark :: MonadPrim s m => a -> m a
spark x = prim (\s -> spark# x s)

-- | Wrapper around the `numSparks#` primop: the number of sparks currently in
-- the local spark pool, boxed into an `Int`.
numSparks :: MonadPrim s m => m Int
numSparks = prim countSparks
  where
    countSparks s0 =
      case numSparks# s0 of
        (# s1, count# #) -> (# s1, I# count# #)

-- | Drain the spark pool. Repeatedly take a spark with `getSpark#` and evaluate
-- it to weak head normal form. Stop as soon as `getSpark#` returns a @0#@ flag,
-- meaning no spark was available.
runSparks :: MonadPrim s m => m ()
runSparks = prim_ drain
  where
    drain s0 =
      case getSpark# s0 of
        (# s1, found#, sparked #)
          | isTrue# (found# ==# 0#) -> s1
          | otherwise -> sparked `seq` drain s1

-- | Wrapper for the `delay#` primop: suspend the current thread for the given
-- number of microseconds. Intended only for the non-threaded runtime:
-- __errors out when compiled with @-threaded@__ (`threadDelay` works with
-- either runtime).
delay :: MonadPrim s m => Int -> m ()
delay micros =
  case micros of
    I# micros# -> prim_ (delay# micros#)

-- | Wrapper for the `waitRead#` primop: block until input is available on the
-- given `Fd`. Intended only for the non-threaded runtime:
-- __errors out when compiled with @-threaded@__.
waitRead :: MonadPrim s m => Fd -> m ()
waitRead !fd
  | I# fd# <- fromIntegral fd = prim_ (waitRead# fd#)


-- | Wrapper for the `waitWrite#` primop: block until output is possible on the
-- given `Fd`. Intended only for the non-threaded runtime:
-- __errors out when compiled with @-threaded@__.
waitWrite :: MonadPrim s m => Fd -> m ()
waitWrite fd =
  fd `seq`
    case fromIntegral fd :: Int of
      I# fd# -> prim_ (waitWrite# fd#)

-- | Wrapper around the `fork#` primop: run the action in a new lightweight
-- thread and return its `GHC.ThreadId`. Unlike `Control.Concurrent.forkIO`,
-- no exception handler is installed around the action, so the caller is
-- responsible for handling exceptions (see `forkFinally`).
fork :: MonadUnliftPrim RW m => m () -> m GHC.ThreadId
fork action = runInPrimBase action spawn
  where
    -- Receives the action unwrapped to its primitive state-passing form.
    spawn io# s0 =
      case fork# (IO io#) s0 of
        (# s1, tid# #) -> (# s1, GHC.ThreadId tid# #)

-- | Spawn a thread (via `fork`) that runs the action and then passes its
-- outcome to the handler. The outcome is either the result or any exception
-- caught by `tryAny`. The thread is forked while masked. The action runs with
-- the caller's masking state restored, and the handler runs in the masked
-- state.
forkFinally :: MonadUnliftPrim RW m => m a -> (Either SomeException a -> m ()) -> m GHC.ThreadId
forkFinally action handler =
  mask $ \restore ->
    fork $ do
      outcome <- tryAny (restore action)
      handler outcome

-- | Wrapper around the `forkOn#` primop: like `fork`, but the new thread is
-- placed on the capability with the given number. Unlike
-- `Control.Concurrent.forkOn`, no exception handler is installed around the
-- action, so the caller is responsible for handling exceptions (see
-- `forkOnFinally`).
forkOn :: MonadUnliftPrim RW m => Int -> m () -> m GHC.ThreadId
forkOn (I# cap#) action = runInPrimBase action spawnOn
  where
    -- Receives the action unwrapped to its primitive state-passing form.
    spawnOn io# s0 =
      case forkOn# cap# (IO io#) s0 of
        (# s1, tid# #) -> (# s1, GHC.ThreadId tid# #)

-- | Like `forkFinally`, but the thread is spawned with `forkOn` on the given
-- capability. The action runs with the caller's masking state restored. Its
-- outcome (result or exception caught by `tryAny`) is passed to the handler,
-- which runs in the masked state.
forkOnFinally ::
     MonadUnliftPrim RW m
  => Int
  -> m a
  -> (Either SomeException a -> m ())
  -> m GHC.ThreadId
forkOnFinally cap action handler =
  mask $ \restore ->
    forkOn cap (handler =<< tryAny (restore action))


-- | Lifted version of `Control.Concurrent.forkOS`. Unlike `fork` and `forkOn`,
-- this delegates to @base@'s `Control.Concurrent.forkOS` rather than calling a
-- primop directly.
forkOS :: MonadUnliftPrim RW m => m () -> m GHC.ThreadId
forkOS action = withRunInIO (\runInIO -> GHC.forkOS (runInIO action))



-- | Throw `GHC.ThreadKilled` to the target thread using `throwTo`. Use
-- `throwTo` directly to deliver a different exception.
killThread :: MonadPrim RW m => GHC.ThreadId -> m ()
killThread tid = tid `seq` throwTo tid GHC.ThreadKilled

-- | Lifted version of `GHC.threadDelay`: suspend the current thread for the
-- given number of microseconds.
threadDelay :: MonadPrim RW m => Int -> m ()
threadDelay micros = liftIO (GHC.threadDelay micros)

-- | Lifted version of `GHC.timeout`: run the action with a time limit of @n@
-- microseconds, returning `Nothing` if the limit was hit. Both arguments are
-- forced before the action starts.
--
-- @since 0.3.0
timeout :: MonadUnliftPrim RW m => Int -> m a -> m (Maybe a)
timeout n action =
  n `seq` action `seq`
    withRunInIO (\runInIO -> GHC.timeout n (runInIO action))

-- | Like `timeout`, but discards the result, including whether the time limit
-- was reached.
--
-- @since 0.3.0
timeout_ :: MonadUnliftPrim RW m => Int -> m a -> m ()
timeout_ n action = void (timeout n action)



-- | Just like `Control.Concurrent.yield`, this is a wrapper around the `yield#`
-- primop, except that this version works for any state token. It is safe to
-- use within `ST` because it cannot affect the result of a computation, only
-- the order of evaluation relative to other threads. That order is irrelevant
-- to a state-thread monad anyway.
--
-- @since 0.3.0
yield :: forall m s. MonadPrim s m => m ()
yield = prim_ yieldAnyState
  where
    -- `yield#` is typed for the RealWorld token only. Coerce it to the monad's
    -- state token (see the safety argument above).
    yieldAnyState :: State# s -> State# s
    yieldAnyState = unsafeCoerce# yield#

-- | Wrapper around the `myThreadId#` primop: the `GHC.ThreadId` of the calling
-- thread.
myThreadId :: MonadPrim RW m => m GHC.ThreadId
myThreadId = prim currentThread
  where
    currentThread s0 =
      case myThreadId# s0 of
        (# s1, tid# #) -> (# s1, GHC.ThreadId tid# #)

-- | Wrapper around the `labelThread#` primop: attach a label to a thread, used
-- to identify it in debugging output.
--
-- The pointer must refer to a UTF-8 encoded, __NUL-terminated__ string of
-- bytes. The primop reads the address as a C string; @base@'s own
-- @GHC.Conc.labelThread@ passes it a buffer produced by @withCString@. A buffer
-- without a terminating zero byte would be read past its end.
labelThread :: MonadPrim RW m => GHC.ThreadId -> Ptr a -> m ()
labelThread (GHC.ThreadId tid#) (Ptr addr#) = prim_ (labelThread# tid# addr#)

-- | Wrapper around the `isCurrentThreadBound#` primop: returns `True` when the
-- calling Haskell thread is /bound/, i.e. tied to its own OS thread (see
-- @Control.Concurrent.isCurrentThreadBound@). Threads started with `forkOS`
-- are bound. Threads started with `forkOn` are __not__: `forkOn` only places a
-- lightweight thread on a particular capability.
--
-- @since 0.3.0
isCurrentThreadBound :: MonadPrim RW m => m Bool
isCurrentThreadBound =
  prim $ \s ->
    case isCurrentThreadBound# s of
      (# s', bound# #) -> (# s', isTrue# bound# #)

-- | Lifted version of `GHC.threadStatus`: query the scheduler status of a thread.
threadStatus :: MonadPrim RW m => GHC.ThreadId -> m GHC.ThreadStatus
threadStatus tid = liftPrimBase (GHC.threadStatus tid)

-- | Lifted version of `GHC.threadCapability`: the capability a thread is on,
-- paired with the `Bool` that @base@ reports alongside it.
threadCapability :: MonadPrim RW m => GHC.ThreadId -> m (Int, Bool)
threadCapability tid = liftPrimBase (GHC.threadCapability tid)

-- | Lifted version of `GHC.getNumCapabilities`: the current number of
-- capabilities.
getNumCapabilities :: MonadPrim RW m => m Int
getNumCapabilities =
  liftPrimBase GHC.getNumCapabilities

-- | Lifted version of `GHC.setNumCapabilities`: change the number of
-- capabilities.
setNumCapabilities :: MonadPrim RW m => Int -> m ()
setNumCapabilities n = liftPrimBase (GHC.setNumCapabilities n)



-- | Convert a `GHC.ThreadId` to a plain integral value: the id the RTS reports
-- for the thread via @rts_getThreadId@. This conversion is not exported from
-- @base@.
--
-- @since 0.0.0
threadIdToCInt :: GHC.ThreadId -> CInt
threadIdToCInt tid =
  case id2TSO tid of
    tso# -> getThreadId tso#

-- | Unwrap the primitive `ThreadId#` stored inside a `GHC.ThreadId`.
id2TSO :: GHC.ThreadId -> ThreadId#
id2TSO tid =
  case tid of
    GHC.ThreadId tso# -> tso#

-- Calls the RTS C function @rts_getThreadId@ directly on the raw `ThreadId#`.
-- Passing an unlifted type through the FFI requires @UnliftedFFITypes@.
--
-- NOTE(review): the call is declared @unsafe@. Presumably this is so the GC
-- cannot run, and move the thread object, while the C function reads it.
-- Confirm against the ticket below.
-- NOTE(review): newer RTS versions may declare @rts_getThreadId@ with a 64-bit
-- return type. Confirm that `CInt` matches every supported GHC version.
--
-- Relevant ticket: https://gitlab.haskell.org/ghc/ghc/-/issues/8281
foreign import ccall unsafe "rts_getThreadId" getThreadId :: ThreadId# -> CInt