module Data.NonBlocking.LockFree.Treiber(TreiberStack(), TreiberStackIO, TreiberStackSTM, newTreiberStack, pushTreiberStack, popTreiberStack) where
import Control.Concurrent.STM (STM())
import Control.Concurrent.STM.TVar (TVar())
import Control.Monad(join, when)
import Control.Monad.Loops(whileM_)
import Control.Monad.Ref(MonadAtomicRef, newRef, readRef, writeRef, atomicModifyRef)
import Data.IORef(IORef)
import GHC.Exts (Int(I#))
import GHC.Prim (reallyUnsafePtrEquality#)
-- | A 'TreiberStack' whose mutable cells are 'IORef's.
type TreiberStackIO a = TreiberStack IORef a
-- | A 'TreiberStack' whose mutable cells are 'TVar's.
type TreiberStackSTM a = TreiberStack TVar a
-- | A stack of @a@ values, parameterised over the reference type @r@.
-- It wraps a single reference holding the cell currently at the top.
data TreiberStack r a = TreiberStack (r (TreiberElem r a))
-- | One cell of the stack: either a value paired with a reference to the
-- cell beneath it, or 'End', the content of a freshly created (empty) stack.
data TreiberElem r a = TreiberElem a (r (TreiberElem r a)) | End
-- | Identity-style equality on stack cells.
--
-- Two 'TreiberElem' cells are equal when their tail references are the very
-- same heap object (checked with 'ptrEq'); the stored payload is not
-- inspected. 'End' equals only 'End'.
instance Eq (TreiberElem r a) where
  End == End = True
  TreiberElem _ rest1 == TreiberElem _ rest2 = ptrEq rest1 rest2
  -- Mixed constructors: without this equation the comparison would abort
  -- with a pattern-match failure instead of reporting inequality.
  _ == _ = False
-- | Create an empty stack: a new reference initialised to 'End', wrapped in
-- 'TreiberStack'.
newTreiberStack :: (MonadAtomicRef r m) => m (TreiberStack r a)
newTreiberStack = newRef End >>= \top -> return (TreiberStack top)
-- | Push a value onto the stack.
--
-- Reads the current top cell, builds a new cell whose tail reference holds
-- that cell, and tries to install it with 'cas'. If the top changed in the
-- meantime the swap fails and the whole attempt is repeated.
pushTreiberStack :: (MonadAtomicRef r m) => TreiberStack r a -> a -> m ()
pushTreiberStack stack@(TreiberStack top) v = do
  oldTop <- readRef top
  -- A fresh tail reference is allocated on every attempt; cell equality in
  -- 'cas' is decided by the identity of this reference.
  tailRef <- newRef oldTop
  pushed <- cas top oldTop (TreiberElem v tailRef)
  -- Retry from scratch when another writer changed the top first.
  when (not pushed) $ pushTreiberStack stack v
-- | Pop the top value off the stack.
--
-- Returns 'Nothing' when the top cell is 'End'. Otherwise it reads the cell
-- beneath the top and tries to install it as the new top with 'cas'; on
-- success the removed value is returned in 'Just', on failure the whole
-- attempt is repeated.
popTreiberStack :: (MonadAtomicRef r m) => TreiberStack r a -> m (Maybe a)
popTreiberStack stack@(TreiberStack top) = do
  cur <- readRef top
  case cur of
    End -> return Nothing
    TreiberElem val rest -> do
      next <- readRef rest
      swapped <- cas top cur next
      if swapped
        then return (Just val)
        -- The top changed between the read and the swap: start over.
        else popTreiberStack stack
-- | Compare-and-swap built on 'atomicModifyRef'.
--
-- If the reference currently holds a value equal (by '==') to @comp@, it is
-- replaced with @rep@ and the result is 'True'; otherwise the contents are
-- kept and the result is 'False'.
cas :: (MonadAtomicRef r m, Eq a) => r a -> a -> a -> m Bool
cas ref comp rep = atomicModifyRef ref step
  where
    -- The pair is built lazily: the comparison only runs once either
    -- component is demanded.
    step current = (next, matched)
      where
        matched = current == comp
        next
          | matched = rep
          | otherwise = current
-- | Pointer equality: 'True' when 'reallyUnsafePtrEquality#' reports that
-- both arguments are the same heap object. Both arguments are forced to
-- weak head normal form before the comparison.
ptrEq :: a -> a -> Bool
ptrEq x y = x `seq` y `seq` samePointer
  where
    samePointer = I# (reallyUnsafePtrEquality# x y) == 1