{-# LANGUAGE CPP #-}
{-# LANGUAGE BangPatterns, PatternGuards, MagicHash #-}
{-# OPTIONS_GHC -funbox-strict-fields #-}
-----------------------------------------------------------------------
-- | A non-blocking concurrent map from hashable keys to values.
--
-- The implementation is based on /lock-free concurrent hash tries/
-- (aka /Ctries/) as described by:
--
--    * Aleksander Prokopec, Phil Bagwell, Martin Odersky,
--      \"/Cache-Aware Lock-Free Concurent Hash Tries/\"
--    * Aleksander Prokopec, Nathan G. Bronson, Phil Bagwell, Martin
--      Odersky \"/Concurrent Tries with Efficient Non-Blocking Snapshots/\"
--
-- Operations have a worst-case complexity of /O(log n)/, with a base
-- equal to the size of the native 'Word'.
--
-----------------------------------------------------------------------

module Control.Concurrent.Map
    ( Map

      -- * Construction
    , empty

      -- * Modification
    , insert
    , delete
    , insertIfAbsent

      -- * Query
    , lookup

      -- * Lists
    , fromList
    , unsafeToList

    --, printMap
    ) where

#if __GLASGOW_HASKELL__ < 710
import Control.Applicative ((<$>))
#endif
import Control.Monad
import Data.Atomics
import Data.Bits
import Data.Hashable (Hashable)
import qualified Data.Hashable as H
import Data.IORef
import qualified Data.List as List
import Data.Maybe
import Data.Word
import Prelude hiding (lookup)

import qualified Control.Concurrent.Map.Array as A

-----------------------------------------------------------------------

-- | A map from keys @k@ to values @v@.
newtype Map k v = Map (INode k v)

type INode k v = IORef (MainNode k v)

data MainNode k v = CNode !Bitmap !(A.Array (Branch k v))
                  | Tomb !(SNode k v)
                  | Collision ![SNode k v]

data Branch k v = INode !(INode k v)
                | SNode !(SNode k v)

data SNode k v = S !k v
    deriving (Eq, Show)

isTomb :: MainNode k v -> Bool
isTomb (Tomb _) = True
isTomb _        = False

type Bitmap = Word
type Hash   = Word
type Level  = Int

hash :: Hashable a => a -> Hash
hash = fromIntegral . H.hash


-----------------------------------------------------------------------
-- * Construction

-- | /O(1)/. Construct an empty map.
empty :: IO (Map k v)
empty = Map <$> newIORef (CNode 0 A.empty)


-----------------------------------------------------------------------
-- * Modification

-- | /O(log n)/. Associate the given value with the given key.
-- If the key is already present in the map, the old value is replaced.
-- Returns 'True' if the value was inserted, 'False' if it was replaced.
insert :: (Eq k, Hashable k) => k -> v -> Map k v -> IO Bool
insert k v (Map root) = go0
    where
        h = hash k
        go0 = go 0 undefined root
        go lev parent inode = do
            ticket <- readForCAS inode
            case peekTicket ticket of
                CNode bmp arr -> do
                    let m = mask h lev
                        i = sparseIndex bmp m
                        n = popCount bmp
                    if bmp .&. m == 0
                        then do
                            let arr' = A.insert (SNode (S k v)) i n arr
                                cn'  = CNode (bmp .|. m) arr'
                            success <- fst <$> casIORef inode ticket cn'
                            if success then return True else go0

                        else case A.index arr i of
                            SNode (S k2 v2)
                                | k == k2 -> do
                                    let arr' = A.update (SNode (S k v)) i n arr
                                        cn'  = CNode bmp arr'
                                    success <- fst <$> casIORef inode ticket cn'
                                    if success then return False else go0

                                | otherwise -> do
                                    let h2 = hash k2
                                    inode2 <- newINode h k v h2 k2 v2 (nextLevel lev)
                                    let arr' = A.update (INode inode2) i n arr
                                        cn'  = CNode bmp arr'
                                    success <- fst <$> casIORef inode ticket cn'
                                    if success then return False else go0

                            INode inode2 -> go (nextLevel lev) inode inode2

                Tomb _ -> clean parent (prevLevel lev) >> go0

                Collision arr -> do
                    let arr' = S k v : filter (\(S k2 _) -> k2 /= k) arr
                        col' = Collision arr'
                    success <- fst <$> casIORef inode ticket col'
                    if success
                        then return $ not $ any (\(S k2 _) -> k2 == k) arr
                        else go0

{-# INLINABLE insert #-}


-- | /O(log n)/. Associate the given value with the given key.
-- If the key is already present in the map, don't change the value.
-- Returns 'True' if the value was inserted, 'False' otherwise.
insertIfAbsent :: (Eq k, Hashable k) => k -> v -> Map k v -> IO Bool
insertIfAbsent k v (Map root) = go0
    where
        h = hash k
        go0 = go 0 undefined root
        go lev parent inode = do
            ticket <- readForCAS inode
            case peekTicket ticket of
                CNode bmp arr -> do
                    let m = mask h lev
                        i = sparseIndex bmp m
                        n = popCount bmp
                    if bmp .&. m == 0
                        then do
                            let arr' = A.insert (SNode (S k v)) i n arr
                                cn'  = CNode (bmp .|. m) arr'
                            success <- fst <$> casIORef inode ticket cn'
                            if success then return True else go0

                        else case A.index arr i of
                            SNode (S k2 v2)
                                | k == k2 -> return False

                                | otherwise -> do
                                    let h2 = hash k2
                                    inode2 <- newINode h k v h2 k2 v2 (nextLevel lev)
                                    let arr' = A.update (INode inode2) i n arr
                                        cn'  = CNode bmp arr'
                                    success <- fst <$> casIORef inode ticket cn'
                                    if success then return True else go0

                            INode inode2 -> go (nextLevel lev) inode inode2

                Tomb _ -> clean parent (prevLevel lev) >> go0

                Collision arr ->
                    if any (\(S k2 _) -> k2 == k) arr
                        then return False
                        else do
                            let arr' = S k v : filter (\(S k2 _) -> k2 /= k) arr
                                col' = Collision arr'
                            success <- fst <$> casIORef inode ticket col'
                            if success then return True else go0

{-# INLINABLE insertIfAbsent #-}


newINode :: Hash -> k -> v -> Hash -> k -> v -> Int -> IO (INode k v)
newINode h1 k1 v1 h2 k2 v2 lev
    | lev >= hashLength = newIORef $ Collision [S k1 v1, S k2 v2]
    | otherwise = do
        let i1 = index h1 lev
            i2 = index h2 lev
            bmp = (unsafeShiftL 1 i1) .|. (unsafeShiftL 1 i2)
        case compare i1 i2 of
            LT -> newIORef $ CNode bmp $ A.pair (SNode (S k1 v1)) (SNode (S k2 v2))
            GT -> newIORef $ CNode bmp $ A.pair (SNode (S k2 v2)) (SNode (S k1 v1))
            EQ -> do inode' <- newINode h1 k1 v1 h2 k2 v2 (nextLevel lev)
                     newIORef $ CNode bmp $ A.singleton (INode inode')


-- | /O(log n)/. Remove the given key and its associated value from the map,
-- if present.
-- Returns 'True' if the value was deleted, 'False' otherwise.
delete :: (Eq k, Hashable k) => k -> Map k v -> IO Bool
delete k (Map root) = go0
    where
        h = hash k
        go0 = go 0 undefined root
        go lev parent inode = do
            ticket <- readForCAS inode
            case peekTicket ticket of
                CNode bmp arr -> do
                    let m = mask h lev
                        i = sparseIndex bmp m
                    if bmp .&. m == 0
                        then return False  -- not found
                        else case A.index arr i of
                            SNode (S k2 _)
                                | k == k2 -> do
                                    let arr' = A.delete i (popCount bmp) arr
                                        cn'  = CNode (bmp `xor` m) arr'
                                        cn'' = contract lev cn'
                                    success <- fst <$> casIORef inode ticket cn''
                                    result <- if success then return True else go0
                                    whenM (isTomb <$> readIORef inode) $
                                        cleanParent parent inode h (prevLevel lev)
                                    return result

                                | otherwise -> return False  -- not found

                            INode inode2 -> go (nextLevel lev) inode inode2

                Tomb _ -> clean parent (prevLevel lev) >> go0

                Collision arr -> do
                    let arr' = filter (\(S k2 _) -> k2 /= k) $ arr
                        col' | [s] <- arr' = Tomb s
                             | otherwise   = Collision arr'
                    success <- fst <$> casIORef inode ticket col'
                    if success then return True else go0

{-# INLINABLE delete #-}

-----------------------------------------------------------------------
-- * Query

-- | /O(log n)/. Return the value associated with the given key, or 'Nothing'.
lookup :: (Eq k, Hashable k) => k -> Map k v -> IO (Maybe v)
lookup k (Map root) = go0
    where
        h = hash k
        go0 = go 0 undefined root
        go lev parent inode = do
            main <- readIORef inode
            case main of
                CNode bmp arr -> do
                    let m = mask h lev
                        i = sparseIndex bmp m
                    if bmp .&. m == 0
                        then return Nothing
                        else case A.index arr i of
                            INode inode2 -> go (nextLevel lev) inode inode2
                            SNode (S k2 v) | k == k2   -> return (Just v)
                                           | otherwise -> return Nothing

                Tomb _ -> clean parent (prevLevel lev) >> go0

                Collision xs -> do
                    case List.find (\(S k2 _) -> k2 == k) xs of
                        Just (S _ v) -> return (Just v)
                        _            -> return Nothing

{-# INLINABLE lookup #-}

-----------------------------------------------------------------------
-- * Internal compression operations

clean :: INode k v -> Level -> IO ()
clean inode lev = do
    ticket <- readForCAS inode
    case peekTicket ticket of
        cn@(CNode _ _) -> do
            cn' <- compress lev cn
            void $ casIORef inode ticket cn'
        _ -> return ()
{-# INLINE clean #-}

cleanParent :: INode k v -> INode k v -> Hash -> Level -> IO ()
cleanParent parent inode h lev = do
    ticket <- readForCAS parent
    case peekTicket ticket of
        cn@(CNode bmp arr) -> do
            let m = mask h lev
                i = sparseIndex bmp m
            unless (bmp .&. m == 0) $
                case A.index arr i of
                    INode inode2 | inode2 == inode ->
                        whenM (isTomb <$> readIORef inode) $ do
                            cn' <- compress lev cn
                            unlessM (fst <$> casIORef parent ticket cn') $
                                cleanParent parent inode h lev
                    _ -> return ()
        _ -> return ()

compress :: Level -> MainNode k v -> IO (MainNode k v)
compress lev (CNode bmp arr) =
    contract lev <$> CNode bmp <$> A.mapM resurrect (popCount bmp) arr
compress _ x = return x
{-# INLINE compress #-}

resurrect :: Branch k v -> IO (Branch k v)
resurrect b@(INode inode) = do
    main <- readIORef inode
    case main of
        Tomb s -> return (SNode s)
        _      -> return b
resurrect b = return b
{-# INLINE resurrect #-}

contract :: Level -> MainNode k v -> MainNode k v
contract lev (CNode bmp arr) | lev > 0
                           , popCount bmp == 1
                           , SNode s <- A.head arr
                           = Tomb s
contract _ x = x
{-# INLINE contract #-}

-----------------------------------------------------------------------
-- * Lists

-- | /O(n * log n)/. Construct a map from a list of key/value pairs.
fromList :: (Eq k, Hashable k) => [(k,v)] -> IO (Map k v)
fromList xs = empty >>= \m -> mapM_ (\(k,v) -> insert k v m) xs >> return m
{-# INLINABLE fromList #-}

-- | /O(n)/. Unsafely convert the map to a list of key/value pairs.
--
-- WARNING: 'unsafeToList' makes no atomicity guarantees. Concurrent
-- changes to the map will lead to inconsistent results.
unsafeToList :: Map k v -> IO [(k,v)]
unsafeToList (Map root) = go root
    where
        go inode = do
            main <- readIORef inode
            case main of
                CNode bmp arr -> A.foldM' go2 [] (popCount bmp) arr
                Tomb (S k v) -> return [(k,v)]
                Collision xs -> return $ map (\(S k v) -> (k,v)) xs

        go2 xs (INode inode) = go inode >>= \ys -> return (ys ++ xs)
        go2 xs (SNode (S k v)) = return $ (k,v) : xs
{-# INLINABLE unsafeToList #-}

-----------------------------------------------------------------------

whenM :: Monad m => m Bool -> m () -> m ()
whenM p s = p >>= \t -> if t then s else return ()
{-# INLINE whenM #-}

unlessM :: Monad m => m Bool -> m () -> m ()
unlessM p s = p >>= \t -> if t then return () else s
{-# INLINE unlessM #-}

-----------------------------------------------------------------------

hashLength :: Int
hashLength = finiteBitSize (undefined :: Word)

bitsPerSubkey :: Int
bitsPerSubkey = floor . logBase (2 :: Float) . fromIntegral $ hashLength

subkeyMask :: Bitmap
subkeyMask = 1 `unsafeShiftL` bitsPerSubkey - 1

index :: Hash -> Level -> Int
index h lev = fromIntegral $ (h `unsafeShiftR` lev) .&. subkeyMask
{-# INLINE index #-}

-- when or-ed with a CNode bitmap, determines if the hash is present
-- in the array at the given level of the trie
mask :: Hash -> Level -> Bitmap
mask h lev = 1 `unsafeShiftL` index h lev
{-# INLINE mask #-}

-- position in the CNode array
sparseIndex :: Bitmap -> Bitmap -> Int
sparseIndex bmp m = popCount ((m - 1) .&. bmp)
{-# INLINE sparseIndex #-}

nextLevel :: Level -> Level
nextLevel = (+) bitsPerSubkey
{-# INLINE nextLevel #-}

prevLevel :: Level -> Level
prevLevel = subtract bitsPerSubkey
{-# INLINE prevLevel #-}

-----------------------------------------------------------------------

-- TODO
--printMap :: (Show k, Show v) => Map k v -> IO ()
--printMap (Map root) = goI root
--    where
--        goI inode = putStr "(I " >> readIORef inode >>= goM >> putStr ")\n"
--        goM (CNode bmp arr) = do
--            putStr $ "(C " ++ (show bmp) ++ " ["
--            A.mapM_ (\b -> goB b >> putStr ", ") (popCount bmp) arr
--            putStr $ "] )"
--        goM (Tomb (S k v)) = putStr $ "(T " ++ (show k) ++ " " ++ (show v) ++ ")"
--        goM (Collision xs) = putStr $ "(Collision " ++ show xs ++ ")"
--        goB (INode i) = putStr "\n" >> goI i
--        goB (SNode (S k v)) = putStr $ "(" ++ (show k) ++ "," ++ (show v) ++ ")"