{-# LANGUAGE GADTs, RecordWildCards #-}
{- |
Module: Data.TTLHashTable
Description: Adds TTL entry expiration to the excellent mutable hash tables from the hashtables package
Copyright: (c) Erick Gonzalez, 2019
License: BSD3
Maintainer: erick@codemonkeylabs.de

This library extends fast mutable hashtables so that entries added can be expired after a given TTL (time to live). This TTL can be specified as a default property of the table or on a per entry basis.

-}
module Data.TTLHashTable (
-- * How to use this module:
-- |
-- Import one of the hash table modules from the hashtables package (e.g. Basic, Cuckoo, etc.)
-- and "wrap" them in a TTLHashTable:
--
-- @
-- import Data.HashTable.ST.Basic as Basic
--
-- type HashTable k v = TTLHashTable Basic.HashTable k v
--
-- @
--
-- You can then use the functions in this module with this hashtable type. Note that the
-- functions in this module which can fail offer a flexible error handling strategy by virtue of
-- working in the context of a 'Failable' monad. So for example, if the function is used directly
-- in the IO monad and a failure occurs it would then result in an exception being thrown. However
-- if the context supports the possibility of failure like a 'MaybeT' or 'ExceptT'
-- transformer, it would then instead return something like @IO Nothing@ or @Left NotFound@
-- respectively (depending on the actual failure of course).
--
-- None of the functions in this module are thread safe, just like the underlying mutable
-- hash tables from the hashtables package. If concurrent threads need to operate on the same
-- table, you need to provide external means of synchronization to guarantee exclusive access
-- to the table.
                          TTLHashTable,
                          TTLHashTableError(..),
                          Settings(..),
                          insert,
                          insert_,
                          insertWithTTL,
                          insertWithTTL_,
                          delete,
                          find,
                          foldM,
                          getSettings,
                          Data.TTLHashTable.mapM_,
                          lookup,
                          new,
                          newWithSettings,
                          reconfigure,
                          removeExpired,
                          size) where

import Prelude                 hiding (lookup)
import Control.Exception              (Exception)
import Control.Monad                  (void, when)
import Control.Monad.Failable         (Failable, failure)
import Control.Monad.IO.Class         (MonadIO, liftIO)
import Control.Monad.Trans.Maybe      (MaybeT(..), runMaybeT)
import Data.Bits                      (finiteBitSize)
import Data.Default                   (Default, def)
import Data.Hashable                  (Hashable)
import Data.IntMap.Strict             (IntMap)
import Data.IORef                     (IORef,
                                       atomicModifyIORef',
                                       modifyIORef',
                                       newIORef,
                                       readIORef,
                                       writeIORef)
import Data.Typeable                  (Typeable)
import System.Clock                   (Clock(..), TimeSpec(..), getTime)

import qualified Data.HashTable.Class as C
import qualified Data.HashTable.IO    as H
import qualified Data.IntMap.Strict   as M

-- | The TTL hash table type, parameterized on the type of table, key and value.
--
-- Every piece of mutable state is held in an 'IORef', so the table's bookkeeping and its
-- settings can be updated in place (see 'reconfigure').
data TTLHashTable h k v where
    TTLHashTable :: (C.HashTable h) =>
        { -- the wrapped table; every stored value carries its expiry and TTL
          hashTable_        :: H.IOHashTable h k (Value v)
          -- current 'maxSize' setting
        , maxSizeRef_       :: IORef Int
          -- number of entries, including expired ones that were not removed yet
        , numEntriesRef_    :: IORef Int
          -- expiry timestamp (nanoseconds) -> key, ordered so the oldest expiry comes first
        , timeStampsRef_    :: IORef (IntMap k)
          -- current 'renewUponRead' setting
        , renewUponReadRef_ :: IORef Bool
          -- current 'defaultTTL' setting, in milliseconds
        , defaultTTLRef_    :: IORef Int
          -- current 'gcMaxEntries' setting
        , gcMaxEntriesRef_  :: IORef Int
        } -> TTLHashTable h k v

-- | A stored entry: the user's value together with the bookkeeping needed to expire it.
-- The two 'Int' fields are strict so the table never retains unevaluated arithmetic
-- (e.g. @now + ttl * 1000000@) inside its entries; the user's value is kept as supplied.
data Value v = Value { expiresAt :: !Int, -- absolute expiry, monotonic-clock nanoseconds
                       ttl       :: !Int, -- the entry's TTL in milliseconds, reused on renewal
                       value     :: v }   -- the stored value

-- | Behavioural knobs of a 'TTLHashTable'. Obtain the defaults through the 'Default' instance
-- and override individual fields as needed.
data Settings = Settings
    { -- | Maximum number of entries in the table. Once reached, inserting a new entry fails
      -- unless the entry with the oldest expiry has already expired and can be evicted.
      -- Any value other than @maxBound@ also pre-sizes the underlying table.
      -- Defaults to @maxBound@
      maxSize       :: Int
      -- | Whether a successful lookup of an entry restarts that entry's TTL.
      -- Default is 'False'
    , renewUponRead :: Bool
      -- | TTL in milliseconds applied to entries inserted without an explicit TTL
    , defaultTTL    :: Int
      -- | Upper bound on how many entries a single call to removeExpired may garbage
      -- collect, so that the duration of a collection run stays under the caller's control.
      -- Default is @maxBound@
    , gcMaxEntries  :: Int
    }

-- | Errors reported by the operations of this module. How an error surfaces depends on the
-- calling 'Failable' context, e.g. a thrown exception in plain 'IO' or 'Nothing' in 'MaybeT'.
data TTLHashTableError
    = NotFound                   -- ^ The entry was not found in the table
    | ExpiredEntry               -- ^ The entry did exist but is no longer valid
    | HashTableFull              -- ^ The maximum size for the table has been reached
    | UnsupportedPlatform String -- ^ The platform is not supported
    | HashTableTooLarge          -- ^ The hash table is too large for the provided settings
    deriving (Eq, Typeable, Show)

-- | Allows the errors to be thrown and caught as ordinary exceptions.
instance Exception TTLHashTableError

-- | No practical size limit, no TTL renewal on reads, a one year default TTL and no limit on
-- how many entries a single garbage collection run may remove.
instance Default Settings where
    def = Settings { maxSize       = maxBound,
                     renewUponRead = False,
                     defaultTTL    = oneYearInMs,
                     gcMaxEntries  = maxBound }
      where oneYearInMs = 365 * 24 * 60 * 60 * 1000

-- | Fails with 'UnsupportedPlatform' when 'Int' is narrower than 64 bits. Timestamps in this
-- module are nanosecond counts stored in an 'Int'.
assertIntSize :: (Failable m) => m ()
assertIntSize =
    when (intWidth < 64) . failure $
      UnsupportedPlatform "Int size on this platform is < 64 bits"
  where intWidth = finiteBitSize (0 :: Int)


-- | Creates a new hash table configured with the 'Default' 'Settings'
-- (shorthand for @newWithSettings def@).
new :: (C.HashTable h, MonadIO m, Failable m) => m (TTLHashTable h k v)
new = newWithSettings (def :: Settings)

-- | Creates a new hash table configured by the given 'Settings'. Start from the 'Default'
-- instance of 'Settings' and override only what you need, e.g.:
--
-- @
-- newWithSettings def { maxSize = 64 }
-- @
--
-- Fails with 'UnsupportedPlatform' when 'Int' is narrower than 64 bits. When @maxSize@ is not
-- @maxBound@, the underlying table is pre-sized for that many entries.
newWithSettings :: (C.HashTable h, MonadIO m, Failable m) => Settings -> m (TTLHashTable h k v)
newWithSettings Settings {..} = assertIntSize >> liftIO build
  where
    -- arguments follow the constructor's field order
    build =
      TTLHashTable <$> allocate
                   <*> newIORef maxSize
                   <*> newIORef 0
                   <*> newIORef M.empty
                   <*> newIORef renewUponRead
                   <*> newIORef defaultTTL
                   <*> newIORef gcMaxEntries
    allocate
      | maxSize == maxBound = H.new
      | otherwise           = H.newSized maxSize

-- | Inserts an entry using the table's current default TTL. Note that __this function can
-- fail__, e.g. with 'HashTableFull' once the table holds @maxSize@ entries and none of them can
-- be evicted. How the failure is signalled depends on the calling 'Failable' context: in plain
-- IO it throws a regular IO exception (of type 'TTLHashTableError'), which is why
-- __you probably want to call this function in a 'MaybeT' or 'ExceptT' monad__.
insert :: (Eq k, Hashable k, C.HashTable h, MonadIO m, Failable m)
          => TTLHashTable h k v
          -> k
          -> v
          -> m ()
insert ht@TTLHashTable {..} k v =
  liftIO (readIORef defaultTTLRef_) >>= \ttlMs -> insertWithTTL ht ttlMs k v

-- | Variant of 'insert' that never fails: insertion failures are silently dropped instead of
-- being signalled, so there is no need to ignore an error result (or catch and discard the
-- exception, in plain IO) by hand.
insert_ :: (Eq k, Hashable k, C.HashTable h, MonadIO m)
          => TTLHashTable h k v
          -> k
          -> v
          -> m ()
insert_ h k v = void (runMaybeT (insert h k v))

-- | Like 'insert' but an entry specific TTL in milliseconds can be provided.
--
-- When the table already holds @maxSize@ entries, the entry with the oldest expiry is evicted
-- if it has expired; otherwise the insertion fails with 'HashTableFull'. Inserting a key that
-- is already present replaces its value and TTL without growing the entry count, and the old
-- expiry record for that key is discarded.
insertWithTTL :: (Eq k, Hashable k, C.HashTable h, MonadIO m, Failable m)
          => TTLHashTable h k v
          -> Int
          -> k
          -> v
          -> m ()
insertWithTTL ht@TTLHashTable {..} ttl k v = do
  numEntries <- liftIO $ readIORef numEntriesRef_
  now        <- getTimeStamp
  limit      <- liftIO $ readIORef maxSizeRef_
  let expiry = now + ttl * 1000000 -- milliseconds to nanoseconds
  if numEntries < limit
    then store expiry
    else do
      madeSpace <- checkOldest ht now
      maybe (failure HashTableFull) (const $ store expiry) madeSpace
  where
    store expiry = liftIO $ do
      -- 'H.mutate' hands back the value being replaced (if any), so an overwrite can be told
      -- apart from a fresh insertion: only the latter adds an entry.
      replaced <- H.mutate hashTable_ k $ \old -> (Just (Value expiry ttl v), old)
      case replaced of
        Nothing  -> modifyIORef' numEntriesRef_ (+1)
        Just old -> modifyIORef' timeStampsRef_ $ M.update ownedByOther (expiresAt old)
      -- NOTE(review): records are keyed by expiry alone, so a different key with an identical
      -- expiry has its record overwritten here.
      modifyIORef' timeStampsRef_ $ M.insert expiry k
    -- drop a record only if it still belongs to this key
    ownedByOther owner = if owner == k then Nothing else Just owner

-- | Variant of 'insertWithTTL' that never fails: insertion failures are silently dropped.
insertWithTTL_ :: (Eq k, Hashable k, C.HashTable h, MonadIO m)
          => TTLHashTable h k v
          -> Int
          -> k
          -> v
          -> m ()
insertWithTTL_ h ttl k v = void (runMaybeT (insertWithTTL h ttl k v))


-- | Tries to free one slot by evicting the entry with the oldest expiry, provided that expiry
-- is at or before @now@. Expiry records whose entry is gone or has since been given a different
-- expiry are discarded, and the search continues with the next oldest record. This keeps stale
-- bookkeeping from causing a spurious 'HashTableFull'.
-- Returns @Just ()@ once an entry has been evicted, 'Nothing' when no expired record is left.
checkOldest :: (Eq k, Hashable k, MonadIO m) => TTLHashTable h k v -> Int -> m (Maybe ())
checkOldest ht@TTLHashTable {..} now = liftIO evictOne
  where
    evictOne = do
      mOldest <- atomicModifyIORef' timeStampsRef_ popExpired
      case mOldest of
        Nothing             -> return Nothing
        Just (timeStamp, k) -> do
          evicted <- mutateWith (deleteExpired timeStamp) ht k
          -- terminates: every round removes one record from the map
          maybe evictOne (return . Just) evicted
    popExpired timeStamps =
      case M.minViewWithKey timeStamps of
        Just ((timeStamp, k), rest) | timeStamp <= now -> (rest, Just (timeStamp, k))
        _                                              -> (timeStamps, Nothing)

-- | Looks up a key in the hash table. The lookup fails with 'NotFound' when the key is absent.
-- It fails with 'ExpiredEntry' when the entry has expired, in which case the entry is also
-- deleted. Called directly in IO, these failures are thrown as exceptions. Under @MaybeT IO@ or
-- @ExceptT SomeException IO@ they become @IO Nothing@ or @IO (Left NotFound)@ respectively, so
-- you probably want to __execute this function in one of these transformer monads__. When the
-- table's 'renewUponRead' setting is on, a successful lookup restarts the entry's TTL.
lookup :: (Eq k, Hashable k, MonadIO m, Failable m) => TTLHashTable h k v -> k -> m v
lookup ht@TTLHashTable {..} k =
  liftIO (readIORef renewUponReadRef_) >>= \renew -> lookup' ht renew k

-- | Worker for 'lookup'; the 'Bool' selects whether a successful read restarts the TTL.
--
-- In both modes an entry counts as expired only once its expiry is strictly before the current
-- time, which is the rule 'checkLookedUp' applies. In renewal mode the entry is rewritten with a
-- fresh expiry and its expiry record is moved. An entry found expired is removed together with
-- its expiry record before the failure is reported.
lookup' :: (Eq k, Hashable k, MonadIO m, Failable m) => TTLHashTable h k v -> Bool -> k -> m v
lookup' ht@TTLHashTable {..} False k = do
        now        <- getTimeStamp
        mValue     <- liftIO $ H.lookup hashTable_ k
        Value {..} <- checkLookedUp ht k mValue now
        return value
lookup' ht@TTLHashTable {..} True k = do
        now               <- getTimeStamp
        (mExpire, mValue) <- mutateWith (refreshEntry now) ht k
        removeTimeStamp mExpire
        Value {..}        <- checkLookedUp ht k mValue now
        liftIO $ modifyIORef' timeStampsRef_ $ M.insert expiresAt k
        return value
    where -- result: ((new table value, entries removed), (expiry record to drop, value to check))
          refreshEntry _ Nothing =
              ((Nothing, 0), (Nothing, Nothing))
          refreshEntry now (Just v@Value{..})
              | expiresAt >= now =
                  let v' = Value { expiresAt = now + ttl * 1000000, -- milliseconds to nanoseconds
                                   ttl       = ttl,
                                   value     = value }
                  in ((Just v', 0), (Just expiresAt, Just v'))
              | otherwise =
                  -- expired: drop the entry and its record; checkLookedUp reports the failure
                  ((Nothing, 1), (Just expiresAt, Just v))
          removeTimeStamp Nothing =
              return ()
          removeTimeStamp (Just timeStamp) =
              liftIO . modifyIORef' timeStampsRef_ $ M.update ownedByOther timeStamp
          -- drop a record only if it still belongs to this key
          ownedByOther owner = if owner == k then Nothing else Just owner

