-- | Threadsafe, shared, atomic counters
--
-- This is based on "Data.Atomics.Counter".
module UnliftIO.MessageBox.Util.Fresh
  ( fresh,
    incrementAndGet,
    newCounterVar,
    HasCounterVar (getCounterVar),
    CounterVar (),
  )
where

import Control.Monad.Reader (MonadReader, asks)
import Data.Atomics.Counter
  ( AtomicCounter,
    incrCounter,
    newCounter,
  )
import Data.Coerce (Coercible, coerce)
import UnliftIO (MonadIO (..))

-- Threadsafe atomic counters producing fresh values of some type @a@
-- (see 'CounterVar').

-- | Atomically increment the 'CounterVar' for type @a@ that must be present
-- in the reader environment @env@, and return the new value as an @a@.
--
-- The counter is taken from the environment with 'getCounterVar' and advanced
-- with 'incrementAndGet'. Because the increment is atomic, concurrent callers
-- sharing the same counter receive distinct values (until the underlying
-- 'Int' wraps around).
{-# INLINE fresh #-}
fresh ::
  forall a env m.
  ( MonadReader env m,
    MonadIO m,
    HasCounterVar a env,
    Coercible a Int
  ) =>
  m a
fresh = asks (getCounterVar @a) >>= incrementAndGet

-- | Atomically add @1@ to the given 'CounterVar' and return the resulting
-- (post-increment) value, coerced from 'Int' to @a@.
--
-- Unlike 'fresh', this takes the 'CounterVar' directly instead of reading it
-- from a reader environment.
{-# INLINE incrementAndGet #-}
incrementAndGet ::
  forall a m.
  ( MonadIO m,
    Coercible a Int
  ) =>
  CounterVar a ->
  m a
-- The bang forces the wrapped 'AtomicCounter' before the increment.
incrementAndGet (MkCounterVar !atomicCounter) =
  coerce <$> liftIO (incrCounter 1 atomicCounter)


-- | Create a new 'CounterVar' starting at @0@.
--
-- Since 'incrementAndGet' returns the value after incrementing, the first
-- value handed out by a fresh counter is @1@.
{-# INLINE newCounterVar #-}
newCounterVar ::
  forall a m.
  MonadIO m =>
  m (CounterVar a)
newCounterVar = MkCounterVar <$> liftIO (newCounter 0)

-- | A shared counter that hands out values of type @a@.
--
-- It is a thin wrapper around an 'AtomicCounter'. The type parameter @a@ is
-- phantom: it only records which type the 'Int' counts are coerced into
-- by 'incrementAndGet' and 'fresh'.
newtype CounterVar a
  = MkCounterVar AtomicCounter

-- | Environments (e.g. the @env@ of a @MonadReader env@ application) that
-- carry a 'CounterVar' for values of type @a@.
--
-- The functional dependency @env -> a@ makes the environment type determine
-- the counter's value type.
class HasCounterVar a env | env -> a where
  -- | Project the 'CounterVar' out of the environment.
  getCounterVar ::
    env ->
    CounterVar a

-- | A 'CounterVar' is trivially its own environment, so a bare 'CounterVar'
-- can be used directly as the reader environment for 'fresh'.
instance HasCounterVar t (CounterVar t) where
  getCounterVar = id