{-# LANGUAGE GADTs, GeneralizedNewtypeDeriving, RecordWildCards #-}
module Data.TTLHashTable (
TimeStamp,
TTLHashTable,
TTLHashTableError(..),
Settings(..),
insert,
insert_,
insertWithTTL,
insertWithTTL_,
delete,
find,
foldM,
getSettings,
getTimeStamp,
Data.TTLHashTable.mapM_,
lookup,
lookupAndRenew,
lookupMaybeExpired,
mutate,
new,
newWithSettings,
reconfigure,
removeExpired,
size) where
import Prelude hiding (lookup)
import Control.Exception (Exception)
import Control.Monad (void, forM_, when)
import Control.Monad.Except (runExcept, throwError)
import Control.Monad.Failable (Failable, failure)
import Control.Monad.IO.Class (MonadIO, liftIO)
import Control.Monad.Trans.Maybe (MaybeT(..), runMaybeT)
import Data.Bifunctor (first)
import Data.Bits (finiteBitSize)
import Data.Default (Default, def)
import Data.Hashable (Hashable)
import Data.IntMap.Strict (IntMap)
import Data.IORef (IORef,
atomicModifyIORef',
modifyIORef',
newIORef,
readIORef,
writeIORef)
import Data.Typeable (Typeable)
import System.Clock (Clock(..), TimeSpec(..), getTime)
import qualified Data.HashTable.Class as C
import qualified Data.HashTable.IO as H
import qualified Data.IntMap.Strict as M
-- | A mutable hash table whose entries expire after a per-entry TTL.
--
-- Besides the underlying hash table it holds a running entry count and an
-- index from expiry timestamp to key, which 'removeExpired' and inserts into
-- a full table consult to find expired entries. The 'Settings' are kept in
-- 'IORef's so that 'reconfigure' can change them in place.
--
-- The 'C.HashTable' constraint is stored in the constructor, so pattern
-- matching on a table brings it into scope; that is why helpers such as
-- 'mutateWith' can call the "Data.HashTable.IO" API without the constraint
-- in their own signatures.
data TTLHashTable h k v where
  TTLHashTable :: (C.HashTable h)
               => { hashTable_ :: H.IOHashTable h k (Value v),
                    -- limit on the entry count (Settings.maxSize)
                    maxSizeRef_ :: IORef Int,
                    -- running count of entries stored in hashTable_
                    numEntriesRef_ :: IORef Int,
                    -- expiry timestamp (as Int) -> key; an IntMap holds only
                    -- one key per timestamp
                    timeStampsRef_ :: IORef (IntMap k),
                    -- TTL in milliseconds used by 'insert' (Settings.defaultTTL)
                    defaultTTLRef_ :: IORef Int,
                    -- per-call limit for 'removeExpired' (Settings.gcMaxEntries)
                    gcMaxEntriesRef_ :: IORef Int }
               -> TTLHashTable h k v
-- | A stored value plus the metadata needed to expire and renew it.
data Value v = Value { expiresAt :: TimeStamp,  -- absolute expiry time on the 'getTimeStamp' clock
                       ttl :: Int,              -- TTL in milliseconds; reapplied on renewal
                       value :: v }             -- the user's value
-- | A reading of the monotonic clock in nanoseconds (see 'getTimeStamp').
-- The numeric instances let TTLs be added directly and let timestamps be
-- converted to 'Int' keys of the expiry index via 'fromIntegral'.
newtype TimeStamp = TimeStamp Int deriving (Num, Integral, Real, Enum, Eq, Ord)
-- | Tunable parameters of a 'TTLHashTable' (see 'newWithSettings' and
-- 'reconfigure').
data Settings = Settings {
  -- | Maximum number of entries; an insert beyond it fails with
  -- 'HashTableFull' unless an expired entry can be evicted first.
  maxSize :: Int,
  -- | TTL in milliseconds used by 'insert', and by 'mutate' for new entries.
  defaultTTL :: Int,
  -- | Maximum number of expired entries handled by one 'removeExpired' call.
  gcMaxEntries :: Int }
-- | Errors reported (through 'failure') by the table operations.
data TTLHashTableError =
    NotFound                    -- ^ the key is not in the table
  | ExpiredEntry                -- ^ the key is present but its expiry time has passed
  | HashTableFull               -- ^ 'maxSize' reached and no expired entry could be evicted
  | UnsupportedPlatform String  -- ^ 'Int' is narrower than 64 bits
  | HashTableTooLarge           -- ^ 'reconfigure' asked for a 'maxSize' below the current entry count
  deriving (Eq, Typeable, Show)
-- | Makes 'TTLHashTableError' throwable/catchable as an exception.
instance Exception TTLHashTableError
-- | Defaults: no size limit, a one-day TTL and no per-call GC limit.
instance Default Settings where
  def = Settings
    { maxSize = maxBound
    , defaultTTL = oneDayMs
    , gcMaxEntries = maxBound
    }
    where
      -- 24 h expressed in milliseconds, the unit of 'defaultTTL'
      oneDayMs = 24 * 60 * 60 * 1000
-- | Fail with 'UnsupportedPlatform' unless 'Int' has at least 64 bits.
-- Nanosecond timestamps are stored as 'Int' keys, so a narrower 'Int'
-- would overflow.
assertIntSize :: (Failable m) => m ()
assertIntSize
  | intBits >= 64 = return ()
  | otherwise = failure $ UnsupportedPlatform "Int size on this platform is < 64 bits"
  where intBits = finiteBitSize (0 :: Int)
-- | Create an empty table with the default 'Settings'.
new :: (C.HashTable h, MonadIO m, Failable m) => m (TTLHashTable h k v)
new = newWithSettings (def :: Settings)
-- | Create an empty table configured by the given 'Settings'.
--
-- Fails with 'UnsupportedPlatform' when 'Int' is narrower than 64 bits.
-- A 'maxSize' of 'maxBound' uses the hash table's default initial size;
-- any other value pre-sizes the underlying table to 'maxSize'.
newWithSettings :: (C.HashTable h, MonadIO m, Failable m) => Settings -> m (TTLHashTable h k v)
newWithSettings Settings {..} = assertIntSize >> liftIO build
  where
    build = do
      table <- if maxSize == maxBound then H.new else H.newSized maxSize
      entryCount <- newIORef 0
      stamps <- newIORef M.empty
      limit <- newIORef maxSize
      ttlMs <- newIORef defaultTTL
      gcLimit <- newIORef gcMaxEntries
      return TTLHashTable { hashTable_ = table,
                            maxSizeRef_ = limit,
                            numEntriesRef_ = entryCount,
                            timeStampsRef_ = stamps,
                            defaultTTLRef_ = ttlMs,
                            gcMaxEntriesRef_ = gcLimit }
-- | Insert @v@ under @k@ using the table's current default TTL.
-- Fails the same way as 'insertWithTTL' (e.g. with 'HashTableFull').
insert :: (Eq k, Hashable k, C.HashTable h, MonadIO m, Failable m)
       => TTLHashTable h k v
       -> k
       -> v
       -> m ()
insert ht@TTLHashTable {..} k v =
  liftIO (readIORef defaultTTLRef_) >>= \ttlMs -> insertWithTTL ht ttlMs k v
-- | Like 'insert', but any failure is silently discarded.
insert_ :: (Eq k, Hashable k, C.HashTable h, MonadIO m)
        => TTLHashTable h k v
        -> k
        -> v
        -> m ()
insert_ h k v = void (runMaybeT (insert h k v))
-- | Insert @v@ under @k@ with an explicit TTL in milliseconds. The TTL is
-- scaled by 1000000 onto the nanosecond clock of 'getTimeStamp'.
--
-- If the table already holds 'maxSize' entries, 'checkOldest' tries to
-- evict an expired entry first; if none can be evicted this fails with
-- 'HashTableFull'.
--
-- Replacing an existing key does not change the entry count. The replaced
-- entry's expiry slot is dropped from the timestamp index (only if it still
-- belongs to @k@), so the index does not accumulate stale slots.
insertWithTTL :: (Eq k, Hashable k, C.HashTable h, MonadIO m, Failable m)
              => TTLHashTable h k v
              -> Int
              -> k
              -> v
              -> m ()
