module Data.HashTable.ST.Cuckoo
( HashTable
, new
, newSized
, delete
, lookup
, insert
, mapM_
, foldM
) where
import Control.Monad hiding (foldM, mapM_)
import Control.Monad.ST
import Data.Hashable hiding (hash)
import qualified Data.Hashable as H
import Data.Int
import Data.Maybe
import Data.Primitive.Array
import Data.STRef
import GHC.Exts
import Prelude hiding ( lookup, read, mapM_ )
import qualified Data.HashTable.Class as C
import Data.HashTable.Internal.CheapPseudoRandomBitStream
import Data.HashTable.Internal.CacheLine
import qualified Data.HashTable.Internal.IntArray as U
import Data.HashTable.Internal.Utils
#ifdef DEBUG
import System.IO
#endif
-- | Public handle to a cuckoo hash table living in the 'ST' monad.  It wraps
-- an 'STRef' to the actual table record ('HashTable_'); 'insert' writes a
-- new record back through this reference after each insertion.
newtype HashTable s k v = HT (STRef s (HashTable_ s k v))
-- | The table record proper.  Hash codes, keys and values are kept in three
-- parallel arrays indexed by slot number.
data HashTable_ s k v = HashTable
    { _size :: !Int
      -- ^ Number of buckets; 'newSizedReal' allocates
      -- @_size * numWordsInCacheLine@ slots in each array.
    , _rng :: !(BitStream s)
      -- ^ Bit stream consulted by 'cuckooOrFail' when choosing which line and
      -- which slot to evict from.
    , _hashes :: !(U.IntArray s)
      -- ^ Per-slot hash code; a slot holding 'emptyMarker' is unoccupied.
    , _keys :: !(MutableArray s k)
      -- ^ Per-slot key.
    , _values :: !(MutableArray s v)
      -- ^ Per-slot value.
    , _maxAttempts :: !Int
      -- ^ Upper bound on the eviction chain length tried by 'cuckooOrFail'
      -- before it reports failure.
    }
-- | Class instance that forwards every method to the same-named top-level
-- function of this module (the class itself is imported qualified as @C@, so
-- the unqualified names on the right-hand sides are the local definitions).
instance C.HashTable HashTable where
    new = new
    newSized = newSized
    insert = insert
    delete = delete
    lookup = lookup
    foldM = foldM
    mapM_ = mapM_
    computeOverhead = computeOverhead
-- | Tables render as a fixed placeholder string; the contents are mutable
-- state and are never inspected here.
instance Show (HashTable s k v) where
    show = const "<HashTable>"
-- | Create an empty table with the smallest initial allocation used by this
-- module (two buckets).
new :: ST s (HashTable s k v)
new = do
    table <- newSizedReal 2
    newRef table
-- | Create an empty table sized for roughly @n@ elements.
--
-- The element count is first converted to a number of cache lines, rounding
-- up, then divided by 'maxLoad' so the table starts below its target load
-- factor; the result is passed through 'nextBestPrime' to pick the bucket
-- count.
newSized :: Int -> ST s (HashTable s k v)
newSized n = do
    -- Ceiling division: add (divisor - 1) before dividing.
    let n' = (n + numWordsInCacheLine - 1) `div` numWordsInCacheLine
    let k = nextBestPrime $ ceiling $ fromIntegral n' / maxLoad
    newSizedReal k >>= newRef
-- | Insert a key/value pair, replacing any existing mapping for the key.
-- Key and value are forced before the table is touched.  'insert'' may
-- return a different (grown) table record, which is stored back into the
-- reference.
insert :: (Eq k, Hashable k) => HashTable s k v -> k -> v -> ST s ()
insert ht !k !v = do
    current <- readRef ht
    updated <- insert' current k v
    writeRef ht updated
