{-# LANGUAGE ScopedTypeVariables #-}
-- | Concurrent queue for single reader, single writer
module Control.Distributed.Process.Internal.CQueue
  ( CQueue
  , BlockSpec(..)
  , newCQueue
  , enqueue
  , dequeue
  ) where

import Control.Concurrent.MVar (MVar, newMVar, takeMVar, putMVar)
import Control.Concurrent.STM
  ( atomically
  , TChan
  , newTChan
  , writeTChan
  , readTChan
  , tryReadTChan
  )
import Control.Applicative ((<$>), (<*>))
import Control.Exception (mask, onException)
import System.Timeout (timeout)

-- | A queue made of two parts: the messages that were already read from the
-- channel but did not match any earlier 'dequeue' ("arrived"), and the
-- channel on which new messages come in ("incoming").
--
-- We use a TChan rather than a Chan so that we have a non-blocking read
data CQueue a = CQueue (MVar [a]) -- Arrived
                       (TChan a)  -- Incoming

-- | Create a queue with no arrived messages and an empty incoming channel
newCQueue :: IO (CQueue a)
newCQueue = CQueue <$> newMVar [] <*> atomically newTChan

-- | Append a message to the incoming channel. The arrived list is not touched.
enqueue :: CQueue a -> a -> IO ()
enqueue (CQueue _arrived incoming) a = atomically $ writeTChan incoming a

-- | How 'dequeue' behaves when none of the messages that are already
-- available matches
data BlockSpec =
    NonBlocking -- ^ Never wait for new messages
  | Blocking    -- ^ Wait until a matching message comes in
  | Timeout Int -- ^ Wait at most this long; the value is handed to
                -- 'System.Timeout.timeout' (microseconds, negative means
                -- wait indefinitely). @Timeout 0@ behaves like 'NonBlocking'.

-- | Dequeue an element
--
-- The first message accepted by any of the matches (tried in list order) is
-- removed from the queue; messages that are inspected but not accepted are
-- kept, in order, in the arrived list for later calls.
--
-- The timeout (if any) is applied only to waiting for incoming messages, not
-- to checking messages that have already arrived
dequeue :: forall a b.
           CQueue a        -- ^ Queue
        -> BlockSpec       -- ^ Blocking behaviour
        -> [a -> Maybe b]  -- ^ List of matches
        -> IO (Maybe b)    -- ^ 'Nothing' if nothing matched and we did not
                           -- wait ('NonBlocking', @'Timeout' 0@), or if the
                           -- timeout expired
dequeue (CQueue arrived incoming) blockSpec matches = go
  where
    go :: IO (Maybe b)
    go = mask $ \restore -> do
      arr <- takeMVar arrived
      -- We first check the arrived messages. If we get interrupted during this
      -- search, we just put the MVar back (we haven't read from the Chan yet)
      (arr', mb) <- onException (restore (checkArrived [] arr))
                                (putMVar arrived arr)
      case (mb, blockSpec) of
        (Just b, _) -> do
          putMVar arrived arr'
          return (Just b)
        (Nothing, NonBlocking) ->
          checkNonBlocking arr'
        (Nothing, Blocking) ->
          Just <$> checkBlocking arr'
        (Nothing, Timeout n)
          -- 'timeout 0' returns Nothing without running its action at all, so
          -- checkBlocking would never put 'arrived' back: the MVar would stay
          -- empty (the next dequeue would block forever) and the messages in
          -- arr' would be lost. A zero timeout is a poll, so do exactly that.
          | n == 0    -> checkNonBlocking arr'
          | otherwise -> timeout n $ checkBlocking arr'

    -- Search the messages that have already arrived.
    --
    -- We reverse the accumulator on return only if we find a match; without a
    -- match the result is the arrived list in *reverse* order, which is the
    -- form checkBlocking and checkNonBlocking expect as their accumulator.
    checkArrived :: [a] -> [a] -> IO ([a], Maybe b)
    checkArrived acc []     = return (acc, Nothing)
    checkArrived acc (x:xs) =
      case check x of
        Just y  -> return (reverse acc ++ xs, Just y)
        Nothing -> checkArrived (x:acc) xs

    -- If we call checkBlocking there may or may not be a timeout
    checkBlocking :: [a] -> IO b
    checkBlocking acc = do
      -- readTChan is a blocking call, and hence is interruptable. If it is
      -- interrupted, we put the value of the accumulator in 'arrived'
      -- (as opposed to the original value), so that no messages get lost
      -- (hence the low-level structure using mask rather than modifyMVar)
      x <- onException (atomically $ readTChan incoming)
                       (putMVar arrived $ reverse acc)
      -- NOTE(review): if a match function throws while 'check x' is being
      -- evaluated, nothing here puts 'arrived' back -- confirm that matches
      -- are expected to be total.
      case check x of
        Nothing -> checkBlocking (x:acc)
        Just y  -> putMVar arrived (reverse acc) >> return y

    -- checkNonBlocking is only called if there is no timeout to wait for
    -- (NonBlocking, or a timeout of zero)
    checkNonBlocking :: [a] -> IO (Maybe b)
    checkNonBlocking acc = do
      -- tryReadTChan is *not* interruptible
      mx <- atomically $ tryReadTChan incoming
      case mx of
        Nothing -> putMVar arrived (reverse acc) >> return Nothing
        Just x  ->
          case check x of
            Nothing -> checkNonBlocking (x:acc)
            Just y  -> putMVar arrived (reverse acc) >> return (Just y)

    -- Try all matches on a single message
    check :: a -> Maybe b
    check = checkMatches matches

    -- Result of the first match in the list that accepts the message
    checkMatches :: [a -> Maybe b] -> a -> Maybe b
    checkMatches []     _ = Nothing
    checkMatches (m:ms) a =
      case m a of
        Nothing -> checkMatches ms a
        Just b  -> Just b