module Control.Concurrent.STM.BTChan
( BTChan
, newBTChan
, newBTChanIO
, writeBTChan
, readBTChan
, tryWriteBTChan
, tryReadBTChan
, unGetBTChan
, isEmptyBTChan
, sizeOfBTChan
, setMaxOfBTChan
, maxOfBTChan
) where
import Control.Concurrent.STM
import Control.Monad (when, liftM)
import Control.Applicative
-- | A bounded STM channel: a 'TChan' paired with a capacity and two
-- counters that together track how many elements are currently queued.
--
-- Writers and readers each update only their own counter in the common
-- case. The true occupancy is @writeSize + readSize@. Readers are expected
-- to decrement 'readSize'. Writers fold 'readSize' back into 'writeSize'
-- (and reset 'readSize' to 0) only once 'writeSize' reaches 'maxSize'.
-- This keeps readers and writers from touching the same 'TVar' most of
-- the time.
--
-- All fields are strict: the handle is always built from already-allocated
-- STM objects, so there is no reason to store thunks in it.
data BTChan a =
  BTChan
    { maxSize   :: !Int         -- ^ capacity bound enforced by writers
    , channel   :: !(TChan a)   -- ^ underlying unbounded channel
    , readSize  :: !(TVar Int)  -- ^ read-side counter (see above)
    , writeSize :: !(TVar Int)  -- ^ write-side counter (see above)
    }
-- | Allocate an empty bounded channel with capacity @m@ directly in 'IO'.
-- Both counters start at zero.
newBTChanIO :: Int -> IO (BTChan a)
newBTChanIO m = liftA3 (BTChan m) newTChanIO (newTVarIO 0) (newTVarIO 0)
-- | Allocate an empty bounded channel with capacity @m@ inside a
-- transaction. Both counters start at zero.
newBTChan :: Int -> STM (BTChan a)
newBTChan m = do
  chan       <- newTChan
  readCount  <- newTVar 0
  writeCount <- newTVar 0
  return $ BTChan m chan readCount writeCount
-- | Append a value to the channel. The transaction 'retry's (blocks) while
-- the channel already holds the maximum number of elements.
--
-- In the fast path, only the write counter is touched. Once it reaches the
-- bound, the read counter is folded in to obtain the real occupancy.
writeBTChan :: BTChan a -> a -> STM ()
writeBTChan (BTChan mx c rdTV wrTV) x = do
  sz <- readTVar wrTV
  if sz < mx
    then writeTVar wrTV $! sz + 1
    else do
      rsz <- readTVar rdTV
      -- rdTV holds (non-positive) read adjustments since the last
      -- reconciliation; adding it yields the current occupancy.
      let occupancy = sz + rsz
      when (occupancy >= mx) retry
      -- Count the element being written. Storing only 'occupancy' here
      -- would leave it uncounted and let the channel grow past 'mx'.
      writeTVar wrTV $! occupancy + 1
      writeTVar rdTV 0
  writeTChan c x
-- | Non-blocking variant of 'writeBTChan'. Appends the value and returns
-- 'True' if there is room; otherwise leaves the channel untouched and
-- returns 'False'.
tryWriteBTChan :: BTChan a -> a -> STM Bool
tryWriteBTChan (BTChan mx c rdTV wrTV) x = do
  sz <- readTVar wrTV
  if sz < mx
    then push (sz + 1)
    else do
      rsz <- readTVar rdTV
      -- Fold outstanding reads into the write counter to get the real
      -- occupancy before deciding whether there is room.
      let occupancy = sz + rsz
      if occupancy >= mx
        then return False
        else do
          writeTVar rdTV 0
          -- +1 accounts for the element being written now.
          push (occupancy + 1)
  where
    -- Record the new write count, enqueue the value, and report success.
    push n = do
      writeTVar wrTV $! n
      writeTChan c x
      return True
-- | Remove and return the oldest element, 'retry'ing while the channel is
-- empty.
--
-- Only the read counter is updated. It is decremented so that
-- @writeSize + readSize@ remains the number of queued elements, which is
-- how the write paths and 'sizeOfBTChan' compute occupancy.
readBTChan :: BTChan a -> STM a
readBTChan (BTChan _ c rdTV _) = do
  x <- readTChan c
  sz <- readTVar rdTV
  -- Forced with $! so no chain of (-1) thunks accumulates in the TVar.
  writeTVar rdTV $! sz - 1
  return x
-- | Non-blocking read: 'Nothing' when the channel is empty, otherwise
-- 'Just' the element that 'readBTChan' would return.
tryReadBTChan :: BTChan a -> STM (Maybe a)
tryReadBTChan bt =
  isEmptyBTChan bt >>= \isEmpty ->
    if isEmpty
      then pure Nothing
      else Just <$> readBTChan bt
-- | Push a value back onto the read end of the channel, so it is the next
-- element read. Like 'writeBTChan', this counts toward the bound and
-- 'retry's while the channel is full.
unGetBTChan :: BTChan a -> a -> STM ()
unGetBTChan (BTChan m c rdTV wrTV) a = do
  sz <- readTVar wrTV
  if sz < m
    then writeTVar wrTV $! sz + 1
    else do
      rsz <- readTVar rdTV
      -- Reconcile with outstanding reads to obtain the real occupancy.
      let occupancy = sz + rsz
      when (occupancy >= m) retry
      -- +1 counts the element being pushed back; without it the channel
      -- could exceed its bound.
      writeTVar wrTV $! occupancy + 1
      writeTVar rdTV 0
  unGetTChan c a
-- | 'True' when the underlying channel currently holds no elements.
isEmptyBTChan :: BTChan a -> STM Bool
isEmptyBTChan = isEmptyTChan . channel
-- | Current number of queued elements, computed as the write counter plus
-- the read-side adjustment counter.
sizeOfBTChan :: BTChan a -> STM Int
sizeOfBTChan (BTChan _ _ rdTV wrTV) = do
  writes    <- readTVar wrTV
  readDelta <- readTVar rdTV
  return (writes + readDelta)
-- | Return a handle with capacity @m@ that shares the same underlying
-- channel and counters as the argument. The original handle keeps its own
-- bound.
setMaxOfBTChan :: BTChan a -> Int -> BTChan a
setMaxOfBTChan chan m = chan { maxSize = m }
-- | The capacity bound this handle enforces on writes.
maxOfBTChan :: BTChan a -> Int
maxOfBTChan = maxSize