{-# LANGUAGE MultiParamTypeClasses, ScopedTypeVariables #-}
module Data.HashTable.Internal
where
import Control.Concurrent.STM
import Control.Concurrent
import Control.Concurrent.Async
import Data.IORef
import Data.Atomics
import Control.Exception
import Control.Monad
import Data.Hashable
import System.Random
import Data.Maybe
import qualified Data.List as L
import Data.Vector(Vector,(!))
import qualified Data.Vector as V
import Prelude hiding (lookup,readList)
-- | Per-bucket resize state.
--
-- 'genericModify' blocks while a bucket is 'Ongoing' and re-resolves its key
-- against the current bucket vector once the bucket is 'Finished'.
data MigrationStatus
  = NotStarted  -- ^ no resize has touched this bucket yet
  | Ongoing     -- ^ entries are being copied into a new bucket vector
  | Finished    -- ^ the bucket belongs to a vector that has been replaced
  deriving (Show, Eq)
-- | One hash bucket: its association list plus its resize status.
data Chain k v = Chain
  { _itemsTV           :: TVar [(k,v)]          -- ^ key/value entries of this bucket
  , _migrationStatusTV :: TVar MigrationStatus  -- ^ resize state of this bucket
  }
  deriving (Eq)
-- | Allocate an empty bucket whose migration has not started.
newChainIO :: IO (Chain k v)
newChainIO = do
  items  <- newTVarIO []
  status <- newTVarIO NotStarted
  return (Chain items status)
-- | A concurrent hash table using separate chaining.
data HashTable k v = HashTable
  { _chainsVecTV :: TVar (Vector (Chain k v))  -- ^ current bucket vector; swapped out by 'resize'
  , _totalLoad   :: IORef Int                  -- ^ number of stored entries
  , _config      :: Config k                   -- ^ tuning parameters and hash function
  }
-- | Tuning parameters for a 'HashTable'.
data Config k = Config
  { _scaleFactor      :: Float     -- ^ factor applied to the bucket count on resize
  , _threshold        :: Float     -- ^ load factor (entries / buckets) that triggers a resize
  , _numResizeWorkers :: Int       -- ^ upper bound on concurrent migration workers
  , _hashFunc         :: k -> Int  -- ^ hash function; bucket index = hash `mod` bucket count
  }
-- | Render the numeric tuning parameters, separated by spaces, e.g.
-- @Config 2.0 0.75 4@. The hash function is omitted because functions have
-- no 'Show' instance.
instance Show (Config k) where
  -- The previous version concatenated the fields with no separator,
  -- producing ambiguous output such as "Config 2.00.754".
  show cfg = unwords
    [ "Config"
    , show (_scaleFactor cfg)
    , show (_threshold cfg)
    , show (_numResizeWorkers cfg)
    ]
-- | Default configuration: grow the bucket count by 2x once the load factor
-- reaches 0.75, use one resize worker per capability, and hash with
-- 'hashWithSalt' using a randomly drawn salt.
mkDefaultConfig :: Hashable k => IO (Config k)
mkDefaultConfig = build <$> getNumCapabilities <*> (randomIO :: IO Int)
  where
    build workers salt = Config
      { _scaleFactor      = 2.0
      , _threshold        = 0.75
      , _numResizeWorkers = workers
      , _hashFunc         = hashWithSalt salt
      }
-- | Create an empty hash table with the given number of buckets and the
-- given configuration.
--
-- The bucket count is clamped to at least 1. Bucket indices are computed
-- with @hash `mod` size@ and 'resize' inspects bucket 0, so a table with
-- zero buckets would crash on its first operation.
new :: (Eq k) => Int
    -> Config k -> IO (HashTable k v)
new size config = do
  chainsVec <- V.replicateM (max 1 size) newChainIO
  HashTable <$> newTVarIO chainsVec
            <*> newIORef 0
            <*> return config
-- | Create a table with the given bucket count and the configuration built
-- by 'mkDefaultConfig'.
newWithDefaults :: (Eq k,Hashable k) => Int
                -> IO (HashTable k v)
newWithDefaults size = do
  config <- mkDefaultConfig
  new size config
{-# INLINABLE readSizeIO #-}
-- | Current number of buckets (not entries), read outside a transaction.
readSizeIO :: HashTable k v -> IO Int
readSizeIO ht = fmap V.length (readTVarIO (_chainsVecTV ht))
{-# INLINABLE readSize #-}
-- | Current number of buckets (not entries), read inside a transaction.
readSize :: HashTable k v -> STM Int
readSize = fmap V.length . readTVar . _chainsVecTV
-- | Grow the bucket vector by '_scaleFactor' and rehash every entry into it.
--
-- At most one resize proceeds per bucket vector. The first caller flips
-- bucket 0's status from 'NotStarted' to 'Ongoing' in an STM transaction.
-- Any caller that finds it already 'Ongoing'/'Finished', or finds the
-- vector already replaced, returns without doing anything.
--
-- The old buckets are split into at most '_numResizeWorkers' slices, which
-- are migrated concurrently. Each old bucket is marked 'Ongoing' before its
-- entries are copied, so writers going through 'genericModify' block on it.
-- When all slices are copied:
--
-- * the new vector is installed;
-- * the old buckets are marked 'Finished', so blocked writers retry against
--   the new vector.
resize :: (Eq k)
       => HashTable k v -> IO ()
resize ht = do
  chainsVec <- readTVarIO $ _chainsVecTV ht
  let size1 = V.length chainsVec
  alreadyResizing <- do
    hasStarted <- atomically $ do
      migrating <- readTVar (_migrationStatusTV $ chainsVec ! 0)
      if migrating `elem` [Ongoing,Finished]
        then return True
        else do
          writeTVar (_migrationStatusTV $ chainsVec ! 0) Ongoing
          return False
    size2 <- readSizeIO ht
    return (hasStarted || (size1 /= size2))
  unless alreadyResizing $ do
    let oldSize = V.length chainsVec
        numWorkers = _numResizeWorkers $ _config ht
        -- Split the old buckets into at most numWorkers contiguous slices;
        -- the last slice also takes the remainder.
        numSlices = min numWorkers (max 1 (oldSize `div` numWorkers))
        sliceLength = oldSize `div` numSlices
        restLength = oldSize - ((numSlices-1)*sliceLength)
        vecSlices = [ V.unsafeSlice (i*sliceLength)
                                    (if i == numSlices-1 then restLength
                                                         else sliceLength)
                                    chainsVec
                    | i <- [0..numSlices-1] ]
        scale = _scaleFactor (_config ht)
        -- Never allocate an empty vector: bucket indices use `mod` newSize.
        newSize = max 1 (round (fromIntegral oldSize * scale))
    newVec <- V.replicateM newSize newChainIO
    forConcurrently_ vecSlices (V.mapM_ (migrate newVec newSize))
      `catch` (\(e::AssertionFailed) -> do
                 debug "ERROR in resize; this should never happen..."
                 throwIO e)
    debug "finished copying over nodes..."
    atomically $ writeTVar (_chainsVecTV ht) newVec
    debug "replaced old array with new one..."
    -- Release writers that were retrying on the old buckets.
    forConcurrently_ vecSlices $
      V.mapM_ (\chain ->
                 atomically $ writeTVar (_migrationStatusTV chain) Finished)
    debug "woke up blocked threads..."
  where
    -- Copy every entry of one old bucket into its bucket in newVec.
    migrate newVec newSize chain = do
      -- Block writers on this bucket until the new vector is installed.
      atomically $ writeTVar (_migrationStatusTV chain) Ongoing
      listOfNodes <- readTVarIO (_itemsTV chain)
      forM_ listOfNodes $ \(k,v) -> do
        let newIndex = _hashFunc (_config ht) k `mod` newSize
            newChain = newVec ! newIndex
        -- Read and write the target bucket in ONE transaction. When newSize
        -- is not a multiple of the old size, old buckets from different
        -- slices (handled by different workers) rehash into the same new
        -- bucket, and a separate read followed by a write would drop entries.
        atomically $ modifyTVar' (_itemsTV newChain) ((k,v):)
{-# INLINABLE lookup #-}
-- | Look up the value stored for a key.
--
-- Does not consult the bucket's migration status: it reads whichever bucket
-- the current vector maps the key to.
lookup :: (Eq k)
       => HashTable k v
       -> k
       -> IO (Maybe v)
lookup htable k =
  L.lookup k <$> (readChainForKeyIO htable k >>= readTVarIO . _itemsTV)
-- | A transaction that 'genericModify' runs against one bucket's entry list.
type STMAction k v a = TVar [(k, v)] -> STM a
-- | Run a transaction against the bucket that currently holds a key, while
-- coordinating with 'resize'.
--
-- * 'Ongoing': the transaction 'retry's, blocking until the bucket's status
--   changes.
-- * 'Finished': the bucket belongs to a replaced vector, so the key is
--   resolved again against the current vector and the whole operation
--   starts over.
genericModify :: (Eq k)
              => HashTable k v
              -> k
              -> STMAction k v a
              -> IO a
genericModify htable k stmAction = loop
  where
    loop = do
      chain <- readChainForKeyIO htable k
      outcome <- atomically (attempt chain)
      maybe loop return outcome
    attempt chain = do
      status <- readTVar (_migrationStatusTV chain)
      case status of
        NotStarted -> Just <$> stmAction (_itemsTV chain)
        Ongoing    -> retry
        Finished   -> return Nothing
-- | Insert a key/value pair, overwriting any existing value for the key.
--
-- Returns 'True' if the key was new; in that case the entry counter is
-- incremented, which may fork a 'resize'. Returns 'False' if an existing
-- value was replaced.
insert :: (Eq k)
       => HashTable k v
       -> k
       -> v
       -> IO Bool
insert htable k v = do
  isNew <- genericModify htable k $ \tvar -> do
    entries <- readTVar tvar
    let fresh = isNothing (L.lookup k entries)
        rest  = if fresh then entries else deleteFirstKey k entries
    writeTVar tvar ((k,v) : rest)
    return fresh
  when isNew $ atomicallyChangeLoad htable 1
  return isNew
-- | Insert a key/value pair only if the key is absent.
--
-- Returns 'True' (and increments the entry counter) if the pair was added.
-- Returns 'False' if the key already existed; its value is left untouched.
add :: (Eq k)
    => HashTable k v
    -> k
    -> v
    -> IO Bool
add htable k v = do
  added <- genericModify htable k $ \tvar -> do
    entries <- readTVar tvar
    if isJust (L.lookup k entries)
      then return False
      else writeTVar tvar ((k,v) : entries) >> return True
  when added $ atomicallyChangeLoad htable 1
  return added
-- | Overwrite the value of an existing key; an absent key is left absent.
--
-- Returns whether the key was present (and therefore updated).
update :: (Eq k)
       => HashTable k v
       -> k
       -> v
       -> IO Bool
update htable k v =
  genericModify htable k $ \tvar -> do
    entries <- readTVar tvar
    maybe (return False)
          (\_ -> writeTVar tvar ((k,v) : deleteFirstKey k entries) >> return True)
          (L.lookup k entries)
-- | Apply a function to the value of an existing key.
--
-- Returns the value from before the update. Returns 'Nothing', leaving the
-- bucket unchanged, when the key is absent.
modify :: (Eq k)
       => HashTable k v
       -> k
       -> (v -> v)
       -> IO (Maybe v)
modify htable k f =
  genericModify htable k $ \tvar -> do
    entries <- readTVar tvar
    let found = L.lookup k entries
    forM_ found $ \old ->
      writeTVar tvar ((k, f old) : deleteFirstKey k entries)
    return found
-- | Replace the value stored under a key and return the previous value.
--
-- Throws 'AssertionFailed' if the key is not present.
swapValues :: (Eq k)
           => HashTable k v
           -> k
           -> v
           -> IO v
swapValues htable k v = do
  result <- modify htable k (const v)
  case result of
    -- throwIO (not the pure 'throw') raises exactly when this IO action runs.
    Nothing -> throwIO $ AssertionFailed "Data.HashTable.swapValues: key not in hash table."
    Just v' -> return v'
-- | Remove a key from the table.
--
-- Returns 'True' (and decrements the entry counter) if the key was present,
-- 'False' otherwise.
delete :: (Eq k)
       => HashTable k v
       -> k
       -> IO Bool
delete htable k = do
  removed <- genericModify htable k $ \tvar -> do
    entries <- readTVar tvar
    let present = isJust (L.lookup k entries)
    when present $ writeTVar tvar (deleteFirstKey k entries)
    return present
  when removed $ atomicallyChangeLoad htable (-1)
  return removed
-- | Add @incr@ to the entry counter and, if needed, start a resize.
--
-- A 'resize' is forked on a new thread when the load factor
-- (entries / buckets) is at or above '_threshold', unless bucket 0 of the
-- current vector shows a migration has already started.
atomicallyChangeLoad :: (Eq k)
                     => HashTable k v
                     -> Int
                     -> IO ()
atomicallyChangeLoad htable incr = do
  newLoad <- atomicModifyIORefCAS (_totalLoad htable) $
    \l -> let l' = l + incr in (l', l')
  buckets <- readSizeIO htable
  let loadFactor = fromIntegral newLoad / fromIntegral buckets
  when (loadFactor >= _threshold (_config htable)) $ do
    status <- readTVarIO . _migrationStatusTV =<< readChainForIndexIO htable 0
    when (status == NotStarted) $ void (forkIO (resize htable))
-- | Current value of the entry counter maintained by insert/add/delete.
readLoad :: HashTable k v -> IO Int
readLoad = readIORef . _totalLoad
-- | All key/value pairs, concatenated in bucket order, read within the
-- caller's single STM transaction.
readAssocs :: (Eq k)
           => HashTable k v -> STM [(k,v)]
readAssocs htable =
  readTVar (_chainsVecTV htable)
    >>= fmap concat . mapM (readTVar . _itemsTV) . V.toList
-- | All key/value pairs, concatenated in bucket order, read outside a
-- transaction.
--
-- The bucket vector is read once, and every bucket is taken from that same
-- vector. A concurrent 'resize' that swaps in a new vector therefore cannot
-- cause this scan to mix buckets from the old and new vectors, which could
-- skip or duplicate entries.
--
-- Each bucket is still read in its own transaction, so the result is not an
-- atomic snapshot of the whole table. 'readAssocs' run under 'atomically'
-- gives that.
readAssocsIO :: (Eq k)
             => HashTable k v -> IO [(k,v)]
readAssocsIO htable = do
  chainsVec <- readTVarIO $ _chainsVecTV htable
  concat <$> mapM (readTVarIO . _itemsTV) (V.toList chainsVec)
{-# INLINABLE deleteFirstKey #-}
-- | Remove the first pair whose key equals the given key; any later pairs
-- with the same key, and the relative order of the rest, are kept.
deleteFirstKey :: Eq a => a -> [(a,b)] -> [(a,b)]
deleteFirstKey key assocs =
  case break ((key ==) . fst) assocs of
    (before, _matched : after) -> before ++ after
    (before, [])               -> before
{-# INLINABLE readChainForKeyIO #-}
-- | The bucket that the current vector assigns to a key
-- (index = hash `mod` bucket count).
readChainForKeyIO :: HashTable k v -> k -> IO (Chain k v)
readChainForKeyIO htable k = pick <$> readTVarIO (_chainsVecTV htable)
  where
    pick vec = vec ! (_hashFunc (_config htable) k `mod` V.length vec)
{-# INLINABLE readChainForIndexIO #-}
-- | The bucket at a given index of the current vector, read outside a
-- transaction. The index is not bounds-checked beyond what '!' does.
readChainForIndexIO :: HashTable k v -> Int -> IO (Chain k v)
readChainForIndexIO htable idx = (! idx) <$> readTVarIO (_chainsVecTV htable)
{-# INLINABLE readChainForIndex #-}
-- | The bucket at a given index of the current vector, read inside a
-- transaction.
readChainForIndex :: HashTable k v -> Int -> STM (Chain k v)
readChainForIndex htable idx = (! idx) <$> readTVar (_chainsVecTV htable)
{-# INLINABLE debug #-}
-- | Debug-trace hook; currently a no-op that ignores its argument.
debug :: Show a => a -> IO ()
debug = const (return ())