{-# LANGUAGE DeriveDataTypeable #-}
{-# OPTIONS_GHC -funbox-strict-fields #-}

module Distribution.Client.Compat.Semaphore
  ( QSem
  , newQSem
  , waitQSem
  , signalQSem
  ) where

import Prelude (Bool (..), Eq (..), IO, Int, Num (..), flip, return, ($), ($!))

import Control.Concurrent.STM
  ( TVar
  , atomically
  , newTVar
  , readTVar
  , retry
  , writeTVar
  )
import Control.Exception (mask_, onException)
import Control.Monad (join, unless)
import Data.List.NonEmpty (NonEmpty (..))
import qualified Data.List.NonEmpty as NE
import Data.Typeable (Typeable)

-- | 'QSem' is a quantity semaphore in which the resource is acquired
-- and released in units of one. It provides guaranteed FIFO ordering
-- for satisfying blocked `waitQSem` calls.
--
-- The three fields are:
--
--   1. the number of currently available units;
--   2. the front of the queue of blocked waiters, oldest first
--      ('signalQSem' pops waiters from here);
--   3. the back of the queue of blocked waiters, newest first
--      ('waitQSem' pushes new waiters here, and 'signalQSem' reverses
--      this list into the front when the front runs empty).
--
-- Each blocked waiter is represented by a @TVar Bool@. The TVar is set
-- to 'True' when the waiter is woken, or when the waiter abandons its
-- slot because of an exception.
data QSem = QSem !(TVar Int) !(TVar [TVar Bool]) !(TVar [TVar Bool])
  deriving (Eq, Typeable)

-- | Create a new semaphore holding the given number of units, with both
-- waiter queues empty.
--
-- NOTE(review): the initial count is not validated. 'waitQSem' only
-- blocks when the count is exactly 0, so a negative initial count would
-- let every 'waitQSem' succeed immediately. Callers are presumably
-- expected to pass a non-negative value; confirm at call sites.
newQSem :: Int -> IO QSem
newQSem i = atomically $ do
  q <- newTVar i
  b1 <- newTVar []
  b2 <- newTVar []
  return (QSem q b1 b2)

-- | Acquire one unit of the semaphore, blocking until one is available.
--
-- If the count is non-zero it is decremented and the call returns at
-- once. Otherwise a fresh waiter 'TVar' is pushed onto the back of the
-- queue, and the caller blocks until 'signalQSem' sets that TVar to
-- 'True'. Blocked callers are therefore served in FIFO order.
--
-- If an exception interrupts the wait, 'wake' is run on the caller's
-- waiter TVar. That marks the slot as abandoned, or passes on a unit that
-- was already handed to this caller, so no unit of the resource is lost.
waitQSem :: QSem -> IO ()
waitQSem s@(QSem q _b1 b2) =
  -- join, because if we need to block, we have to add a TVar to the
  -- block queue inside the transaction and then wait on it outside it.
  -- mask_, because we need a chance to set up an exception handler
  -- (see 'wait') after the join returns.
  mask_ $ join $ atomically $ do
    v <- readTVar q
    if v == 0
      then do
        -- No unit available: enqueue a new waiter. b2 holds the newest
        -- waiters first.
        b <- newTVar False
        ys <- readTVar b2
        writeTVar b2 (b : ys)
        return (wait b)
      else do
        writeTVar q $! v - 1
        return (return ())
  where
    --
    -- very careful here: if we receive an exception, then we need to
    --  (a) write True into the TVar, so that another signalQSem doesn't
    --      try to wake up this thread, and
    --  (b) if the TVar is *already* True, then we need to do another
    --      signalQSem to avoid losing a unit of the resource.
    --
    -- The 'wake' function does both (a) and (b), so we can just call
    -- it here.
    --
    wait :: TVar Bool -> IO ()
    wait t =
      flip onException (wake s t) $
        atomically $ do
          b <- readTVar t
          -- Block (retry the transaction) until a signaller sets t.
          unless b retry

-- | Hand one unit of the semaphore to the waiter represented by the
-- given 'TVar'.
--
-- * If the TVar is still 'False', it is set to 'True'. This either
--   releases the blocked waiter, or (when called from the waiter's own
--   exception handler) marks the slot as abandoned. In the abandoned
--   case, a later 'signalQSem' that dequeues the slot will pass the unit
--   on instead.
--
-- * If the TVar is already 'True', the unit cannot be delivered to this
--   waiter, so it is released again with 'signalQSem' rather than lost.
--
-- The follow-up 'signalQSem' is returned from the transaction and run by
-- 'join' afterwards, outside the STM transaction.
wake :: QSem -> TVar Bool -> IO ()
wake s x = join $ atomically $ do
  b <- readTVar x
  if b
    then return (signalQSem s)
    else do
      writeTVar x True
      return (return ())

{-
 Property we want:

   bracket (waitQSem s) (\_ -> signalQSem s) (\_ -> ...)

 never loses a unit of the resource, even when an exception interrupts
 a thread that is blocked in waitQSem.
-}

-- | Release one unit of the semaphore.
--
-- When the available count is non-zero it is simply incremented. When it
-- is zero, the oldest queued waiter (if any) is handed the unit via
-- 'wake'. If both waiter queues are empty, the count becomes 1.
--
-- The waiter queue is stored as two lists: b1 is the front (oldest
-- first) and b2 is the back (newest first). When b1 is empty, b2 is
-- reversed to find the oldest waiter, and the remainder becomes the new
-- front.
signalQSem :: QSem -> IO ()
signalQSem s@(QSem q b1 b2) =
  -- join, so we don't force the reverse inside the txn
  -- mask_ is needed so we don't lose a wakeup
  mask_ $ join $ atomically $ do
    v <- readTVar q
    if v /= 0
      then do
        writeTVar q $! v + 1
        return (return ())
      else do
        xs <- readTVar b1
        checkwake1 xs
  where
    -- Try the front of the queue first; fall back to the back.
    checkwake1 :: [TVar Bool] -> STM (IO ())
    checkwake1 [] = do
      ys <- readTVar b2
      checkwake2 ys
    checkwake1 (x : xs) = do
      writeTVar b1 xs
      return (wake s x)

    -- Front is empty. If the back is empty too, nobody is waiting and the
    -- unit goes back into the counter. Otherwise reverse the back so the
    -- oldest waiter comes first, wake it, and keep the rest as the front.
    checkwake2 :: [TVar Bool] -> STM (IO ())
    checkwake2 [] = do
      writeTVar q 1
      return (return ())
    checkwake2 (y : ys) = do
      let (z :| zs) = NE.reverse (y :| ys)
      writeTVar b1 zs
      writeTVar b2 []
      return (wake s z)