{-# LANGUAGE Safe #-}

-- | Bounded quantity semaphores.
module Control.Concurrent.BQSem
  ( BQSem
  , newBQSem
  , waitBQSem
  , signalBQSem
  , getBQSemQuantity
    ) where

import Control.Concurrent.QSem
import Control.Concurrent.MVar
import Control.Exception (mask, onException)
import Control.Monad (unless)

-- | A semaphore whose resource is taken and given back one unit at a time,
--   and which never holds more than a fixed maximum number of units.
data BQSem = BQSem
  { unboundedQSem :: QSem
    -- ^ Plain (unbounded) quantity semaphore that performs the blocking.
  , bqsemBound    :: Int
    -- ^ Largest number of units the semaphore may hold.
  , bqsemCounter  :: MVar Int
    -- ^ Number of units currently held, kept in step with 'unboundedQSem'.
  }

-- | Build a new 'BQSem' with the supplied initial and maximum supply.
--   The action fails (via 'fail', i.e. with an 'IOError') in any of the
--   following cases:
--
--   * Initial supply is negative.
--   * Maximum supply is less than 1.
--   * Initial supply exceeds maximum.
--
newBQSem
  :: Int -- ^ Initial unit supply.
  -> Int -- ^ Maximum unit supply.
  -> IO BQSem
newBQSem n0 m = do
  unless (n0 <= m) $
    fail "newBQSem: Initial quantity must be less or equal than maximum."
  unless (m > 0) $
    fail "newBQSem: Maximum quantity must be at least 1."
  -- Checked explicitly (instead of relying on 'newQSem' to reject it) so
  -- that every documented precondition is enforced and reported here.
  unless (n0 >= 0) $
    fail "newBQSem: Initial quantity must be non-negative."
  qsem <- newQSem n0
  counter <- newMVar n0
  pure $ BQSem
    { unboundedQSem = qsem
    , bqsemBound = m
    , bqsemCounter = counter
    }

-- | Wait for a unit to become available and take it.
--
--   The call can be interrupted by an asynchronous exception while it is
--   blocked; in that case no unit is consumed and the counter is unchanged.
waitBQSem :: BQSem -> IO ()
waitBQSem bqsem =
  -- The whole operation runs masked; the restore function is deliberately
  -- unused (see the comment on 'waitQSem' below).
  mask $ \_ -> do
    -- Call 'waitQSem' while masked rather than under 'restore'. base's
    -- 'waitQSem' masks internally and is still interruptible while blocked,
    -- so cancellation keeps working. Returning from it into an unmasked
    -- state, however, would let a pending exception fire right after the
    -- unit was taken but before the counter below was decremented.
    waitQSem qsem
    -- 'takeMVar' can block (and therefore be interrupted) while a concurrent
    -- 'signalBQSem' holds the counter. If that happens, give the unit back so
    -- the QSem and the counter stay in step.
    n <- takeMVar counter `onException` signalQSem qsem
    putMVar counter $! n - 1
  where
    qsem = unboundedQSem bqsem
    counter = bqsemCounter bqsem

-- | Make a new unit available, unless the maximum number of units has been
--   reached, in which case it does nothing (it doesn't block).
signalBQSem :: BQSem -> IO ()
signalBQSem bqsem =
  -- The whole operation runs masked; the restore function is deliberately
  -- unused (see the comment on 'signalQSem' below).
  mask $ \_ -> do
    n <- takeMVar counter
    -- '>=' rather than '==': equivalent while the count stays within the
    -- bound, and it never pushes an out-of-range count any further.
    if n >= bqsemBound bqsem
      then putMVar counter n
      else do
        -- 'signalQSem' is not run under 'restore'. base's 'signalQSem' masks
        -- exceptions internally, so an asynchronous exception arriving
        -- meanwhile stays pending until it returns. If we were unmasked at
        -- that point, the exception would fire after the unit had been
        -- released, and the handler would restore the old count, leaving the
        -- counter one below the QSem. The handler still covers a synchronous
        -- failure of 'signalQSem'.
        signalQSem qsem `onException` putMVar counter n
        putMVar counter $! n + 1
  where
    qsem = unboundedQSem bqsem
    counter = bqsemCounter bqsem

-- | Read how many units the semaphore currently holds.
--   The result is a snapshot; concurrent waits and signals may change it
--   immediately afterwards. It may briefly block while one of them is
--   updating the counter.
getBQSemQuantity :: BQSem -> IO Int
getBQSemQuantity sem = readMVar (bqsemCounter sem)