insertWithTTL ht@TTLHashTable {..} ttl k v = do
  numEntries <- liftIO $ readIORef numEntriesRef_
  now <- getTimeStamp
  let expiry = now + fromIntegral (ttl * 1000000)
  maxSize <- liftIO $ readIORef maxSizeRef_
  if numEntries < maxSize
    then insert' expiry
    else do
      madeSpace <- checkOldest ht now
      maybe (failure HashTableFull) (const $ insert' expiry) madeSpace
  where insert' expiry = do
          let newValue = Value expiry ttl v
          -- mutateWith adjusts numEntries by +1 only when k is new, unlike a
          -- plain H.insert followed by an unconditional increment.
          replaced <- mutateWith (\old -> (Just newValue, expiresAt <$> old)) ht k
          liftIO . modifyIORef' timeStampsRef_ $
            M.insert (fromIntegral expiry) k . maybe id dropStamp replaced
        -- Another key with the same expiry may own this slot; leave it alone then.
        dropStamp stamp =
          M.update (\owner -> if owner == k then Nothing else Just owner)
                   (fromIntegral stamp)
-- | Like 'insertWithTTL', but any failure is silently discarded.
insertWithTTL_ :: (Eq k, Hashable k, C.HashTable h, MonadIO m)
               => TTLHashTable h k v
               -> Int
               -> k
               -> v
               -> m ()
insertWithTTL_ h ttl k v = void (runMaybeT (insertWithTTL h ttl k v))
-- | Try to free one slot by evicting an expired entry.
--
-- Pops expiry slots from the timestamp index, oldest first, while they are
-- at or before @now@. Each popped slot's entry is deleted if its
-- 'expiresAt' still matches.
--
-- A slot can be stale (its key was re-inserted, renewed or deleted since).
-- Such a slot is discarded and the next expired slot is tried, rather than
-- giving up.
--
-- Returns @Just ()@ once an entry was actually removed, or 'Nothing' when
-- no expired entry is left.
checkOldest :: (Eq k, Hashable k, MonadIO m) => TTLHashTable h k v -> TimeStamp -> m (Maybe ())
checkOldest ht@TTLHashTable {..} now = liftIO $ runMaybeT go
  where
    go = do
      (timeStamp, k) <- MaybeT $ atomicModifyIORef' timeStampsRef_ popExpired
      removed <- liftIO $ mutateWith (deleteExpired $ fromIntegral timeStamp) ht k
      -- stale slot: nothing was deleted, so keep looking
      maybe go return removed
    popExpired timeStamps =
      case M.minViewWithKey timeStamps of
        Just ((timeStamp, k), rest)
          | fromIntegral timeStamp <= now -> (rest, Just (timeStamp, k))
        _ -> (timeStamps, Nothing)
-- | Return the value stored under @k@ without renewing it. Fails with
-- 'NotFound' if the key is absent, or 'ExpiredEntry' if it has expired.
lookup :: (Eq k, Hashable k, MonadIO m, Failable m) => TTLHashTable h k v -> k -> m v
lookup ht k = lookup' False ht k
-- | Like 'lookup', but a hit resets the entry's expiry to now plus its TTL.
-- An expired entry is removed from the table.
lookupAndRenew :: (Eq k, Hashable k, MonadIO m, Failable m) => TTLHashTable h k v -> k -> m v
lookupAndRenew ht k = lookup' True ht k
-- | Return the stored value and its expiry time whether or not it has
-- expired. Fails with 'NotFound' only when the key is absent.
lookupMaybeExpired :: (Eq k, Hashable k, MonadIO m, Failable m) => TTLHashTable h k v -> k -> m (v, TimeStamp)
lookupMaybeExpired TTLHashTable {..} k = do
  found <- liftIO $ H.lookup hashTable_ k
  case found of
    Nothing -> failure NotFound
    Just Value { value = v, expiresAt = expiry } -> return (v, expiry)
-- | Shared implementation of 'lookup' (@False@) and 'lookupAndRenew' (@True@).
--
-- With @False@ the table is only read. With @True@ the entry is handled in
-- one 'mutateWith' step:
--
-- * a live entry gets expiry @now + ttl@, and its index slot is moved;
-- * an expired entry is deleted together with its index slot.
--
-- "Live" uses the same rule as 'checkLookedUp': @expiresAt >= now@. An entry
-- expiring exactly at @now@ is therefore renewed rather than deleted and
-- then returned.
lookup' :: (Eq k, Hashable k, MonadIO m, Failable m) => Bool -> TTLHashTable h k v -> k -> m v
lookup' False TTLHashTable {..} k = do
  now <- getTimeStamp
  mValue <- liftIO $ H.lookup hashTable_ k
  Value {..} <- checkLookedUp mValue now
  return value
lookup' True ht@TTLHashTable {..} k = do
  now <- getTimeStamp
  (mOldExpiry, mValue) <- mutateWith (refreshEntry now) ht k
  -- The old slot is dropped whether the entry was renewed or evicted.
  forM_ mOldExpiry removeTimeStamp
  Value {..} <- checkLookedUp mValue now
  liftIO $ modifyIORef' timeStampsRef_ $ M.insert (fromIntegral expiresAt) k
  return value
  where refreshEntry _ Nothing =
          (Nothing, (Nothing, Nothing))
        refreshEntry now (Just v@Value{..})
          | expiresAt >= now =
              let v' = Value { expiresAt = now + (fromIntegral ttl * 1000000),
                               ttl = ttl,
                               value = value }
              in (Just v', (Just expiresAt, Just v'))
          -- Expired: delete it; checkLookedUp then reports ExpiredEntry.
          | otherwise =
              (Nothing, (Just expiresAt, Just v))
        -- Only drop the slot if it still maps to k (another key with the
        -- same expiry may have overwritten it).
        removeTimeStamp timeStamp =
          liftIO . modifyIORef' timeStampsRef_ $
            M.update (\owner -> if owner == k then Nothing else Just owner)
                     (fromIntegral timeStamp)