-- | Estimate the bookkeeping overhead per stored element.
--
-- The numerator adds one unit per slot (the hash-code array), two units for
-- every /unfilled/ slot (the unused key and value cells) and a constant 12;
-- presumably the unit is machine words.  The sum is divided by the number
-- of stored elements, so an empty table yields a non-finite 'Double'.
computeOverhead :: HashTable s k v -> ST s Double
computeOverhead htRef = readRef htRef >>= work
  where
    work (HashTable sz _ _ _ _ _) = do
        nFilled <- foldM f 0 htRef
        let oh = totSz                   -- one per slot in the hash array
               + 2 * (totSz - nFilled)   -- two per non-filled slot
               + 12                      -- fixed part
        return $! fromIntegral (oh :: Int) / fromIntegral nFilled
      where
        totSz = numWordsInCacheLine * sz

        -- Counts elements; the strict accumulator avoids a thunk chain.
        f !a _ = return $! a + 1
-- | Remove the mapping for a key, if present.  Both candidate lines of the
-- key are computed here and handed to 'delete'' in non-updating mode; its
-- result is discarded.
delete :: (Hashable k, Eq k) =>
          HashTable s k v
       -> k
       -> ST s ()
delete htRef k = do
    ht <- readRef htRef
    let !sz = _size ht
        line1 = whichLine code1 sz
        line2 = whichLine code2 sz
    _ <- delete' ht False k line1 line2 code1 code2
    return ()
  where
    code1 = hash1 k
    code2 = hash2 k
-- | Look a key up, returning its value when the table contains it.
lookup :: (Eq k, Hashable k) =>
          HashTable s k v
       -> k
       -> ST s (Maybe v)
lookup htRef k = readRef htRef >>= \table -> lookup' table k
-- | Lookup on the unwrapped table record.  The key's first candidate line is
-- searched; only if that misses is the second one searched.  A non-negative
-- index from 'searchOne' is treated as a hit.
lookup' :: (Eq k, Hashable k) =>
           HashTable_ s k v
        -> k
        -> ST s (Maybe v)
lookup' (HashTable sz _ hashes keys values _) !k = do
    first <- probe line1 code1
    if first >= 0
      then found first
      else do
        second <- probe line2 code2
        if second >= 0
          then found second
          else return Nothing
  where
    code1 = hash1 k
    code2 = hash2 k
    line1 = whichLine code1 sz
    line2 = whichLine code2 sz

    probe = searchOne keys hashes k

    -- Fetch the value stored at a slot and wrap it.
    found slot = do
        v <- readArray values slot
        return $! Just v
-- | Search one cache line for a key.
--
-- Starting at slot @b@, 'cacheLineSearch' is asked for a slot whose stored
-- hash code is @h@.  A hash match is confirmed by comparing the stored key;
-- on a key mismatch the search resumes at the next slot, unless that slot
-- is cache-line aligned (i.e. the line has been exhausted).
--
-- Returns the slot index of the key, or @-1@ when the line does not contain
-- it.  Callers test the result with @>= 0@ / @< 0@.
searchOne :: (Eq k) =>
             MutableArray s k
          -> U.IntArray s
          -> k
          -> Int
          -> Int
          -> ST s Int
searchOne !keys !hashes !k = go
  where
    go !b !h = do
        debug $ "searchOne: go " ++ show b ++ " " ++ show h
        idx <- cacheLineSearch hashes b h
        debug $ "searchOne: cacheLineSearch returned " ++ show idx
        case idx of
          -- No slot in the line carries this hash code.
          -1 -> return (-1)
          _  -> do
              k' <- readArray keys idx
              if k == k'
                then return idx
                else do
                  -- Same hash code, different key: keep scanning the line.
                  let !idx' = idx + 1
                  if isCacheLineAligned idx'
                    then return (-1)
                    else go idx' h
-- | Monadic left fold over every key/value pair in the table, in slot order.
foldM :: (a -> (k,v) -> ST s a)
      -> a
      -> HashTable s k v
      -> ST s a
foldM f seed0 htRef = do
    table <- readRef htRef
    foldMWork f seed0 table
