{-# LANGUAGE NamedFieldPuns, BangPatterns #-}

-- | A concurrent finite map represented as a singly linked list.
--
-- In contrast to standard maps, this one only allows lookups and insertions,
-- not modifications or removals.  While modifications would be fairly easy to
-- add, removals would significantly complicate the logic, and aren't needed for
-- the primary application -- LVars.
--
-- The interface is also somewhat low-level: rather than a standard insert
-- function, @tryInsert@ takes a "token" (i.e. a pointer into the linked list)
-- and attempts to insert at that location (but may fail).  Tokens are acquired
-- through the @find@ function, which yields a token in the case that a key is
-- *not* found; the token represents the location in the list where the key
-- *should* go.  This low-level interface is intended for use in higher-level
-- data structures, e.g. SkipListMap.

module Data.Concurrent.LinkedMap (
  LMap(), newLMap, Token(), value, find, FindResult(..), tryInsert,
  foldlWithKey, map, reverse)
where
  
import Data.IORef
import Data.Atomics  
import Control.Reagent -- AT: not yet using this, but would be nice to refactor
                       -- to use it.
import Control.Monad.IO.Class
import Prelude hiding (reverse, map)

-- | The cells of the linked list that backs a concurrent finite map.  A cell
-- is either a 'Node' holding a key, a value and a mutable reference to the
-- rest of the list, or 'Empty', which marks the end of the list.
data LMList k v = 
    Node k v {-# UNPACK #-} !(IORef (LMList k v))
  | Empty 

-- | A handle on a map: a mutable reference to the first cell of its list.
type LMap k v = IORef (LMList k v)

-- | Create a new, empty concurrent map: a fresh reference holding 'Empty'.
newLMap :: IO (LMap k v)
newLMap = newIORef Empty
  
-- | A position in the map into which a key/value pair can be inserted.
-- Tokens are produced by 'find' when the key is absent and are consumed by
-- 'tryInsert'.
data Token k v = Token {
  keyToInsert :: k,                   -- ^ what key were we looking up?
  value       :: Maybe v,             -- ^ as filled in by 'find': the value of
                                      -- the node just before this position, or
                                      -- 'Nothing' if this is the head of the map
  nextRef     :: IORef (LMList k v),  -- ^ the reference at which to insert
  nextTicket  :: Ticket (LMList k v)  -- ^ a ticket for the old value of nextRef,
                                      -- i.e. its contents when the token was made
}

-- | The outcome of a lookup: either the value associated with a key, or else
-- a token at the position where that key should go.
data FindResult k v =
    Found v                -- ^ the key is present; this is its value
  | NotFound (Token k v)   -- ^ the key is absent; the token marks where it belongs

-- | Look a key up in the map.
--
-- The list is walked from the head.  The walk stops at the first cell that is
-- either 'Empty' or holds a key that is not smaller than the one sought:
--
-- * if that cell holds the sought key, its value is returned as 'Found';
--
-- * otherwise a 'NotFound' token is returned that records the reference at
--   which the walk stopped, a ticket for the contents read from it, and the
--   value of the node stepped over just before ('Nothing' at the head).
{-# INLINE find #-}
find :: Ord k => LMap k v -> k -> IO (FindResult k v)
find m k = walk m Nothing
  where
    -- @cursor@ is the reference being inspected; @prevVal@ is the value of
    -- the node we stepped over to get here, if any.
    walk cursor prevVal = do
      ticket <- readForCAS cursor
      let insertHere = NotFound Token
            { keyToInsert = k
            , value       = prevVal
            , nextRef     = cursor
            , nextTicket  = ticket
            }
      case peekTicket ticket of
        Node nodeKey nodeVal rest ->
          case compare k nodeKey of
            EQ -> return (Found nodeVal)
            LT -> return insertHere
            GT -> walk rest (Just nodeVal)
        Empty -> return insertHere
      
-- | Attempt to insert a value at the position described by a token; the key
-- inserted is the one recorded in the token.
--
-- A new node is built whose tail is a fresh reference holding the contents
-- the token observed, and a single compare-and-swap against the token's
-- ticket tries to install it.  On success the result is @Just@ the reference
-- that now holds the new node, i.e. a (mutable!) view of the map beginning at
-- the inserted key; on failure it is @Nothing@.
--
-- NB: a token is single-use -- tryInsert will *always* fail after the first
-- attempt with a given token.
{-# INLINE tryInsert #-}
tryInsert :: Token k v -> v -> IO (Maybe (LMap k v))
tryInsert (Token { keyToInsert = key, nextRef = slot, nextTicket = expected }) v = do
  tailRef <- newIORef (peekTicket expected)
  (swapped, _) <- casIORef slot expected (Node key v tailRef)
  return (if swapped then Just slot else Nothing)

-- | Left fold over the key/value pairs of the map, running the step function
-- in the given monad.  Pairs are visited in list order, starting from the
-- head.  Each cell is read only when the fold reaches it, so inserts that
-- arrive concurrently may or may not be included.
foldlWithKey :: MonadIO m => (a -> k -> v -> m a) -> a -> LMap k v -> m a
foldlWithKey f = go
  where
    -- Read the cell behind the reference, then dispatch on it.
    go acc ref = liftIO (readIORef ref) >>= step acc

    step acc Empty           = return acc
    step acc (Node k v rest) = f acc k v >>= \acc' -> go acc' rest


-- | Map over a snapshot of the list, producing a fresh list.  Inserts that
-- arrive concurrently may or may not be included.  This does not affect keys,
-- so the physical structure remains the same.
--
-- The copy is built front-to-back in a single pass: each new node is written
-- into the tail reference of the node created before it.  The loop is
-- tail-recursive, so it runs in constant stack space without needing to build
-- a reversed intermediate list and reverse it afterwards.  The new references
-- are private to this call until the result is returned, so the in-place
-- writes cannot be observed by other threads.
--
-- The function is applied lazily: each new node stores the unevaluated
-- application @fn v@.
map :: MonadIO m => (a -> b) -> LMap k a -> m (LMap k b)
map fn mp = liftIO $ do
  headRef <- newIORef Empty
  copyFrom mp headRef
  return headRef
  where
    -- Copy the cells reachable from @src@ into the (currently 'Empty')
    -- reference @dst@, transforming each value along the way.
    copyFrom src dst = do
      cell <- readIORef src
      case cell of
        Empty -> return ()
        Node k v next -> do
          tailRef <- newIORef Empty
          -- Force the node itself (not the mapped value) before storing it,
          -- so the reference never holds a thunk for the constructor.
          writeIORef dst $! Node k (fn v) tailRef
          copyFrom next tailRef

-- | Build a new linked map containing the same key/value pairs as the input
-- but in the opposite order.  The input is only read, never modified; every
-- cell of the result lives in a freshly allocated reference.
reverse :: MonadIO m => LMap k v -> m (LMap k v)
reverse mp = do
  flipped <- go Empty mp
  liftIO (newIORef flipped)
  where
    -- @soFar@ is the already-reversed prefix; each node read from the input
    -- is pushed onto its front.
    go !soFar ref = do
      cell <- liftIO (readIORef ref)
      case cell of
        Node k v rest -> do
          tailRef <- liftIO (newIORef soFar)
          go (Node k v tailRef) rest
        Empty -> return soFar