{-# LANGUAGE Haskell2010 #-}
{-# LANGUAGE DeriveDataTypeable #-}

{-|
An implementation of a mutable map with randomly generated keys and values
that time out and must be cleaned up when not in use. This pattern occurs
in a number of situations related to session management, so it is useful
to have the code in one place.
-}
module Snap.Snaplet.TypedSession.SessionMap (
    SessionMap,
    new,
    close,
    insert,
    update,
    delete,
    lookup,
    touch
    ) where

-- 'lookup' is hidden because this module defines its own. 'catch' is hidden
-- so that Control.Exception's version is the one in scope.
-- NOTE(review): newer versions of base no longer export 'catch' from the
-- Prelude, in which case this hiding clause only produces a warning.
import Prelude hiding (lookup, catch)

import Control.Concurrent
import Control.Exception
import Control.Monad
import Data.ByteString (ByteString)
import Data.Time
import Data.Typeable
import Snap.Snaplet.Session (RNG, mkRNG, randomToken)

import qualified Data.HashTable.IO as HT
import qualified Data.PSQueue as PQ

{-|
A SessionMap is a mutable map with ByteString keys and values of a
user-defined type. The values in the map are associated with timeouts,
after which they will automatically be deleted.

A SessionMap is in one of two states: active or closed. When first
created, the map is active. Values can be inserted, queried, and removed,
and will be automatically removed once time has expired. When 'close' is
called, the map changes to the closed state. Attempts to use it will then
result in exceptions, and all data and resources associated with value
timeout are freed. There is no way to reopen a closed SessionMap.
-}
data SessionMap v = SessionMap
    !(MVar (Maybe (PQ.PSQ ByteString UTCTime)))
        -- ^ Expiration queue: each key's expiry time, ordered so the
        -- earliest is found first. 'Just' while the map is active,
        -- 'Nothing' once it has been closed.
        --
        -- In addition to holding the priority queue, this also acts
        -- as the lock for the entire data structure. The hash table
        -- should not be modified except when holding this lock.
    !(HT.BasicHashTable ByteString v)
        -- ^ The stored values, keyed by session key.
    !ThreadId
        -- ^ The watcher thread that removes expired entries. It is sent
        -- 'PokeWatcher' when the expiration queue changes.
    !RNG
        -- ^ Random generator used to create new session keys.
    !Int
        -- ^ Timeout applied to each value, in seconds (it is added to the
        -- current time with 'addUTCTime').

{-|
This exception type is created specifically for the purpose of poking the
watcher thread when new values are inserted into the map.
-}
data PokeWatcher = PokeWatcher deriving (Typeable, Show)

instance Exception PokeWatcher

{-| Creates a new random map, which is initially empty but active.
The argument is the expiration timeout, in seconds, applied to a value each
time it is inserted, updated, or touched. A watcher thread is started to
remove expired values; it exits once the map has been closed.
-}
new :: Int -> IO (SessionMap v)
new to = do
    q   <- newMVar (Just PQ.empty)
    ht  <- HT.new
    w   <- forkIO (doWatcher q ht)
    gen <- mkRNG
    return (SessionMap q ht w gen to)

{-|
The watcher thread. This thread loops, checking for expired elements and
removing them from the map. When there are no expired elements, it sleeps
until the next expiration time. It should be thrown a PokeWatcher exception
when the expiration queue is changed, so that it notices earlier expiration
times. The thread returns once it finds the map closed.

The loop runs with asynchronous exceptions masked, so a 'PokeWatcher' can
only be delivered at an interruptible point. There are two such points: the
sleep in 'threadDelay', and blocking on the lock inside 'modifyMVar'. Both
must absorb the poke. Pokes are always sent by a thread that is holding the
lock, so the watcher being blocked on that lock is exactly when a poke is
likely to land. If it escaped there, it would kill the watcher and silently
stop all expiration.
-}
doWatcher :: MVar (Maybe (PQ.PSQ ByteString UTCTime))
          -> HT.BasicHashTable ByteString v
          -> IO ()
doWatcher q ht = mask_ loop
  where
    loop = do
        nxt <- cleanOnce
        case nxt of
            Nothing  -> return ()
            Just del -> do
                threadDelay del `catch` \PokeWatcher -> return ()
                loop

    -- A poke delivered while blocked in takeMVar means modifyMVar never
    -- took the lock, so the map is untouched: simply try again.
    cleanOnce =
        modifyMVar q (cleanExpired ht) `catch` \PokeWatcher -> cleanOnce

{-|
The inner portion of the watcher thread. It removes expired elements and
returns how long to wait, in microseconds, before cleaning again. It returns
'Nothing' instead if the map has been closed.
-}
cleanExpired :: HT.BasicHashTable ByteString v
             -> Maybe (PQ.PSQ ByteString UTCTime)
             -> IO (Maybe (PQ.PSQ ByteString UTCTime), Maybe Int)
cleanExpired _ Nothing = return (Nothing, Nothing)
cleanExpired ht (Just q) = case PQ.minView q of
    -- Nothing to expire: sleep as long as allowed. Any insertion pokes us.
    Nothing -> return (Just q, Just maxWatcherDelay)
    Just (k PQ.:-> e, q') -> do
        t <- getCurrentTime
        if e <= t
            then do
                HT.delete ht k
                cleanExpired ht (Just q')
            else return (Just q, Just (delayUntil e t))

{-|
The number of microseconds from @now@ until @deadline@. It is rounded up so
the watcher never wakes just before an expiration and spins. It is capped at
'maxWatcherDelay' so the conversion to 'Int' cannot overflow, which would
otherwise happen on 32-bit platforms for waits longer than about 35 minutes.
-}
delayUntil :: UTCTime -> UTCTime -> Int
delayUntil deadline now =
    fromInteger (min (toInteger maxWatcherDelay) micros)
  where
    micros = ceiling (1000000 * diffUTCTime deadline now)

{-|
The longest single sleep of the watcher thread, in microseconds (about 16.7
minutes). It is chosen to fit in a 32-bit 'Int' and to avoid passing
'maxBound' to 'threadDelay'. Waking up early is harmless: the watcher just
recomputes the delay and goes back to sleep.
-}
maxWatcherDelay :: Int
maxWatcherDelay = 1000000000

{-|
Performs some action with the MVar held for the priority queue. That MVar
serves as a lock for the entire data structure, ensuring thread safety.

The action receives the current expiration queue. It returns the new queue
state ('Nothing' closes the map) together with its result. If the action
throws, the previous state is restored. Throws an 'ErrorCall' if the map
has already been closed.
-}
withOpenMap
    :: MVar (Maybe (PQ.PSQ ByteString UTCTime))
    -> (PQ.PSQ ByteString UTCTime -> IO (Maybe (PQ.PSQ ByteString UTCTime), a))
    -> IO a
withOpenMap qq f = modifyMVar qq go
  where
    go Nothing  = throwIO (ErrorCall "Session map is already closed")
    go (Just q) = f q

{-|
Closes a SessionMap, moving it to the closed state. Once in the closed
state, any data contained in the map is lost and resources are freed, and
any attempt to use the map will result in an error.
A closed SessionMap cannot be reopened.
-}
close :: SessionMap v -> IO ()
close (SessionMap lock table watcher _ _) = withOpenMap lock shutDown
  where
    -- Poke the watcher so it wakes from its sleep and notices the map has
    -- been closed, then drop every stored value. Returning Nothing as the
    -- new queue state is what marks the map as closed.
    shutDown _ = do
        throwTo watcher PokeWatcher
        entries <- HT.toList table
        forM_ (map fst entries) (HT.delete table)
        return (Nothing, ())

{-|
Stores a value under a freshly generated session key that is not already
in use. It starts that key's expiration clock and returns the key.
-}
insert :: SessionMap v -> v -> IO ByteString
insert sm@(SessionMap lock table _ rng _) val =
    withOpenMap lock $ \queue -> uniqueKey rng table >>= store queue
  where
    store queue key = do
        (queue', _) <- update' sm key val queue
        return (queue', key)

{-|
Draws random 40-byte-length tokens until one is found that is not already
a key in the table. The key space is large enough that a retry is
extremely unlikely, although in principle the loop is unbounded.
-}
uniqueKey :: RNG -> HT.BasicHashTable ByteString v -> IO ByteString
uniqueKey gen ht = attempt
  where
    attempt = do
        candidate <- randomToken 40 gen
        existing  <- HT.lookup ht candidate
        case existing of
            Nothing -> return candidate
            Just _  -> attempt

{-|
Sets the value stored under a key and resets that key's expiration time.
If the key is not present, it is added.
-}
update :: SessionMap v -> ByteString -> v -> IO ()
update sm@(SessionMap lock _ _ _ _) key val =
    withOpenMap lock (\queue -> update' sm key val queue)

{-|
Lock-held worker shared by 'insert' and 'update'. It writes the value into
the hash table, then resets the key's entry in the expiration queue.
-}
update' :: SessionMap v -> ByteString -> v -> PQ.PSQ ByteString UTCTime
        -> IO (Maybe (PQ.PSQ ByteString UTCTime), ())
update' sm@(SessionMap _ table _ _ _) key val queue =
    HT.insert table key val >> touch' sm key queue

{-|
Removes a key and its value from the map, along with its pending
expiration.
-}
delete :: SessionMap v -> ByteString -> IO ()
delete (SessionMap lock table _ _ _) key = withOpenMap lock remove
  where
    remove queue = do
        HT.delete table key
        return (Just (PQ.delete key queue), ())

{-|
Returns the value stored under a session key, if any. This does NOT reset
the expiration time; use 'touch' for that.
-}
lookup :: SessionMap v -> ByteString -> IO (Maybe v)
lookup (SessionMap lock table _ _ _) key = withOpenMap lock $ \queue ->
    fmap (\found -> (Just queue, found)) (HT.lookup table key)

{-|
Resets the expiration time of a key in the map.
The new expiration time is the current time plus the map's timeout. Keys
that are not currently in the map are left alone.
-}
touch :: SessionMap v -> ByteString -> IO ()
touch m@(SessionMap q _ _ _ _) k = withOpenMap q (touch' m k)

{-|
Lock-held worker for 'touch'. Given the current expiration queue, it returns
the queue with @k@'s expiry set to now plus the timeout. If @k@ is not in
the hash table, the queue is returned unchanged.

The watcher is poked only when the queue actually changes. The poke is a
synchronous 'throwTo' performed while the global lock is held, so sending
it for absent keys (for example, stale session cookies) is pure overhead.
-}
touch' :: SessionMap v -> ByteString -> PQ.PSQ ByteString UTCTime
       -> IO (Maybe (PQ.PSQ ByteString UTCTime), ())
touch' (SessionMap _ ht w _ to) k q = do
    ans <- HT.lookup ht k
    case ans of
        Nothing -> return (Just q, ())
        Just _  -> do
            -- The queue is about to change, so make the watcher recompute
            -- its sleep. It cannot read the queue until we release the lock.
            throwTo w PokeWatcher
            t <- getCurrentTime
            let p = addUTCTime (fromIntegral to) t
            return (Just (PQ.insert k p q), ())