-- | Worker for 'foldM' on the unwrapped record.  Walks every slot from 0 up
-- to @numWordsInCacheLine * size@, skipping slots whose hash code is
-- 'emptyMarker', and threads a strictly evaluated accumulator through the
-- supplied step function.
foldMWork :: (a -> (k,v) -> ST s a)
          -> a
          -> HashTable_ s k v
          -> ST s a
foldMWork f seed0 (HashTable sz _ hashes keys values _) = loop 0 seed0
  where
    slotCount = numWordsInCacheLine * sz

    loop !slot !acc
        | slot < slotCount = do
            code <- U.readArray hashes slot
            if code == emptyMarker
              then loop (slot + 1) acc
              else do
                key <- readArray keys slot
                val <- readArray values slot
                !acc' <- f acc (key, val)
                loop (slot + 1) acc'
        | otherwise = return acc
-- | Run an action on every key/value pair in the table, discarding results.
mapM_ :: ((k,v) -> ST s a)
      -> HashTable s k v
      -> ST s ()
mapM_ f htRef = do
    table <- readRef htRef
    mapMWork f table
-- | Worker for 'mapM_' on the unwrapped record.  Visits slots in index order
-- and applies the action to each slot whose hash code is not 'emptyMarker'.
mapMWork :: ((k,v) -> ST s a)
         -> HashTable_ s k v
         -> ST s ()
mapMWork f (HashTable sz _ hashes keys values _) = loop 0
  where
    slotCount = numWordsInCacheLine * sz

    loop !slot
        | slot >= slotCount = return ()
        | otherwise = do
            code <- U.readArray hashes slot
            when (code /= emptyMarker) $ do
                key <- readArray keys slot
                val <- readArray values slot
                _ <- f (key, val)
                return ()
            loop (slot + 1)
-- | Allocate a table record with exactly the given number of buckets.
--
-- Each of the three arrays gets @nbuckets * numWordsInCacheLine@ slots.  Key
-- and value cells start out as 'undefined'; they are only read for slots
-- whose hash code is not 'emptyMarker'.  The eviction budget is
-- @12 + log2 nbuckets@.
newSizedReal :: Int -> ST s (HashTable_ s k v)
newSizedReal nbuckets = do
    let !slotCount = nbuckets * numWordsInCacheLine
        !attempts  = 12 + log2 (toEnum nbuckets)
    debug $ "creating cuckoo hash table with " ++
            show nbuckets ++ " buckets having " ++
            show slotCount ++ " total slots"
    bits      <- newBitStream
    codeArr   <- U.newArray slotCount
    keyArr    <- newArray slotCount undefined
    valueArr  <- newArray slotCount undefined
    return $! HashTable nbuckets bits codeArr keyArr valueArr attempts
-- | Insert into the unwrapped record.  'updateOrFail' is tried first; if it
-- hands back a pair that could not be placed, the table is grown via 'grow'
-- with that pair.  Returns the record that is current afterwards.
insert' :: (Eq k, Hashable k) =>
           HashTable_ s k v
        -> k
        -> v
        -> ST s (HashTable_ s k v)
insert' ht k v = do
    debug "insert': begin"
    leftover <- updateOrFail ht k v
    result <- case leftover of
        Nothing       -> return ht
        Just (k', v') -> grow ht k' v'
    debug "insert': end"
    return result
-- | Try to store a pair without growing the table.
--
-- 'delete'' is called in updating mode: it reports either the slot already
-- holding the key or a free slot in one of the key's two lines, together
-- with the hash code to record there.  If it finds neither (negative
-- index), eviction via 'cuckooOrFail' is attempted.
--
-- Returns 'Nothing' on success, or 'Just' the pair left without a home.
updateOrFail :: (Eq k, Hashable k) =>
                HashTable_ s k v
             -> k
             -> v
             -> ST s (Maybe (k,v))