-- | Validates the result of a raw table lookup. A missing entry fails with 'NotFound'. An entry
-- whose expiry is strictly before @now@ is deleted and fails with 'ExpiredEntry'. Any other
-- entry is returned as is.
checkLookedUp :: (Eq k, Hashable k, MonadIO m, C.HashTable h, Failable m)
                 => TTLHashTable h k v
                 -> k
                 -> Maybe (Value v)
                 -> Int
                 -> m (Value v)
checkLookedUp _ _ Nothing _ = failure NotFound
checkLookedUp ht k (Just entry) now
  | expiresAt entry < now = delete ht k >> failure ExpiredEntry
  | otherwise             = return entry

-- | Looks up a key and reports the outcome as a 'Maybe' in the caller's 'MonadIO' context:
-- 'Nothing' when the key is absent or its entry has expired. This suits callers who prefer a
-- conventional interface over a 'Failable' one.
find :: (Eq k, Hashable k, MonadIO m) => TTLHashTable h k v -> k -> m (Maybe v)
find ht k = runMaybeT (lookup ht k)

-- | Deletes an entry from the hash table. Deleting a key that is not present is a no-op.
-- The entry's expiry record is dropped together with the entry. Otherwise deleted keys would
-- linger in the garbage-collection bookkeeping until their TTL ran out.
delete :: (C.HashTable h, Eq k, Hashable k, MonadIO m) => TTLHashTable h k v -> k -> m ()
delete ht@TTLHashTable {..} k = do
  removedExpiry <- mutateWith removeEntry ht k
  liftIO $ maybe (return ()) forgetTimeStamp removedExpiry
  where
    removeEntry Nothing           = ((Nothing, 0), Nothing)
    removeEntry (Just Value {..}) = ((Nothing, 1), Just expiresAt)
    forgetTimeStamp timeStamp =
      modifyIORef' timeStampsRef_ $ M.update ownedByOther timeStamp
    -- drop a record only if it still belongs to this key
    ownedByOther owner = if owner == k then Nothing else Just owner

