{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeFamilies #-}
module Control.Concurrent.KazuraQueue
( Queue
, newQueue
, readQueue
, readQueueWithoutMask
, tryReadQueue
, tryReadQueueWithoutMask
, writeQueue
, writeQueueWithoutMask
, lengthQueue
, lengthQueue'
) where
import Control.Concurrent.WVar (WCached, WTicket, WVar)
import qualified Control.Concurrent.WVar as WVar
import qualified Control.Concurrent as CC
import Control.Concurrent.MVar (MVar)
import qualified Control.Concurrent.MVar as MVar
import qualified Control.Exception as E
import qualified Control.Monad as M
import Control.Monad.Primitive (RealWorld)
import qualified Data.Atomics as Atm
import qualified Data.Atomics.Counter as Atm
import Data.Bits ((.&.))
import qualified Data.Bits as Bits
import Data.IORef (IORef)
import qualified Data.IORef as Ref
import qualified Data.Primitive.Array as Arr
{-# INLINE bufferLength #-}
bufferLength :: Int
bufferLength = 64
{-# INLINE logBufferLength #-}
logBufferLength :: Int
logBufferLength = 6
{-# INLINE divModBufferLength #-}
divModBufferLength :: Int -> (Int,Int)
divModBufferLength n = d `seq` m `seq` (d,m)
where
d = n `Bits.unsafeShiftR` logBufferLength
m = n .&. (bufferLength - 1)
data Queue a = Queue
{ queueWriteStream :: {-# UNPACK #-} !(IORef (Stream a))
, queueWriteCounter :: {-# UNPACK #-} !Atm.AtomicCounter
, queueReadStream :: {-# UNPACK #-} !(IORef (Stream a))
, queueReadState :: {-# UNPACK #-} !(WVar (ReadState a))
, queueNoneTicket :: !(Atm.Ticket (Item a))
}
data ReadState a = ReadState
{ rsCounter :: {-# UNPACK #-} !Atm.AtomicCounter
, rsLimit :: {-# UNPACK #-} !StreamIndex
}
type Buffer a = Arr.MutableArray RealWorld (Item a)
type BufferSource a = IO (Buffer a)
data Item a =
Item a
| None
| Wait {-# UNPACK #-} !(MVar a)
| Done
data Stream a = Stream
{ streamBuffer :: {-# UNPACK #-} !(Buffer a)
, streamNext :: {-# UNPACK #-} !(IORef (NextStream a))
, streamOffset :: {-# UNPACK #-} !StreamIndex
}
data NextStream a =
NextStream {-# UNPACK #-} !(Stream a)
| NextSource !(BufferSource a)
type StreamIndex = Int
type BufferIndex = Int
newBufferSource :: IO (BufferSource a)
newBufferSource = do
arr <- Arr.newArray bufferLength None
return (Arr.cloneMutableArray arr 0 bufferLength)
newReadState :: StreamIndex -> IO (WVar (ReadState a))
newReadState strIdx = do
rcounter <- Atm.newCounter strIdx
WVar.newWVar ReadState
{ rsCounter = rcounter
, rsLimit = strIdx
}
newQueue :: IO (Queue a)
newQueue = do
bufSrc <- newBufferSource
buf <- bufSrc
noneTicket <- Atm.readArrayElem buf 0
next <- Ref.newIORef $ NextSource bufSrc
let stream = Stream buf next initialOffset
wstream <- Ref.newIORef stream
wcounter <- Atm.newCounter initialIndex
rstream <- Ref.newIORef stream
rsvar <- newReadState initialIndex
return Queue
{ queueWriteStream = wstream
, queueWriteCounter = wcounter
, queueReadStream = rstream
, queueReadState = rsvar
, queueNoneTicket = noneTicket
}
where
initialOffset = maxBound - 3
initialIndex = initialOffset - 1
{-# INLINE waitItem #-}
waitItem :: Buffer a -> BufferIndex -> IO ()
waitItem buf bufIdx = do
ticket <- Atm.readArrayElem buf bufIdx
case Atm.peekTicket ticket of
None -> do
mv <- MVar.newEmptyMVar
(_ret, next) <- Atm.casArrayElem buf bufIdx ticket $! Wait mv
case Atm.peekTicket next of
None -> error "impossible case waitItem"
Wait mv' -> M.void $ MVar.readMVar mv'
_ -> return ()
Wait mv -> M.void $ MVar.readMVar mv
_ -> return ()
{-# INLINE writeItem #-}
writeItem :: Buffer a -> BufferIndex -> Atm.Ticket (Item a) -> a -> IO ()
writeItem buf bufIdx ticket a = do
(suc, next) <- Atm.casArrayElem buf bufIdx ticket (Item a)
M.unless suc $ case Atm.peekTicket next of
Wait mv -> do
Arr.writeArray buf bufIdx $ Item a
MVar.putMVar mv a
_ -> error "impossible case writeItem"
{-# INLINE readQueue #-}
readQueue :: Queue a -> IO a
readQueue = E.mask_ . readQueueWithoutMask
{-# INLINE readQueueWithoutMask #-}
readQueueWithoutMask :: Queue a -> IO a
readQueueWithoutMask queue@(Queue _ _ _ rsvar _) =
WVar.cacheWVar rsvar >>= readQueueRaw queue
readQueueRaw :: Queue a -> WCached (ReadState a) -> IO a
readQueueRaw queue rswc0 = do
rstr0 <- Ref.readIORef rstrRef
strIdx <- Atm.incrCounter 1 rcounter
if rlimit0 `gteIndex` strIdx
then readStream rstrRef rstr0 strIdx
else do
rswt1 <- extendReadStreamWithLock rstr0 rswc0 True True
let rswc1 = rswc0 { WVar.cachedTicket = rswt1 }
readQueueRaw queue rswc1
where
rstrRef = queueReadStream queue
rswt0 = WVar.cachedTicket rswc0
(ReadState rcounter rlimit0) = WVar.readWTicket rswt0
{-# INLINE tryReadQueue #-}
tryReadQueue :: Queue a -> IO (Maybe a)
tryReadQueue = E.mask_ . tryReadQueueWithoutMask
{-# INLINE tryReadQueueWithoutMask #-}
tryReadQueueWithoutMask :: Queue a -> IO (Maybe a)
tryReadQueueWithoutMask queue@(Queue _ _ _ rsvar _) =
WVar.cacheWVar rsvar >>= tryReadQueueRaw queue
tryReadQueueRaw :: Queue a -> WCached (ReadState a) -> IO (Maybe a)
tryReadQueueRaw queue rswc0 = do
rstr0 <- Ref.readIORef rstrRef
strIdx <- Atm.incrCounter 1 rcounter
if rlimit0 `gteIndex` strIdx
then Just <$> readStream rstrRef rstr0 strIdx
else do
rswt1 <- extendReadStreamWithLock rstr0 rswc0 False False
let rswc1 = rswc0 { WVar.cachedTicket = rswt1 }
(ReadState _ rlimit1) = WVar.readWTicket rswt1
if rlimit1 /= rlimit0
then tryReadQueueRaw queue rswc1
else return Nothing
where
rstrRef = queueReadStream queue
rswt0 = WVar.cachedTicket rswc0
(ReadState rcounter rlimit0) = WVar.readWTicket rswt0
{-# INLINE readStream #-}
readStream :: IORef (Stream a) -> Stream a -> StreamIndex -> IO a
readStream rstrRef rstr0 strIdx = do
(bufIdx, rstr1) <- targetStream rstr0 strIdx
M.when (bufIdx == 0) $ Ref.writeIORef rstrRef rstr1
let buf = streamBuffer rstr1
item <- Arr.readArray buf bufIdx
Arr.writeArray buf bufIdx Done
case item of
Item a -> return a
_ -> error "impossible case readQueue"
extendReadStreamWithLock ::
Stream a
-> WCached (ReadState a)
-> Bool
-> Bool
-> IO (WTicket (ReadState a))
extendReadStreamWithLock rstr0 rswc0 waitLock waitWrite = do
(suc, rswt1) <- WVar.tryTakeWCached rswc0
let rstate1 = WVar.readWTicket rswt1
if suc
then do
rstate2 <- extendReadStream rstate1 rstr0 waitWrite
`E.onException` WVar.putWCached rswc0 rstate1
WVar.putWCached rswc0 rstate2
else do
let rswc1 = rswc0 { WVar.cachedTicket = rswt1 }
if waitLock
then WVar.readFreshWCached rswc1
else do
rswc2 <- WVar.recacheWCached rswc1
return $ WVar.cachedTicket rswc2
{-# INLINE extendReadStream #-}
extendReadStream :: ReadState a -> Stream a -> Bool -> IO (ReadState a)
extendReadStream rstate0 rstr0 waitWrite = do
(rlimitNext1, rstr1) <- searchStreamReadLimit rstr0 rlimitNext0
if rlimitNext0 /= rlimitNext1
then newRState rlimitNext1
else if waitWrite
then do
let (Stream buf1 _ offset1) = rstr1
bufIdx1 = rlimitNext1 - offset1
waitItem buf1 bufIdx1
(rlimitNext2, _) <- searchStreamReadLimit rstr1 rlimitNext1
newRState rlimitNext2
else return rstate0
where
rlimit0 = rsLimit rstate0
rlimitNext0 = rlimit0 + 1
newRState rlimitNext = do
rcounter <- Atm.newCounter rlimit0
return rstate0
{ rsCounter = rcounter
, rsLimit = rlimitNext - 1
}
writeQueue :: Queue a -> a -> IO ()
writeQueue queue = E.mask_ . writeQueueRaw queue
{-# INLINE writeQueueRaw #-}
writeQueueWithoutMask :: Queue a -> a -> IO ()
writeQueueWithoutMask = writeQueueRaw
writeQueueRaw :: Queue a -> a -> IO ()
writeQueueRaw (Queue wstrRef wcounter _ _ noneTicket) a = do
wstr0 <- Ref.readIORef wstrRef
strIdx <- Atm.incrCounter 1 wcounter
(bufIdx, wstr1) <- targetStream wstr0 strIdx
writeItem (streamBuffer wstr1) bufIdx noneTicket a
M.when (bufIdx == 0) $ Ref.writeIORef wstrRef wstr1
{-# INLINE targetStream #-}
targetStream :: Stream a -> StreamIndex -> IO (BufferIndex, Stream a)
targetStream str0@(Stream _ _ offset) strIdx = do
let (strNum, bufIdx) = divModBufferLength $ strIdx - offset
str1 <- getStream strNum bufIdx str0
return (bufIdx, str1)
where
{-# INLINE getStream #-}
getStream 0 _ strA = return strA
getStream n bufIdx strA = do
strB <- waitNextStream strA bufIdx
getStream (n-1) bufIdx strB
{-# NOINLINE waitNextStream #-}
waitNextStream :: Stream a -> Int -> IO (Stream a)
waitNextStream (Stream _ nextStrRef offset) = go
where
{-# INLINE go #-}
go wait = do
ticket <- Atm.readForCAS nextStrRef
case Atm.peekTicket ticket of
NextStream strNext -> return strNext
nextSrc@(NextSource bufSrc)
| wait > 0 -> do
CC.yield
go (wait - 1)
| otherwise -> do
newBuf <- bufSrc
newNext <- Ref.newIORef nextSrc
let nextStrCand = NextStream Stream
{ streamBuffer = newBuf
, streamNext = newNext
, streamOffset = offset + bufferLength
}
(_, next) <- Atm.casIORef nextStrRef ticket nextStrCand
case Atm.peekTicket next of
NextStream nextStr -> return nextStr
NextSource _ -> go 1
{-# INLINE searchStreamReadLimit #-}
searchStreamReadLimit :: Stream a -> StreamIndex -> IO (StreamIndex, Stream a)
searchStreamReadLimit baseStr strIdx =
go (strIdx - streamOffset baseStr) baseStr
where
{-# INLINE go #-}
go bufIdx stream@(Stream buf _ offset) = do
ret <- searchBufferReadLimit buf bufIdx
case ret of
Just retBufIdx -> return (offset + retBufIdx, stream)
Nothing -> waitNextStream stream 0 >>= go 0
{-# INLINE searchBufferReadLimit #-}
searchBufferReadLimit :: Buffer a -> BufferIndex -> IO (Maybe BufferIndex)
searchBufferReadLimit buf = go
where
{-# INLINE go #-}
go bufIdx
| idxIsOutOfBuf = return Nothing
| otherwise = do
item <- Arr.readArray buf bufIdx
case item of
None -> return $ Just bufIdx
Wait _ -> return $ Just bufIdx
_ -> go $ bufIdx + 1
where
idxIsOutOfBuf = bufIdx >= bufferLength
lengthQueue :: Queue a -> IO Int
lengthQueue (Queue _ wcounter _ rsvar _) = do
rs <- WVar.readWVar rsvar
wcount <- Atm.readCounter wcounter
rcount <- Atm.readCounter $ rsCounter rs
return $ wcount - rcount
lengthQueue' :: Queue a -> IO Int
lengthQueue' queue = f <$> lengthQueue queue
where
f i | i > 0 = i
| otherwise = 0
{-# INLINE gteIndex #-}
gteIndex :: Int -> Int -> Bool
gteIndex a b | a - b < 0 = False
| otherwise = True