updateOrFail ht@(HashTable sz _ hashes keys values _) k v = do
    debug $ "updateOrFail: begin: sz = " ++ show sz
    debug $ " h1=" ++ show code1 ++ ", h2=" ++ show code2
            ++ ", b1=" ++ show line1 ++ ", b2=" ++ show line2
    (slot, code) <- delete' ht True k line1 line2 code1 code2
    debug $ "delete' returned (" ++ show slot ++ "," ++ show code ++ ")"
    if slot < 0
      then evict
      else do
        U.writeArray hashes slot code
        writeArray keys slot k
        writeArray values slot v
        return Nothing
  where
    code1 = hash1 k
    code2 = hash2 k
    line1 = whichLine code1 sz
    line2 = whichLine code2 sz

    evict = do
        debug "cuckoo: calling cuckooOrFail"
        result <- cuckooOrFail ht code1 code2 line1 line2 k v
        debug $ "cuckoo: cuckooOrFail returned " ++
                (if isJust result then "Just _" else "Nothing")
        return result
-- | Shared worker for deletion and for the first phase of insertion.
--
-- Arguments after the table: the @updating@ flag, the key, its two candidate
-- line offsets @b1@ / @b2@ and the matching hash codes @h1@ / @h2@.
--
-- * If the key is found, its slot index and the hash code of the line it
--   was found in are returned.  When @updating@ is 'False' the slot is also
--   cleared (hash code reset to 'emptyMarker', key and value overwritten
--   with 'undefined').
--
-- * If the key is absent and @updating@ is 'True', the first free slot in
--   line one, else in line two, is returned with that line's hash code.
--
-- * Otherwise @(-1, -1)@ is returned; callers test the index with @>= 0@.
delete' :: (Hashable k, Eq k) =>
           HashTable_ s k v
        -> Bool
        -> k
        -> Int
        -> Int
        -> Int
        -> Int
        -> ST s (Int, Int)
delete' (HashTable _ _ hashes keys values _) !updating !k b1 b2 h1 h2 = do
    debug $ "delete' b1=" ++ show b1
            ++ " b2=" ++ show b2
            ++ " h1=" ++ show h1
            ++ " h2=" ++ show h2
    -- Issued before line one is searched, so line two is requested early.
    prefetchWrite hashes b2
    idx1 <- searchOne keys hashes k b1 h1
    if idx1 >= 0
      then deleteIt idx1 h1
      else do
        idx2 <- searchOne keys hashes k b2 h2
        if idx2 >= 0
          then deleteIt idx2 h2
          else if updating
            then findEmpty
            else return notFound
  where
    notFound = (-1, -1)

    -- Key absent while inserting: look for a free slot in either line.
    findEmpty = do
        debug $ "delete': looking for empty element"
        idxE1 <- cacheLineSearch hashes b1 emptyMarker
        debug $ "delete': idxE1 was " ++ show idxE1
        if idxE1 >= 0
          then return (idxE1, h1)
          else do
            idxE2 <- cacheLineSearch hashes b2 emptyMarker
            debug $ "delete': idxE2 was " ++ show idxE2
            if idxE2 >= 0
              then return (idxE2, h2)
              else return notFound

    -- When updating, the slot is left intact: the caller overwrites it.
    deleteIt !idx !h = do
        if not updating
          then do
            U.writeArray hashes idx emptyMarker
            writeArray keys idx undefined
            writeArray values idx undefined
          else return ()
        return $! (idx, h)
-- | Place a pair by evicting existing entries ("cuckoo" displacement).
--
-- One of the key's two lines is chosen using a bit from the table's bit
-- stream.  Each round ('go') overwrites a randomly chosen slot of the
-- current line with the pending entry, takes the entry that was there, and
-- tries to write the evicted entry into a free slot of /its/ other line
-- ('tryWrite').  If that line has no free slot the process repeats from
-- there with the attempt budget reduced by one.
--
-- Returns 'Nothing' once an evicted entry finds a free slot, or 'Just' the
-- entry still pending when the budget reaches zero.  Note that the returned
-- pair is generally /not/ the pair originally passed in.
cuckooOrFail :: (Hashable k, Eq k) =>
                HashTable_ s k v
             -> Int
             -> Int
             -> Int
             -> Int
             -> k
             -> v
             -> ST s (Maybe (k,v))