-- | Mutator for 'mutateWith': removes the entry only if its expiry equals @timeStamp@, i.e. the
-- expiry record being processed still describes the stored entry. An entry whose expiry has
-- since changed is written back untouched. The result is @Just ()@ exactly when an entry was
-- removed.
deleteExpired :: Int -> Maybe (Value v) -> ((Maybe (Value v), Int), Maybe ())
deleteExpired timeStamp mEntry =
  case mEntry of
    Just entry
      | expiresAt entry /= timeStamp -> ((Just entry, 0), Nothing)
      | otherwise                    -> ((Nothing, 1), Just ())
    Nothing                          -> ((Nothing, 0), Nothing)

-- | Runs @mutator@ on the current value stored under @k@, atomically with respect to the
-- underlying table's 'H.mutate'. The mutator returns the value to store ('Nothing' removes the
-- key), how many entries were removed (subtracted from the entry count), and a result that is
-- handed back to the caller.
mutateWith :: (Eq k, Hashable k, MonadIO m)
              => (Maybe (Value v) -> ((Maybe (Value v), Int), a))
              -> TTLHashTable h k v
              -> k
              -> m a
mutateWith mutator TTLHashTable {..} k = liftIO $ do
    (removed, result) <- H.mutate hashTable_ k (reshape . mutator)
    modifyIORef' numEntriesRef_ (subtract removed)
    return result
  where reshape ~((newValue, removed), result) = (newValue, (removed, result))

