{-# LANGUAGE CPP           #-}
{-# LANGUAGE UnboxedTuples #-}
{-# OPTIONS_HADDOCK not-home #-}
-- |
-- Copyright: (c) 2021 Xy Ren
-- License: BSD3
-- Maintainer: xy.r@outlook.com
-- Stability: unstable
-- Portability: non-portable (GHC only)
--
-- This module contains a contention-free thread-local variable datatype.
--
-- __This is an /internal/ module and its API may change even between minor versions.__ Therefore you should be
-- extra careful if you're to depend on this module.
module Cleff.Internal.ThreadVar (ThreadVar, newThreadVar, getThreadVar) where

import           Cleff.Internal
import           Control.Monad.IO.Class (MonadIO (liftIO))
import           Data.Atomics           (atomicModifyIORefCAS_)
import           Data.IntMap.Strict     (IntMap)
import qualified Data.IntMap.Strict     as Map
import           Foreign.C.Types
import           GHC.Conc               (ThreadId (ThreadId))
import           GHC.Exts               (ThreadId#, mkWeak#)
import           GHC.IO                 (IO (IO))
import           UnliftIO.Concurrent    (myThreadId)
import           UnliftIO.IORef         (IORef, newIORef, readIORef)

-- | Get the numeric ID the RTS assigns to a thread, by calling the RTS C function @rts_getThreadId@ directly on
-- the unboxed 'ThreadId#'. The Haskell-side result type is selected per GHC version (presumably mirroring the C
-- signature of @rts_getThreadId@ in that version — verify against the RTS headers when bumping GHC support):
--
-- * @__GLASGOW_HASKELL__ >= 903@ (GHC 9.4 releases onward): 'CULLong', i.e. 64 bits on every platform;
-- * GHC 9.0 to 9.2: 'CLong', which is only 32 bits wide on 32-bit and Windows platforms;
-- * older GHCs: 'CInt', which is 32 bits wide on common platforms, including 64-bit ones.
--
-- The width of this type bounds how many distinct threads 'hashThreadId' can tell apart.
#if __GLASGOW_HASKELL__ >= 903
foreign import ccall unsafe "rts_getThreadId"
  getThreadId :: ThreadId# -> CULLong
#elif __GLASGOW_HASKELL__ >= 900
foreign import ccall unsafe "rts_getThreadId"
  getThreadId :: ThreadId# -> CLong
#else
foreign import ccall unsafe "rts_getThreadId"
  getThreadId :: ThreadId# -> CInt
#endif

-- | Generates a numeric hash for a 'ThreadId' by converting the RTS-assigned thread ID to an 'Int'.
--
-- Two live threads can only share a hash when the ID loses bits somewhere along the way, i.e. once on the order of
-- 2^32 threads have been created while an older thread's entry is still in use. This is a practical possibility:
--
-- * before GHC 9.0 on all platforms, since the RTS ID is read as a 32-bit C @int@;
-- * on GHC 9.0 to 9.2 on 32-bit or Windows platforms, where the C @long@ is 32 bits;
-- * on any GHC version on 32-bit platforms, because the result is narrowed to 'Int', which is 32 bits there.
--
-- On 64-bit platforms from GHC 9.4 onward, the full 64-bit ID is kept, so collisions are practically impossible.
hashThreadId :: ThreadId -> Int
hashThreadId (ThreadId tid#) = fromIntegral $ getThreadId tid#

-- | Attach a finalizer (an 'IO' computation) to a thread.
--
-- The finalizer is registered with 'mkWeak#' keyed on the thread's unboxed 'ThreadId#'. The RTS therefore runs it
-- some time after that thread object becomes unreachable, which is not necessarily as soon as the thread exits.
-- As with every weak-pointer finalizer, it is not guaranteed to run at all (for example if the program terminates
-- first). The weak pointer itself is discarded; only the attached finalizer matters.
attachFinalizer :: ThreadId -> IO () -> IO ()
attachFinalizer (ThreadId tid#) (IO finalize#) = IO $ \s1 ->
  -- Key on the unboxed 'ThreadId#' rather than the boxed 'ThreadId'. The optimizer may remove or rebuild a boxed
  -- wrapper, which could fire the finalizer early; 'System.Mem.Weak' warns about this for boxed keys.
  case mkWeak# tid# () finalize# s1 of
    (# s2, _weak #) -> (# s2, () #)

-- | A thread-local variable. It is designed so that any operation originating from a thread that has already
-- accessed the variable produces no contention. Contention only occurs when several threads access the variable
-- for the first time concurrently (each such access inserts an entry into a shared table). It also occurs when
-- per-thread finalizers remove those entries from the table again.
--
-- Fields: the initial value each thread's copy starts with on its first access, and the shared table mapping a
-- thread's 'hashThreadId' to that thread's own 'IORef'.
data ThreadVar a = ThreadVar a {-# UNPACK #-} !(IORef (IntMap (IORef a)))

-- | Create a thread variable with the same initial value for each thread.
--
-- No per-thread storage is allocated here. Each thread's own 'IORef' is created, holding the given value, the
-- first time that thread calls 'getThreadVar'.
newThreadVar :: MonadIO m => a -> m (ThreadVar a)
newThreadVar x = ThreadVar x <$> newIORef Map.empty

-- | Get the variable local to this thread, in the form of an 'IORef'.
--
-- On a thread's first access:
--
-- * a fresh 'IORef' holding the variable's initial value is created;
-- * it is recorded in the shared table under the thread's hash;
-- * a finalizer is attached to the thread that removes the entry once the thread's 'ThreadId' becomes unreachable.
--
-- Later accesses from the same thread only read the table and return that same 'IORef'.
--
-- It is guaranteed that the returned 'IORef' will not be read or mutated by other threads inadvertently, provided
-- thread hashes do not collide (see 'hashThreadId').
getThreadVar :: MonadIO m => ThreadVar a -> m (IORef a)
getThreadVar (ThreadVar x0 table) = do
  tid <- myThreadId
  let thash = hashThreadId tid
  maybeRef <- Map.lookup thash <$> readIORef table
  case maybeRef of
    Just ref -> pure ref
    Nothing -> do
      ref <- newIORef x0
      liftIO $ do
        -- Register the cleanup *before* publishing the entry. Suppose an asynchronous exception (e.g. from
        -- 'timeout' or 'killThread') interrupts us between the two steps. The worst case is then a finalizer that
        -- deletes a key that was never inserted, which is a no-op. With the opposite order, the table could keep
        -- an entry that no finalizer ever removes.
        -- The finalizer cannot fire before the insert below: this thread is running, so its 'ThreadId#' is
        -- reachable.
        -- NOTE(review): 'noinline' around the CAS update is kept from the original; the reason is not visible in
        -- this module.
        attachFinalizer tid $ noinline atomicModifyIORefCAS_ table (Map.delete thash)
        noinline atomicModifyIORefCAS_ table (Map.insert thash ref)
      pure ref
{-# INLINE getThreadVar #-}