{-# LANGUAGE CPP #-}

module Data.Interned.Extended.HashTableBased
  ( Id
  , Cache(..)
  , freshCache
  , cacheSize
  , resetCache

#ifdef PROFILE_CACHES
  , getMetrics
#endif

  , Interned(..)
  , intern
  ) where

import Data.Hashable
import qualified Data.HashTable.IO as HT
import Data.IORef
import GHC.IO ( unsafeDupablePerformIO )

import Data.HashTable.Extended

#ifdef PROFILE_CACHES
import Data.Memoization.Metrics ( CacheMetrics(CacheMetrics) )
#endif

----------------------------------------------------------------------------------------------------------

--------------------
------- Caches
--------------------

-- | Identifier assigned to each interned value by 'intern'
-- (taken from the cache's 'fresh' counter).
type Id = Int

-- | Per-type interning cache.
--
-- 'content' maps a value's 'Description' to the interned value, and 'fresh'
-- holds the next 'Id' to hand out. Because 'intern' bumps 'fresh' exactly when
-- it inserts a new entry, the counter also serves as the cache size.
--
-- We tried using the BasicHashtable size function to remove the need for this
-- IORef (see https://github.com/gregorycollins/hashtables/pull/68), but it was
-- slower.
data Cache t = Cache { fresh   :: !(IORef Id)
                       -- ^ next 'Id' to allocate
                     , content :: !(HT.CuckooHashTable (Description t) t)
                       -- ^ description -> interned value
#ifdef PROFILE_CACHES
                     , queryCount :: !(IORef Int)
                     , missCount  :: !(IORef Int)
#endif
                     }

-- | Allocate an empty cache. The id counter starts at 0 and the hashtable is
-- created with 'HT.new'. In @PROFILE_CACHES@ builds, the query and miss
-- counters also start at 0.
freshCache :: IO (Cache t)
freshCache = Cache <$> newIORef 0
                   <*> HT.new
#ifdef PROFILE_CACHES
                   <*> newIORef 0
                   <*> newIORef 0
#endif

-- | Number of ids handed out since the cache was created or last reset.
--
-- This reads the 'fresh' counter rather than asking the hashtable for its
-- size. 'intern' increments the counter once per inserted entry, so the two
-- normally agree.
-- NOTE(review): if 'intern' work is duplicated (it runs under
-- 'unsafeDupablePerformIO'), the counter may exceed the real entry count.
cacheSize :: Cache t -> IO Int
cacheSize Cache {fresh = refI} = readIORef refI

-- | Reset a cache in place.
--
-- The id counter goes back to 0 and the table is passed to 'resetHashTable'
-- (from "Data.HashTable.Extended"), which presumably empties it. In
-- @PROFILE_CACHES@ builds, the query and miss counters are also zeroed.
--
-- Because ids restart at 0, ids handed out before the reset can be reissued
-- to values interned afterwards.
--
-- The 'Interned' constraint supplies the @Eq@/@Hashable@ instances on
-- 'Description' that are needed to wrap the table in 'AnyHashTable'.
resetCache :: (Interned t) => Cache t -> IO ()
resetCache _c@Cache {fresh = refI, content = ht} = do
  writeIORef refI 0
  resetHashTable (AnyHashTable ht)
#ifdef PROFILE_CACHES
  writeIORef (queryCount _c) 0
  writeIORef (missCount  _c) 0
#endif

-- | Count one cache lookup. This only has an effect in @PROFILE_CACHES@
-- builds; otherwise it is a no-op.
--
-- The strict 'modifyIORef'' is used so the counter stays an evaluated 'Int'.
-- The lazy 'modifyIORef' would stack up one @(+1)@ thunk per query.
bumpQueryCount :: Cache t -> IO ()
#ifdef PROFILE_CACHES
bumpQueryCount Cache {queryCount = ref} = modifyIORef' ref (+ 1)
#else
bumpQueryCount _ = pure ()
#endif
{-# INLINE bumpQueryCount #-}

-- | Count one cache miss. This only has an effect in @PROFILE_CACHES@ builds;
-- otherwise it is a no-op.
--
-- The strict 'modifyIORef'' is used so the counter stays an evaluated 'Int'.
-- The lazy 'modifyIORef' would stack up one @(+1)@ thunk per miss.
bumpMissCount :: Cache t -> IO ()
#ifdef PROFILE_CACHES
bumpMissCount Cache {missCount = ref} = modifyIORef' ref (+ 1)
#else
bumpMissCount _ = pure ()
#endif
{-# INLINE bumpMissCount #-}


#ifdef PROFILE_CACHES
-- | Read the cache's query counter, then its miss counter, and package both
-- as a 'CacheMetrics' (in that argument order).
getMetrics :: Cache t -> IO CacheMetrics
getMetrics c = do
  queries <- readIORef (queryCount c)
  misses  <- readIORef (missCount c)
  return (CacheMetrics queries misses)
#endif

--------------------
------- Interning
--------------------

-- | Types whose values are interned through a per-type global 'Cache'.
class (Eq (Description t), Hashable (Description t)) => Interned t where
  -- | Key under which a value is looked up and stored in the cache.
  data Description t

  -- | Raw, not-yet-interned form of a @t@.
  type Uninterned t

  -- | Compute the cache key of a raw value.
  describe :: Uninterned t -> Description t

  -- | Build the interned value from its newly allocated 'Id' and raw form.
  identify :: Id -> Uninterned t -> t

  -- | The cache that 'intern' consults and fills for this type.
  cache :: Cache t

-- | Return the interned value for a raw value.
--
-- The raw value's 'Description' is looked up in the type's 'cache'.
--
-- * On a hit, the stored value is returned.
-- * On a miss, a new 'Id' is taken from the counter with
--   'atomicModifyIORef'', the value is built with 'identify', stored under
--   its description, and returned.
--
-- This is pure from the caller's point of view and uses
-- 'unsafeDupablePerformIO'. That means the IO action may be run more than
-- once if several threads force the same thunk.
--
-- NOTE(review): the lookup-then-insert on the mutable cuckoo hashtable is not
-- guarded by a lock. Concurrent interning from several threads could race,
-- for example by allocating two ids for one description. Confirm that callers
-- are single-threaded, or add synchronisation.
intern :: Interned t => Uninterned t -> t
intern !bt = unsafeDupablePerformIO $ do
    let c = cache
    bumpQueryCount c
    hit <- HT.lookup (content c) dt
    case hit of
      Just t  -> return t
      Nothing -> do
        bumpMissCount c
        -- Post-increment: take the current counter value as this entry's id.
        i <- atomicModifyIORef' (fresh c) (\n -> (n + 1, n))
        let t = identify i bt
        HT.insert (content c) dt t
        return t
  where
    -- Forced up front, so the description is computed once and shared by
    -- both the lookup and the insert.
    !dt = describe bt