module Control.Concurrent.STM.BTChan
( BTChan
, newBTChan
, newBTChanIO
, writeBTChan
, readBTChan
, tryWriteBTChan
, tryReadBTChan
, unGetBTChan
, isEmptyBTChan
, sizeOfBTChan
, setMaxOfBTChan
, maxOfBTChan
) where
import Control.Concurrent.STM
import Control.Monad (when, liftM)
-- | A bounded channel for use inside 'STM', layered over an ordinary 'TChan'.
--
-- Occupancy is tracked with two separate counters so that readers and
-- writers normally touch different 'TVar's; the current occupancy is the
-- sum of the two counters.
data BTChan a = BTChan
  { maxSize   :: !Int
    -- ^ Capacity limit that writers compare the occupancy against.
  , channel   :: TChan a
    -- ^ Underlying unbounded channel that actually holds the items.
  , readSize  :: TVar Int
    -- ^ Reader-side counter. Writers add it to 'writeSize' to obtain the
    -- occupancy and reset it to 0 after folding it in.
  , writeSize :: TVar Int
    -- ^ Writer-side counter, bumped on every enqueue.
  }
-- | Create an empty bounded channel directly in 'IO'.
--
-- The argument is the capacity limit stored in the channel; both
-- occupancy counters start at 0.
newBTChanIO :: Int -> IO (BTChan a)
newBTChanIO limit =
  BTChan limit
    <$> newTChanIO
    <*> newTVarIO 0
    <*> newTVarIO 0
-- | Create an empty bounded channel inside an 'STM' transaction.
--
-- The argument is the capacity limit stored in the channel; both
-- occupancy counters start at 0.
newBTChan :: Int -> STM (BTChan a)
newBTChan limit =
  BTChan limit
    <$> newTChan
    <*> newTVar 0
    <*> newTVar 0
-- | Append an item to the channel, blocking (via 'retry') while it is full.
--
-- Fast path: if the writer-side counter alone is below the limit, only
-- that counter is touched. Otherwise the reader-side counter is folded
-- in to obtain the occupancy; if that is still at or above the limit the
-- transaction retries, else the folded count (plus one for this item) is
-- stored and the reader-side counter is reset to 0.
writeBTChan :: BTChan a -> a -> STM ()
writeBTChan (BTChan limit chan readCount writeCount) item = do
  written <- readTVar writeCount
  if written < limit
    then writeTVar writeCount (written + 1)
    else do
      adjustment <- readTVar readCount
      let occupied = written + adjustment
      if occupied >= limit
        then retry
        else do
          writeTVar writeCount (occupied + 1)
          writeTVar readCount 0
  -- Reached only when one of the branches above accepted the write.
  writeTChan chan item
-- | Non-blocking variant of 'writeBTChan'.
--
-- Returns 'True' after enqueuing the item, or 'False' (leaving the
-- channel and its counters untouched) when the occupancy has reached the
-- limit.
tryWriteBTChan :: BTChan a -> a -> STM Bool
tryWriteBTChan (BTChan limit chan readCount writeCount) item = do
  written <- readTVar writeCount
  if written < limit
    then do
      writeTVar writeCount (written + 1)
      enqueue
    else do
      -- Writer-side counter alone hit the limit: fold in the reader-side
      -- counter to find out how many items are really outstanding.
      adjustment <- readTVar readCount
      let occupied = written + adjustment
      if occupied >= limit
        then return False
        else do
          writeTVar writeCount (occupied + 1)
          writeTVar readCount 0
          enqueue
  where
    -- Push the item and report success.
    enqueue = writeTChan chan item >> return True
-- | Take the next item from the channel, blocking while it is empty
-- (the blocking comes from 'readTChan').
--
-- Each successful read decrements the reader-side counter. Writers and
-- 'sizeOfBTChan' add that counter to the writer-side counter to obtain
-- the occupancy, so the decrement is what frees a slot for writers.
readBTChan :: BTChan a -> STM a
readBTChan (BTChan _ c rdTV _) = do
  x <- readTChan c
  sz <- readTVar rdTV
  -- Force the new count before storing it so the TVar never holds a
  -- growing chain of (- 1) thunks.
  writeTVar rdTV $! sz - 1
  return x
-- | Non-blocking variant of 'readBTChan'.
--
-- Yields 'Nothing' when the channel is empty, otherwise 'Just' the item
-- obtained from 'readBTChan'.
tryReadBTChan :: BTChan a -> STM (Maybe a)
tryReadBTChan bt = do
  noItems <- isEmptyBTChan bt
  if noItems
    then return Nothing
    else Just <$> readBTChan bt
-- | Push an item back onto the read end of the channel, blocking (via
-- 'retry') while the channel is full.
--
-- Capacity accounting is the same as for a normal write: the pushed-back
-- item is counted on the writer-side counter.
unGetBTChan :: BTChan a -> a -> STM ()
unGetBTChan (BTChan limit chan readCount writeCount) item = do
  written <- readTVar writeCount
  if written < limit
    then writeTVar writeCount $! written + 1
    else do
      -- Writer-side counter alone hit the limit: fold in the reader-side
      -- counter before deciding whether the channel is really full.
      adjustment <- readTVar readCount
      let occupied = written + adjustment
      if occupied >= limit
        then retry
        else do
          writeTVar writeCount (occupied + 1)
          writeTVar readCount 0
  -- Reached only when one of the branches above accepted the item.
  unGetTChan chan item
-- | 'True' when the underlying 'TChan' currently holds no items.
isEmptyBTChan :: BTChan a -> STM Bool
isEmptyBTChan = isEmptyTChan . channel
-- | Current occupancy as tracked by the channel: the sum of the
-- writer-side and reader-side counters.
sizeOfBTChan :: BTChan a -> STM Int
sizeOfBTChan (BTChan _ _ readCount writeCount) =
  (+) <$> readTVar writeCount <*> readTVar readCount
-- | Return a channel handle with a different capacity limit.
--
-- This is a pure function: the result shares the underlying 'TChan' and
-- both counters with the argument, and the argument itself keeps its old
-- limit.
setMaxOfBTChan :: BTChan a -> Int -> BTChan a
setMaxOfBTChan bt newLimit = bt { maxSize = newLimit }
-- | The capacity limit stored in this channel handle.
maxOfBTChan :: BTChan a -> Int
maxOfBTChan = maxSize