-- | Turn a raw hash-table lookup result into the stored 'Value'. Fails with
-- 'NotFound' for a missing entry, or 'ExpiredEntry' when its expiry time is
-- strictly before @now@.
checkLookedUp :: (MonadIO m, Failable m)
              => Maybe (Value v)
              -> TimeStamp
              -> m (Value v)
checkLookedUp mEntry now =
  case mEntry of
    Nothing -> failure NotFound
    Just entry
      | expiresAt entry < now -> failure ExpiredEntry
      | otherwise -> return entry
-- | Like 'lookup', but yields 'Nothing' instead of failing when the key is
-- missing or expired.
find :: (Eq k, Hashable k, MonadIO m) => TTLHashTable h k v -> k -> m (Maybe v)
find ht k = runMaybeT (lookup ht k)
-- | Remove @k@ from the table, together with its expiry slot in the
-- timestamp index. Deleting an absent key is a no-op.
delete :: (C.HashTable h, Eq k, Hashable k, MonadIO m, Failable m) =>
          TTLHashTable h k v -> k -> m ()
delete ht@TTLHashTable {..} k = do
  timeStamp <- mutateWith (\v -> (Nothing, expiresAt <$> v)) ht k
  -- Only drop the slot if it still maps to k: the index keeps one key per
  -- timestamp, so another key with the same expiry may own it, and
  -- deleting that slot would make the other key un-collectable.
  forM_ timeStamp $ \stamp ->
    liftIO . modifyIORef' timeStampsRef_ $
      M.update (\owner -> if owner == k then Nothing else Just owner)
               (fromIntegral stamp)
