{-# LANGUAGE BangPatterns #-}
module Control.Concurrent.STM.BTChan
        ( BTChan
        , newBTChan
        , newBTChanIO
        , writeBTChan
        , readBTChan
        , tryWriteBTChan
        , tryReadBTChan
        , unGetBTChan
        , isEmptyBTChan
        , sizeOfBTChan
        , setMaxOfBTChan
        , maxOfBTChan
        ) where

import Control.Concurrent.STM
import Control.Monad (when, liftM)
import Control.Applicative

-- |A 'BTChan' is a bounded 'TChan': a FIFO channel that holds at most
-- 'maxSize' elements.  The element count is kept in two transactional
-- counters whose sum is the number of queued elements.  Readers only
-- decrement 'readSize'.  Writers increment 'writeSize' and consult
-- 'readSize' only when 'writeSize' alone has reached the bound.  This means
-- readers and writers usually touch different 'TVar's (presumably to reduce
-- STM conflicts between them).
data BTChan a = BTChan
    { maxSize   :: {-# UNPACK #-} !Int
      -- ^ Bound on the number of queued elements.
    , channel   :: TChan a
      -- ^ Underlying (unbounded) channel that stores the elements.
    , readSize  :: TVar Int
      -- ^ Read-side counter: starts at 0, is decremented by each read, and is
      --   reset to 0 whenever a writer folds it into 'writeSize'.
    , writeSize :: TVar Int
      -- ^ Write-side counter: incremented by each write / unGet.
    }

-- |An 'IO' version of 'newBTChan'.  This should be useful with
-- unsafePerformIO in the same manner as 'newTVarIO' and 'newTChanIO' are used.
-- Both size counters start at 0.
newBTChanIO :: Int -> IO (BTChan a)
newBTChanIO bound = do
    chan  <- newTChanIO
    reads <- newTVarIO 0
    wrote <- newTVarIO 0
    return (BTChan bound chan reads wrote)

-- |@newBTChan m@ makes a new, empty bounded channel that holds at most @m@
-- elements.  Both size counters start at 0.
newBTChan :: Int -> STM (BTChan a)
newBTChan bound = do
    chan  <- newTChan
    reads <- newTVar 0
    wrote <- newTVar 0
    return (BTChan bound chan reads wrote)

-- |Writes the value to the 'BTChan'.  The transaction retries (blocks, when
-- run under 'atomically') while the channel already holds @maxSize@ or more
-- elements.
writeBTChan :: BTChan a -> a -> STM ()
writeBTChan (BTChan mx c rdTV wrTV) x = do
    sz <- readTVar wrTV
    if sz < mx
        then writeTVar wrTV $! sz + 1
        else do
            -- The write-side counter alone says "full", but reads since the
            -- last reconciliation are recorded as a negative count in rdTV.
            -- Fold them in to get the true number of queued elements.
            rsz <- readTVar rdTV
            let !cur = sz + rsz
            when (cur >= mx) retry
            -- Move the whole count into wrTV.  The +1 accounts for the
            -- element written below; without it writeSize + readSize would
            -- undercount by one and let the channel grow past its bound.
            writeTVar wrTV $! cur + 1
            writeTVar rdTV 0
    writeTChan c x

-- |A non-blocking write.  Returns 'True' if the value was written, or
-- 'False' (leaving the channel untouched) if it already holds @maxSize@ or
-- more elements.
tryWriteBTChan :: BTChan a -> a -> STM Bool
tryWriteBTChan (BTChan mx c rdTV wrTV) x = do
    sz <- readTVar wrTV
    if sz < mx
        then do
            writeTVar wrTV $! sz + 1
            writeTChan c x
            return True
        else do
            -- Fold the (negative) read-side count into the write-side count
            -- to learn the true number of queued elements.
            rsz <- readTVar rdTV
            let !cur = sz + rsz
            if cur >= mx
                then return False
                else do
                    -- The +1 counts the element written here, keeping
                    -- writeSize + readSize equal to the queue length.
                    writeTVar wrTV $! cur + 1
                    writeTVar rdTV 0
                    writeTChan c x
                    return True

-- |Reads the next value from the 'BTChan'.  The transaction retries (blocks)
-- while the channel is empty.  Only the read-side counter is updated: it is
-- decremented, and forced so that no chain of thunks builds up between
-- reconciliations.
readBTChan :: BTChan a -> STM a
readBTChan (BTChan _ c rdTV _) = do
    v <- readTChan c
    n <- readTVar rdTV
    writeTVar rdTV $! n - 1
    return v

-- |A non-blocking read.  Returns @Just x@ with the next value, or 'Nothing'
-- if the channel is currently empty.
tryReadBTChan :: BTChan a -> STM (Maybe a)
tryReadBTChan bt = do
    isEmpty <- isEmptyBTChan bt
    if isEmpty
        then return Nothing
        else Just <$> readBTChan bt

-- |Put an element on the front of the 'BTChan' so that it will be the next
-- item read.  The element counts towards the bound exactly like
-- 'writeBTChan'.  The transaction retries (blocks) while the channel already
-- holds @maxSize@ or more elements.
unGetBTChan :: BTChan a -> a -> STM ()
unGetBTChan (BTChan m c rdTV wrTV) a = do
    sz <- readTVar wrTV
    if sz < m
        then writeTVar wrTV $! sz + 1
        else do
            -- Reconcile: add the (negative) read-side count to get the true
            -- number of queued elements.
            rsz <- readTVar rdTV
            let !cur = sz + rsz
            when (cur >= m) retry
            -- The +1 counts the element pushed back below, keeping
            -- writeSize + readSize equal to the queue length.
            writeTVar wrTV $! cur + 1
            writeTVar rdTV 0
    unGetTChan c a

-- |Returns 'True' if the supplied 'BTChan' holds no elements.  This asks the
-- underlying 'TChan' directly rather than consulting the size counters.
isEmptyBTChan :: BTChan a -> STM Bool
isEmptyBTChan = isEmptyTChan . channel

-- |Get the current number of elements in the 'BTChan'.  This is the sum of
-- the write-side counter and the (non-positive) read-side counter.
sizeOfBTChan :: BTChan a -> STM Int
sizeOfBTChan (BTChan _ _ rdTV wrTV) = do
    written <- readTVar wrTV
    consumed <- readTVar rdTV
    return (written + consumed)

-- |@let c2 = setMaxOfBTChan c1 mx@ gives a channel @c2@ that shares the
-- underlying 'TChan' and size counters of @c1@ but has the bound @mx@.
-- If the current size exceeds @mx@, no messages are dropped; writes through
-- @c2@ simply block until the size falls below @mx@.  @c1@ and @c2@ may be
-- used concurrently.  Writes to @c2@ block at the new maximum, while writes
-- to @c1@ block at the old one.  This biases the channel against whichever
-- writer uses the smaller bound.
setMaxOfBTChan :: BTChan a -> Int -> BTChan a
setMaxOfBTChan bt m = bt { maxSize = m }

-- |The maximum number of elements the 'BTChan' may hold.
maxOfBTChan :: BTChan a -> Int
maxOfBTChan = maxSize