{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE BinaryLiterals #-}
{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE DeriveTraversable #-}
{-# LANGUAGE DeriveFoldable #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE MultiWayIf #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE UnboxedSums #-}
{-# LANGUAGE UnboxedTuples #-}

-- | Implementation of a static hash map data structure.
module Data.Bytes.HashMap
  ( Map
  , lookup
  , fromList
  , fromTrustedList
  , fromListWith
    -- * Used for testing
  , HashMapException(..)
  , distribution
  , distinctEntropies
  ) where

-- Implementation notes. This module uses a variant of the technique
-- described in http://stevehanov.ca/blog/?id=119 with the big difference
-- being that we do not throw away the keys. You can only throw away the
-- keys in very specific problem domains where you somehow control
-- everything that is going to be looked up.
--
-- General implementation thoughts. It would be really nice to figure
-- out how to parallelize hashing. We currently go one byte at a time.
-- Processing more bytes at a time would cut down on memory loads.
-- However, doing more than one byte at a time is tricky. When you
-- get to the end of a string, you end up having to do some extra
-- finagling to make sure you do not read past the end. I have tried
-- to do this in the past, and it is difficult to do correctly.
--
-- Other thought: Using a random 64-bit word for each byte is pretty
-- heavy-handed. 64-bit words give us 32-bit hashes, but in most cases,
-- we are not building maps that are that big. We really only need
-- 16-bit hashes most of the time (maps with fewer than 64K values).
-- Switching to 32-bit words would save space. Plus, if we did this,
-- we could also use SSE instructions such as _mm_add_epi32 to process
-- four bytes at a time.

import Prelude hiding (lookup)

import Control.Exception (Exception,throw)
import Control.Monad (when)
import Control.Monad.ST (ST,stToIO,runST)
import Control.Monad.Trans.Except (ExceptT(ExceptT),runExceptT)
import Data.Bits ((.&.),complement)
import Data.Bytes.HashMap.Internal (Map(Map))
import Data.Bytes.Types (Bytes(Bytes))
import Data.Foldable (for_,foldlM)
import Data.Int (Int32)
import Data.Ord (Down(Down))
import Data.Primitive (ByteArray(..),PrimArray(..))
import Data.Primitive.SmallArray (SmallArray(..))
import Data.Primitive.Unlifted.Array (UnliftedArray(..))
import Data.STRef (STRef,newSTRef,writeSTRef,readSTRef)
import Foreign.Ptr (plusPtr)
import GHC.Exts (Ptr(Ptr),Int(I#),SmallArray#,ByteArray#,ArrayArray#,Int#)
import GHC.Exts (RealWorld)
import GHC.IO (ioToST)
import GHC.Word (Word(W#),Word32,Word8)
import System.Entropy (CryptHandle,hGetEntropy)

import qualified Data.ByteString as ByteString
import qualified Data.ByteString.Unsafe as ByteString
import qualified Data.Bytes as Bytes
import qualified Data.Bytes.Hash as Hash
import qualified Data.List as List
import qualified Data.Primitive as PM
import qualified Data.Primitive.Ptr as PM
import qualified Data.Primitive.Unlifted.Array as PM
import qualified GHC.Exts as Exts

-- | Build a static hash map. This may be used on input that comes
-- from an adversarial user. It always produces a perfect hash map.
fromList :: CryptHandle -> [(Bytes,v)] -> IO (Map v)
fromList h = fromListWith h const
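
-- A usage sketch for 'fromList'. This is illustrative only: it assumes
-- 'openHandle' from the entropy package (not imported by this module)
-- and 'fromAsciiString' from the byteslice package, which is already
-- imported qualified as Bytes above.
--
-- > example :: IO ()
-- > example = do
-- >   h <- System.Entropy.openHandle
-- >   m <- fromList h
-- >     [ (Bytes.fromAsciiString "red", 0xFF0000 :: Int)
-- >     , (Bytes.fromAsciiString "green", 0x00FF00)
-- >     ]
-- >   print (lookup (Bytes.fromAsciiString "red") m) -- Just 16711680
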
-- | Build a map from keys that are known at compile time.
-- All keys must be 64 bytes or less. This uses a built-in source
-- of entropy and is entirely deterministic. An adversarial user
-- could feed this function keys that cause it to error out rather
-- than completing.
fromTrustedList :: [(Bytes,v)] -> Map v
fromTrustedList xs = runST $ do
  ref <- newSTRef 0
  fromListWithGen ref askForEntropyST const xs

-- | Returns the value associated with the key in the map.
lookup :: Bytes -> Map v -> Maybe v
{-# inline lookup #-}
lookup
  (Bytes (ByteArray keyArr) (I# keyOff) (I# keyLen))
  (Map (ByteArray entropyA) (UnliftedArray entropies) (PrimArray offsets) (UnliftedArray keys) (SmallArray vals)) =
    case lookup# (# keyArr,keyOff,keyLen #) (# entropyA,entropies,offsets,keys,vals #) of
      (# (# #) | #) -> Nothing
      (# | v #) -> Just v
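
-- A pure usage sketch for 'fromTrustedList' and 'lookup' (illustrative;
-- 'fromAsciiString' is assumed to come from the byteslice package,
-- imported qualified as Bytes above):
--
-- > colors :: Map Int
-- > colors = fromTrustedList
-- >   [ (Bytes.fromAsciiString "cyan", 0)
-- >   , (Bytes.fromAsciiString "magenta", 1)
-- >   ]
-- >
-- > -- lookup (Bytes.fromAsciiString "cyan") colors == Just 0
-- > -- lookup (Bytes.fromAsciiString "black") colors == Nothing
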
-- One compelling optimization done here is that we use sameByteArray
-- to check if the sources of entropy are pointer-wise equal. This is
-- a very inexpensive check, and it ends up being true close to 50%
-- of the time. If it is true, we can avoid hashing a second time,
-- which avoids reading from a place in memory that is essentially
-- random. One way to further improve the performance of this library
-- would be to try to get doubleton buckets to use entropyA by searching
-- for a suitable offset.
lookup# ::
     (# ByteArray#, Int#, Int# #)
  -> (# ByteArray#, ArrayArray#, ByteArray#, ArrayArray#, SmallArray# v #)
  -> (# (# #) | v #)
{-# noinline lookup# #-}
lookup# (# keyArr#, keyOff#, keyLen# #) (# entropyA#, entropies#, offsets#, keys#, vals# #)
  | sz == 0 = (# (# #) | #)
  | PM.sizeofByteArray entropyA < reqEntropy = (# (# #) | #)
  | ixA <- w2i (unsafeRem (upW32 (Hash.bytes entropyA key)) (i2w sz))
  , entropyB <- PM.indexUnliftedArray entropies ixA
  , offset <- fromIntegral @Int32 @Int (PM.indexPrimArray offsets ixA) =
      case sameByteArray entropyA entropyB of
        1# | ix <- ixA
           , offsetIx <- offset + ix
           , bytesEqualsByteArray key (PM.indexUnliftedArray keys offsetIx)
           , !(# v #) <- PM.indexSmallArray## vals offsetIx
             -> (# | v #)
           | otherwise -> (# (# #) | #)
        _ | PM.sizeofByteArray entropyB >= reqEntropy
          , ix <- w2i (unsafeRem (upW32 (Hash.bytes entropyB key)) (i2w sz))
          , offsetIx <- offset + ix
          , bytesEqualsByteArray key (PM.indexUnliftedArray keys offsetIx)
          , !(# v #) <- PM.indexSmallArray## vals offsetIx
            -> (# | v #)
          | otherwise -> (# (# #) | #)
  where
  sz = PM.sizeofUnliftedArray entropies
  reqEntropy = w2i (requiredEntropy (i2w (Bytes.length key)))
  key = Bytes (ByteArray keyArr#) (I# keyOff#) (I# keyLen#)
  entropyA = ByteArray entropyA#
  entropies = UnliftedArray entropies# :: UnliftedArray ByteArray
  keys = UnliftedArray keys# :: UnliftedArray ByteArray
  vals = SmallArray vals#
  offsets = PrimArray offsets# :: PrimArray Int32

unsafeRem :: Word -> Word -> Word
unsafeRem (W# a) (W# b) = W# (Exts.remWord# a b)

fromListWithGen :: forall a s v.
     a -- ^ Source of randomness
  -> (a -> Int -> ST s ByteArray)
  -> (v -> v -> v)
  -> [(Bytes,v)]
  -> ST s (Map v)
fromListWithGen h ask combine xs
  | count == 0 = pure (Map mempty mempty mempty mempty mempty)
  | otherwise = do
      let maxLen' = w2i $ requiredEntropy $ i2w $
            List.foldl' (\acc (b,_) -> max (Bytes.length b) acc) 0 xs'
          allowedCollisions = ceiling (sqrt (fromIntegral @Int @Double (count + 1))) :: Int
      entropyA <- findInitialEntropy h ask maxLen' count allowedCollisions xs'
      let groups :: [[(Word,(Bytes,v))]]
          groups = List.sortOn (Down . List.length @[])
            (List.groupBy (\(x,_) (y,_) -> x == y)
              (List.sortOn fst
                (List.map
                  (\(b,v) ->
                    ( rem (fromIntegral @Word32 @Word (Hash.bytes entropyA b)) (i2w count)
                    , (b,v)
                    )
                  )
                  xs'
                )
              )
            )
      used <- PM.newSmallArray count False
      keys <- PM.newUnliftedArray count (mempty :: ByteArray)
      values <- PM.newSmallArray count (errorThunk @v)
      entropies <- PM.newUnliftedArray count (mempty :: ByteArray)
      offsets <- PM.newPrimArray count
      PM.setPrimArray offsets 0 count (0 :: Int32)
      let goB :: [ByteArray] -> [[(Word,(Bytes,v))]] -> ST s ()
          {-# SCC goB #-}
          goB !_ [] = pure ()
          goB !cache (x : ps) = case x of
            [(w,(key,val))] -> do
              -- Space optimization for singleton buckets. If only one key
              -- hashed to a bucket, we just use entropyA as the entropy
              -- since it is guaranteed to be big enough. Then we use the
              -- offset field to correct the hash. This avoids creating any
              -- additional entropy byte arrays for singleton buckets.
              -- Technically, it should be possible to do this for some
              -- of the doubletons as well. It is just a little more
              -- difficult. A worked example of the offset arithmetic
              -- follows these local definitions.
              let ix = w2i (unsafeHeadFst x)
              j <- findUnused used
              PM.writeUnliftedArray entropies ix entropyA
              PM.writePrimArray offsets ix (fromIntegral @Int @Int32 (j - fromIntegral w))
              PM.writeSmallArray used j True
              PM.writeUnliftedArray keys j (Bytes.toByteArray key)
              PM.writeSmallArray values j val
              goB cache ps
            _ -> do
              let ix = w2i (unsafeHeadFst x)
                  keyVals = map snd x
                  !maxGroupLen = List.foldl' (\acc (b,_) -> max (Bytes.length b) acc) 0 keyVals
                  reqEntrSz = w2i (requiredEntropy (i2w maxGroupLen))
              -- As a space optimization, we try out all options from the cache.
              -- If we can reuse random bytes that were used for a different key,
              -- we can save a lot of space. Reuse is frequently possible.
              e <- runExceptT $ for_ (entropyA : cache) $ \entropy ->
                if PM.sizeofByteArray entropy >= reqEntrSz
                  then ExceptT $ attempt entropy ix keyVals >>= \case
                    True -> pure (Left ())
                    False -> pure (Right ())
                  else ExceptT (pure (Right ()))
              case e of
                Left () -> goB cache ps
                Right () -> do
                  goD cache 100000 ix keyVals reqEntrSz ps
          updateSlots :: ByteArray -> Int -> [(Bytes,v)] -> ST s ()
          updateSlots !entropy !ix keyVals = do
            PM.writeUnliftedArray entropies ix entropy
            for_ keyVals $ \(key,val) -> do
              let j = fromIntegral @Word @Int
                    (rem (fromIntegral @Word32 @Word (Hash.bytes entropy key)) (i2w count))
              PM.writeSmallArray used j True
              PM.writeUnliftedArray keys j (Bytes.toByteArray key)
              PM.writeSmallArray values j val
          attempt :: ByteArray -> Int -> [(Bytes,v)] -> ST s Bool
          attempt !entropy !ix keyVals = do
            tmpUsed <- PM.cloneSmallMutableArray used 0 count
            allGood <- foldlM
              (\good (key,_) -> if good
                then do
                  let j = fromIntegral @Word @Int
                        (rem (upW32 (Hash.bytes entropy key)) (i2w count))
                  PM.readSmallArray tmpUsed j >>= \case
                    True -> pure False
                    False -> do
                      PM.writeSmallArray tmpUsed j True
                      pure True
                else pure False
              ) True keyVals
            if allGood
              then do
                updateSlots entropy ix keyVals
                pure True
              else pure False
          goD :: [ByteArray] -> Int -> Int -> [(Bytes,v)] -> Int -> [[(Word,(Bytes,v))]] -> ST s ()
          {-# SCC goD #-}
          goD !cache !counter !ix keyVals !entropySz zs = do
            entropy <- ask h entropySz
            attempt entropy ix keyVals >>= \case
              True -> goB (entropy : cache) zs
              False -> case counter of
                0 -> throw $ HashMapException count (map fst keyVals) ((fmap.fmap.fmap) fst groups)
                _ -> goD cache (counter - 1) ix keyVals entropySz zs
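
      -- A worked example of the singleton-bucket offset trick above
      -- (illustrative numbers, not from a real run): suppose count = 7
      -- and a lone key hashes to bucket 3 under entropyA, while the first
      -- unused slot found by findUnused is 5. The singleton case of goB
      -- stores entropyA at index 3 and offset = 5 - 3 = 2. At lookup
      -- time the key hashes to 3 again, the pointer-equality check on
      -- entropyA in lookup# succeeds, and the slot is recovered as
      -- 3 + 2 = 5 without computing a second hash.
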
      -- Notice that we do not start out with entropyA. We manually cons that
      -- onto the top every time, so that if it can get reused, it does. We
      -- would rather it get reused than anything else since there is an
      -- optimization in the lookup function that avoids computing the hash
      -- twice if this entropy gets used.
      goB [] groups
      vals' <- PM.unsafeFreezeSmallArray values
      keys' <- PM.unsafeFreezeUnliftedArray keys
      entropies' <- PM.unsafeFreezeUnliftedArray entropies
      offsets' <- PM.unsafeFreezePrimArray offsets
      pure (Map entropyA entropies' offsets' keys' vals')
  where
  -- Combine duplicates upfront.
  xs' :: [(Bytes,v)]
  xs' = map
    (\rs ->
      ( unsafeHeadFst rs
      , List.foldl1' combine (map snd rs)
      )
    )
    (List.groupBy (\(x,_) (y,_) -> x == y) (List.sortOn fst xs))
  count = List.length @[] xs' :: Int

findUnused :: PM.SmallMutableArray s Bool -> ST s Int
findUnused xs = go 0
  where
  len = PM.sizeofSmallMutableArray xs
  go !ix = if ix < len
    then do
      PM.readSmallArray xs ix >>= \case
        True -> go (ix + 1)
        False -> pure ix
    else error "findUnused: could not find unused slot"

-- | Build a static hash map like 'fromList', combining the values of
-- keys that appear more than once with the supplied function.
fromListWith :: forall v.
     CryptHandle -- ^ Source of randomness
  -> (v -> v -> v)
  -> [(Bytes,v)]
  -> IO (Map v)
fromListWith h combine xs =
  stToIO (fromListWithGen h askForEntropy combine xs)
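
-- A usage sketch for 'fromListWith': counting occurrences by combining
-- duplicate keys with (+). Illustrative only; 'fromAsciiString' is
-- assumed to come from the byteslice package, imported qualified as
-- Bytes above.
--
-- > occurrences :: CryptHandle -> [String] -> IO (Map Int)
-- > occurrences h ws =
-- >   fromListWith h (+) [ (Bytes.fromAsciiString w, 1) | w <- ws ]
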
findInitialEntropy :: forall s a v.
     a
  -> (a -> Int -> ST s ByteArray)
  -> Int
  -> Int
  -> Int
  -> [(Bytes,v)]
  -> ST s ByteArray
{-# SCC findInitialEntropy #-}
findInitialEntropy !h ask !maxLen' !count !allowedCollisions xs = go 40
  where
  go :: Int -> ST s ByteArray
  go !ix = do
    entropy <- ask h maxLen'
    let maxCollisions = List.foldl'
          (\acc zs -> max acc (List.length @[] zs)) 0
          (List.group
            (List.sort
              (map (\(b,_) -> rem (fromIntegral @Word32 @Word (Hash.bytes entropy b)) (i2w count)) xs)
            )
          )
    if maxCollisions <= allowedCollisions
      then pure entropy
      else case ix of
        0 -> throw (HashMapException (-1) [] [])
        _ -> go (ix - 1)

askForEntropyST :: STRef s Int -> Int -> ST s ByteArray
askForEntropyST !ref !n = do
  counter <- readSTRef ref
  writeSTRef ref $! mod (counter + 1) 8192
  let (_,r) = divMod n 8
  if | r /= 0 -> error "bytehash: askForEntropyST, request is not divisible by 8"
     | n > 8192 -> error "bytehash: askForEntropyST, requested more than 8192"
     | otherwise -> do
         dst <- PM.newPrimArray n
         PM.copyPtrToMutablePrimArray dst 0 (plusPtr Hash.entropy counter :: Ptr Word8) n
         PM.PrimArray x <- PM.unsafeFreezePrimArray dst
         pure (ByteArray x)

askForEntropy :: CryptHandle -> Int -> ST RealWorld ByteArray
askForEntropy !h !n = ioToST $ do
  entropy <- hGetEntropy h n
  when (ByteString.length entropy /= n)
    (fail "bytehash: askForEntropy failed, blame entropy")
  dst <- PM.newByteArray n
  ByteString.unsafeUseAsCStringLen entropy $ \(ptr, len) -> do
    let !(PM.MutableByteArray primDst) = dst
    PM.copyPtrToMutablePrimArray (PM.MutablePrimArray primDst) 0 ptr len
  PM.unsafeFreezeByteArray dst

requiredEntropy :: Word -> Word
requiredEntropy n = 8 * n + 8

errorThunk :: a
errorThunk = error "Data.Bytes.HashMap: mistake"

unsafeHeadFst :: [(a,b)] -> a
unsafeHeadFst ((x,_) : _) = x
unsafeHeadFst [] = error "Data.Bytes.HashMap: bad use of unsafeHeadFst"

w2i :: Word -> Int
w2i = fromIntegral

i2w :: Int -> Word
i2w = fromIntegral

upW32 :: Word32 -> Word
upW32 = fromIntegral

bytesEqualsByteArray :: Bytes -> ByteArray -> Bool
bytesEqualsByteArray (Bytes arr1 off1 len1) arr2
  | len1 /= PM.sizeofByteArray arr2 = False
  | otherwise = compareByteArrays arr1 off1 arr2 0 len1 == EQ

compareByteArrays :: ByteArray -> Int -> ByteArray -> Int -> Int -> Ordering
compareByteArrays (ByteArray ba1#) (I# off1#) (ByteArray ba2#) (I# off2#) (I# n#) =
  compare (I# (Exts.compareByteArrays# ba1# off1# ba2# off2# n#)) 0

data HashMapException = HashMapException !Int [Bytes] [[(Word,Bytes)]]
  deriving stock (Show,Eq)
  deriving anyclass (Exception)

-- | A histogram of slot loads under the first hash function: each
-- pair is the number of keys hashing to a slot together with the
-- number of slots that receive that many keys.
distribution :: Map v -> [(Int,Int)]
distribution (Map entropy entropies _ keys _) =
  let counts = runST $ do
        let sz = PM.sizeofUnliftedArray entropies
        dst <- PM.newPrimArray sz
        PM.setPrimArray dst 0 sz (0 :: Int)
        let go !ix = case ix of
              (-1) -> pure ()
              _ -> do
                let key = PM.indexUnliftedArray keys ix
                let ixA = w2i (unsafeRem (upW32 (Hash.byteArray entropy key)) (i2w sz))
                old <- PM.readPrimArray dst ixA
                PM.writePrimArray dst ixA (old + 1)
                go (ix - 1)
        go (sz - 1)
        PM.unsafeFreezePrimArray dst
   in List.sort $ List.map
        (\xs -> case xs of
          [] -> errorWithoutStackTrace "bytehash: distribution impl error"
          y : _ -> (y,List.length xs)
        )
        (List.group (List.sort (Exts.toList counts)))

-- | The number of distinct entropies in use, counting the first-level
-- entropy alongside the per-bucket entropies.
distinctEntropies :: Map v -> Int
distinctEntropies (Map entropy entropies _ _ _) =
  List.length (List.group (List.sort (entropy : Exts.toList entropies)))

sameByteArray :: ByteArray -> ByteArray -> Int#
sameByteArray (ByteArray x) (ByteArray y) =
  Exts.sameMutableByteArray# (Exts.unsafeCoerce# x) (Exts.unsafeCoerce# y)