{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE BinaryLiterals #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE UnboxedSums #-}
{-# LANGUAGE UnboxedTuples #-}

-- | A static perfect hash map from byte-array keys to unboxed 'Word' values.
module Data.Bytes.HashMap.Word
  ( Map
  , lookup
  , fromList
  , fromTrustedList
  , fromListWith

    -- * Used for testing
  , distribution
  , distinctEntropies
  ) where

import Prelude hiding (lookup)

import Data.Bytes.Types (Bytes (Bytes))
import Data.Int (Int32)
import Data.Primitive (ByteArray (..), PrimArray (..))
import Data.Primitive.Unlifted.Array (UnliftedArray, UnliftedArray_ (UnliftedArray))
import Data.Primitive.Unlifted.Array.Primops (UnliftedArray#)
import GHC.Exts (ByteArray#, Int (I#), Int#, Word#)
import GHC.Word (Word (W#), Word32)
import System.Entropy (CryptHandle)

import qualified Data.Bytes as Bytes
import qualified Data.Bytes.Hash as Hash
import qualified Data.Bytes.HashMap as Lifted
import qualified Data.Bytes.HashMap.Internal as Lifted
import qualified Data.List as List
import qualified Data.Primitive as PM
import qualified Data.Primitive.Unlifted.Array as PM
import qualified GHC.Exts as Exts

{- | A static perfect hash table where the keys are byte arrays and the
  values are unboxed 'Word's. This table cannot be updated after its
  creation, but every lookup performs a constant number of probes (at
  most two hashes and one key comparison, each linear in the key length).
  It consumes linear space. This is an excellent candidate for use with
  compact regions.
-}
data Map
  = Map
      !ByteArray
      -- ^ top-level entropy; hashing a key with it selects a bucket
      !(UnliftedArray ByteArray)
      -- ^ per-bucket entropies, indexed by bucket
      !(PrimArray Int32)
      -- ^ per-bucket offsets, added to the in-bucket slot to find the key
      !(UnliftedArray ByteArray)
      -- ^ keys
      !(PrimArray Word)
      -- ^ values, at the same positions as their keys

-- | Convert a boxed-value map from "Data.Bytes.HashMap" into this module's
-- 'Map'. The hashing structure (entropies, offsets, keys) is reused as is;
-- only the values are copied into an unboxed 'PrimArray'.
fromLifted :: Lifted.Map Word -> Map
fromLifted (Lifted.Map topEntropy bucketEntropies bucketOffsets storedKeys boxedVals) =
  Map topEntropy bucketEntropies bucketOffsets storedKeys unboxedVals
 where
  unboxedVals =
    PM.primArrayFromListN (PM.sizeofSmallArray boxedVals) (Exts.toList boxedVals)

-- | Build a map from key-value pairs, using the handle as the source of
-- randomness for the hash functions. Construction is delegated to
-- 'Lifted.fromList'; the resulting values are then stored unboxed.
fromList :: CryptHandle -> [(Bytes, Word)] -> IO Map
fromList handle pairs = fromLifted <$> Lifted.fromList handle pairs

-- | Like 'fromList', but with a combining function that is passed through
-- to 'Lifted.fromListWith'. Presumably it merges the values of duplicate
-- keys; confirm the argument order against "Data.Bytes.HashMap".
fromListWith ::
  -- | Source of randomness
  CryptHandle ->
  -- | Combining function, forwarded to 'Lifted.fromListWith'
  (Word -> Word -> Word) ->
  -- | Key-value pairs
  [(Bytes, Word)] ->
  IO Map
fromListWith handle combine = fmap fromLifted . Lifted.fromListWith handle combine

{- | Build a map from keys that are known at compile time.
All keys must be 64 bytes or less. This uses a built-in source
of entropy and is entirely deterministic. An adversarial user
could feed this function keys that cause it to error out rather
than complete.
-}
fromTrustedList :: [(Bytes, Word)] -> Map
fromTrustedList pairs = fromLifted (Lifted.fromTrustedList pairs)

-- | Look up the value stored for a key, returning 'Nothing' when the key is
-- absent. This wrapper only unpacks its arguments and reboxes the result;
-- the work happens in the out-of-line worker 'lookup#'.
lookup :: Bytes -> Map -> Maybe Word
{-# INLINE lookup #-}
lookup (Bytes (ByteArray karr) (I# koff) (I# klen)) m =
  case m of
    Map (ByteArray ent) (UnliftedArray ents) (PrimArray offs) (UnliftedArray ks) (PrimArray vs) ->
      case lookup# (# karr, koff, klen #) (# ent, ents, offs, ks, vs #) of
        (# | w #) -> Just (W# w)
        (# (# #) | #) -> Nothing

-- | Worker for 'lookup', taking unboxed arguments so the wrapper can be
-- inlined while this body is compiled once.
--
-- The first argument is the key as (array, offset, length). The second is
-- the unpacked 'Map': top-level entropy, per-bucket entropies, per-bucket
-- offsets, keys, and values. The result is the left alternative when the
-- key is absent, and the right alternative carrying the value when present.
--
-- Lookup is two-level. Hashing the key with the top-level entropy selects a
-- bucket @ixA@, which supplies a second entropy and an offset. If that
-- entropy is the very same array as the top-level one, @ixA@ is reused as
-- the slot; otherwise the key is rehashed with the bucket's entropy. The
-- slot plus the offset is the candidate position, and the stored key there
-- is compared byte-for-byte with the query.
lookup# ::
  (# ByteArray#, Int#, Int# #) ->
  (# ByteArray#, UnliftedArray# ByteArray#, ByteArray#, UnliftedArray# ByteArray#, ByteArray# #) ->
  (# (# #) | Word# #)
{-# NOINLINE lookup# #-}
lookup# (# keyArr#, keyOff#, keyLen# #) (# entropyA#, entropies#, offsets#, keys#, vals# #)
  -- No buckets. This guard also protects the division in 'slotFor'.
  | sz == 0 = (# (# #) | #)
  -- The top-level entropy is too short to hash a key of this length.
  | PM.sizeofByteArray entropyA < reqEntropy = (# (# #) | #)
  | ixA <- slotFor entropyA
  , entropyB <- PM.indexUnliftedArray entropies ixA
  , offset <- fromIntegral @Int32 @Int (PM.indexPrimArray offsets ixA) =
      case sameByteArray entropyA entropyB of
        -- Same entropy object: the top-level hash already is the slot.
        1# -> probe (offset + ixA)
        _
          | PM.sizeofByteArray entropyB >= reqEntropy -> probe (offset + slotFor entropyB)
          | otherwise -> (# (# #) | #)
 where
  sz :: Int
  sz = PM.sizeofUnliftedArray entropies
  -- Minimum entropy size, in bytes, needed to hash this key.
  reqEntropy :: Int
  reqEntropy = w2i (requiredEntropy (i2w (Bytes.length key)))
  key :: Bytes
  key = Bytes (ByteArray keyArr#) (I# keyOff#) (I# keyLen#)
  entropyA :: ByteArray
  entropyA = ByteArray entropyA#
  entropies :: UnliftedArray ByteArray
  entropies = UnliftedArray entropies#
  keys :: UnliftedArray ByteArray
  keys = UnliftedArray keys#
  vals :: PrimArray Word
  vals = PrimArray vals#
  offsets :: PrimArray Int32
  offsets = PrimArray offsets#
  -- Hash the key with the given entropy and reduce it modulo the bucket
  -- count. Only called after the @sz == 0@ guard has been passed.
  slotFor :: ByteArray -> Int
  slotFor e = w2i (unsafeRem (upW32 (Hash.bytes e key)) (i2w sz))
  -- Check the candidate position and compare the stored key. The index
  -- functions used here do not check bounds, and for an absent key the slot
  -- plus a signed Int32 offset is not visibly guaranteed to fall inside the
  -- arrays. An out-of-range position is reported as absent instead of
  -- being read.
  probe :: Int -> (# (# #) | Word# #)
  probe offsetIx
    | offsetIx >= 0
    , offsetIx < PM.sizeofUnliftedArray keys
    , offsetIx < PM.sizeofPrimArray vals
    , bytesEqualsByteArray key (PM.indexUnliftedArray keys offsetIx)
    , !(W# v) <- PM.indexPrimArray vals offsetIx =
        (# | v #)
    | otherwise = (# (# #) | #)

-- | Unsigned remainder via the raw primop. The divisor must be nonzero:
-- unlike 'rem', no check is made before dividing.
unsafeRem :: Word -> Word -> Word
unsafeRem (W# dividend#) (W# divisor#) =
  case Exts.remWord# dividend# divisor# of
    r# -> W# r#

-- | Reinterpret an 'Int' as a 'Word' (wrapping on negative input).
i2w :: Int -> Word
i2w n = fromIntegral n

-- | Minimum entropy size, in bytes, that 'lookup#' requires before hashing
-- a key of @n@ bytes: eight bytes per key byte plus eight more.
requiredEntropy :: Word -> Word
requiredEntropy n = 8 * (n + 1)

-- | Reinterpret a 'Word' as an 'Int' (wrapping above 'maxBound' :: 'Int').
w2i :: Word -> Int
w2i w = fromIntegral w

-- | Whether a byte slice has exactly the contents of an entire byte array.
-- The lengths are compared first, so the byte comparison only runs when
-- they agree.
bytesEqualsByteArray :: Bytes -> ByteArray -> Bool
bytesEqualsByteArray (Bytes src srcOff srcLen) whole =
  srcLen == PM.sizeofByteArray whole
    && compareByteArrays src srcOff whole 0 srcLen == EQ

-- | Compare @n@ bytes of two arrays, starting at the given offsets, using
-- 'Exts.compareByteArrays#'. The primop's negative, zero or positive result
-- is mapped to 'LT', 'EQ' or 'GT'.
compareByteArrays :: ByteArray -> Int -> ByteArray -> Int -> Int -> Ordering
compareByteArrays (ByteArray lhs#) (I# lhsOff#) (ByteArray rhs#) (I# rhsOff#) (I# len#) =
  let r = I# (Exts.compareByteArrays# lhs# lhsOff# rhs# rhsOff# len#)
   in compare r 0

-- | Zero-extend a 32-bit hash to a machine 'Word'.
upW32 :: Word32 -> Word
upW32 w = fromIntegral w

-- | Exposed for testing. Copies the values back into a boxed array and
-- delegates to 'Lifted.distribution'.
distribution :: Map -> [(Int, Int)]
distribution (Map topEntropy bucketEntropies bucketOffsets storedKeys unboxedVals) =
  Lifted.distribution lifted
 where
  lifted = Lifted.Map topEntropy bucketEntropies bucketOffsets storedKeys boxedVals
  boxedVals = Exts.fromList (Exts.toList unboxedVals)

-- | The number of distinct entropy arrays in the map, counting the
-- top-level entropy together with the per-bucket ones. Arrays are compared
-- by their 'Eq'/'Ord' instances, not by identity.
distinctEntropies :: Map -> Int
distinctEntropies (Map top perBucket _ _ _) =
  List.length (List.group (List.sort everyEntropy))
 where
  everyEntropy = top : Exts.toList perBucket

-- | Pointer identity of two byte arrays: @1#@ when both are the same heap
-- object and @0#@ otherwise.
--
-- The comparison primop used here is typed for 'MutableByteArray#', so both
-- arguments are coerced. This is intended to be sound because the primop
-- only compares object addresses and never reads or writes the contents.
sameByteArray :: ByteArray -> ByteArray -> Int#
sameByteArray (ByteArray lhs#) (ByteArray rhs#) =
  Exts.sameMutableByteArray#
    (Exts.unsafeCoerce# lhs#)
    (Exts.unsafeCoerce# rhs#)