-- | Reports the current number of entries in the table. Entries that have expired but have
-- not been garbage collected yet are included.
size :: (MonadIO m) => TTLHashTable h k v -> m Int
size TTLHashTable { numEntriesRef_ = entries } = liftIO (readIORef entries)

-- | Runs garbage collection of expired entries in the table, removing at most 'gcMaxEntries'
-- of them. It returns how many expired entries are still left to be removed when that limit
-- cut the run short. Like every other operation on the table, this is __not__ thread safe: if
-- concurrent threads need to operate on the table, some concurrency primitive must be used to
-- guarantee exclusive access.
removeExpired :: (MonadIO m, Eq k, Hashable k) => TTLHashTable h k v -> m Int
removeExpired ht@TTLHashTable {..} =
  liftIO (readIORef gcMaxEntriesRef_) >>= removeExpired' ht

-- | Worker for 'removeExpired'. Removes, oldest first, at most @gcMaxEntries@ entries whose
-- expiry lies strictly before the current time. Returns how many such expired records were
-- left over for a later run.
removeExpired' :: (MonadIO m, Eq k, Hashable k) => TTLHashTable h k v -> Int -> m Int
removeExpired' ht@TTLHashTable {..} gcMaxEntries =
  liftIO $ do
    now          <- getTimeStamp
    (n, expired) <- atomicModifyIORef' timeStampsRef_ $ selectedEntries now
    Prelude.mapM_ remove expired
    return $! n
  where
    remove (timeStamp, k) = mutateWith (deleteExpired timeStamp) ht k
    selectedEntries now timeStamps =
      let (old, current, active) = M.splitLookup now timeStamps
          -- A record expiring exactly at @now@ has not expired yet. 'M.split' would silently
          -- drop it and leave its entry without any expiry record, so keep it explicitly.
          active'            = maybe active (\owner -> M.insert now owner active) current
          (selected, notYet) = splitAt gcMaxEntries (M.toList old)
          -- 'M.toList' is ascending, so the leftovers can be rebuilt in linear time
          leftOver           = M.fromDistinctAscList notYet
      in (M.union active' leftOver, (length notYet, selected))

-- | Reads the monotonic clock and returns the time as a nanosecond count.
getTimeStamp :: (MonadIO m) => m Int
getTimeStamp = liftIO (asNanoseconds <$> getTime Monotonic)
  where asNanoseconds (TimeSpec secs ns) = fromIntegral (secs * 1000000000 + ns)

-- | A strict fold in IO over the @(key, value)@ records in a hash table. Entries that have
-- expired but were not garbage collected yet are visited too.
foldM :: (MonadIO m) => (a -> (k, v) -> IO a) -> a -> TTLHashTable h k v -> m a
foldM f seed TTLHashTable {..} = liftIO (H.foldM step seed hashTable_)
  where step acc (key, Value { value = x }) = f acc (key, x)

-- | Runs an IO action on every @(key, value)@ record in a hash table, discarding the
-- results. Entries that have expired but were not garbage collected yet are visited too.
mapM_ :: (MonadIO m) => ((k, v) -> IO a) -> TTLHashTable h k v -> m ()
mapM_ f TTLHashTable {..} = liftIO (H.mapM_ visit hashTable_)
  where visit (key, Value { value = x }) = f (key, x)

-- | Replaces all settings of a table. Fails with 'HashTableTooLarge' when the table currently
-- holds more entries than the new @maxSize@, in which case nothing is changed. The underlying
-- table is not resized.
reconfigure :: (MonadIO m, Failable m) => TTLHashTable h k v -> Settings -> m ()
reconfigure TTLHashTable {..} Settings {..} = do
  current <- liftIO (readIORef numEntriesRef_)
  when (current > maxSize) (failure HashTableTooLarge)
  liftIO . sequence_ $
    [ writeIORef maxSizeRef_       maxSize
    , writeIORef renewUponReadRef_ renewUponRead
    , writeIORef defaultTTLRef_    defaultTTL
    , writeIORef gcMaxEntriesRef_  gcMaxEntries
    ]

-- | Returns the settings the table is currently configured with.
getSettings :: (MonadIO m) => TTLHashTable h k v -> m Settings
getSettings TTLHashTable {..} =
  liftIO $ Settings <$> readIORef maxSizeRef_
                    <*> readIORef renewUponReadRef_
                    <*> readIORef defaultTTLRef_
                    <*> readIORef gcMaxEntriesRef_