{-# LANGUAGE GADTs, RecordWildCards #-}
module Data.TTLHashTable (
TTLHashTable,
TTLHashTableError(..),
Settings(..),
insert,
insert_,
insertWithTTL,
insertWithTTL_,
delete,
find,
lookup,
new,
newWithSettings,
removeExpired,
size) where
import Prelude hiding (lookup)
import Control.Exception (Exception)
import Control.Monad (void)
import Control.Monad.Trans.Class (lift)
import Control.Monad.Failable (Failable, failure)
import Control.Monad.IO.Class (MonadIO, liftIO)
import Control.Monad.Trans.Maybe (MaybeT(..), runMaybeT)
import Data.Default (Default, def)
import Data.Hashable (Hashable)
import Data.IntMap.Strict (IntMap)
import Data.IORef (IORef,
atomicModifyIORef',
modifyIORef',
newIORef,
readIORef)
import Data.Tuple (swap)
import Data.Typeable (Typeable)
import System.Clock (Clock(Monotonic), TimeSpec(..), getTime)
import System.Mem.Weak (Weak, deRefWeak, mkWeak)
import qualified Data.HashTable.Class as C
import qualified Data.HashTable.IO as H
import qualified Data.IntMap.Strict as M
-- | A mutable hash table in which every entry carries an expiry time.
--
-- @h@ selects the concrete "Data.HashTable" implementation backing the table.
data TTLHashTable h k v where
  TTLHashTable
    :: C.HashTable h
    => { -- Underlying table mapping each key to its value plus expiry data.
         hashTable_ :: H.IOHashTable h k (Value v)
         -- Upper bound on the entry count, checked when inserting.
       , maxSize_ :: Int
         -- Number of entries currently stored in 'hashTable_'; maintained by
         -- hand on insertion and deletion.
       , numEntriesRef_ :: IORef Int
         -- Expiry index: expiry timestamp (ms) -> weak pointer whose key is the
         -- stored 'Value' and whose payload is that entry's key.
         -- NOTE(review): one slot per timestamp, so two entries expiring in the
         -- same millisecond overwrite each other here.
       , timeStampsRef_ :: IORef (IntMap (Weak k))
         -- Whether a successful lookup pushes the entry's expiry forward.
       , renewUponRead_ :: Bool
         -- TTL (ms) applied by 'insert'.
       , defaultTTL_ :: Int
       }
    -> TTLHashTable h k v
-- | What the hash table stores for each key: the user's payload together with
-- the bookkeeping needed to expire and renew it.
data Value v = Value
  { expiresAt :: Int
    -- ^ Monotonic timestamp (ms); the entry counts as expired once this is
    -- strictly before the current time.
  , ttl :: Int
    -- ^ Lifetime (ms) the entry was inserted with; reused when renewing.
  , value :: v
    -- ^ The stored payload.
  }
-- | Construction-time options accepted by 'newWithSettings'.
data Settings = Settings
  { maxSize :: Int
    -- ^ Maximum number of entries. With 'maxBound' the underlying table is
    -- created without a size hint.
  , renewUponRead :: Bool
    -- ^ When 'True', every successful 'lookup' resets the entry's expiry to
    -- the current time plus the entry's own TTL.
  , defaultTTL :: Int
    -- ^ Time to live, in milliseconds, used by 'insert'.
  }
-- | Failures reported by table operations.
data TTLHashTableError
  = NotFound
    -- ^ The key is not present in the table.
  | ExpiredEntry
    -- ^ The key was present but its entry had expired; it has been removed.
  | HashTableFull
    -- ^ The table is at its size limit and no expired entry could be evicted.
  deriving (Eq, Show, Typeable)

instance Exception TTLHashTableError
-- | Defaults: no size limit, no renewal on read, and a 365-day TTL.
instance Default Settings where
  def =
    Settings
      { maxSize = maxBound
      , renewUponRead = False
      , defaultTTL = oneYearInMs
      }
    where
      -- TTLs are measured in milliseconds (see 'getTimeStamp').
      oneYearInMs = 365 * 24 * 60 * 60 * 1000
-- | Create an empty table using the default 'Settings' ('def').
new :: (C.HashTable h, MonadIO m) => m (TTLHashTable h k v)
new = newWithSettings (def :: Settings)
-- | Create an empty table configured by the given 'Settings'.
--
-- Unless 'maxSize' is 'maxBound', the underlying hash table is pre-sized to
-- 'maxSize' entries.
newWithSettings :: (C.HashTable h, MonadIO m) => Settings -> m (TTLHashTable h k v)
newWithSettings Settings {..} = liftIO $ do
  table <- if maxSize == maxBound then H.new else H.newSized maxSize
  countRef <- newIORef 0
  stampsRef <- newIORef M.empty
  pure TTLHashTable
    { hashTable_ = table
    , maxSize_ = maxSize
    , numEntriesRef_ = countRef
    , timeStampsRef_ = stampsRef
    , renewUponRead_ = renewUponRead
    , defaultTTL_ = defaultTTL
    }
-- | Insert a key/value pair that expires after the table's default TTL
-- (the 'defaultTTL' setting). Fails exactly as 'insertWithTTL' does.
insert :: (Eq k, Hashable k, C.HashTable h, MonadIO m, Failable m)
       => TTLHashTable h k v
       -> k
       -> v
       -> m ()
insert table@TTLHashTable { defaultTTL_ = lifetime } = insertWithTTL table lifetime
-- | Like 'insert', but runs it inside 'MaybeT' and throws the outcome away,
-- so failures such as 'HashTableFull' are silently ignored.
-- NOTE(review): relies on the 'Failable' instance for 'MaybeT' turning a
-- failure into 'Nothing'.
insert_ :: (Eq k, Hashable k, C.HashTable h, MonadIO m)
        => TTLHashTable h k v
        -> k
        -> v
        -> m ()
insert_ table key val = runMaybeT (insert table key val) >> return ()
-- | Insert a key/value pair that expires @ttl@ milliseconds from now,
-- replacing any existing entry for the key.
--
-- Overwriting a key that is already present never needs extra room. When the
-- key is new and the table already holds 'maxSize' entries, 'checkOldest' is
-- asked to evict an expired entry; if it cannot, the call fails with
-- 'HashTableFull'.
insertWithTTL :: (Eq k, Hashable k, C.HashTable h, MonadIO m, Failable m)
              => TTLHashTable h k v
              -> Int
              -> k
              -> v
              -> m ()
insertWithTTL ht@TTLHashTable {..} ttl k v = do
  numEntries <- liftIO $ readIORef numEntriesRef_
  present <- liftIO $ maybe False (const True) <$> H.lookup hashTable_ k
  now <- getTimeStamp
  let expiresAt = now + ttl
  if present || numEntries < maxSize_
    then insert' expiresAt
    else do
      madeSpace <- checkOldest ht now
      maybe (failure HashTableFull) (const $ insert' expiresAt) madeSpace
  where
    insert' expiresAt = liftIO $ do
      let value = Value expiresAt ttl v
      -- Store the entry and learn whether the key was new in one step; only a
      -- new key grows the count (a plain replace must not be counted twice).
      added <- H.mutate hashTable_ k $ \old -> (Just value, maybe 1 (const 0) old)
      modifyIORef' numEntriesRef_ (+ added)
      -- Index the entry by its expiry time. The weak pointer is keyed on the
      -- stored Value; its finalizer drops the index record at that timestamp.
      wPtr <- mkWeak value k . Just . modifyIORef' timeStampsRef_ $ M.delete expiresAt
      modifyIORef' timeStampsRef_ $ M.insert expiresAt wPtr
-- | Like 'insertWithTTL', but runs it inside 'MaybeT' and throws the outcome
-- away, so failures such as 'HashTableFull' are silently ignored.
-- NOTE(review): relies on the 'Failable' instance for 'MaybeT' turning a
-- failure into 'Nothing'.
insertWithTTL_ :: (Eq k, Hashable k, C.HashTable h, MonadIO m)
               => TTLHashTable h k v
               -> Int
               -> k
               -> v
               -> m ()
insertWithTTL_ table lifetime key val =
  runMaybeT (insertWithTTL table lifetime key val) >> return ()
-- | Try to free one slot by evicting an entry whose expiry time is at or
-- before @now@.
--
-- Index records are popped from the oldest end while their timestamp is
-- @<= now@. A record can be stale: its key may have been overwritten, renewed
-- by a read, or deleted since the record was written, or its weak pointer may
-- be dead. A key is therefore only removed when its /current/ value has
-- expired; stale records are discarded and the next one is tried.
--
-- Returns @Just ()@ once an entry has actually been removed, or @Nothing@ when
-- no expired entry is left in the index.
checkOldest :: (Eq k, Hashable k, MonadIO m) => TTLHashTable h k v -> Int -> m (Maybe ())
checkOldest TTLHashTable {..} now = liftIO go
  where
    go = do
      mOldest <- atomicModifyIORef' timeStampsRef_ popExpired
      case mOldest of
        Nothing -> return Nothing
        Just wPtr -> do
          mKey <- deRefWeak wPtr
          removed <- maybe (return False) removeIfExpired mKey
          if removed then return (Just ()) else go
    -- Take the oldest index record out of the map if it is dated <= now.
    popExpired timeStamps = case M.minViewWithKey timeStamps of
      Just ((timeStamp, wPtr), rest) | timeStamp <= now -> (rest, Just wPtr)
      _ -> (timeStamps, Nothing)
    -- Delete the key only if its stored entry has expired, keeping the entry
    -- count in sync; report whether anything was removed.
    removeIfExpired key = do
      removed <- H.mutate hashTable_ key $ \current -> case current of
        Just Value { expiresAt = e } | e <= now -> (Nothing, True)
        _ -> (current, False)
      if removed then modifyIORef' numEntriesRef_ (subtract 1) else return ()
      return removed
-- | Look up the value stored under a key.
--
-- Fails with 'NotFound' when the key is absent. Fails with 'ExpiredEntry'
-- when the entry's expiry time is strictly before now; in that case the stale
-- entry is first deleted, which also decrements the entry count.
--
-- When the table was created with 'renewUponRead' set, a successful lookup
-- also moves the entry's expiry to now plus the entry's own TTL and registers
-- a fresh index record for it. The record for the previous expiry is left in
-- the index.
lookup :: (Eq k, Hashable k, MonadIO m, Failable m) => TTLHashTable h k v -> k -> m v
lookup ht@TTLHashTable {..} k
  | not renewUponRead_ = do
      now <- getTimeStamp
      mValue <- liftIO $ H.lookup hashTable_ k
      Value {..} <- checkLookedUp ht k mValue now
      return value
  | otherwise = do
      now <- getTimeStamp
      mValue <- liftIO $ H.mutate hashTable_ k $ refreshEntry now
      v@Value {..} <- checkLookedUp ht k mValue now
      liftIO $ do
        wPtr <- mkWeak v k (Just (modifyIORef' timeStampsRef_ (M.delete expiresAt)))
        modifyIORef' timeStampsRef_ $ M.insert expiresAt wPtr
      return value
  where
    refreshEntry _ Nothing = (Nothing, Nothing)
    refreshEntry now (Just v@Value {..})
      -- Expired (same "< now" test as checkLookedUp): leave the entry in place
      -- so checkLookedUp's delete finds it and decrements the count.
      | expiresAt < now = (Just v, Just v)
      | otherwise =
          let v' = Value { expiresAt = now + ttl, ttl = ttl, value = value }
          in (Just v', Just v')
-- | Validate the result of looking a key up.
--
-- An absent entry fails with 'NotFound'. An entry whose expiry time is
-- strictly before @now@ is deleted from the table, and the call then fails
-- with 'ExpiredEntry'. Any other entry is returned unchanged.
checkLookedUp :: (Eq k, Hashable k, MonadIO m, C.HashTable h, Failable m)
              => TTLHashTable h k v
              -> k
              -> Maybe (Value v)
              -> Int
              -> m (Value v)
checkLookedUp table key found now =
  case found of
    Nothing -> failure NotFound
    Just entry
      | expiresAt entry < now -> delete table key >> failure ExpiredEntry
      | otherwise -> return entry
-- | Like 'lookup', but yields 'Nothing' instead of failing when the key is
-- absent or expired.
-- NOTE(review): relies on the 'Failable' instance for 'MaybeT' turning a
-- failure into 'Nothing'.
find :: (Eq k, Hashable k, MonadIO m) => TTLHashTable h k v -> k -> m (Maybe v)
find table key = runMaybeT (lookup table key)
-- | Remove a key from the table, decrementing the entry count if the key was
-- present. The expiry-index record for the entry is left in place.
delete :: (C.HashTable h, Eq k, Hashable k, MonadIO m) => TTLHashTable h k v -> k -> m ()
delete TTLHashTable { hashTable_ = table, numEntriesRef_ = countRef } key =
  liftIO $ do
    removed <- H.mutate table key $ \entry -> case entry of
      Nothing -> (Nothing, 0)
      Just Value {} -> (Nothing, 1)
    modifyIORef' countRef (subtract removed)
-- | Number of entries currently stored. This includes entries that have
-- expired but have not been removed yet.
size :: (MonadIO m) => TTLHashTable h k v -> m Int
size TTLHashTable { numEntriesRef_ = countRef } = liftIO (readIORef countRef)
-- | Remove every entry whose expiry time is at or before the current time.
--
-- All expiry-index records dated @<= now@ are taken out of the index in one
-- atomic step. A record can be stale: its key may have been overwritten,
-- renewed by a read, or deleted since the record was written. A key is
-- therefore only removed from the hash table when its /current/ value has
-- expired, and the entry count is decremented only for entries actually
-- removed.
removeExpired :: (MonadIO m, Eq k, Hashable k) => TTLHashTable h k v -> m ()
removeExpired TTLHashTable {..} =
  liftIO $ do
    now <- getTimeStamp
    -- 'M.split' excludes its pivot key, so split at @now + 1@ to put the
    -- record dated exactly @now@ in the expired half.
    expired <- atomicModifyIORef' timeStampsRef_ $ swap . M.split (now + 1)
    mapM_ (remove now) expired
  where
    -- Resolve the weak pointer to its key (skipping dead pointers) and remove
    -- that key if its stored entry really has expired.
    remove now wPtr = runMaybeT $ do
      key <- MaybeT $ deRefWeak wPtr
      lift $ removeIfExpired now key
    removeIfExpired now key = do
      n <- H.mutate hashTable_ key $ \current -> case current of
        Just Value { expiresAt = e } | e <= now -> (Nothing, 1)
        _ -> (current, 0)
      modifyIORef' numEntriesRef_ (subtract n)
-- | Read the monotonic clock and express it in whole milliseconds
-- (nanoseconds are truncated by integer division).
getTimeStamp :: (MonadIO m) => m Int
getTimeStamp = liftIO (toMillis <$> getTime Monotonic)
  where
    toMillis (TimeSpec s n) = fromIntegral ((s * 1000000000 + n) `div` 1000000)