module Control.Concurrent.LightSwitch
    ( LightSwitch
    , lockLightSwitch
    , newLightSwitch
    , unlockLightSwitch
    , withLightSwitch
    ) where

import Control.Concurrent.Util (withQSem)

import Control.Applicative ((<$>), (<*>))
import Control.Concurrent.QSem (newQSem, QSem, signalQSem, waitQSem)
import Control.Exception (bracket_)
import Control.Monad (when)
import Data.IORef (IORef, newIORef, readIORef, writeIORef)

-- | A light switch: a group of threads that collectively hold one shared
-- semaphore. The first locker of the group waits on the semaphore and the
-- last unlocker signals it, so the semaphore stays held while any member of
-- the group is inside.
data LightSwitch = LightSwitch
    { counter :: IORef Int
      -- ^ Outstanding 'lockLightSwitch' calls not yet matched by an
      -- 'unlockLightSwitch'.
    , mutex, semaphore :: QSem
      -- ^ @mutex@ serialises access to 'counter'; @semaphore@ is the shared
      -- semaphore supplied to 'newLightSwitch'.
    }

-- | Create a light switch guarding the given shared semaphore. The switch
-- starts with a holder count of 0 and a fresh internal mutex created with
-- an initial count of 1.
newLightSwitch :: QSem -> IO LightSwitch
newLightSwitch shared = do
    holders <- newIORef 0
    lock <- newQSem 1
    return LightSwitch { counter = holders, mutex = lock, semaphore = shared }

-- | Apply a function to the contents of an 'IORef', store the result, and
-- return it.
--
-- The new value is forced to WHNF before it is written, so the reference
-- never holds an unevaluated @f old@ thunk (the previous point-free version
-- wrote the thunk lazily).
--
-- Not atomic: the read and the write are separate steps, so callers must
-- provide their own mutual exclusion. In this module the callers do so by
-- running inside @withQSem (mutex s)@.
mutateIORef :: (a -> a) -> IORef a -> IO a
mutateIORef f r = do
    new <- f <$> readIORef r
    writeIORef r $! new
    return new

-- | Enter the switch. If this caller is the first holder (the count is 0),
-- it blocks on the shared semaphore before being counted. The counter is
-- read and written inside @withQSem (mutex s)@, which presumably holds the
-- internal mutex for the duration (TODO confirm against
-- "Control.Concurrent.Util").
--
-- Why the increment comes /after/ 'waitQSem': 'waitQSem' can block, and a
-- blocking wait is interruptible even under 'mask' (e.g. inside the acquire
-- step of 'bracket_'). If the counter were bumped first, an asynchronous
-- exception delivered while blocked would leave the count at 1 without the
-- semaphore being held. Later lockers would then skip the wait and enter
-- unprotected. Counting only after the wait succeeds leaves the state
-- unchanged when the wait is interrupted.
lockLightSwitch :: LightSwitch -> IO ()
lockLightSwitch s = withQSem (mutex s) $ do
    c <- readIORef (counter s)
    when (c == 0) $ waitQSem (semaphore s)
    writeIORef (counter s) $! c + 1

-- | Leave the switch. The holder count is decremented inside
-- @withQSem (mutex s)@; the caller whose decrement brings it to 0 signals
-- the shared semaphore, releasing it on behalf of the whole group.
unlockLightSwitch :: LightSwitch -> IO ()
unlockLightSwitch s = withQSem (mutex s) $ do
    remaining <- subtract 1 <$> readIORef ref
    writeIORef ref remaining
    when (remaining == 0) $ signalQSem (semaphore s)
  where
    ref = counter s

-- | Run an action while holding the switch. 'lockLightSwitch' runs before
-- the action and 'unlockLightSwitch' runs afterwards, even if the action
-- throws; both are handled by 'bracket_'.
--
-- The result type is generalised from @()@ to any @a@, so the protected
-- section can return a value; existing @IO ()@ call sites still type-check.
withLightSwitch :: LightSwitch -> IO a -> IO a
withLightSwitch s = bracket_ (lockLightSwitch s) (unlockLightSwitch s)