-- | Mutator for 'mutateWith': delete the entry only if its 'expiresAt'
-- equals the given timestamp, yielding @Just ()@ when it was deleted. A
-- mismatch means the index slot is stale, so the entry is kept.
deleteExpired :: TimeStamp -> Maybe (Value v) -> (Maybe (Value v), Maybe ())
deleteExpired stamp mEntry =
  case mEntry of
    Just entry | expiresAt entry /= stamp -> (mEntry, Nothing)
    Just _ -> (Nothing, Just ())
    Nothing -> (Nothing, Nothing)
-- | Apply a pure mutator to the slot for @k@ and keep the entry count in
-- sync.
--
-- The mutator sees the current value (if any) and returns the new value
-- (if any) plus a result. The count changes by +1 when an entry is created,
-- -1 when one is removed, and 0 otherwise.
--
-- Creating an entry when the count has already reached the maximum size
-- leaves the slot unchanged and fails with 'HashTableFull'.
mutateWith :: (Eq k, Hashable k, MonadIO m, Failable m)
           => (Maybe (Value v) -> (Maybe (Value v), a))
           -> TTLHashTable h k v
           -> k
           -> m a
mutateWith mutator TTLHashTable {..} k = do
  outcome <- liftIO $ do
    count <- readIORef numEntriesRef_
    limit <- readIORef maxSizeRef_
    H.mutate hashTable_ k (step count limit)
  case outcome of
    Left err -> failure err
    Right (delta, result) -> do
      liftIO $ modifyIORef' numEntriesRef_ (+ delta)
      return result
  where
    step count limit current =
      let (updated, result) = mutator current
          outcome = runExcept $ do
            delta <- case (current, updated) of
              (Nothing, Just _)
                | count < limit -> return 1
                | otherwise -> throwError HashTableFull
              (Just _, Nothing) -> return (-1)
              _ -> return 0
            return (delta, result)
      in case outcome of
           -- on error, write back the original value untouched
           Left _ -> (current, outcome)
           Right _ -> (updated, outcome)
