module Semaphore (Semaphore, getSemaphore, putPos, getPos, getNeg, putNeg) where

import Control.Concurrent.MVar
import System.IO.Unsafe
import Control.Concurrent

-- A two-phase semaphore: any number of threads may hold it in the "positive"
-- phase, or any number in the "negative" phase, but never both phases at once.

-- | Shared state of a two-phase semaphore.
data Semaphore
  = Sem
      (MVar Int)       -- ^ counter: > 0 is the number of positive holders,
                       --   < 0 the (negated) number of negative holders
      (MVar Int)       -- ^ request: threads that entered 'getPos' and have
                       --   not yet called 'putPos' (waiting or holding)
      (MVar [MVar ()]) -- ^ triggers of blocked threads; 'signal' fills and
                       --   clears them

-- | The single process-wide semaphore.
--
-- Created once via 'unsafePerformIO'. The NOINLINE pragma keeps GHC from
-- duplicating the allocation at each use site, so every caller shares the
-- same three MVars.
{-# NOINLINE getSemaphore #-}
getSemaphore :: Semaphore
getSemaphore = unsafePerformIO $
  newMVar 0  >>= \cnt ->
  newMVar 0  >>= \req ->
  newMVar [] >>= \waiters ->
  return (Sem cnt req waiters)

-- | Acquire the semaphore in the positive phase.
--
-- First records a pending positive request, which makes new 'getNeg' callers
-- back off. Then blocks until no negative holder is active.
getPos :: Semaphore -> IO ()
getPos s@(Sem _ pending _) = incr pending >> waitForCounter s

-- | Block until no negative holder is active (counter >= 0), then register
-- one more positive holder by incrementing the counter.
waitForCounter :: Semaphore -> IO ()
waitForCounter s@(Sem counter _ event) = do
  c <- takeMVar counter
  if c < 0
    then do
      -- Register for the next 'signal' while still holding the counter.
      -- 'putNeg' must take the counter before it signals, so its signal
      -- cannot fire between our check and our registration. Releasing the
      -- counter first could lose the wakeup and block forever.
      trigger <- newEmptyMVar
      modifyMVar_ event (return . (trigger :))
      putMVar counter c
      takeMVar trigger
      -- Woken up: re-check, since the state may have changed again.
      waitForCounter s
    else putMVar counter $! c + 1


-- | Release a positive hold.
--
-- Drops one positive holder from the counter and one entry from the pending
-- request count, then wakes every blocked thread so it can re-check.
putPos :: Semaphore -> IO ()
putPos (Sem cnt req waiters) =
  mapM_ decr [cnt, req] >> signal waiters
               
-- | Acquire the semaphore in the negative phase.
--
-- Blocks while any positive holder is active (counter > 0) or any positive
-- request is outstanding (request > 0), so pending 'getPos' callers take
-- priority. On success the counter is decremented, recording one more
-- negative holder.
getNeg :: Semaphore -> IO ()
getNeg s@(Sem counter request event) = do
  c <- takeMVar counter
  r <- takeMVar request
  if c > 0 || r > 0
    then do
      -- Register for the next 'signal' while still holding both MVars.
      -- 'putPos' and 'putNeg' must take the counter/request before they
      -- signal, so no signal can slip in between the check above and this
      -- registration. Releasing first could lose the wakeup.
      trigger <- newEmptyMVar
      modifyMVar_ event (return . (trigger :))
      putMVar counter c
      putMVar request r
      takeMVar trigger
      -- Woken up: re-check, since the state may have changed again.
      getNeg s
    else do
      putMVar counter $! c - 1
      putMVar request r

-- | Release a negative hold.
--
-- Moves the counter back toward zero by one, then wakes every blocked thread
-- so it can re-check.
putNeg :: Semaphore -> IO ()
putNeg (Sem cnt _ waiters) = incr cnt >> signal waiters

-- | Block the calling thread until the next 'signal' on the given wait list.
--
-- A fresh, empty trigger MVar is pushed onto the list. The thread then blocks
-- taking it; 'signal' fills every registered trigger.
--
-- NOTE: the trigger is registered only when 'wait' is entered. A caller that
-- inspects shared state, releases it, and only then calls 'wait' can miss a
-- 'signal' fired in that gap. Register before releasing the state instead.
wait :: MVar [MVar ()] -> IO ()
wait list = do
  trigger <- newEmptyMVar
  -- modifyMVar_ restores the list if an exception arrives mid-update.
  -- Otherwise the list MVar could be left empty, deadlocking later callers.
  modifyMVar_ list (return . (trigger :))
  takeMVar trigger

-- | Wake every thread currently blocked in 'wait' on this list.
--
-- Holds the list while filling each registered trigger, then leaves it empty.
signal :: MVar [MVar ()] -> IO ()
signal list =
  takeMVar list >>= \triggers ->
    sequence_ [putMVar t () | t <- triggers] >> putMVar list []

-- | Atomically decrement the integer stored in the MVar.
--
-- The new value is forced before being stored. Otherwise repeated updates
-- build an unevaluated (c-1)+1-... thunk chain. The request counter is only
-- inspected by 'getNeg', so that chain can grow without bound.
-- 'modifyMVar_' restores the old value if an exception interrupts the update.
decr :: MVar Int -> IO ()
decr mvar = modifyMVar_ mvar (\c -> return $! c - 1)
-- | Atomically increment the integer stored in the MVar.
--
-- The new value is forced before being stored, to avoid accumulating a thunk
-- chain across many updates. 'modifyMVar_' restores the old value if an
-- exception interrupts the update.
incr :: MVar Int -> IO ()
incr mvar = modifyMVar_ mvar (\c -> return $! c + 1)