{-# LANGUAGE CPP #-}

module Network.Wai.Handler.Warp.Counter (
    Counter,
    newCounter,
    waitForZero,
    increase,
    decrease,
) where

import Control.Concurrent.STM

import Network.Wai.Handler.Warp.Imports

-- | A shared integer counter stored in an STM 'TVar'.
--
-- The constructor is not exported (the export list names only the type),
-- so values are created with 'newCounter' and changed only via
-- 'increase' / 'decrease'.
newtype Counter
    = Counter
        (TVar Int) -- ^ current count, read and modified inside 'atomically'

-- | Create a fresh 'Counter' whose initial value is @0@.
newCounter :: IO Counter
newCounter = Counter <$> newTVarIO 0

-- | Block the calling thread until the counter's value is no longer
-- positive.
--
-- The check runs inside a single STM transaction. While the value is
-- greater than zero, 'retry' abandons the transaction, and STM re-runs it
-- after the 'TVar' changes. Note that a negative value (possible if
-- 'decrease' is called more often than 'increase') also lets this return
-- immediately, because only @x > 0@ blocks.
waitForZero :: Counter -> IO ()
waitForZero (Counter ref) = atomically $ do
    x <- readTVar ref
    when (x > 0) retry

-- | Atomically add one to the counter.
--
-- Uses the strict 'modifyTVar'' so the stored value is evaluated on each
-- update instead of accumulating a chain of @(+ 1)@ thunks.
increase :: Counter -> IO ()
increase (Counter ref) = atomically $ modifyTVar' ref (+ 1)

-- | Atomically subtract one from the counter.
--
-- There is no lower bound: calling this more times than 'increase' drives
-- the value negative. Uses the strict 'modifyTVar'' so no thunk chain
-- builds up inside the 'TVar'.
decrease :: Counter -> IO ()
decrease (Counter ref) = atomically $ modifyTVar' ref (subtract 1)