{-# LANGUAGE BangPatterns, CPP, MagicHash, UnboxedTuples, ScopedTypeVariables #-}
module Data.Concurrent.Queue.MichaelScott
(
LinkedQueue(), newQ, nullQ, pushL, tryPopR,
)
where
import Data.IORef (readIORef, newIORef)
import System.IO (stderr)
#ifdef DEBUG
import Data.ByteString.Char8 (hPutStrLn, pack)
#endif
import GHC.IORef(IORef(IORef))
import GHC.STRef(STRef(STRef))
import qualified Data.Concurrent.Deque.Class as C
import Data.Atomics (readForCAS, casIORef, Ticket, peekTicket)
import GHC.Base hiding ((==#), sameMutVar#)
import GHC.Exts hiding ((==#), sameMutVar#)
import qualified GHC.Exts as Exts
-- | 'Bool'-returning equality on unboxed 'Int#'s.  The primitive
-- 'Exts.==#' it wraps returns an 'Int#' (0# meaning "not equal").
(==#) :: Int# -> Int# -> Bool
a ==# b = isTrue# (a Exts.==# b)
-- | 'Bool'-returning identity test on primitive 'MutVar#'s.  It wraps
-- 'Exts.sameMutVar#', which reports its answer as an 'Int#'.
sameMutVar# :: MutVar# s a -> MutVar# s a -> Bool
sameMutVar# a b = isTrue# (Exts.sameMutVar# a b)
-- | A concurrent FIFO queue in the style of Michael and Scott.  It is a
-- singly linked list of 'Pair' nodes with separate mutable pointers to
-- both ends.  The node under 'head' acts as a sentinel; the queued
-- elements live in the nodes after it.
data LinkedQueue a = LQ
  { head :: {-# UNPACK #-} !(IORef (Pair a))
    -- ^ Points at the current sentinel node.
  , tail :: {-# UNPACK #-} !(IORef (Pair a))
    -- ^ Points at the last node, or may briefly lag behind it.
  }
-- | A list node.  'Null' terminates the list.  'Cons' holds an element
-- and a mutable link to the following node.
data Pair a
  = Null
  | Cons a {-# UNPACK #-} !(IORef (Pair a))
-- | Node identity.  Two nodes compare equal when:
--
-- * both are 'Null', or
-- * both are 'Cons' whose link 'IORef's wrap the very same 'MutVar#'
--   (pointer identity).
--
-- Element values are never inspected.
pairEq :: Pair a -> Pair a -> Bool
{-# INLINE pairEq #-}
pairEq x y =
  case x of
    Null ->
      case y of
        Null     -> True
        Cons _ _ -> False
    Cons _ (IORef (STRef mvX)) ->
      case y of
        Null                       -> False
        Cons _ (IORef (STRef mvY)) -> sameMutVar# mvX mvY
-- | Push an element onto the left end of the queue, i.e. the end that
-- 'tail' points at.  The queue is unbounded, so this always succeeds.
-- Under contention it retries until its CAS links the new node after the
-- current last node.
--
-- After linking, it makes a single attempt to swing 'tail' to the new
-- node.  If that CAS fails, 'tail' has already been moved from the
-- snapshot we read (for example by another thread helping in 'extend').
pushL :: forall a . LinkedQueue a -> a -> IO ()
pushL (LQ _ tailPtr) val = do
  link <- newIORef Null
  let newp :: Pair a
      newp = Cons val link

      -- Snapshot the tail node and its next pointer, then act on them.
      loop :: IO ()
      loop = do
        tailTicket <- readForCAS tailPtr
        case peekTicket tailTicket of
          Null -> error "push: LinkedQueue invariants broken. Internal error."
          Cons _ nextPtr -> do
            nextTicket <- readForCAS nextPtr
#ifdef RECHECK_ASSUMPTIONS
            -- Confirm that tail still points at the node we read.  If not,
            -- the (tail, next) snapshot may be inconsistent, so start over.
            latestTail <- readIORef tailPtr
            if not (pairEq (peekTicket tailTicket) latestTail)
              then loop
              else extend tailTicket nextPtr nextTicket
#else
            extend tailTicket nextPtr nextTicket
#endif

      -- Either link newp after the snapshot tail node (when its next is
      -- Null), or help advance a lagging tail pointer and retry.
      extend :: Ticket (Pair a) -> IORef (Pair a) -> Ticket (Pair a) -> IO ()
      extend tailTicket nextPtr nextTicket =
        case peekTicket nextTicket of
          Null -> do
            (linked, _) <- casIORef nextPtr nextTicket newp
            if linked
              then do
                _ <- casIORef tailPtr tailTicket newp
                return ()
              else loop
          nxt@(Cons _ _) -> do
            _ <- casIORef tailPtr tailTicket nxt
            loop
  loop
-- | Debugging aid; not exported.  When head and tail point at different
-- nodes, it requires two things of the head node:
--
-- * it is a 'Cons';
-- * its next pointer is not 'Null'.
--
-- If either fails it calls 'error'; messages are prefixed with @s@.
-- When head and tail coincide, nothing is checked.
checkInvariant :: String -> LinkedQueue a -> IO ()
checkInvariant s (LQ headPtr tailPtr) = do
  hd <- readIORef headPtr
  tl <- readIORef tailPtr
  if pairEq hd tl
    then return ()
    else case hd of
      Null ->
        error (s ++ " checkInvariant: LinkedQueue invariants broken. Internal error.")
      Cons _ nextPtr -> do
        afterHead <- readIORef nextPtr
        case afterHead of
          Null     -> error (s ++ " checkInvariant: next' should not be null")
          Cons _ _ -> return ()
-- | Try to remove the oldest element, taken from the end that 'head'
-- points at.  Returns 'Nothing' when the queue is observed empty, i.e. head
-- and tail coincide and the sentinel has no successor.  It never waits: a
-- lost CAS simply triggers another attempt.
--
-- On success, the node holding the returned element becomes the new
-- sentinel.
tryPopR :: forall a . LinkedQueue a -> IO (Maybe a)
tryPopR (LQ headPtr tailPtr) = attempt 0
  where
    -- One attempt: snapshot head, tail and the sentinel's next pointer.
    attempt :: Int -> IO (Maybe a)
#ifdef DEBUG
    attempt 25   = do hPutStrLn stderr (pack "tryPopR: tried ~25 times!!"); attempt 26
    attempt 50   = do hPutStrLn stderr (pack "tryPopR: tried ~50 times!!"); attempt 51
    attempt 100  = do hPutStrLn stderr (pack "tryPopR: tried ~100 times!!"); attempt 101
    attempt 1000 = do hPutStrLn stderr (pack "tryPopR: tried ~1000 times!!"); attempt 1001
#endif
    attempt !tries = do
      headTicket <- readForCAS headPtr
      tailTicket <- readForCAS tailPtr
      case peekTicket headTicket of
        Null -> error "tryPopR: LinkedQueue invariants broken. Internal error."
        sentinel@(Cons _ nextPtr) -> do
          nextTicket <- readForCAS nextPtr
#ifdef RECHECK_ASSUMPTIONS
          -- Re-read head; if it has moved, the snapshot is stale, so retry.
          sentinel' <- readIORef headPtr
          if not (pairEq sentinel sentinel')
            then attempt (tries + 1)
            else advance tries headTicket tailTicket sentinel nextTicket
#else
          advance tries headTicket tailTicket sentinel nextTicket
#endif

    -- Act on a snapshot, in one of three ways:
    --   * report emptiness;
    --   * help a lagging tail pointer along and retry;
    --   * try to swing head past the sentinel and return the element
    --     held by the node after it.
    advance :: Int -> Ticket (Pair a) -> Ticket (Pair a) -> Pair a
            -> Ticket (Pair a) -> IO (Maybe a)
    advance tries headTicket tailTicket sentinel nextTicket
      | pairEq sentinel (peekTicket tailTicket) =
          case peekTicket nextTicket of
            Null -> return Nothing
            next'@(Cons _ _) -> do
              _ <- casIORef tailPtr tailTicket next'
              attempt (tries + 1)
      | otherwise =
          case peekTicket nextTicket of
            Null -> error "tryPop: Internal error. Next should not be null if head/=tail."
            next'@(Cons value _) -> do
              (won, _) <- casIORef headPtr headTicket next'
              if won
                then return (Just value)
                else attempt (tries + 1)
-- | Create an empty queue.  Head and tail both start out pointing at one
-- shared sentinel node.  The sentinel's element is an 'error' thunk and
-- must never be forced.
newQ :: IO (LinkedQueue a)
newQ = do
  lastLink <- newIORef Null
  let sentinel = Cons (error "LinkedQueue: Used uninitialized magic value.") lastLink
  -- Allocation order matches the field order: head first, then tail.
  LQ <$> newIORef sentinel <*> newIORef sentinel
-- | Is the queue currently empty?  Under concurrent pushes and pops the
-- answer may already be stale when it is returned.
--
-- The queue is empty exactly when the sentinel under 'head' has no
-- successor.  Comparing 'head' with 'tail' is not sufficient: after
-- 'pushL' has linked its node but before it swings 'tail', head and tail
-- still coincide, even though 'tryPopR' would already return the new
-- element.
nullQ :: LinkedQueue a -> IO Bool
nullQ (LQ headPtr _) = do
  sentinel <- readIORef headPtr
  case sentinel of
    Null -> error "nullQ: LinkedQueue invariants broken. Internal error."
    Cons _ nextPtr -> do
      successor <- readIORef nextPtr
      return $ case successor of
        Null     -> True
        Cons _ _ -> False
-- | Class instance that delegates to this module's concrete operations.
-- Both ends are reported as safe for concurrent use.
instance C.DequeClass LinkedQueue where
  newQ            = newQ
  nullQ           = nullQ
  pushL           = pushL
  tryPopR         = tryPopR
  leftThreadSafe  = const True
  rightThreadSafe = const True