{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE DeriveAnyClass #-}
module Data.CuckooFilter.Internal (
    
    Size(..),
    makeSize,
    Filter(..),
    empty,
    
    FingerPrint(..),
    emptyFP,
    makeFingerprint,
    
    Bucket(..),
    emptyBucket,
    Index(..),
    IndexA(..),
    IndexB(..),
    replaceInBucket,
    insertBucket,
    primaryIndex,
    secondaryIndex,
    kickedSecondaryIndex,
    
    getCell,
    setCell
) where
import Data.Aeson (ToJSON, FromJSON)
import Data.Bits (xor, (.&.), (.|.), shiftR, shiftL)
import Data.Foldable (foldl')
import qualified Data.IntMap.Strict as IM
import Data.Hashable (Hashable, hash)
import Data.Serialize (Serialize)
import Data.Word (Word32, Word8)
import GHC.Generics (Generic)
import Numeric.Natural (Natural)
-- | Requested size of a filter, as passed to 'empty'. Prefer the smart
-- constructor 'makeSize', which rejects zero, over the raw constructor.
newtype Size = Size Natural
    deriving stock (Eq, Generic, Ord, Show)
    deriving newtype (FromJSON, Serialize, ToJSON)
-- | Smart constructor for 'Size'.
--
-- Returns 'Nothing' for zero and wraps any positive value unchanged.
-- The argument is already a 'Natural', so no numeric conversion is needed.
makeSize :: Natural -> Maybe Size
makeSize 0 = Nothing
makeSize n = Just (Size n)
-- | Raw index types that can be reduced to a concrete bucket position.
class Index idx where
    -- | Given the number of buckets, turn a raw index into a bucket
    -- position.
    toIndex :: Natural -> idx -> Int
-- | Raw (unreduced) primary index: an element's 'hash' truncated to
-- 32 bits (see 'primaryIndex'). Reduce it with 'toIndex'.
newtype IndexA = IA Word32
    deriving stock (Eq, Generic, Ord, Show)
    deriving newtype (FromJSON, Hashable, ToJSON)
    deriving anyclass Serialize
-- | Reduce a raw primary index modulo the bucket count (computed in 'Int').
-- As with any 'Int' @mod@, a bucket count of 0 raises a divide-by-zero
-- error.
instance Index IndexA where
    toIndex bucketCount (IA raw) =
        mod (fromIntegral raw) (fromIntegral bucketCount)
-- | Raw (unreduced) alternate index: a primary index XOR-ed with the
-- fingerprint's hash (see 'secondaryIndex'). Reduce it with 'toIndex'.
newtype IndexB = IB Word32
    deriving stock (Eq, Generic, Ord, Show)
    deriving newtype (FromJSON, Hashable, ToJSON)
    deriving anyclass Serialize
-- | Reduce a raw alternate index modulo the bucket count (computed in
-- 'Int'). As with any 'Int' @mod@, a bucket count of 0 raises a
-- divide-by-zero error.
instance Index IndexB where
    toIndex bucketCount (IB raw) =
        mod (fromIntegral raw) (fromIntegral bucketCount)
-- | An 8-bit fingerprint. The value 0 ('emptyFP') marks a vacant cell.
-- For that reason 'makeFingerprint' never produces it.
newtype FingerPrint = FP Word8
    deriving stock (Eq, Generic, Ord, Show)
    deriving newtype (FromJSON, Hashable, ToJSON)
    deriving anyclass Serialize
-- | The all-zero fingerprint, used to mark a vacant bucket cell.
emptyFP :: FingerPrint
emptyFP = FP 0x00
-- | Four 8-bit fingerprint cells packed into one 'Word32'. Cell @n@ lives
-- in bits @8n@ to @8n + 7@ (see 'getCell' and 'setCell').
newtype Bucket = B Word32
    deriving stock (Generic, Ord, Show)
    deriving newtype (Eq, FromJSON, ToJSON)
    deriving anyclass Serialize
-- | A bucket whose four cells all hold 'emptyFP'.
emptyBucket :: Bucket
emptyBucket = B 0x00000000
-- | Read the fingerprint stored in cell @cellNumber@ of a bucket.
--
-- Cell @n@ occupies bits @8n@ to @8n + 7@ of the bucket word, so only
-- cells 0 to 3 carry data.
getCell ::
    Bucket
    -> Natural
    -> FingerPrint
getCell (B word) cellNumber = FP (fromIntegral shifted)
    where
        -- Bring the target byte down to bits 0..7. The conversion to
        -- 'Word8' then discards every higher byte.
        shifted = word `shiftR` bitOffset
        bitOffset = 8 * fromIntegral cellNumber
-- | Store a fingerprint in cell @cellNumber@ of a bucket.
--
-- The byte previously in that cell is overwritten. The other cells keep
-- their contents.
setCell ::
    Bucket
    -> Natural
    -> FingerPrint
    -> Bucket
setCell (B word) cellNumber (FP fingerprint) = B (cleared .|. placed)
    where
        bitOffset = 8 * fromIntegral cellNumber
        cellMask = 0xFF `shiftL` bitOffset :: Word32
        -- Zero the target byte. XOR with all-ones is the complement of the
        -- mask, so the AND keeps every bit outside the cell.
        cleared = word .&. (maxBound `xor` cellMask)
        placed = fromIntegral fingerprint `shiftL` bitOffset
-- | A cuckoo filter. The type parameter @a@ does not appear in any field
-- (it is a phantom). It only ties a filter to the element type it is used
-- with.
--
-- All fields are strict. Previously the bucket map was the one lazy field,
-- so repeated record updates could build up unevaluated map thunks.
data Filter a = F {
    -- | Bucket storage keyed by bucket position. NOTE(review): presumably
    -- only buckets that have been written are present, with missing keys
    -- meaning 'emptyBucket'; confirm against callers.
    buckets :: !(IM.IntMap Bucket),
    -- | Number of buckets the raw indices are reduced against.
    numBuckets :: !Natural,
    -- | The 'Size' this filter was created with.
    size :: !Size
    }
    deriving (Show, Eq, Generic, Serialize, ToJSON, FromJSON)
-- | Create an empty filter for the given 'Size'.
--
-- Each bucket holds four fingerprint cells, so the bucket count is
-- @s \`div\` 4@, clamped to at least one. Without the clamp, any size below
-- 4 (which 'makeSize' accepts) would yield zero buckets. The 'Index'
-- instances reduce raw indices with @mod@ by the bucket count, so that
-- would raise a divide-by-zero.
empty ::
    Size
    -> Filter a
empty (Size s) = F {
    buckets = IM.empty,
    numBuckets = bucketCount,
    size = Size s
    }
    where
        bucketCount = max 1 (s `div` 4)
-- | Put a fingerprint into the lowest-numbered vacant cell of a bucket.
--
-- A cell is vacant when it holds 'emptyFP'. Cells are tried in the order
-- 0, 1, 2, 3. Returns 'Nothing' when all four cells are occupied.
insertBucket ::
    FingerPrint
    -> Bucket
    -> Maybe Bucket
insertBucket fp bucket =
    case filter isVacant [0 .. 3] of
        (cell : _) -> Just (setCell bucket cell fp)
        [] -> Nothing
    where
        isVacant cell = getCell bucket cell == emptyFP
-- | Overwrite one cell of a bucket with a new fingerprint.
--
-- The predicate receives the new fingerprint and the bucket and flags
-- cells 0 to 3. The first flagged cell, in that order, is replaced. The
-- result pairs the fingerprint that was evicted from it with the updated
-- bucket. When no cell is flagged, the inputs are returned unchanged as
-- @(fp, bucket)@.
replaceInBucket ::
    FingerPrint
    -> (FingerPrint -> Bucket -> (Bool, Bool, Bool, Bool))
    -> Bucket
    -> (FingerPrint, Bucket)
replaceInBucket fp predicate bucket =
    case [cell | (cell, True) <- zip [0 ..] (flags (predicate fp bucket))] of
        (cell : _) -> (getCell bucket cell, setCell bucket cell fp)
        [] -> (fp, bucket)
    where
        flags (f0, f1, f2, f3) = [f0, f1, f2, f3]
-- | Derive an 8-bit fingerprint from an element's 'hash'.
--
-- Steps:
--
-- * The absolute hash is truncated to its low byte.
-- * The byte is reduced modulo 255.
-- * A result of 0 is bumped to 1, so a fingerprint never equals 'emptyFP'.
--
-- NOTE(review): both this fingerprint and 'primaryIndex' take their bits
-- from the low-order end of the same 'hash' value, so the two may be
-- correlated (e.g. when the bucket count is a power of two). Changing the
-- derivation would presumably alter fingerprints in already-serialised
-- filters, so it is left as is.
makeFingerprint :: Hashable a =>
    a
    -> FingerPrint
makeFingerprint a =
    let lowByte = fromIntegral (abs (hash a)) :: Word8
        reduced = lowByte `mod` 255
    in FP (if reduced == 0 then 1 else reduced)
-- | Raw primary index of an element: its 'hash' truncated to 32 bits.
--
-- The bucket-count argument is not used here. Reduction to an actual
-- bucket position happens in 'toIndex'.
primaryIndex :: Hashable a =>
    a
    -> Natural
    -> IndexA
primaryIndex a _bucketCount = IA (fromIntegral (hash a))
-- | Raw alternate index: the raw primary index XOR-ed with the fingerprint's
-- 'hash' truncated to 32 bits (partial-key cuckoo hashing).
--
-- The bucket-count argument is not used here. Reduction happens in
-- 'toIndex'.
secondaryIndex ::
    FingerPrint
    -> Natural
    -> IndexA
    -> IndexB
secondaryIndex fp _bucketCount (IA primary) =
    let fpHash = fromIntegral (hash fp)
    in IB (fpHash `xor` primary)
-- | Alternate index computed from an 'IndexB'. Judging by the name, it is
-- meant for a fingerprint being kicked out of that bucket.
--
-- It applies the same XOR as 'secondaryIndex'. Because XOR is its own
-- inverse, feeding in the raw alternate index of a fingerprint gives back
-- the raw primary index it was derived from.
kickedSecondaryIndex ::
    FingerPrint
    -> Natural
    -> IndexB
    -> IndexB
kickedSecondaryIndex fp bucketCount (IB alt) =
    secondaryIndex fp bucketCount asPrimary
    where
        asPrimary = IA alt