module Control.Concurrent.STM.Map
( Map
, empty
, insert
, delete
, unsafeDelete
, lookup
, phantomLookup
, member
, fromList
, unsafeToList
) where
#if __GLASGOW_HASKELL__ < 710
import Control.Applicative ((<$>))
#endif
import Control.Concurrent.STM
import Control.Monad
import Data.Atomics
import Data.IORef
import Data.Maybe
import GHC.Conc.Sync (unsafeIOToSTM)
import Prelude hiding (lookup)
import Data.SparseArray
newtype Map k v = Map (INode k v)
type INode k v = IORef (Node k v)
data Node k v = Array !(SparseArray (Branch k v))
| List ![Leaf k v]
| Tomb !(Leaf k v)
data Branch k v = I !(INode k v)
| L !(Leaf k v)
data Leaf k v = Leaf !k !(TVar (Maybe v))
empty :: STM (Map k v)
empty = unsafeIOToSTM $ Map <$> newIORef (Array emptyArray)
insert :: (Eq k, Hashable k) => k -> v -> Map k v -> STM ()
insert k v m = do var <- getTVar k m
writeTVar var (Just v)
lookup :: (Eq k, Hashable k) => k -> Map k v -> STM (Maybe v)
lookup k m = do var <- getTVar k m
readTVar var
delete :: (Eq k, Hashable k) => k -> Map k v -> STM ()
delete k m = do var <- getTVar k m
writeTVar var Nothing
member :: (Eq k, Hashable k) => k -> Map k v -> STM Bool
member k m = do
v <- lookup k m
case v of
Nothing -> return False
Just _ -> return True
getTVar :: (Eq k, Hashable k) => k -> Map k v -> STM (TVar (Maybe v))
getTVar k (Map root) = go root 0 undefined
where
h = hash k
go inode level parent = do
ticket <- unsafeIOToSTM $ readForCAS inode
case peekTicket ticket of
Array a -> case arrayLookup level h a of
Just (I inode2) -> go inode2 (down level) inode
Just (L leaf2@(Leaf k2 var))
| k == k2 -> return var
| otherwise -> cas inode ticket (growTrie level a (hash k2) leaf2)
Nothing -> cas inode ticket (insertLeaf level a)
List xs -> case listLookup k xs of
Just var -> return var
Nothing -> cas inode ticket (return . List . (:xs))
Tomb _ -> unsafeIOToSTM (clean parent (up level)) >> go root 0 undefined
cas inode ticket f = do
var <- newTVar Nothing
node <- f (Leaf k var)
(ok,_) <- unsafeIOToSTM $ casIORef inode ticket node
if ok then return var
else go root 0 undefined
insertLeaf level a leaf = do
let a' = arrayInsert level h (L leaf) a
return (Array a')
growTrie level a h2 leaf2 leaf1 = do
inode2 <- unsafeIOToSTM $ combineLeaves (down level) h leaf1 h2 leaf2
let a' = arrayUpdate level h (I inode2) a
return (Array a')
combineLeaves level h1 leaf1 h2 leaf2
| level >= lastLevel = newIORef (List [leaf1, leaf2])
| otherwise =
case mkPair level h (L leaf1) h2 (L leaf2) of
Just pair -> newIORef (Array pair)
Nothing -> do
inode <- combineLeaves (down level) h1 leaf1 h2 leaf2
let a = mkSingleton level h (I inode)
newIORef (Array a)
phantomLookup :: (Eq k, Hashable k) => k -> Map k v -> STM (Maybe v)
phantomLookup k (Map root) = go root 0 undefined
where
h = hash k
go inode level parent = do
node <- unsafeIOToSTM $ readIORef inode
case node of
Array a -> case arrayLookup level h a of
Just (I inode2) -> go inode2 (down level) inode
Just (L (Leaf k2 var))
| k == k2 -> readTVar var
| otherwise -> return Nothing
Nothing -> return Nothing
List xs -> case listLookup k xs of
Just var -> readTVar var
Nothing -> return Nothing
Tomb _ -> unsafeIOToSTM (clean parent (up level)) >> go root 0 undefined
unsafeDelete :: (Eq k, Hashable k) => k -> Map k v -> IO ()
unsafeDelete k m@(Map root) = do
ok <- go root 0 undefined
unless ok (unsafeDelete k m)
where
h = hash k
go inode level parent = do
ticket <- readForCAS inode
case peekTicket ticket of
Array a -> do
ok <- case arrayLookup level h a of
Just (I inode2) -> go inode2 (down level) inode
Just (L (Leaf k2 _))
| k == k2 -> casArrayDelete inode ticket level a
| otherwise -> return True
Nothing -> return True
when ok (compressIfPossible level inode parent)
return ok
List xs -> casListDelete inode ticket xs
Tomb _ -> clean parent (up level) >> return False
compressIfPossible level inode parent = do
n <- readIORef inode
case n of
Tomb _ -> cleanParent parent inode h (up level)
_ -> return ()
casArrayDelete inode ticket level a = do
let a' = arrayDelete level h a
n = contract level (Array a')
(ok,_) <- casIORef inode ticket n
return ok
casListDelete inode ticket xs = do
let xs' = listDelete k xs
n | [l] <- xs' = Tomb l
| otherwise = List xs'
(ok,_) <- casIORef inode ticket n
return ok
clean :: INode k v -> Level -> IO ()
clean inode level = do
ticket <- readForCAS inode
case peekTicket ticket of
n@(Array _) -> do
n' <- compress level n
void $ casIORef inode ticket n'
_ -> return ()
cleanParent :: INode k v -> INode k v -> Hash -> Level -> IO ()
cleanParent parent inode h level = do
ticket <- readForCAS parent
case peekTicket ticket of
n@(Array a) -> case arrayLookup level h a of
Just (I inode2) | inode2 == inode -> do
n2 <- readIORef inode
case n2 of
Tomb _ -> do
n' <- compress level n
(ok,_) <- casIORef parent ticket n'
unless ok $ cleanParent parent inode h level
_ -> return ()
_ -> return ()
_ -> return ()
compress :: Level -> Node k v -> IO (Node k v)
compress level (Array a) = contract level . Array <$> arrayMapM resurrect a
compress _ n = return n
resurrect :: Branch k v -> IO (Branch k v)
resurrect b@(I inode) = do n <- readIORef inode
case n of
Tomb leaf -> return (L leaf)
_ -> return b
resurrect b = return b
contract :: Level -> Node k v -> Node k v
contract level (Array a) | level > 0
, Just (L leaf) <- arrayToMaybe a
= Tomb leaf
contract _ n = n
listLookup :: Eq k => k -> [Leaf k v] -> Maybe (TVar (Maybe v))
listLookup k1 = go
where
go [] = Nothing
go (Leaf k2 var : xs) | k1 == k2 = Just var
| otherwise = go xs
listDelete :: Eq k => k -> [Leaf k v] -> [Leaf k v]
listDelete k1 = go
where
go [] = []
go (x@(Leaf k2 _):xs) | k1 == k2 = xs
| otherwise = x : go xs
fromList :: (Eq k, Hashable k) => [(k,v)] -> IO (Map k v)
fromList xs = do
m <- atomically empty
forM_ xs $ \(k,v) -> atomically (insert k v m)
return m
unsafeToList :: Map k v -> IO [(k,v)]
unsafeToList (Map root) = go root
where
go inode = do
node <- readIORef inode
case node of
Array a -> arrayFoldM' go2 [] a
List xs -> foldM go3 [] xs
Tomb leaf -> go3 [] leaf
go2 xs (I inode) = go inode >>= \ys -> return (ys ++ xs)
go2 xs (L leaf) = go3 xs leaf
go3 xs (Leaf k var) = do
v <- readTVarIO var
case v of
Nothing -> return xs
Just v' -> return $ (k,v') : xs