{-# LANGUAGE GADTs, RecordWildCards #-}
{- |
Module: Data.TTLHashTable
Description: Adds TTL entry expiration to the excellent mutable hash tables from the hashtables package
Copyright: (c) Erick Gonzalez, 2019
License: BSD3
Maintainer: erick@codemonkeylabs.de

This library extends fast mutable hashtables so that entries added can be expired after a given TTL (time to live). This TTL can be specified as a default property of the table or on a per entry basis.

-}
module Data.TTLHashTable (
-- * How to use this module:
-- |
-- Import one of the hash table modules from the hashtables package (e.g. Basic, Cuckoo, etc.)
-- and "wrap" them in a TTLHashTable:
--
-- @
-- import qualified Data.HashTable.ST.Basic as Basic
--
-- type HashTable k v = TTLHashTable Basic.HashTable k v
--
-- @
--
-- You can then use the functions in this module with this hashtable type. Note that the
-- functions in this module which can fail offer a flexible error handling strategy by virtue of
-- working in the context of a 'Failable' monad. So for example, if the function is used directly
-- in the IO monad and a failure occurs it would then result in an exception being thrown. However
-- if the context supports the possibility of failure like a 'MaybeT' or 'ExceptT'
-- transformer, it would then instead return something like @IO Nothing@ or @Left NotFound@
-- respectively (depending on the actual failure of course).
--
-- None of the functions in this module are thread safe, just like the underlying mutable
-- hash tables from the hashtables package. If concurrent threads need to operate on the same
-- table, you need to provide external means of synchronization to guarantee exclusive access
-- to the table.
                          TTLHashTable,
                          TTLHashTableError(..),
                          Settings(..),
                          insert,
                          insert_,
                          insertWithTTL,
                          insertWithTTL_,
                          delete,
                          find,
                          lookup,
                          new,
                          newWithSettings,
                          removeExpired,
                          size) where

import Prelude                 hiding (lookup)
import Control.Exception              (Exception)
import Control.Monad                  (void)
import Control.Monad.Trans.Class      (lift)
import Control.Monad.Failable         (Failable, failure)
import Control.Monad.IO.Class         (MonadIO, liftIO)
import Control.Monad.Trans.Maybe      (MaybeT(..), runMaybeT)
import Data.Default                   (Default, def)
import Data.Hashable                  (Hashable)
import Data.IntMap.Strict             (IntMap)
import Data.IORef                     (IORef,
                                       atomicModifyIORef',
                                       modifyIORef',
                                       newIORef,
                                       readIORef)
import Data.Tuple                     (swap)
import Data.Typeable                  (Typeable)
import System.Clock                   (Clock(Monotonic), TimeSpec(..), getTime)
import System.Mem.Weak                (Weak, deRefWeak, mkWeak)

import qualified Data.HashTable.Class as C
import qualified Data.HashTable.IO    as H
import qualified Data.IntMap.Strict   as M

-- | The TTL hash table type, parameterized on the type of the underlying table @h@ (one of
-- the hashtables implementations, such as Basic or Cuckoo), the key type @k@ and the value
-- type @v@.
data TTLHashTable h k v where
    TTLHashTable :: (C.HashTable h) =>
        { hashTable_     :: H.IOHashTable h k (Value v)
          -- ^ The wrapped mutable table; each value is stored along with its expiry data
        , maxSize_       :: Int
          -- ^ Maximum number of entries, taken from 'maxSize'
        , numEntriesRef_ :: IORef Int
          -- ^ Number of entries currently accounted for in the table
        , timeStampsRef_ :: IORef (IntMap (Weak k))
          -- ^ Index from expiry timestamp (milliseconds) to a weak pointer, keyed on the
          -- stored 'Value', that yields the entry's key. Being an 'IntMap', it is ordered
          -- oldest first. NOTE(review): entries expiring in the same millisecond share one
          -- slot, so a later insertion overwrites the weak pointer of an earlier one.
        , renewUponRead_ :: Bool
          -- ^ Whether a successful lookup restarts the entry's TTL, taken from 'renewUponRead'
        , defaultTTL_    :: Int
          -- ^ TTL (milliseconds) applied by 'insert', taken from 'defaultTTL'
        } -> TTLHashTable h k v

-- | A value stored in the table together with its expiry bookkeeping. Times are in
-- milliseconds, on the same clock as 'getTimeStamp'.
data Value v = Value
    { expiresAt :: Int -- ^ Timestamp after which the entry counts as expired
    , ttl       :: Int -- ^ The entry's time to live, used to compute 'expiresAt'
    , value     :: v   -- ^ The user supplied value
    }

-- | The 'Settings' type allows for specifying how the hash table should behave. Start from
-- 'def' and override the fields you need.
data Settings = Settings
    { maxSize       :: Int
      -- ^ Maximum number of entries in the hash table. Once reached, insertion fails unless
      -- an expired entry can be evicted to make room. Defaults to 'maxBound'
    , renewUponRead :: Bool
      -- ^ Whether a successful lookup of an entry restarts the entry's TTL. Defaults to 'False'
    , defaultTTL    :: Int
      -- ^ TTL in milliseconds used for an entry if none is specified at insertion time.
      -- Defaults to one year
    }

-- | Errors reported by the operations of this module. How they surface depends on the
-- calling 'Failable' context: an exception in plain 'IO', 'Nothing' under 'MaybeT', etc.
data TTLHashTableError
    = NotFound      -- ^ The entry was not found in the table
    | ExpiredEntry  -- ^ The entry did exist but is no longer valid
    | HashTableFull -- ^ The maximum size for the table has been reached
    deriving (Eq, Typeable, Show)

-- | Lets the errors be thrown as exceptions, e.g. when an operation runs in plain 'IO'.
instance Exception TTLHashTableError

-- | Default settings: 'maxBound' as maximum size, no TTL renewal on reads and a default TTL
-- of one year.
instance Default Settings where
    def = Settings { maxSize       = maxBound
                   , renewUponRead = False
                   , defaultTTL    = oneYearInMillis
                   }
      where oneYearInMillis = 1000 * 60 * 60 * 24 * 365

-- | Creates a new hash table configured with the 'Default' instance of 'Settings'.
new :: (C.HashTable h, MonadIO m) => m (TTLHashTable h k v)
new = newWithSettings (def :: Settings)

-- | Creates a new hash table with the specified settings. Use the 'Default' instance of
-- 'Settings' and then fine tune parameters as needed, i.e.:
--
-- @
-- newWithSettings def { maxSize = 64 }
-- @
--
-- When 'maxSize' is anything other than 'maxBound', the underlying table is created
-- pre-sized for 'maxSize' entries.
newWithSettings :: (C.HashTable h, MonadIO m) => Settings -> m (TTLHashTable h k v)
newWithSettings Settings {..} = liftIO $ do
    table       <- if maxSize == maxBound
                     then H.new              -- unbounded: default-sized table
                     else H.newSized maxSize -- bounded: size the table up front
    entryCount  <- newIORef 0
    expiryIndex <- newIORef M.empty
    pure TTLHashTable { hashTable_     = table
                      , maxSize_       = maxSize
                      , numEntriesRef_ = entryCount
                      , timeStampsRef_ = expiryIndex
                      , renewUponRead_ = renewUponRead
                      , defaultTTL_    = defaultTTL
                      }

-- | Insert a new entry into the hash table, using the table's default TTL. Take note of the
-- fact that __this function can fail__, for example if the table has reached 'maxSize'
-- entries. Failure is signaled depending on the calling 'Failable' context. So for example if
-- called in plain IO, it would throw a regular IO exception (of type 'TTLHashTableError').
-- For this reason, __you probably want to call this function in a 'MaybeT' or 'ExceptT' monad__
insert :: (Eq k, Hashable k, C.HashTable h, MonadIO m, Failable m)
          => TTLHashTable h k v
          -> k
          -> v
          -> m ()
insert ht@TTLHashTable { defaultTTL_ = fallbackTTL } k v = insertWithTTL ht fallbackTTL k v

-- | Just like 'insert', but a failed insertion is silently ignored instead of being signaled.
-- It saves you from discarding the result of 'insert' manually (or catching and ignoring the
-- exception in the case of IO).
insert_ :: (Eq k, Hashable k, C.HashTable h, MonadIO m)
          => TTLHashTable h k v
          -> k
          -> v
          -> m ()
insert_ table key val = void $ runMaybeT (insert table key val)

-- | Like 'insert', but an entry specific TTL (in milliseconds) can be provided.
--
-- If the key is already present, its entry is replaced. The entry count is incremented only
-- when a new key is added. When the table already holds 'maxSize' entries, an attempt is made
-- to evict an expired entry to make room. If none can be evicted, the insertion fails with
-- 'HashTableFull' (signaled according to the calling 'Failable' context).
insertWithTTL :: (Eq k, Hashable k, C.HashTable h, MonadIO m, Failable m)
          => TTLHashTable h k v
          -> Int
          -> k
          -> v
          -> m ()
insertWithTTL ht@TTLHashTable {..} ttl k v = do
  numEntries <- liftIO $ readIORef numEntriesRef_
  now        <- getTimeStamp
  let expiry = now + ttl
  if numEntries < maxSize_
    then store expiry
    else do
      -- Table is full: only proceed if an expired entry could be evicted to make room.
      madeSpace <- checkOldest ht now
      maybe (failure HashTableFull) (const $ store expiry) madeSpace
  where
    -- Insert or replace the entry for @k@ and index it under its expiry time.
    store expiry = liftIO $ do
      let entry = Value expiry ttl v
      -- Count the key only if it was not already present. Replacing an existing entry must
      -- leave the number of entries unchanged.
      added <- H.mutate hashTable_ k $ \previous -> (Just entry, maybe 1 (const 0) previous)
      modifyIORef' numEntriesRef_ (+ added)
      -- The finalizer drops this index slot once the stored value is garbage collected.
      wPtr <- mkWeak entry k . Just . modifyIORef' timeStampsRef_ $ M.delete expiry
      modifyIORef' timeStampsRef_ $ M.insert expiry wPtr

-- | Like 'insertWithTTL', but a failed insertion is silently ignored instead of being signaled.
insertWithTTL_ :: (Eq k, Hashable k, C.HashTable h, MonadIO m)
          => TTLHashTable h k v
          -> Int
          -> k
          -> v
          -> m ()
insertWithTTL_ table lifetime key val = void $ runMaybeT (insertWithTTL table lifetime key val)


-- | Try to make room for one more entry by evicting an expired entry, oldest first.
-- Returns @Just ()@ if an entry was evicted, 'Nothing' otherwise.
--
-- Index slots are popped in timestamp order while their timestamp is strictly before @now@.
-- A slot is stale, and is simply discarded before moving on to the next oldest one, when:
--
-- * its weak pointer no longer yields a key, or
-- * the key no longer maps to an expired value (it was deleted, re-inserted or renewed since).
--
-- An entry counts as expired when its 'expiresAt' is strictly before @now@, matching the check
-- done on lookup.
checkOldest :: (Eq k, Hashable k, MonadIO m) => TTLHashTable h k v -> Int -> m (Maybe ())
checkOldest ht@TTLHashTable {..} now = liftIO evictOldest
  where
    -- Pop the oldest index slot, provided its timestamp is strictly before @now@.
    popExpired timeStamps =
      case M.minViewWithKey timeStamps of
        Just ((timeStamp, wPtr), rest) | timeStamp < now -> (rest, Just wPtr)
        _                                                -> (timeStamps, Nothing)
    evictOldest = do
      next <- atomicModifyIORef' timeStampsRef_ popExpired
      case next of
        Nothing   -> return Nothing
        Just wPtr -> do
          mKey <- deRefWeak wPtr
          case mKey of
            -- The indexed value is gone: the slot is stale, try the next one.
            Nothing -> evictOldest
            Just k  -> do
              current <- H.lookup hashTable_ k
              case current of
                Just entry | expiresAt entry < now -> Just <$> delete ht k
                -- Key absent, or re-inserted/renewed with a later expiry: stale slot.
                _                                  -> evictOldest

-- | Lookup a key in the hash table. If called straight in the IO monad it would throw a
-- 'NotFound' exception, but if called under @MaybeT IO@ or @ExceptT SomeException IO@ it would
-- return @IO Nothing@ or @IO (Left NotFound)@ respectively. So you probably want to
-- __execute this function in one of these transformer monads__
--
-- An entry whose expiry time lies strictly before the current time is deleted from the table
-- and reported as 'ExpiredEntry'. If the table was created with 'renewUponRead' set, a
-- successful lookup restarts the entry's TTL.
lookup :: (Eq k, Hashable k, MonadIO m, Failable m) => TTLHashTable h k v -> k -> m v
lookup ht@TTLHashTable {..} k
    | not renewUponRead_ = do
        now        <- getTimeStamp
        mValue     <- liftIO $ H.lookup hashTable_ k
        Value {..} <- checkLookedUp ht k mValue now
        return value
    | otherwise = do
        now     <- getTimeStamp
        -- Renew the entry in place if it is still valid. Expired entries are left untouched so
        -- that 'checkLookedUp' removes them through 'delete', which keeps the count accurate.
        mValue  <- liftIO $ H.mutate hashTable_ k $ refreshEntry now
        renewed <- checkLookedUp ht k mValue now
        liftIO $ do
          let expiry = expiresAt renewed
          -- Index the renewed entry under its new expiry time.
          wPtr <- mkWeak renewed k . Just . modifyIORef' timeStampsRef_ $ M.delete expiry
          modifyIORef' timeStampsRef_ $ M.insert expiry wPtr
        return (value renewed)
    where refreshEntry _ Nothing = (Nothing, Nothing)
          refreshEntry now (Just entry@(Value expiry life payload))
              -- Same criterion as 'checkLookedUp': expired iff strictly before @now@.
              | expiry < now = (Just entry, Just entry)
              | otherwise    = let entry' = Value (now + life) life payload
                               in (Just entry', Just entry')

-- | Validate the result of a table lookup taken at time @now@.
--
-- * A missing entry fails with 'NotFound'.
-- * An entry whose expiry time is strictly before @now@ is deleted from the table and fails
--   with 'ExpiredEntry'.
-- * Otherwise the entry is returned unchanged.
checkLookedUp :: (Eq k, Hashable k, MonadIO m, C.HashTable h, Failable m)
                 => TTLHashTable h k v
                 -> k
                 -> Maybe (Value v)
                 -> Int
                 -> m (Value v)
checkLookedUp ht k found now =
  case found of
    Nothing -> failure NotFound
    Just entry
      | expiresAt entry < now -> delete ht k >> failure ExpiredEntry
      | otherwise             -> return entry

-- | A lookup function which simply returns 'Maybe' wrapped in the calling 'MonadIO' context,
-- to accommodate more conventional usage. Both missing and expired entries yield 'Nothing'.
find :: (Eq k, Hashable k, MonadIO m) => TTLHashTable h k v -> k -> m (Maybe v)
find table key = runMaybeT (lookup table key)

-- | Delete an entry from the hash table, decrementing the entry count if the key was present.
-- Deleting a key that is not in the table leaves the table and the count unchanged. The
-- expiry timestamp index is not touched here.
delete :: (C.HashTable h, Eq k, Hashable k, MonadIO m) => TTLHashTable h k v -> k -> m ()
delete TTLHashTable {..} k = liftIO $ do
    removed <- H.mutate hashTable_ k $ \existing -> (Nothing, maybe 0 (const 1) existing)
    modifyIORef' numEntriesRef_ (subtract removed)

-- | Report the current number of entries in the table, including those that have expired but
-- have not been removed yet.
size :: (MonadIO m) => TTLHashTable h k v -> m Int
size TTLHashTable { numEntriesRef_ = counter } = liftIO (readIORef counter)

-- | Run garbage collection of expired entries in the table. Note that this function, like all
-- other operations on the hash table, is __not__ thread safe. If concurrent threads need to
-- operate on the table, some concurrency primitive must be used to guarantee exclusive access.
--
-- An entry is expired when its expiry time is strictly before the current time. Entries
-- expiring exactly now are kept, and stay indexed.
--
-- Index slots that no longer describe the key's current entry are dropped without touching the
-- table: the indexed value is gone, or the key was deleted, re-inserted or renewed since.
removeExpired :: (MonadIO m, Eq k, Hashable k) => TTLHashTable h k v -> m ()
removeExpired ht@TTLHashTable {..} =
  liftIO $ do
    now     <- getTimeStamp
    expired <- atomicModifyIORef' timeStampsRef_ $ \timeStamps ->
                 -- 'M.splitLookup' hands back the slot for @now@ itself so it can be kept;
                 -- a plain split would silently drop it.
                 let (older, atNow, newer) = M.splitLookup now timeStamps
                     kept                  = maybe newer (\w -> M.insert now w newer) atNow
                 in (kept, older)
    mapM_ (remove now) expired
  where remove cutoff wPtr = do
          mKey <- deRefWeak wPtr
          case mKey of
            Nothing -> return ()
            Just k  -> do
              current <- H.lookup hashTable_ k
              case current of
                -- Only delete if the key's current entry really is expired.
                Just entry | expiresAt entry < cutoff -> delete ht k
                _                                     -> return ()

-- | Returns the current reading of the monotonic clock as a timestamp in milliseconds.
getTimeStamp :: (MonadIO m) => m Int
getTimeStamp = liftIO $ toMillis <$> getTime Monotonic
  where
    -- Whole seconds scaled to milliseconds plus the floored millisecond part of the
    -- nanoseconds. This equals flooring the total nanosecond count divided by 10^6.
    toMillis (TimeSpec seconds nanos) = fromIntegral (seconds * 1000 + nanos `div` 1000000)