module Engine.Types.RefCounted where

import RIO

import Control.Monad.Trans.Resource (allocate_)
import GHC.IO.Exception (IOErrorType(UserError), IOException(IOError))
import UnliftIO.Resource (MonadResource)

-- | A reference counter paired with an action to run once no references
-- remain.
--
-- 'releaseRefCounted' runs 'rcAction' whenever it brings 'rcCount' down to
-- exactly 0; 'takeRefCounted' increments 'rcCount'.
data RefCounted = RefCounted
  { rcCount  :: !(IORef Int)
    -- ^ Number of currently held references.
  , rcAction :: !(IO ())
    -- ^ Action run by 'releaseRefCounted' when the count reaches 0.
  }

-- | Create a 'RefCounted' whose counter starts at 1.
--
-- The given action is only stored here, not run. Because the count starts
-- at 1, a single 'releaseRefCounted' with no matching 'takeRefCounted'
-- brings it to 0 and runs the action.
newRefCounted :: MonadIO m => IO () -> m RefCounted
newRefCounted onZero =
  liftIO (mkRefCounted <$> newIORef 1)
  where
    mkRefCounted counter =
      RefCounted { rcCount = counter, rcAction = onZero }

-- | Drop one reference by atomically decrementing the counter.
--
-- If the decrement brings the counter to exactly 0, the stored action is run
-- promptly, in the calling thread, inside 'mask'. If the counter is still
-- positive afterwards, nothing else happens.
--
-- Throws an 'IOException' of type 'UserError' when the counter drops below 0
-- (more releases than takes). The decrement has already been committed at
-- that point, so the counter is left negative.
releaseRefCounted :: MonadIO m => RefCounted -> m ()
releaseRefCounted (RefCounted counter onZero) =
  liftIO (mask (\_ -> step))
  where
    step = do
      remaining <- atomicModifyIORef' counter decrement
      if remaining < 0
        then throwIO underflow
        else when (remaining == 0) onZero

    -- Store the decremented value and also return it to the caller.
    decrement n =
      let n' = n - 1
      in (n', n')

    underflow =
      IOError
        Nothing
        UserError
        ""
        "Ref counted value decremented below 0"
        Nothing
        Nothing

-- | Add one reference by atomically incrementing the counter by 1.
--
-- No check is made that the counter is still positive: taking a reference
-- after the action has already run raises the count again, and a later
-- 'releaseRefCounted' that reaches 0 will run the action again.
takeRefCounted :: MonadIO m => RefCounted -> m ()
takeRefCounted (RefCounted counter _) =
  liftIO $ atomicModifyIORef' counter increment
  where
    increment n = (n + 1, ())

-- | Hold a reference for the lifetime of the enclosing 'MonadResource' scope.
--
-- A reference is taken immediately via 'allocate_'. The matching
-- 'releaseRefCounted' is registered as that allocation's cleanup. The
-- 'ReleaseKey' is discarded, so the reference is only dropped when the
-- resource scope releases it.
resourceTRefCount :: MonadResource f => RefCounted -> f ()
resourceTRefCount refCounted = do
  _releaseKey <- allocate_ takeRef dropRef
  pure ()
  where
    takeRef = takeRefCounted refCounted
    dropRef = releaseRefCounted refCounted