{-# LANGUAGE NamedFieldPuns #-}
-- | Module: Lifetimes.Rc
-- Description: Support for working with reference-counted resources.
--
-- Rather than associating a resource with a single lifetime, a
-- reference-counted resource associates each *reference* with a lifetime, and
-- the resource is released once all of its references have expired.
module Lifetimes.Rc
    ( Rc
    , addRef
    , refCounted
    ) where

import Control.Concurrent.STM
import Lifetimes
import Zhp

-- | A handle to a shared value whose release is governed by a reference
-- count rather than by a single lifetime.
data Rc a = Rc
    { count :: TVar Int
      -- ^ Number of outstanding references. When a release brings this to
      -- zero, 'decRef' hands back 'cleanup' to be run.
    , value :: a
      -- ^ The shared value given to every reference holder.
    , cleanup :: IO ()
      -- ^ Action that releases the underlying resource.
    }

-- | Acquire an additional reference to an existing 'Rc'.
--
-- Acquiring increments the reference count and yields the shared value.
-- Releasing decrements it again; if that brings the count to zero, the
-- resource's cleanup action is run after the STM transaction has committed.
addRef :: Rc a -> Acquire a
addRef rc = mkAcquire takeRef dropRef
  where
    takeRef = atomically (incRef rc)
    -- 'decRef' only *computes* the finalizer inside STM. It has type IO (),
    -- so it cannot run inside the transaction and is executed afterwards.
    dropRef _ = do
        finalizer <- atomically (decRef rc)
        finalizer

-- | Move a live 'Resource' into a fresh 'Rc' that starts with exactly one
-- reference.
--
-- The resource's value is read first, then 'detach' yields its cleanup
-- action, which becomes the 'Rc''s 'cleanup'. NOTE(review): 'detach'
-- presumably also stops the resource's original lifetime from releasing it.
-- Confirm against the Lifetimes documentation.
resourceToRc :: Resource a -> STM (Rc a)
resourceToRc res = do
    val <- mustGetResource res
    finalizer <- detach res
    refs <- newTVar 1
    pure Rc { count = refs, value = val, cleanup = finalizer }


-- | Acquire a resource using reference counting.
--
-- Takes an 'Acquire' for the underlying resource and returns one that
-- acquires an initial reference to it. Further references can be taken with
-- 'addRef'. The underlying resource is kept alive until every reference has
-- been released.
refCounted :: Acquire a -> Acquire (Rc a)
refCounted acq = do
    outer <- currentLifetime
    liftIO $ withLifetime (wrapIn outer)
  where
    -- The raw resource is first acquired in a temporary lifetime from
    -- 'withLifetime'. 'resourceToRc' then detaches it and places the initial
    -- reference in the caller's lifetime. NOTE(review): this appears intended
    -- so that a failure before the hand-off still releases the resource via
    -- the temporary lifetime. Confirm against the Lifetimes docs.
    wrapIn outer scratch = do
        res <- acquire scratch acq
        acquireValue outer $
            mkAcquire (atomically (resourceToRc res)) dropRef
    -- Releasing the initial reference works exactly like releasing one taken
    -- via 'addRef'. The finalizer from 'decRef' runs after the transaction.
    dropRef rc = do
        finalizer <- atomically (decRef rc)
        finalizer


-- | Take one more reference: bump the count and return the shared value.
--
-- Fails with 'throwSTM' if the count is already zero (or below). By then
-- 'decRef' has handed out 'cleanup' for what was the last reference.
-- Re-incrementing would hand out an already-released value. Dropping that
-- new reference would then make 'decRef' return 'cleanup' a second time.
incRef :: Rc a -> STM a
incRef Rc{count, value} = do
    live <- readTVar count
    if live <= 0
        then throwSTM $ userError
            "Lifetimes.Rc.addRef: resource has already been released"
        else do
            writeTVar count $! succ live
            pure value

-- | Drop one reference.
--
-- Decrements the count and returns an action for the caller to run once the
-- transaction has committed. If the count is now exactly zero, that action is
-- the resource's 'cleanup'; otherwise it does nothing.
decRef :: Rc a -> STM (IO ())
decRef Rc{count, cleanup} = do
    remaining <- pred <$> readTVar count
    writeTVar count $! remaining
    pure $ if remaining == 0 then cleanup else pure ()