{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE BinaryLiterals #-}
{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE DeriveTraversable #-}
{-# LANGUAGE DeriveFoldable #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE MultiWayIf #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE UnboxedSums #-}
{-# LANGUAGE UnboxedTuples #-}
module Data.Bytes.HashMap
( Map
, lookup
, fromList
, fromTrustedList
, fromListWith
, HashMapException(..)
, distribution
, distinctEntropies
) where
import Prelude hiding (lookup)
import Control.Exception (Exception,throw)
import Control.Monad (when)
import Control.Monad.ST (ST,stToIO,runST)
import Control.Monad.Trans.Except (ExceptT(ExceptT),runExceptT)
import Data.Bits ((.&.),complement)
import Data.Bytes.HashMap.Internal (Map(Map))
import Data.Bytes.Types (Bytes(Bytes))
import Data.Foldable (for_,foldlM)
import Data.Int (Int32)
import Data.Ord (Down(Down))
import Data.Primitive (ByteArray(..),PrimArray(..))
import Data.Primitive.SmallArray (SmallArray(..))
import Data.Primitive.Unlifted.Array (UnliftedArray(..))
import Data.STRef (STRef,newSTRef,writeSTRef,readSTRef)
import Foreign.Ptr (plusPtr)
import GHC.Exts (Ptr(Ptr),Int(I#),SmallArray#,ByteArray#,ArrayArray#,Int#)
import GHC.Exts (RealWorld)
import GHC.IO (ioToST)
import GHC.Word (Word(W#),Word32,Word8)
import System.Entropy (CryptHandle,hGetEntropy)
import qualified Data.ByteString as ByteString
import qualified Data.ByteString.Unsafe as ByteString
import qualified Data.Bytes as Bytes
import qualified Data.Bytes.Hash as Hash
import qualified Data.List as List
import qualified Data.Primitive as PM
import qualified Data.Primitive.Ptr as PM
import qualified Data.Primitive.Unlifted.Array as PM
import qualified GHC.Exts as Exts
-- | Build a map from key-value pairs, drawing randomness from the given
-- entropy handle. This is 'fromListWith' with 'const' as the function that
-- merges the values of duplicate keys.
fromList :: CryptHandle -> [(Bytes,v)] -> IO (Map v)
fromList handle pairs = fromListWith handle const pairs
-- | Build a map without performing any IO. Entropy comes from
-- @askForEntropyST@, driven by a counter that starts at zero, and the
-- values of duplicate keys are merged with 'const'.
fromTrustedList :: [(Bytes,v)] -> Map v
fromTrustedList pairs = runST
  ( newSTRef 0 >>= \counterRef ->
      fromListWithGen counterRef askForEntropyST const pairs
  )
-- | Look up the value stored under a key, if there is one.
--
-- This wrapper only unpacks the key and the map into their primitive
-- components and turns the unboxed result of the worker @lookup#@ into a
-- 'Maybe'. The wrapper is marked inline; the worker is not.
lookup :: Bytes -> Map v -> Maybe v
{-# inline lookup #-}
lookup (Bytes (ByteArray kArr) (I# kOff) (I# kLen)) theMap = case theMap of
  Map (ByteArray seed) (UnliftedArray seeds) (PrimArray offs) (UnliftedArray ks) (SmallArray vs) ->
    case lookup# (# kArr, kOff, kLen #) (# seed, seeds, offs, ks, vs #) of
      (# | found #) -> Just found
      (# (# #) | #) -> Nothing
-- | Worker behind 'lookup', operating on unboxed arguments and returning an
-- unboxed sum: the left alternative means "absent", the right alternative
-- carries the value.
--
-- The first tuple is the key (array, offset, length). The second tuple holds
-- the map's components: primary entropy, per-bucket entropies, per-bucket
-- offsets, stored keys and stored values.
lookup# ::
     (# ByteArray#, Int#, Int# #)
  -> (# ByteArray#, ArrayArray#, ByteArray#, ArrayArray#, SmallArray# v #)
  -> (# (# #) | v #)
{-# noinline lookup# #-}
lookup# (# keyArr#, keyOff#, keyLen# #) (# entropyA#, entropies#, offsets#, keys#, vals# #)
  -- An empty map has no buckets. This guard also keeps unsafeRem below from
  -- ever seeing a zero divisor.
  | sz == 0 = (# (# #) | #)
  -- The primary entropy is shorter than what hashing this key requires.
  -- The construction code requests the primary entropy sized for its
  -- longest key, so (assuming the entropy source honours the requested
  -- size) such a key cannot be present.
  | PM.sizeofByteArray entropyA < reqEntropy = (# (# #) | #)
  -- First level: hash with the primary entropy to pick a bucket, then read
  -- that bucket's own entropy and offset.
  | ixA <- w2i (unsafeRem (upW32 (Hash.bytes entropyA key)) (i2w sz))
  , entropyB <- PM.indexUnliftedArray entropies ixA
  , offset <- fromIntegral @Int32 @Int (PM.indexPrimArray offsets ixA) =
      case sameByteArray entropyA entropyB of
        -- The bucket refers to the primary entropy itself. No second hash
        -- is computed; the slot is the bucket index shifted by the stored
        -- offset.
        1# | ix <- ixA
           , offsetIx <- offset + ix
           , bytesEqualsByteArray key (PM.indexUnliftedArray keys offsetIx)
           -- The bang on the unboxed 1-tuple forces the array read itself.
           , !(# v #) <- PM.indexSmallArray## vals offsetIx -> (# | v #)
           | otherwise -> (# (# #) | #)
        -- The bucket has its own entropy: hash the key a second time with
        -- it. The size check rejects entropies too short for this key,
        -- which includes the empty byte array the construction code uses
        -- as the initial value for every bucket.
        _ | PM.sizeofByteArray entropyB >= reqEntropy
          , ix <- w2i (unsafeRem (upW32 (Hash.bytes entropyB key)) (i2w sz))
          , offsetIx <- offset + ix
          , bytesEqualsByteArray key (PM.indexUnliftedArray keys offsetIx)
          , !(# v #) <- PM.indexSmallArray## vals offsetIx -> (# | v #)
          | otherwise -> (# (# #) | #)
  where
    -- Number of buckets.
    sz = PM.sizeofUnliftedArray entropies
    -- Bytes of entropy needed to hash a key of this length.
    reqEntropy = w2i (requiredEntropy (i2w (Bytes.length key)))
    -- Boxed views of the unboxed arguments.
    key = Bytes (ByteArray keyArr#) (I# keyOff#) (I# keyLen#)
    entropyA = ByteArray entropyA#
    entropies = UnliftedArray entropies# :: UnliftedArray ByteArray
    keys = UnliftedArray keys# :: UnliftedArray ByteArray
    vals = SmallArray vals#
    offsets = PrimArray offsets# :: PrimArray Int32
-- | Remainder of two machine words via the raw primop. Unlike 'rem', no
-- check for a zero divisor is performed here; callers must guarantee that
-- the second argument is non-zero.
unsafeRem :: Word -> Word -> Word
unsafeRem dividend divisor =
  case dividend of
    W# n# -> case divisor of
      W# d# -> W# (n# `Exts.remWord#` d#)
-- | Core construction routine shared by 'fromListWith' and
-- 'fromTrustedList'.
--
-- Duplicate keys are merged with @combine@ first. A primary entropy is then
-- chosen by @findInitialEntropy@ and every key is assigned to a bucket by
-- hashing it with that entropy. Buckets are processed from largest to
-- smallest:
--
-- * A bucket holding several keys needs an entropy that sends all of its
--   keys to distinct slots that are still free. The primary entropy and the
--   entropies that already worked for earlier buckets are tried first. After
--   that, fresh entropy is requested repeatedly (the retry counter starts at
--   100000) and a 'HashMapException' is thrown once the counter runs out.
--
-- * A bucket holding a single key takes the lowest free slot. The slot is
--   recorded as an offset relative to the bucket index, and the bucket's
--   entropy is set to the primary entropy itself.
--
-- An input without keys yields a map whose components are all empty.
fromListWithGen :: forall a s v.
     a -- ^ State handed to the entropy source on every request
  -> (a -> Int -> ST s ByteArray) -- ^ Entropy source; the 'Int' is the requested size
  -> (v -> v -> v) -- ^ Merges the values of duplicate keys
  -> [(Bytes,v)] -- ^ Keys and values, possibly containing duplicate keys
  -> ST s (Map v)
fromListWithGen h ask combine xs
  | count == 0 = pure (Map mempty mempty mempty mempty mempty)
  | otherwise = do
      let maxLen' = w2i $ requiredEntropy $ i2w $
            List.foldl' (\acc (b,_) -> max (Bytes.length b) acc) 0 xs'
          allowedCollisions = ceiling (sqrt (fromIntegral @Int @Double (count + 1))) :: Int
      entropyA <- findInitialEntropy h ask maxLen' count allowedCollisions xs'
      -- Keys tagged with their bucket index, grouped per bucket, largest
      -- bucket first.
      let groups :: [[(Word,(Bytes,v))]]
          groups = List.sortOn (Down . List.length @[])
            (List.groupBy (\(x,_) (y,_) -> x == y)
              (List.sortOn fst
                (List.map
                  (\(b,v) -> (rem (upW32 (Hash.bytes entropyA b)) (i2w count), (b,v)))
                  xs'
                )
              )
            )
      used <- PM.newSmallArray count False
      keys <- PM.newUnliftedArray count (mempty :: ByteArray)
      values <- PM.newSmallArray count (errorThunk @v)
      entropies <- PM.newUnliftedArray count (mempty :: ByteArray)
      offsets <- PM.newPrimArray count
      PM.setPrimArray offsets 0 count (0 :: Int32)
      -- Where the search for a free slot resumes. Slots are only ever marked
      -- used and never released, so the lowest free slot cannot move
      -- backwards. Resuming here instead of at slot 0 keeps the placement of
      -- single-key buckets linear overall.
      freeCursor <- newSTRef (0 :: Int)
      let {-# SCC goB #-}
          -- Walk the buckets. The first argument caches the extra entropies
          -- that have worked so far, so later buckets can reuse them.
          goB :: [ByteArray] -> [[(Word,(Bytes,v))]] -> ST s ()
          goB !_ [] = pure ()
          goB !cache (x : ps) = case x of
            [(w,(key,val))] -> do
              let ix = w2i w
              start <- readSTRef freeCursor
              j <- findUnused used start
              writeSTRef freeCursor $! j + 1
              PM.writeUnliftedArray entropies ix entropyA
              PM.writePrimArray offsets ix (fromIntegral @Int @Int32 (j - ix))
              PM.writeSmallArray used j True
              PM.writeUnliftedArray keys j (Bytes.toByteArray key)
              PM.writeSmallArray values j val
              goB cache ps
            _ -> do
              let ix = w2i (unsafeHeadFst x)
                  keyVals = map snd x
                  !maxGroupLen = List.foldl' (\acc (b,_) -> max (Bytes.length b) acc) 0 keyVals
                  reqEntrSz = w2i (requiredEntropy (i2w maxGroupLen))
              -- 'Left' is used as the early exit: it stops the traversal as
              -- soon as one known entropy places the whole bucket.
              e <- runExceptT $ for_ (entropyA : cache) $ \entropy -> if PM.sizeofByteArray entropy >= reqEntrSz
                then ExceptT $ attempt entropy ix keyVals >>= \case
                  True -> pure (Left ())
                  False -> pure (Right ())
                else ExceptT (pure (Right ()))
              case e of
                Left () -> goB cache ps
                Right () -> do
                  goD cache 100000 ix keyVals reqEntrSz ps
          -- Commit a bucket: record its entropy and store every key and
          -- value in the slot that entropy hashes the key to.
          updateSlots :: ByteArray -> Int -> [(Bytes,v)] -> ST s ()
          updateSlots !entropy !ix keyVals = do
            PM.writeUnliftedArray entropies ix entropy
            for_ keyVals $ \(key,val) -> do
              let j = w2i (rem (upW32 (Hash.bytes entropy key)) (i2w count))
              PM.writeSmallArray used j True
              PM.writeUnliftedArray keys j (Bytes.toByteArray key)
              PM.writeSmallArray values j val
          -- Check on a scratch copy of the occupancy array whether the
          -- entropy sends every key of the bucket to a distinct free slot.
          -- Only a successful check modifies the real arrays.
          attempt :: ByteArray -> Int -> [(Bytes,v)] -> ST s Bool
          attempt !entropy !ix keyVals = do
            tmpUsed <- PM.cloneSmallMutableArray used 0 count
            allGood <- foldlM
              (\good (key,_) -> if good
                then do
                  let j = w2i (rem (upW32 (Hash.bytes entropy key)) (i2w count))
                  PM.readSmallArray tmpUsed j >>= \case
                    True -> pure False
                    False -> do
                      PM.writeSmallArray tmpUsed j True
                      pure True
                else pure False
              ) True keyVals
            if allGood
              then do
                updateSlots entropy ix keyVals
                pure True
              else pure False
          {-# SCC goD #-}
          -- Keep requesting fresh entropy for one bucket until it fits or
          -- the retry counter reaches zero.
          goD :: [ByteArray] -> Int -> Int -> [(Bytes,v)] -> Int -> [[(Word,(Bytes,v))]] -> ST s ()
          goD !cache !counter !ix keyVals !entropySz zs = do
            entropy <- ask h entropySz
            attempt entropy ix keyVals >>= \case
              True -> goB (entropy : cache) zs
              False -> case counter of
                0 -> throw $ HashMapException count
                  (map fst keyVals)
                  ((fmap.fmap.fmap) fst groups)
                _ -> goD cache (counter - 1) ix keyVals entropySz zs
      goB [] groups
      vals' <- PM.unsafeFreezeSmallArray values
      keys' <- PM.unsafeFreezeUnliftedArray keys
      entropies' <- PM.unsafeFreezeUnliftedArray entropies
      offsets' <- PM.unsafeFreezePrimArray offsets
      pure (Map entropyA entropies' offsets' keys' vals')
  where
    -- Input with duplicate keys merged, in ascending key order.
    xs' :: [(Bytes,v)]
    xs' = map
      (\rs ->
        ( unsafeHeadFst rs
        , List.foldl1' combine (map snd rs)
        )
      ) (List.groupBy (\(x,_) (y,_) -> x == y) (List.sortOn fst xs))
    -- Number of distinct keys. This is also the number of buckets and slots.
    count = List.length @[] xs' :: Int

-- | Index of the first slot at or after the given start index that is not
-- marked as used. Calls 'error' when every slot from there on is taken.
findUnused :: PM.SmallMutableArray s Bool -> Int -> ST s Int
findUnused xs start = go start
  where
    len = PM.sizeofSmallMutableArray xs
    go !ix = if ix < len
      then PM.readSmallArray xs ix >>= \case
        True -> go (ix + 1)
        False -> pure ix
      else error "findUnused: could not find unused slot"
-- | Build a map from key-value pairs, drawing randomness from the given
-- entropy handle. The supplied function merges the values of duplicate
-- keys. Construction runs in 'ST' and is lifted into 'IO' here.
fromListWith :: forall v.
     CryptHandle
  -> (v -> v -> v)
  -> [(Bytes,v)]
  -> IO (Map v)
fromListWith h combine xs =
  stToIO $ fromListWithGen h askForEntropy combine xs
-- | Request entropy until one is found under which no bucket receives more
-- than the allowed number of keys.
--
-- The arguments are, in order: the entropy-source state, the entropy
-- source, the number of entropy bytes to request, the number of buckets,
-- the largest bucket size that is acceptable, and the keys. The retry
-- counter starts at 40; when it runs out a 'HashMapException' carrying
-- @-1@ and two empty lists is thrown.
findInitialEntropy :: forall s a v.
     a
  -> (a -> Int -> ST s ByteArray)
  -> Int
  -> Int
  -> Int
  -> [(Bytes,v)]
  -> ST s ByteArray
{-# SCC findInitialEntropy #-}
findInitialEntropy !h ask !maxLen' !count !allowedCollisions xs = retry 40
  where
    buckets :: Word
    buckets = i2w count
    -- Size of the fullest bucket when the keys are hashed with this entropy.
    worstBucket :: ByteArray -> Int
    worstBucket entropy =
      List.foldl' max 0
        . map (List.length @[])
        . List.group
        . List.sort
        $ [ rem (fromIntegral @Word32 @Word (Hash.bytes entropy key)) buckets
          | (key,_) <- xs
          ]
    retry :: Int -> ST s ByteArray
    retry !remaining = do
      entropy <- ask h maxLen'
      if | worstBucket entropy <= allowedCollisions -> pure entropy
         | remaining == 0 -> throw (HashMapException (-1) [] [])
         | otherwise -> retry (remaining - 1)
-- | Deterministic entropy source for 'fromTrustedList'.
--
-- Copies the requested number of bytes out of the static buffer
-- @Hash.entropy@, starting at the byte offset currently stored in the
-- reference. The stored offset is advanced by one (wrapping at 8192) on
-- every call, before the request is validated. The request size must be a
-- multiple of 8 and at most 8192; anything else is reported with 'error'.
askForEntropyST :: STRef s Int -> Int -> ST s ByteArray
askForEntropyST !ref !n = do
  offset <- readSTRef ref
  writeSTRef ref $! (offset + 1) `mod` 8192
  case () of
    _ | n `mod` 8 /= 0 -> error "bytehash: askForEntropyST, request does not divide 8"
      | n > 8192 -> error "bytehash: askForEntropyST, requested more than 8192"
      | otherwise -> do
          buf <- PM.newPrimArray n
          PM.copyPtrToMutablePrimArray buf 0
            (Hash.entropy `plusPtr` offset :: Ptr Word8) n
          PM.PrimArray frozen <- PM.unsafeFreezePrimArray buf
          pure (ByteArray frozen)
-- | Entropy source for 'fromListWith': reads the requested number of bytes
-- from the entropy handle and copies them into a fresh 'ByteArray'. Fails
-- in 'IO' when the handle returns a different number of bytes than were
-- requested.
askForEntropy :: CryptHandle -> Int -> ST RealWorld ByteArray
askForEntropy !h !n = ioToST $ do
  bytes <- hGetEntropy h n
  if ByteString.length bytes == n
    then pure ()
    else fail "bytehash: askForEntropy failed, blame entropy"
  dst <- PM.newByteArray n
  ByteString.unsafeUseAsCStringLen bytes $ \(src, srcLen) ->
    case dst of
      -- Reinterpret the byte array as a primitive array whose element type
      -- matches the source pointer, so the copy length counts bytes.
      PM.MutableByteArray raw ->
        PM.copyPtrToMutablePrimArray (PM.MutablePrimArray raw) 0 src srcLen
  PM.unsafeFreezeByteArray dst
-- | Number of entropy bytes needed to hash a key of the given length:
-- eight bytes per key byte plus eight more.
requiredEntropy :: Word -> Word
requiredEntropy keyLen = 8 * (keyLen + 1)
-- | Placeholder used to initialise value slots before they are filled in.
-- Evaluating it raises an error.
errorThunk :: a
errorThunk = error message
  where
    message = "Data.Bytes.HashMap: mistake"
-- | First component of the first pair in a list. Calls 'error' on the
-- empty list, so callers must only pass lists known to be non-empty.
unsafeHeadFst :: [(a,b)] -> a
unsafeHeadFst = \case
  [] -> error "Data.Bytes.HashMap: bad use of unsafeHeadFst"
  (x,_) : _ -> x
-- | Convert a 'Word' to an 'Int' with 'fromIntegral'.
w2i :: Word -> Int
w2i w = fromIntegral w
-- | Convert an 'Int' to a 'Word' with 'fromIntegral'.
i2w :: Int -> Word
i2w i = fromIntegral i
-- | Widen a 32-bit word to a machine 'Word' with 'fromIntegral'.
upW32 :: Word32 -> Word
upW32 w = fromIntegral w
-- | Whether a byte slice has exactly the same contents as a whole byte
-- array. The lengths are compared first; the contents are only inspected
-- when the lengths agree.
bytesEqualsByteArray :: Bytes -> ByteArray -> Bool
bytesEqualsByteArray (Bytes arr1 off1 len1) arr2 =
  len1 == PM.sizeofByteArray arr2
    && compareByteArrays arr1 off1 arr2 0 len1 == EQ
-- | Compare a range of one byte array with a range of another. The
-- arguments are: first array, offset into it, second array, offset into
-- it, and the number of bytes to compare. The integer produced by the
-- primop is turned into an 'Ordering' by comparing it with zero.
compareByteArrays :: ByteArray -> Int -> ByteArray -> Int -> Int -> Ordering
compareByteArrays (ByteArray ba1#) (I# off1#) (ByteArray ba2#) (I# off2#) (I# n#) =
  case Exts.compareByteArrays# ba1# off1# ba2# off2# n# of
    r# -> compare (I# r#) 0
-- | Thrown when a map cannot be constructed.
--
-- When a bucket of keys cannot be placed, the fields are: the number of
-- distinct keys, the keys of the bucket that failed, and every bucket as a
-- list of (bucket index, key) pairs.
--
-- When no acceptable initial entropy is found, the exception carries @-1@
-- and two empty lists instead.
data HashMapException = HashMapException !Int [Bytes] [[(Word,Bytes)]]
  deriving stock (Show,Eq)
  deriving anyclass (Exception)
-- | Histogram of bucket sizes under the map's primary entropy.
--
-- Every stored key is hashed with the primary entropy and counted against
-- the bucket it lands in. The result pairs each bucket size with the number
-- of buckets of that size, in ascending order.
distribution :: Map v -> [(Int,Int)]
distribution (Map entropy entropies _ keys _) =
  List.sort (List.map summarize (List.group (List.sort (Exts.toList counts))))
  where
    sz = PM.sizeofUnliftedArray entropies
    -- Bucket size paired with how many buckets have that size.
    summarize run = case run of
      y : _ -> (y,List.length run)
      [] -> errorWithoutStackTrace "bytehash: distribution impl error"
    -- Number of keys that land in each bucket.
    counts :: PrimArray Int
    counts = runST $ do
      dst <- PM.newPrimArray sz
      PM.setPrimArray dst 0 sz (0 :: Int)
      for_ [0 .. sz - 1] $ \slot -> do
        let key = PM.indexUnliftedArray keys slot
            bucket = w2i (unsafeRem (upW32 (Hash.byteArray entropy key)) (i2w sz))
        old <- PM.readPrimArray dst bucket
        PM.writePrimArray dst bucket (old + 1)
      PM.unsafeFreezePrimArray dst
-- | Number of distinct byte arrays among the primary entropy and all
-- per-bucket entropies, compared by content.
distinctEntropies :: Map v -> Int
distinctEntropies (Map entropy entropies _ _ _) =
  List.length . List.group . List.sort $ entropy : Exts.toList entropies
-- | Identity test for two byte arrays: do both arguments refer to the very
-- same array? Contents are not compared. The result is the raw 'Int#'
-- produced by the primop; the caller in this module matches on @1#@ for
-- "same".
--
-- Both immutable arrays are coerced to mutable arrays only so that the
-- identity primop for mutable byte arrays can be applied to them. Nothing
-- is written through the coerced references.
-- NOTE(review): newer versions of GHC.Exts appear to export a direct
-- @sameByteArray#@ primop; confirm against the minimum supported GHC before
-- switching to it.
sameByteArray :: ByteArray -> ByteArray -> Int#
sameByteArray (ByteArray x) (ByteArray y) =
  Exts.sameMutableByteArray# (Exts.unsafeCoerce# x) (Exts.unsafeCoerce# y)