{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeFamilies #-}

-- | KazuraQueue is a fast concurrent FIFO queue implementation inspired by
-- unagi-chan.
module Control.Concurrent.KazuraQueue
  ( Queue
  , newQueue
  , readQueue
  , readQueueWithoutMask
  , tryReadQueue
  , tryReadQueueWithoutMask
  , writeQueue
  , writeQueueWithoutMask
  , lengthQueue
  , lengthQueue'
  ) where

import Control.Concurrent.WVar (WCached, WTicket, WVar)
import qualified Control.Concurrent.WVar as WVar

import qualified Control.Concurrent as CC
import Control.Concurrent.MVar (MVar)
import qualified Control.Concurrent.MVar as MVar
import qualified Control.Exception as E
import qualified Control.Monad as M
import Control.Monad.Primitive (RealWorld)
import qualified Data.Atomics as Atm
import qualified Data.Atomics.Counter as Atm
import Data.Bits ((.&.))
import qualified Data.Bits as Bits
import Data.IORef (IORef)
import qualified Data.IORef as Ref
import qualified Data.Primitive.Array as Arr

--------------------------------
-- constants and their utilities

-- | Base-2 logarithm of 'bufferLength'.
--
-- This is the single source of truth for the buffer size. 'bufferLength' is
-- derived from it, so the shift and the mask in 'divModBufferLength' always
-- agree.
{-# INLINE logBufferLength #-}
logBufferLength :: Int
logBufferLength = 6

-- | Number of slots in each buffer segment (@2 ^ logBufferLength@).
{-# INLINE bufferLength #-}
bufferLength :: Int
bufferLength = 1 `Bits.unsafeShiftL` logBufferLength

-- | @divModBufferLength n@ equals @n `divMod` bufferLength@ (floor
-- semantics, also for negative @n@, because right shifts on 'Int' are
-- sign-extending). It is computed with a shift and a mask because
-- 'bufferLength' is a power of two. Both components are forced before the
-- pair is returned.
{-# INLINE divModBufferLength #-}
divModBufferLength :: Int -> (Int,Int)
divModBufferLength n = d `seq` m `seq` (d,m)
  where
    d = n `Bits.unsafeShiftR` logBufferLength
    m = n .&. (bufferLength - 1)

--------------------------------
-- Queue

-- | Type of a Queue. /a/ is the type of an item in the Queue.
data Queue a = Queue
  { queueWriteStream :: {-# UNPACK #-} !(IORef (Stream a))
    -- ^ Segment where writers begin their search. A writer that lands on
    -- slot 0 of a later segment stores that segment here.
  , queueWriteCounter :: {-# UNPACK #-} !Atm.AtomicCounter
    -- ^ Source of the stream index each writer fills.
  , queueReadStream :: {-# UNPACK #-} !(IORef (Stream a))
    -- ^ Segment where readers begin their search. A reader that lands on
    -- slot 0 of a later segment stores that segment here.
  , queueReadState :: {-# UNPACK #-} !(WVar (ReadState a))
    -- ^ Current reader counter and read limit. The whole 'ReadState' is
    -- replaced when the limit is extended.
  , queueNoneTicket :: !(Atm.Ticket (Item a))
    -- ^ CAS ticket for the shared 'None' value that every fresh slot holds.
    -- Writers use it as the expected value when filling a slot.
  }

-- | Reader-side bookkeeping. Readers draw indices from 'rsCounter' and may
-- read an index directly only while it does not exceed 'rsLimit'.
data ReadState a = ReadState
  { rsCounter :: {-# UNPACK #-} !Atm.AtomicCounter
    -- ^ Source of the stream index each reader consumes.
  , rsLimit :: {-# UNPACK #-} !StreamIndex
    -- ^ Highest stream index whose slot is known to have been written.
  }

-- | One fixed-size segment of slots.
type Buffer a = Arr.MutableArray RealWorld (Item a)

-- | Action that allocates a fresh 'Buffer' whose slots all hold 'None'.
type BufferSource a = IO (Buffer a)

-- | Contents of a single slot.
data Item a
  -- | A written value that has not been consumed yet.
  = Item a
  -- | Empty: nothing written yet and no reader waiting.
  | None
  -- | Empty, but a reader has parked this 'MVar' in the slot and blocks on
  -- it. The writer stores the item in the slot and then puts it into the
  -- 'MVar'.
  | Wait {-# UNPACK #-} !(MVar a)
  -- | The value has been consumed by a reader.
  | Done

-- | A segment of the queue: one buffer plus a link to the following segment.
data Stream a = Stream
  { streamBuffer :: {-# UNPACK #-} !(Buffer a)
    -- ^ Slots of this segment.
  , streamNext :: {-# UNPACK #-} !(IORef (NextStream a))
    -- ^ Link to the next segment, or the action used to allocate its buffer.
  , streamOffset :: {-# UNPACK #-} !StreamIndex
    -- ^ Stream index of slot 0 of 'streamBuffer'.
  }

-- | Link to the following segment. It is either already allocated, or not
-- yet allocated and carrying the 'BufferSource' used to allocate it.
data NextStream a = NextStream {-# UNPACK #-} !(Stream a) | NextSource !(BufferSource a)

-- | Global position of a slot, counted across all segments. It is a plain
-- 'Int' and may wrap around past 'maxBound'.
type StreamIndex = Int

-- | Position of a slot inside one 'Buffer' (0 .. bufferLength - 1).
type BufferIndex = Int

------------------------------

-- | Build the 'BufferSource' used for every segment of a queue.
--
-- A single template array filled with 'None' is created once, and every run
-- of the returned action clones it. Cloning copies references, so every slot
-- of every buffer starts out as the very same 'None' object. That is what
-- lets one CAS ticket ('queueNoneTicket') serve as the expected value for
-- any slot.
newBufferSource :: IO (BufferSource a)
newBufferSource = do
  arr <- Arr.newArray bufferLength None
  return (Arr.cloneMutableArray arr 0 bufferLength)

-- | Create the reader state. Both the counter and the read limit start at
-- @strIdx@. The first index a reader draws is therefore past the limit, and
-- the limit must be extended before anything can be read.
newReadState :: StreamIndex -> IO (WVar (ReadState a))
newReadState strIdx = do
  rcounter <- Atm.newCounter strIdx
  WVar.newWVar ReadState
    { rsCounter = rcounter
    , rsLimit = strIdx
    }

-- | Create a new empty 'Queue'.
newQueue :: IO (Queue a)
newQueue = do
  bufSrc <- newBufferSource
  buf <- bufSrc
  -- Every fresh slot holds the same 'None' object (see 'newBufferSource'),
  -- so this one ticket is the CAS expected value 'writeItem' uses for any
  -- slot.
  noneTicket <- Atm.readArrayElem buf 0
  next <- Ref.newIORef $ NextSource bufSrc
  -- Writers and readers both start from this first segment.
  let stream = Stream buf next initialOffset
  wstream <- Ref.newIORef stream
  wcounter <- Atm.newCounter initialIndex
  rstream <- Ref.newIORef stream
  rsvar <- newReadState initialIndex
  return Queue
    { queueWriteStream = wstream
    , queueWriteCounter = wcounter
    , queueReadStream = rstream
    , queueReadState = rsvar
    , queueNoneTicket = noneTicket
    }
  where
    -- for test of counter overflow:
    -- Starting just below 'maxBound' makes the 'Int' indices wrap around
    -- after a few operations. This means the wrap-around-safe index
    -- arithmetic ('gteIndex', offset subtraction) is exercised in normal use.
    initialOffset = maxBound - 3
    -- The counters start one below the first slot's index, because
    -- 'Atm.incrCounter' returns the incremented value.
    initialIndex = initialOffset - 1

----------------------------------------------------------

-- | Block until slot @bufIdx@ of @buf@ is no longer empty.
--
-- If the slot holds 'None', a fresh 'Wait' is CASed in. The thread then
-- waits, with 'MVar.readMVar', on whichever 'Wait' is in the slot afterwards:
-- its own on success, or the one already there on failure. If the slot
-- already holds something else, or the CAS lost to a writer, it returns at
-- once.
{-# INLINE waitItem #-}
waitItem :: Buffer a -> BufferIndex -> IO ()
waitItem buf bufIdx = do
  ticket <- Atm.readArrayElem buf bufIdx
  case Atm.peekTicket ticket of
    None -> do
      mv <- MVar.newEmptyMVar
      -- The success flag is ignored: @next@ is the ticket for whatever the
      -- slot holds after the CAS, and it is inspected directly.
      (_ret, next) <- Atm.casArrayElem buf bufIdx ticket $! Wait mv
      case Atm.peekTicket next of
        None -> error "impossible case waitItem"
        Wait mv' -> M.void $ MVar.readMVar mv'
        _ -> return ()
    Wait mv -> M.void $ MVar.readMVar mv
    _ -> return ()

-- | Fill the empty slot @bufIdx@ of @buf@ with @a@.
--
-- @ticket@ is expected to be the shared 'None' ticket, so the CAS succeeds
-- when no reader is waiting. If it fails, the only legitimate occupant is a
-- reader's 'Wait'. In that case the item is stored in the slot first and
-- only then is the reader woken through its 'MVar'. A woken reader that
-- rescans the buffer therefore already sees the item.
{-# INLINE writeItem #-}
writeItem :: Buffer a -> BufferIndex -> Atm.Ticket (Item a) -> a -> IO ()
writeItem buf bufIdx ticket a = do
  (suc, next) <- Atm.casArrayElem buf bufIdx ticket (Item a)
  M.unless suc $ case Atm.peekTicket next of
    Wait mv -> do
      -- Order matters: publish the item before waking the reader.
      Arr.writeArray buf bufIdx $ Item a
      MVar.putMVar mv a
    _ -> error "impossible case writeItem"

----------------------------------------------------------

-- | Read an item from the 'Queue', blocking until one is available.
-- Runs with asynchronous exceptions masked.
{-# INLINE readQueue #-}
readQueue :: Queue a -> IO a
readQueue = E.mask_ . readQueueWithoutMask

-- | Non-masked version of 'readQueue'.
-- It is not safe against asynchronous exceptions.
{-# INLINE readQueueWithoutMask #-}
readQueueWithoutMask :: Queue a -> IO a
readQueueWithoutMask queue@Queue{ queueReadState = stateVar } = do
  cached <- WVar.cacheWVar stateVar
  readQueueRaw queue cached

-- | Blocking read loop shared by 'readQueue' and 'readQueueWithoutMask'.
--
-- Each round takes a snapshot of the read-stream reference and claims an
-- index from the cached reader counter. If the index lies within the cached
-- read limit, it is read. Otherwise the limit is extended (waiting both for
-- the reader lock and for a writer) and the round is repeated with the
-- refreshed ticket.
readQueueRaw :: Queue a -> WCached (ReadState a) -> IO a
readQueueRaw queue cached = do
  let readRef = queueReadStream queue
      ReadState counter limit = WVar.readWTicket (WVar.cachedTicket cached)
  stream <- Ref.readIORef readRef
  idx <- Atm.incrCounter 1 counter
  if not (limit `gteIndex` idx)
    then do
      ticket <- extendReadStreamWithLock stream cached True True
      readQueueRaw queue (cached { WVar.cachedTicket = ticket })
    else readStream readRef stream idx

-- | Non-blocking read from the 'Queue'. Returns 'Nothing' instead of waiting
-- when no item can be read right now. Runs with asynchronous exceptions
-- masked.
--
-- Note: every call claims an index even when it returns 'Nothing'. The
-- "length" reported for the 'Queue' therefore drops on a failed attempt too,
-- and can become lower than 0.
{-# INLINE tryReadQueue #-}
tryReadQueue :: Queue a -> IO (Maybe a)
tryReadQueue queue = E.mask_ (tryReadQueueWithoutMask queue)

-- | Non-masked version of 'tryReadQueue'.
-- It is not safe against asynchronous exceptions.
{-# INLINE tryReadQueueWithoutMask #-}
tryReadQueueWithoutMask :: Queue a -> IO (Maybe a)
tryReadQueueWithoutMask queue@(Queue _ _ _ rsvar _) =
  WVar.cacheWVar rsvar >>= tryReadQueueRaw queue

-- | Worker for 'tryReadQueueWithoutMask'.
--
-- Claims an index from the cached read counter and reads it if it lies
-- within the cached read limit. Otherwise it tries to extend the limit
-- without waiting for the reader lock or for a writer. It retries only if
-- the limit actually moved, and returns 'Nothing' if it did not.
tryReadQueueRaw :: Queue a -> WCached (ReadState a) -> IO (Maybe a)
tryReadQueueRaw queue rswc0 = do
  rstr0 <- Ref.readIORef rstrRef
  strIdx <- Atm.incrCounter 1 rcounter
  if rlimit0 `gteIndex` strIdx
    then Just <$> readStream rstrRef rstr0 strIdx
    else do
      rswt1 <- extendReadStreamWithLock rstr0 rswc0 False False
      let rswc1 = rswc0 { WVar.cachedTicket = rswt1 }
          (ReadState _ rlimit1) = WVar.readWTicket rswt1
      if rlimit1 /= rlimit0
        then tryReadQueueRaw queue rswc1
        else return Nothing
  where
    rstrRef = queueReadStream queue
    rswt0 = WVar.cachedTicket rswc0
    (ReadState rcounter rlimit0) = WVar.readWTicket rswt0

-- | Consume the slot at @strIdx@. The caller must already have established
-- that it lies within the read limit, so the slot has been written.
--
-- Walks from @rstr0@ to the segment holding the slot. If the slot is the
-- first of its segment, that segment is published to @rstrRef@. The slot is
-- then marked 'Done' and its value returned.
{-# INLINE readStream #-}
readStream :: IORef (Stream a) -> Stream a -> StreamIndex -> IO a
readStream rstrRef rstr0 strIdx = do
  (bufIdx, rstr1) <- targetStream rstr0 strIdx
  M.when (bufIdx == 0) $ Ref.writeIORef rstrRef rstr1
  let buf = streamBuffer rstr1
  item <- Arr.readArray buf bufIdx
  Arr.writeArray buf bufIdx Done
  case item of
    Item a -> return a
    _ -> error "impossible case readQueue"

-- | Try to extend the read limit while holding the reader lock.
--
-- If the lock is obtained, the limit is extended with 'extendReadStream' and
-- the new state is put back. If that throws, the previous state is put back
-- instead. If another thread holds the lock, @waitLock@ chooses between two
-- paths. With @waitLock@ set, 'WVar.readFreshWCached' is used (presumably it
-- waits for that thread's result; confirm against
-- "Control.Concurrent.WVar"). Without it, the cache is merely refreshed.
-- @waitWrite@ is passed on to 'extendReadStream'. Returns the ticket the
-- caller should continue with.
extendReadStreamWithLock
  :: Stream a -> WCached (ReadState a) -> Bool -> Bool
  -> IO (WTicket (ReadState a))
extendReadStreamWithLock rstr0 rswc0 waitLock waitWrite = do
  (suc, rswt1) <- WVar.tryTakeWCached rswc0
  let rstate1 = WVar.readWTicket rswt1
  if suc
    then do
      rstate2 <- extendReadStream rstate1 rstr0 waitWrite
        `E.onException` WVar.putWCached rswc0 rstate1
      WVar.putWCached rswc0 rstate2
    else do
      let rswc1 = rswc0 { WVar.cachedTicket = rswt1 }
      if waitLock
        then WVar.readFreshWCached rswc1
        else do
          rswc2 <- WVar.recacheWCached rswc1
          return $ WVar.cachedTicket rswc2

-- | Compute the next 'ReadState' once the current read limit is exhausted.
--
-- Scans forward from the slot right after 'rsLimit' for the oldest slot that
-- is still empty ('None' or 'Wait'). If that slot lies beyond the scan start,
-- the new limit is the slot just before it. If nothing new has been written:
-- with @waitWrite@, it blocks on that slot with 'waitItem' and scans again;
-- without it, the state is returned unchanged.
--
-- A new state always gets a fresh counter starting at the old limit. Readers
-- therefore continue at the old limit + 1. Indices drawn beyond the old
-- limit from the previous counter are simply abandoned.
{-# INLINE extendReadStream #-}
extendReadStream :: ReadState a -> Stream a -> Bool -> IO (ReadState a)
extendReadStream rstate0 rstr0 waitWrite = do
  (rlimitNext1, rstr1) <- searchStreamReadLimit rstr0 rlimitNext0
  if rlimitNext0 /= rlimitNext1
    then newRState rlimitNext1
    else if waitWrite
      then do
        let (Stream buf1 _ offset1) = rstr1
            bufIdx1 = rlimitNext1 - offset1
        waitItem buf1 bufIdx1
        (rlimitNext2, _) <- searchStreamReadLimit rstr1 rlimitNext1
        newRState rlimitNext2
      else return rstate0
  where
    rlimit0 = rsLimit rstate0
    rlimitNext0 = rlimit0 + 1
    newRState rlimitNext = do
      rcounter <- Atm.newCounter rlimit0
      return rstate0
        { rsCounter = rcounter
        , rsLimit = rlimitNext - 1
        }

-- | Write an item to the 'Queue'.
-- The item is evaluated (WHNF) before actual queueing.
--
-- Evaluation happens before asynchronous exceptions are masked and before a
-- slot is claimed. A slow evaluation therefore stays interruptible, and an
-- item that throws while being evaluated never leaves a claimed slot
-- unfilled.
writeQueue :: Queue a -> a -> IO ()
writeQueue queue a = do
  a' <- E.evaluate a
  E.mask_ $ writeQueueRaw queue a'

-- | Non-masked version of 'writeQueue'. Like 'writeQueue', it evaluates the
-- item to WHNF before queueing it.
-- It is not safe against asynchronous exceptions.
{-# INLINE writeQueueWithoutMask #-}
writeQueueWithoutMask :: Queue a -> a -> IO ()
writeQueueWithoutMask queue a = E.evaluate a >>= writeQueueRaw queue

-- | Claim the next write index, locate the segment that holds it (allocating
-- it if needed), and fill the slot. A writer that lands on slot 0 of a
-- segment publishes that segment to the shared write-stream reference.
{-# INLINE writeQueueRaw #-}
writeQueueRaw :: Queue a -> a -> IO ()
writeQueueRaw (Queue wstrRef wcounter _ _ noneTicket) a = do
  wstr0 <- Ref.readIORef wstrRef
  strIdx <- Atm.incrCounter 1 wcounter
  (bufIdx, wstr1) <- targetStream wstr0 strIdx
  writeItem (streamBuffer wstr1) bufIdx noneTicket a
  M.when (bufIdx == 0) $ Ref.writeIORef wstrRef wstr1

-- | Locate the slot for @strIdx@, starting from segment @str0@.
--
-- @strIdx - offset@ is split into the number of segments to walk forward and
-- the slot position inside the final buffer. The slot position doubles as
-- the yield budget handed to 'waitNextStream'. The thread targeting slot 0
-- therefore allocates a missing segment immediately, while threads further
-- into the buffer first yield and give it the chance to do so.
{-# INLINE targetStream #-}
targetStream :: Stream a -> StreamIndex -> IO (BufferIndex, Stream a)
targetStream str0@(Stream _ _ offset) strIdx = do
  let (strNum, bufIdx) = divModBufferLength $ strIdx - offset
  str1 <- getStream strNum bufIdx str0
  return (bufIdx, str1)
  where
    {-# INLINE getStream #-}
    getStream 0 _ strA = return strA
    getStream n bufIdx strA = do
      strB <- waitNextStream strA bufIdx
      getStream (n-1) bufIdx strB

-- | Return the segment that follows the given one, creating it if necessary.
--
-- While the link is still a 'NextSource', the thread yields up to @wait@
-- times, hoping another thread links the segment. After that it allocates a
-- candidate segment and tries to CAS it into the link. Whichever thread wins,
-- the linked segment is returned. If the CAS fails and the link still shows
-- 'NextSource', the thread retries with a budget of one yield.
{-# NOINLINE waitNextStream #-}
waitNextStream :: Stream a -> Int -> IO (Stream a)
waitNextStream (Stream _ nextStrRef offset) = go
  where
    {-# INLINE go #-}
    go wait = do
      ticket <- Atm.readForCAS nextStrRef
      case Atm.peekTicket ticket of
        NextStream strNext -> return strNext
        nextSrc@(NextSource bufSrc)
          | wait > 0 -> do
              CC.yield
              go (wait - 1)
          | otherwise -> do
              newBuf <- bufSrc
              -- The new segment reuses the same buffer source for its own
              -- successor.
              newNext <- Ref.newIORef nextSrc
              let nextStrCand = NextStream Stream
                    { streamBuffer = newBuf
                    , streamNext = newNext
                    , streamOffset = offset + bufferLength
                    }
              (_, next) <- Atm.casIORef nextStrRef ticket nextStrCand
              case Atm.peekTicket next of
                NextStream nextStr -> return nextStr
                NextSource _ -> go 1

-- | Search forward from @strIdx@, starting in segment @baseStr@, for the
-- oldest slot that is not yet available ('None' or 'Wait'). Returns its
-- 'StreamIndex' together with the segment that contains it. Later segments
-- are followed, and created if necessary.
{-# INLINE searchStreamReadLimit #-}
searchStreamReadLimit :: Stream a -> StreamIndex -> IO (StreamIndex, Stream a)
searchStreamReadLimit baseStr strIdx = go (strIdx - streamOffset baseStr) baseStr
  where
    {-# INLINE go #-}
    go bufIdx stream@(Stream buf _ offset) = do
      ret <- searchBufferReadLimit buf bufIdx
      case ret of
        Just retBufIdx -> return (offset + retBufIdx, stream)
        Nothing -> waitNextStream stream 0 >>= go 0

-- | Scan @buf@ from @bufIdx@ for the first slot that is still empty ('None'
-- or 'Wait') and return its 'BufferIndex'. Returns 'Nothing' if every slot
-- from @bufIdx@ to the end of the buffer holds an 'Item' or is 'Done'.
{-# INLINE searchBufferReadLimit #-}
searchBufferReadLimit :: Buffer a -> BufferIndex -> IO (Maybe BufferIndex)
searchBufferReadLimit buf = go
  where
    {-# INLINE go #-}
    go bufIdx
      | idxIsOutOfBuf = return Nothing
      | otherwise = do
          item <- Arr.readArray buf bufIdx
          case item of
            None -> return $ Just bufIdx
            Wait _ -> return $ Just bufIdx
            _ -> go $ bufIdx + 1
      where
        idxIsOutOfBuf = bufIdx >= bufferLength

-- | Get the length of the items in the 'Queue'.
--
-- Caution: It returns a value lower than 0 when the Queue is empty and some
-- threads are waiting for a new value. The reader state and the two
-- counters are read separately, so under concurrent use the result is not
-- an atomic snapshot.
lengthQueue :: Queue a -> IO Int
lengthQueue (Queue _ wcounter _ rsvar _) = do
  rs <- WVar.readWVar rsvar
  wcount <- Atm.readCounter wcounter
  rcount <- Atm.readCounter $ rsCounter rs
  return $ wcount - rcount

-- | Non-negative version of 'lengthQueue': negative results are clamped to 0.
lengthQueue' :: Queue a -> IO Int
lengthQueue' queue = max 0 <$> lengthQueue queue

-- | @a `gteIndex` b@ is @a >= b@ for stream indices that may have wrapped
-- around past 'maxBound'. It compares the sign of the difference instead of
-- the raw values, which is correct as long as the two indices are less than
-- 'maxBound' apart.
{-# INLINE gteIndex #-}
gteIndex :: Int -> Int -> Bool
gteIndex a b = a - b >= 0