{-# LANGUAGE NamedFieldPuns, BangPatterns #-}
{-# LANGUAGE RankNTypes #-}

-- | A concurrent finite map represented as a single linked list.
--
-- In contrast to standard maps, this one only allows lookups and insertions,
-- not modifications or removals.  While modifications would be fairly easy to
-- add, removals would significantly complicate the logic, and aren't needed for
-- the primary application -- LVars.
--
-- The interface is also somewhat low-level: rather than a standard insert
-- function, @tryInsert@ takes a "token" (i.e. a pointer into the linked list)
-- and attempts to insert at that location (but may fail).  Tokens are acquired
-- through the @find@ function, which yields a token in the case that a key is
-- *not* found; the token represents the location in the list where the key
-- *should* go.  This low-level interface is intended for use in higher-level
-- data structures, e.g. SkipListMap.
module Data.Concurrent.LinkedMap (
  LMap(), LMList(..),
  newLMap, Token(), value, find, FindResult(..), tryInsert,
  foldlWithKey, map, reverse, head, toList, fromList, findIndex,

  -- * Utilities for splitting/slicing
  halve, halve', dropUntil
  )
where

import Data.IORef
import Data.Atomics
import Control.Reagent -- AT: not yet using this, but would be nice to refactor
                       -- to use it.
import Control.Monad.IO.Class
import Control.Exception (assert)
import Prelude hiding (reverse, map, head)

-- | A concurrent finite map, represented as a linked list.
--
-- Each 'Node' carries a key, a value, and a mutable reference to the rest of
-- the list; 'Empty' terminates the list.
data LMList k v =
    Node k v {-# UNPACK #-} !(IORef (LMList k v))
  | Empty

-- | A map is a mutable reference to the first cell of the list.
type LMap k v = IORef (LMList k v)

-- | Create a new, empty concurrent map.
newLMap :: IO (LMap k v)
newLMap = newIORef Empty

-- | A position in the map into which a key/value pair can be inserted.
data Token k v = Token {
  keyToInsert :: k,                   -- ^ what key were we looking up?
  value       :: Maybe v,             -- ^ the value at this position in the map
  nextRef     :: IORef (LMList k v),  -- ^ the reference at which to insert
  nextTicket  :: Ticket (LMList k v)  -- ^ a ticket for the old value of nextRef
}

-- | Either the value associated with a key, or else a token at the position
-- where that key should go.
data FindResult k v =
    Found v
  | NotFound (Token k v)

-- | Attempt to locate a key in the map.
--
-- Walks the list from the front, comparing the sought key against each node's
-- key.  Returns 'Found' on an equal key.  Returns 'NotFound' with a 'Token'
-- as soon as the end of the list or a strictly greater key is reached; the
-- token records the reference that was being inspected, the ticket obtained
-- when reading it, and the value of the node visited just before it (if any).
{-# INLINE find #-}
find :: Ord k => LMap k v -> k -> IO (FindResult k v)
find m k = findInner m Nothing
  where
    -- @ref@ is the reference currently being inspected; @prev@ is the value
    -- of the node we stepped over to get here ('Nothing' at the list head).
    findInner ref prev = do
      nextTicket <- readForCAS ref
      let stopHere = NotFound $ Token { keyToInsert = k
                                      , value       = prev
                                      , nextRef     = ref
                                      , nextTicket }
      case peekTicket nextTicket of
        Empty -> return stopHere
        Node k' v' next ->
          case compare k k' of
            LT -> return stopHere
            EQ -> return $ Found v'
            GT -> findInner next (Just v')

-- | Attempt to insert a key/value pair at the given location (where the key is
-- given by the token).  NB: tryInsert will *always* fail after the first attempt.
-- If successful, returns a (mutable!) view of the map beginning at the given key.
{-# INLINE tryInsert #-}
tryInsert :: Token k v -> v -> IO (Maybe (LMap k v))
tryInsert Token { keyToInsert, nextRef, nextTicket } v = do
  -- The new node's tail is whatever the token saw at the insertion point.
  newRef <- newIORef $ peekTicket nextTicket
  (success, _) <- casIORef nextRef nextTicket $ Node keyToInsert v newRef
  return $ if success then Just nextRef else Nothing

-- | Concurrently fold over all key/value pairs in the map within the given
-- monad, in increasing key order.  Inserts that arrive concurrently may or may
-- not be included in the fold.
--
-- The first argument lifts 'IO' actions into the target monad.
--
-- Strict in the accumulator.
foldlWithKey :: Monad m => (forall x . IO x -> m x) ->
                (a -> k -> v -> m a) -> a -> LMap k v -> m a
foldlWithKey lft f !a !m = do
  n <- lft (readIORef m)
  case n of
    Empty -> return a
    Node k v next -> do
      a' <- f a k v
      foldlWithKey lft f a' next

-- | Map over a snapshot of the list.  Inserts that arrive concurrently may or may
-- not be included.  This does not affect keys, so the physical structure remains the
-- same.
map :: MonadIO m => (a -> b) -> LMap k a -> m (LMap k b)
map fn mp = do
  -- The fold conses each transformed node onto the accumulator, so @tmp@
  -- holds the nodes in reverse order.
  tmp <- foldlWithKey liftIO
           (\ acc k v -> do
               r <- liftIO (newIORef acc)
               return $! Node k (fn v) r)
           Empty mp
  tmp' <- liftIO (newIORef tmp)
  -- Here we suffer a reverse to avoid blowing the stack.
  reverse tmp'

-- | Create a new linked map that is the reverse order from the input.
reverse :: MonadIO m => LMap k v -> m (LMap k v)
reverse mp = liftIO . newIORef =<< loop Empty mp
  where
    loop !acc ref = do
      n <- liftIO $ readIORef ref
      case n of
        Empty -> return acc
        Node k v next -> do
          r <- liftIO (newIORef acc)
          loop (Node k v r) next

-- | Return the key of the first node in the map, or 'Nothing' if the map is
-- empty.
head :: LMap k v -> IO (Maybe k)
head lm = do
  x <- readIORef lm
  case x of
    Empty      -> return Nothing
    Node k _ _ -> return $! Just k

-- | Convert to a list of key/value pairs, in list order.
toList :: LMap k v -> IO [(k,v)]
toList lm = do
  x <- readIORef lm
  case x of
    Empty -> return []
    Node k v tl -> do
      ls <- toList tl
      return $! (k,v) : ls

-- | Convert from a list.  The pairs are linked in the order given; no sorting
-- is performed.
fromList :: [(k,v)] -> IO (LMap k v)
fromList ls = do
  let loop [] = return Empty
      loop ((k,v):tl) = do
        tl' <- loop tl
        ref <- newIORef tl'
        return $! Node k v ref
  lm <- loop ls
  newIORef lm

-- | Split a map into two halves using 'halve'.
--
-- Returns 'Nothing' when 'halve' does.  Otherwise returns a freshly built
-- copy of the first half together with a new reference to the existing node
-- at which the second half begins (the second half is shared, not copied).
halve' :: Ord k => Maybe k -> LMap k v -> IO (Maybe (LMap k v, LMap k v))
halve' mend lm = do
  lml <- readIORef lm
  res <- halve mend lml
  case res of
    Nothing -> return Nothing
    Just (len1, _len2, tailhd) -> do
      -- Walk only the first @len1@ nodes, starting from the same head cell
      -- that 'halve' counted from, rather than re-reading the map and
      -- materialising the whole list.
      prefix <- takeNodes len1 lml
      l' <- fromList prefix
      r' <- newIORef tailhd
      return $! Just $! (l', r')
  where
    -- Collect up to @n@ leading key/value pairs of a list.
    takeNodes :: Int -> LMList a b -> IO [(a, b)]
    takeNodes n _ | n <= 0 = return []
    takeNodes _ Empty = return []
    takeNodes n (Node k v next) = do
      rest <- readIORef next >>= takeNodes (n - 1)
      return ((k, v) : rest)

-- | Attempt to split into two halves.
--
-- This optionally takes an upper bound key, which is treated as an alternate
-- end-of-list signifier.
--
-- Result: If there is only one element, then return Nothing.  If there are more,
-- return the number of elements in the first and second halves, plus a pointer to
-- the beginning of the second half.  It is a contract of this function that the
-- two Ints returned are non-zero.
--
halve :: Ord k => Maybe k -> LMList k v -> IO (Maybe (Int, Int, LMList k v))
{-# INLINE halve #-}
halve mend ls = loop 0 ls ls
  where
    -- A cell counts as the end of the list if it is 'Empty' or its key has
    -- reached the optional upper bound.
    isEnd Empty = True
    isEnd (Node k _ _) =
      case mend of
        Just end -> k >= end
        Nothing  -> False

    -- A first half of length zero means there is nothing to split.
    emptCheck (0, _, _) = return Nothing
    emptCheck !x = return $! Just x

    -- @tort@ advances one node per iteration and @hare@ two, so when @hare@
    -- reaches the end @tort@ sits at the start of the second half.
    loop len tort hare | isEnd hare = emptCheck (len, len, tort)
    loop len tort@(Node _ _ next1) (Node _ _ next2) = do
      next2' <- readIORef next2
      case next2' of
        x | isEnd x -> emptCheck (len, len+1, tort)
        Node _ _ next3 -> do
          next1' <- readIORef next1
          next3' <- readIORef next3
          loop (len+1) next1' next3'

-- | Drop from the front of the list until the first key is equal or greater than the
-- given key.
dropUntil :: Ord k => k -> LMList k v -> IO (LMList k v)
dropUntil _ Empty = return Empty
dropUntil stop nd@(Node k _ tl)
  | stop <= k = return nd
  | otherwise = do
      tl' <- readIORef tl
      dropUntil stop tl'

-- | Given a pointer into the middle of the list, find how deep it is.
--
-- NOTE: not yet implemented; calling this always raises an error.
findIndex :: Eq k => LMList k v -> LMList k v -> IO (Maybe Int)
findIndex _ls1 _ls2 = error "FINISHME - LinkedMap.findIndex"