{-# LANGUAGE BangPatterns #-}
-- | Bounded transactional FIFO channels.
--
-- A 'BTChan' wraps an unbounded 'TChan' together with a 'TVar' counter so
-- that writers block (via 'retry') once the channel holds as many elements
-- as its bound allows.
module Control.Concurrent.STM.BTChan
        ( BTChan
        , newBTChan
        , newBTChanIO
        , writeBTChan
        , readBTChan
        , unGetBTChan
        , isEmptyBTChan
        , sizeOfBTChan
        , setMaxOfBTChan
        , maxOfBTChan
        ) where

import Control.Concurrent.STM
import Control.Monad (when)

-- | A 'BTChan' is a bounded 'TChan': a FIFO channel that pairs a 'TChan'
-- with a transactional counter used to limit the number of elements held
-- on the channel.
data BTChan a
    = BTChan
        {-# UNPACK #-} !Int  -- the bound: writes block once the count reaches it
        (TChan a)            -- the underlying channel holding the elements
        (TVar  Int)          -- current element count (shared by 'setMaxOfBTChan' copies)

-- | An 'IO' version of 'newBTChan': creates an empty channel bounded at the
-- given size.  Like 'newTVarIO' and 'newTChanIO', it is intended to be usable
-- with 'unsafePerformIO' to create top-level channels.
newBTChanIO :: Int -> IO (BTChan a)
newBTChanIO bound =
    newTVarIO 0 >>= \countVar ->
    newTChanIO  >>= \chan ->
    return (BTChan bound chan countVar)

-- | @newBTChan m@ makes a new, empty 'BTChan' that holds at most @m@
-- elements.  With a bound of zero or less, every write blocks.
newBTChan :: Int -> STM (BTChan a)
newBTChan bound =
    newTVar 0 >>= \countVar ->
    newTChan  >>= \chan ->
    return (BTChan bound chan countVar)

-- | Writes the value to the 'BTChan', or blocks (via 'retry') while the
-- channel already holds at least as many elements as its bound.
writeBTChan :: BTChan a -> a -> STM ()
writeBTChan (BTChan mx c szTV) x = do
    sz <- readTVar szTV
    when (sz >= mx) retry
    -- Store the incremented count strictly, matching 'readBTChan' and
    -- 'unGetBTChan', so the counter never holds an unevaluated thunk.
    writeTVar szTV $! sz + 1
    writeTChan c x

-- | Reads the next value from the 'BTChan'.  Blocks (via 'readTChan') while
-- the channel is empty, then decrements the element counter.
readBTChan :: BTChan a -> STM a
readBTChan (BTChan _ chan countVar) = do
    item  <- readTChan chan
    count <- readTVar countVar
    -- force the new count before storing it
    writeTVar countVar $! count - 1
    return item

-- | Puts an element on the front of the 'BTChan' so it will be the next item
-- read.  Like 'writeBTChan', this blocks (via 'retry') while the channel is
-- already at or above its bound.
unGetBTChan :: BTChan a -> a -> STM ()
unGetBTChan (BTChan bound chan countVar) item = do
    count <- readTVar countVar
    when (count >= bound) retry
    -- force the new count before storing it
    writeTVar countVar $! count + 1
    unGetTChan chan item

-- | Returns 'True' if the supplied 'BTChan' is empty, i.e. its underlying
-- 'TChan' holds no elements.
isEmptyBTChan :: BTChan a -> STM Bool
isEmptyBTChan chan = case chan of
    BTChan _ underlying _ -> isEmptyTChan underlying

-- | Get the current number of elements in the 'BTChan', as recorded by its
-- element counter.
sizeOfBTChan :: BTChan a -> STM Int
sizeOfBTChan chan = case chan of
    BTChan _ _ countVar -> readTVar countVar

-- | @c2 = setMaxOfBTChan c1 mx@ returns a channel that shares @c1@'s
-- underlying 'TChan' and element counter but has the new maximum @mx@.
-- If the current size is greater than @mx@, no messages are dropped; writes
-- to @c2@ simply block until the size goes lower than @mx@.  Using @c1@ and
-- @c2@ concurrently is possible, but writes to @c2@ block at the new maximum
-- while writes to @c1@ block at the old one.  This biases the channel against
-- whichever writer holds the handle with the smaller bound.
setMaxOfBTChan :: BTChan a -> Int -> BTChan a
setMaxOfBTChan chan newMax = case chan of
    BTChan _ underlying countVar -> BTChan newMax underlying countVar

-- | Get the bound of the 'BTChan': the element count at which writes start
-- to block.
maxOfBTChan :: BTChan a -> Int
maxOfBTChan chan = case chan of
    BTChan bound _ _ -> bound