cuckooOrFail (HashTable sz rng hashes keys values maxAttempts0)
             !h1_0 !h2_0 !b1_0 !b2_0 !k0 !v0 = do
    debug $ "cuckooOrFail h1_0=" ++ show h1_0
            ++ " h2_0=" ++ show h2_0
            ++ " b1_0=" ++ show b1_0
            ++ " b2_0=" ++ show b2_0
    !lineChoice <- getNextBit rng
    debug $ "chose line " ++ show lineChoice
    let (!b, !h) = if lineChoice == 0 then (b1_0, h1_0) else (b2_0, h2_0)
    go b h k0 v0 maxAttempts0
  where
    -- Random slot within the line starting at offset b.
    randomIdx !b = do
        !z <- getNBits cacheLineIntBits rng
        return $! b + z

    -- Swap the pending entry into slot idx, returning the previous occupant.
    bumpIdx !idx !h !k !v = do
        debug $ "bumpIdx idx=" ++ show idx ++ " h=" ++ show h
        !h' <- U.readArray hashes idx
        debug $ "bumpIdx: h' was " ++ show h'
        !k' <- readArray keys idx
        v'  <- readArray values idx
        U.writeArray hashes idx h
        writeArray keys idx k
        writeArray values idx v
        debug $ "bumped key with h'=" ++ show h'
        return $! (h', k', v')

    -- Given the hash code an entry is currently filed under, return the
    -- key's other hash code.
    otherHash h k = if h2 == h then h1 else h2
      where
        h1 = hash1 k
        h2 = hash2 k

    tryWrite !b !h k v maxAttempts = do
        debug $ "tryWrite b=" ++ show b ++ " h=" ++ show h
        idx <- cacheLineSearch hashes b emptyMarker
        debug $ "cacheLineSearch returned " ++ show idx
        if idx >= 0
          then do
            U.writeArray hashes idx h
            writeArray keys idx k
            writeArray values idx v
            return Nothing
          -- No room: evict again from this line with one attempt fewer.
          else go b h k v $! maxAttempts - 1

    go !b !h !k v !maxAttempts
        | maxAttempts == 0 = return $! Just (k,v)
        | otherwise = do
            idx <- randomIdx b
            (!h0', !k', v') <- bumpIdx idx h k v
            let !h' = otherHash h0' k'
            let !b' = whichLine h' sz
            tryWrite b' h' k' v' maxAttempts
-- | Build a larger table containing every entry of the old one plus the
-- pending pair @(k0, v0)@.
--
-- The old arrays are only read, never written, so they remain a faithful
-- source for rehashing however many times a larger table must be built.
--
-- If inserting @(k0, v0)@ into a freshly grown table fails, that table has
-- already been disturbed (some entry was evicted and is no longer stored),
-- so it is discarded: a still larger table is rebuilt from the old arrays
-- and the insertion is retried.  Returning the rebuilt table without
-- retrying would lose @(k0, v0)@, since the old arrays do not contain it.
grow :: (Eq k, Hashable k) =>
        HashTable_ s k v
     -> k
     -> v
     -> ST s (HashTable_ s k v)
grow (HashTable sz _ hashes keys values _) k0 v0 = attempt $! bumpSize sz
  where
    -- Grow to at least newSz, then insert the pending pair; on failure
    -- start over from the old arrays with a larger size.
    attempt newSz = do
        newHt <- grow' newSz
        mbR <- updateOrFail newHt k0 v0
        case mbR of
          Nothing -> return newHt
          -- grow' may already have gone beyond newSz, so bump from the
          -- size the table actually ended up with.
          Just _  -> attempt $! bumpSize (_size newHt)

    grow' newSz = do
        debug $ "growing table, oldsz = " ++ show sz ++
                ", newsz=" ++ show newSz
        newHt <- newSizedReal newSz
        rehash newSz newHt

    -- Copy every occupied slot of the old table into newHt.  If any entry
    -- cannot be placed, abandon newHt and restart with a larger size.
    rehash !newSz !newHt = go 0
      where
        totSz = numWordsInCacheLine * sz

        go !i | i >= totSz = return newHt
              | otherwise  = do
                  h <- U.readArray hashes i
                  if (h /= emptyMarker)
                    then do
                      k <- readArray keys i
                      v <- readArray values i
                      mbR <- updateOrFail newHt k v
                      maybe (go $ i + 1)
                            (\_ -> grow' $ bumpSize newSz)
                            mbR
                    else go $ i + 1
-- | Salt passed to 'H.hashWithSalt' by 'hash2'; one constant is used when
-- 'wordSize' is 32 and another otherwise.
hashPrime :: Int
hashPrime
    | wordSize == 32 = 0xedf2a025
    | otherwise      = 0x3971ca9c8b3722e9
-- | First hash function: the key's plain 'H.hash', post-processed by 'hashF'.
hash1 :: Hashable k => k -> Int
hash1 key = hashF H.hash key
-- | Second hash function: 'H.hashWithSalt' salted with 'hashPrime',
-- post-processed by 'hashF'.
hash2 :: Hashable k => k -> Int
hash2 key = hashF (H.hashWithSalt hashPrime) key
-- | Apply a raw hash function and post-process the result without
-- branching.
--
-- A mask is computed from the raw hash and 0 via @maskw#@; the output takes
-- the bits of the constant 1 where the mask is set and the bits of the raw
-- hash where it is clear.  NOTE(review): presumably @maskw# a b@ is all-ones
-- when @a == b@ and zero otherwise, which would make this "replace a hash
-- of 0 (the 'emptyMarker') by 1" -- confirm against the definition of
-- @maskw#@.
hashF :: (k -> Int) -> k -> Int
hashF f k = I# (word2Int# chosen#)
  where
    !(I# raw#) = f k
    !sel#      = maskw# raw# 0#
    !keepRaw#  = int2Word# raw# `and#` not# sel#
    !useOne#   = int2Word# 1# `and#` sel#
    !chosen#   = useOne# `or#` keepRaw#
-- | Hash-code value that marks a slot as unoccupied.  Traversals skip slots
-- holding it, and deletion writes it back into the hash-code array.
emptyMarker :: Int
emptyMarker = 0
-- | Load factor used by 'newSized' when turning a requested element count
-- into an initial bucket count.
maxLoad :: Double
maxLoad = 0.88
-- | Trace hook.  When the module is compiled with the @DEBUG@ CPP macro
-- defined, the message is printed to stdout and the handle is flushed;
-- otherwise this is a no-op.
debug :: String -> ST s ()
#ifdef DEBUG
debug s = unsafeIOToST (putStrLn s >> hFlush stdout)
#else
debug _ = return ()
#endif
-- | Slot offset of the line a hash code maps to: the bucket number chosen by
-- 'whichBucket', shifted left by 'cacheLineIntBits'.
whichLine :: Int -> Int -> Int
whichLine !h !sz = bucket `iShiftL` cacheLineIntBits
  where
    bucket = whichBucket h sz
-- | Wrap a table record in a fresh mutable reference.
newRef :: HashTable_ s k v -> ST s (HashTable s k v)
newRef table = do
    ref <- newSTRef table
    return (HT ref)
-- | Replace the table record stored behind a handle.
writeRef :: HashTable s k v -> HashTable_ s k v -> ST s ()
writeRef (HT ref) = writeSTRef ref
-- | Fetch the table record currently stored behind a handle.
readRef :: HashTable s k v -> ST s (HashTable_ s k v)
readRef handle = case handle of
    HT ref -> readSTRef ref