-- | The tracked entry count. It can include entries that have expired but
-- have not been removed yet.
size :: (MonadIO m) => TTLHashTable h k v -> m Int
size TTLHashTable { numEntriesRef_ = countRef } = liftIO $ readIORef countRef
-- | Garbage-collect expired entries, limited by the table's current
-- 'gcMaxEntries' setting. Returns the count produced by 'removeExpired''.
removeExpired :: (MonadIO m, Eq k, Hashable k) => TTLHashTable h k v -> m Int
removeExpired ht@TTLHashTable { gcMaxEntriesRef_ = limitRef } =
  liftIO (readIORef limitRef) >>= removeExpired' ht
-- | Remove up to @gcMaxEntries@ expired entries and return how many table
-- entries were actually deleted.
--
-- An entry is expired when its expiry slot is strictly before the current
-- time, the same rule 'checkLookedUp' uses. Slots at or after @now@ stay in
-- the index, including one keyed exactly at @now@, which 'M.split' used to
-- drop. Expired slots beyond the limit are put back for a later call.
--
-- Stale slots (whose entry's 'expiresAt' no longer matches) are discarded
-- without deleting anything and are not counted.
removeExpired' :: (MonadIO m, Eq k, Hashable k) => TTLHashTable h k v -> Int -> m Int
removeExpired' ht@TTLHashTable {..} gcMaxEntries =
  liftIO $ do
    now <- getTimeStamp
    expired <- atomicModifyIORef' timeStampsRef_ $ selectEntries (fromIntegral now)
    removed <- traverse remove expired
    return $ length [() | Just () <- removed]
  where remove (timeStamp, k) = mutateWith (deleteExpired $ fromIntegral timeStamp) ht k
        selectEntries now stamps =
          let (old, atNow, active) = M.splitLookup now stamps
              -- an entry expiring exactly now is still valid; keep its slot
              active' = maybe active (\k -> M.insert now k active) atNow
              (selected, notYet) = splitAt gcMaxEntries $ M.toAscList old
          in (M.union active' (M.fromDistinctAscList notYet), selected)
-- | Current reading of the monotonic clock, in nanoseconds.
getTimeStamp :: (MonadIO m) => m TimeStamp
getTimeStamp = liftIO $ toNanos <$> getTime Monotonic
  where toNanos (TimeSpec s ns) = fromIntegral (s * 1000000000 + ns)
-- | Fold an IO action over every stored (key, value) pair. Expired entries
-- that have not been removed yet are included.
foldM :: (MonadIO m) => (a -> (k, v) -> IO a) -> a -> TTLHashTable h k v -> m a
foldM f x TTLHashTable {..} =
  liftIO $ H.foldM (\acc (k, Value { value = v }) -> f acc (k, v)) x hashTable_
-- | Run an IO action on every stored (key, value) pair. Expired entries
-- that have not been removed yet are included.
mapM_ :: (MonadIO m) => ((k, v) -> IO a) -> TTLHashTable h k v -> m ()
mapM_ f TTLHashTable {..} =
  liftIO $ H.mapM_ (\(k, Value { value = v }) -> f (k, v)) hashTable_
-- | Replace the table's settings. Fails with 'HashTableTooLarge', leaving
-- the settings unchanged, when the new 'maxSize' is below the current entry
-- count.
reconfigure :: (MonadIO m, Failable m) => TTLHashTable h k v -> Settings -> m ()
reconfigure TTLHashTable {..} Settings {..} = do
  current <- liftIO $ readIORef numEntriesRef_
  when (maxSize < current) $ failure HashTableTooLarge
  liftIO $ forM_ updates (uncurry writeIORef)
  where updates = [ (maxSizeRef_, maxSize)
                  , (defaultTTLRef_, defaultTTL)
                  , (gcMaxEntriesRef_, gcMaxEntries) ]
-- | Snapshot of the table's current settings.
getSettings :: (MonadIO m) => TTLHashTable h k v -> m Settings
getSettings TTLHashTable {..} =
  liftIO $ Settings <$> readIORef maxSizeRef_
                    <*> readIORef defaultTTLRef_
                    <*> readIORef gcMaxEntriesRef_
-- | Apply a user function to the value stored under @k@ (or to 'Nothing'
-- when absent). The function returns the new value (or 'Nothing' to delete)
-- plus a result.
--
-- A resulting value gets expiry @now + ttl@: the existing entry's TTL, or
-- the current default TTL for a new entry. The timestamp index is updated
-- to match: the old slot is dropped if it still belongs to @k@, and the new
-- one is inserted. Without this, created or mutated entries could never be
-- found by 'removeExpired' or by eviction on a full table.
--
-- Creating an entry in a full table fails with 'HashTableFull'.
mutate :: (Eq k, Hashable k, MonadIO m, Failable m)
       => TTLHashTable h k v
       -> k
       -> (Maybe v -> (Maybe v, a))
       -> m a
mutate ht@TTLHashTable {..} k f = do
  now <- getTimeStamp
  defaultTTL <- liftIO $ readIORef defaultTTLRef_
  (result, oldExpiry, newExpiry) <- mutateWith (tracked $ mutate' now defaultTTL f) ht k
  liftIO . modifyIORef' timeStampsRef_ $
    maybe id (\stamp -> M.insert (fromIntegral stamp) k) newExpiry
      . maybe id dropStamp oldExpiry
  return result
  where -- Also report the expiry times before and after the mutation.
        tracked mutator mOld =
          let (mNew, r) = mutator mOld
          in (mNew, (r, expiresAt <$> mOld, expiresAt <$> mNew))
        dropStamp stamp =
          M.update (\owner -> if owner == k then Nothing else Just owner)
                   (fromIntegral stamp)
-- | Lift a mutator over plain values to one over stored 'Value's.
--
-- The user function sees the bare value (expired or not). Any value it
-- returns is wrapped with expiry @now + ttl@, where @ttl@ is the existing
-- entry's TTL, or @defaultTTL@ when there was no entry.
mutate' :: TimeStamp -> Int -> (Maybe v -> (Maybe v, a)) -> Maybe (Value v) -> (Maybe (Value v), a)
mutate' now defaultTTL f mV = first (fmap rebuild) (f current)
  where (ttlMs, current) = case mV of
          Nothing -> (defaultTTL, Nothing)
          Just Value {..} -> (ttl, Just value)
        rebuild v = Value { expiresAt = now + fromIntegral (ttlMs * 1000000),
                            ttl = ttlMs,
                            value = v }