{-# LANGUAGE Haskell2010         #-}
{-# LANGUAGE DeriveDataTypeable  #-}

{-|
    An implementation of a mutable map with randomly generated keys and
    values that time out and must be cleaned up when not in use.  This
    pattern occurs in a number of situations related to session
    management, so it is useful to have the code in one place.
-}
module Snap.Snaplet.TypedSession.SessionMap (
    SessionMap,
    new,
    close,
    insert,
    update,
    delete,
    lookup,
    touch
    ) where

import Prelude hiding (lookup, catch)

import Control.Concurrent
import Control.Exception
import Control.Monad
import Data.ByteString (ByteString)
import Data.Time
import Data.Typeable
import Snap.Snaplet.Session (RNG, mkRNG, randomToken)

import qualified Data.HashTable.IO as HT
import qualified Data.PSQueue      as PQ


{-|
    A SessionMap is a mutable map with ByteString keys and values of a
    user-defined type.  The values in the map are associated with
    timeouts, after which they will automatically be deleted.

    A SessionMap is in one of two states: active or closed.  When first
    created, the map is active.  Values can be inserted, queried, and
    removed, and will be automatically removed once their time has
    expired.  Calling 'close' moves the map to the closed state.  After
    that, attempts to use it will result in exceptions, and all data and
    resources associated with value timeout are freed.  There is no way
    to reopen a closed SessionMap.
-}
data SessionMap v = SessionMap
    !(MVar (Maybe (PQ.PSQ ByteString UTCTime)))
        -- Expiration queue: each key in the map together with the time
        -- at which it expires.  Nothing means the map has been closed.
        --
        -- In addition to holding the priority queue, this also acts
        -- as the lock for the entire data structure.  The hash table
        -- should not be modified except when holding this lock.
    !(HT.BasicHashTable ByteString v)
        -- The stored values, indexed by session key.
    !ThreadId
        -- The watcher thread that removes expired entries.
    !RNG
        -- Source of randomness used to generate new session keys.
    !Int
        -- Timeout in seconds: an entry expires this long after it was
        -- last inserted, updated, or touched.


{-|
    This exception type is created specifically for the purpose of
    poking the watcher thread.  It is thrown to the watcher (with
    'throwTo') when the expiration queue changes or the map is closed.
    This wakes the watcher from its sleep so that it re-examines the
    queue.
-}
data PokeWatcher = PokeWatcher deriving (Typeable, Show)
instance Exception PokeWatcher


{-|
    Creates a new session map, which is initially empty but active.

    The argument is the timeout, in seconds, after which an entry that
    has not been inserted, updated, or touched is removed.  A watcher
    thread is forked to perform that removal; it exits once the map is
    closed.
-}
new :: Int -> IO (SessionMap v)
new to = do
    -- Create the RNG before forking the watcher.  Otherwise a failure in
    -- mkRNG would leave an orphaned watcher thread behind.
    gen <- mkRNG
    q   <- newMVar (Just PQ.empty)
    ht  <- HT.new
    w   <- forkIO (doWatcher q ht)
    return (SessionMap q ht w gen to)


{-|
    The watcher thread.  This thread loops, checking for expired
    elements and removing them from the map.  When there are no expired
    elements, it sleeps until the next expiration time.  It should be
    thrown a PokeWatcher exception when the expiration queue is changed,
    so that it notices earlier expiration times.  The thread exits once
    the map has been closed.

    Asynchronous exceptions are masked for the whole loop.  A PokeWatcher
    can therefore only be delivered while the thread is blocked in an
    interruptible operation: either sleeping in 'threadDelay', or
    waiting for the lock inside 'modifyMVar'.  Both cases must be
    handled, because the callers that poke the watcher do so while they
    hold the lock.
-}
doWatcher :: MVar (Maybe (PQ.PSQ ByteString UTCTime))
          -> HT.BasicHashTable ByteString v
          -> IO ()
doWatcher q ht = mask_ loop
  where
    loop = do
        nxt <- clean
        case nxt of
            Nothing  -> return ()
            Just del -> do
                threadDelay del `catch` \PokeWatcher -> return ()
                loop

    -- A poke that arrives while we are queued for the lock aborts the
    -- takeMVar before the lock is acquired, so it is safe to retry.
    -- Without this, the exception would escape and kill the watcher,
    -- and expired entries would never be removed again.
    clean = modifyMVar q (cleanExpired ht) `catch` \PokeWatcher -> clean


{-
    The inner portion of the watcher thread.  It is run with the lock
    held.  It removes expired elements and returns the amount of time,
    in microseconds, to wait before cleaning again.

    A closed map (Nothing) yields a delay of Nothing, which tells the
    watcher to exit.  An empty queue yields the longest possible delay.
-}
cleanExpired :: HT.BasicHashTable ByteString v
             -> Maybe (PQ.PSQ ByteString UTCTime)
             -> IO (Maybe (PQ.PSQ ByteString UTCTime), Maybe Int)
cleanExpired _  Nothing  = return (Nothing, Nothing)
cleanExpired ht (Just q) = case PQ.minView q of
    Nothing -> return (Just q, Just maxBound)
    Just (k PQ.:-> e, q') -> do
        t <- getCurrentTime
        if e <= t then do
            HT.delete ht k
            cleanExpired ht (Just q')
        else return (Just q, Just (toMicroseconds (diffUTCTime e t)))
  where
    -- Compute in Integer and clamp to the Int range.  Converting
    -- directly to Int could wrap around to a negative value for long
    -- delays (on 32-bit platforms, maxBound microseconds is only about
    -- 35 minutes).  A negative delay makes threadDelay return almost
    -- immediately, so the watcher would spin.
    toMicroseconds :: NominalDiffTime -> Int
    toMicroseconds d =
        fromInteger (min (toInteger (maxBound :: Int)) (round (1000000 * d)))


{-|
    Runs an action against the expiration queue while holding the lock
    that guards the whole session map.

    The action returns the new queue state to store, together with a
    result.  Storing Nothing closes the map.  If the map has already
    been closed, the action is not run and an error is raised instead.
-}
withOpenMap :: MVar (Maybe (PQ.PSQ ByteString UTCTime))
            -> (PQ.PSQ ByteString UTCTime -> IO (Maybe (PQ.PSQ ByteString UTCTime), a))
            -> IO a
withOpenMap lock action =
    modifyMVar lock (maybe (error "Session map is already closed") action)


{-|
    Closes a SessionMap, moving it to the closed state.  Once in the
    closed state, any data contained in the map is lost, resources are
    freed, and any attempt to use the map will result in an error.
    A closed SessionMap cannot be reopened.
-}
close :: SessionMap v -> IO ()
close (SessionMap lock table watcher _ _) = withOpenMap lock $ \_ -> do
    -- Wake the watcher so that, once the lock is released, it sees the
    -- closed state and exits.
    throwTo watcher PokeWatcher
    -- Snapshot the entries first, rather than deleting while traversing.
    entries <- HT.toList table
    mapM_ (HT.delete table . fst) entries
    return (Nothing, ())


{-|
    Inserts a value into the map under a freshly generated key that is
    not already in use.  It starts the entry's expiration timer and
    returns the new key.
-}
insert :: SessionMap v -> v -> IO ByteString
insert m@(SessionMap lock table _ rng _) val = withOpenMap lock step
  where
    step queue = do
        key          <- uniqueKey rng table
        (queue', _)  <- update' m key val queue
        return (queue', key)


{-|
    Chooses a session key that is not currently used by the map.  It
    repeatedly draws a token from 'randomToken' (with length argument
    40) until it finds one absent from the hash table.  In principle
    this could loop arbitrarily long.  In practice the key space is so
    large that a collision is extremely unlikely.
-}
uniqueKey :: RNG -> HT.BasicHashTable ByteString v -> IO ByteString
uniqueKey gen ht = loop
  where
    loop = do
        candidate <- randomToken 40 gen
        existing  <- HT.lookup ht candidate
        case existing of
            Nothing -> return candidate
            Just _  -> loop


{-|
    Stores a value under the given key and resets the key's expiration
    time.  If the key is not already present, it is added with a fresh
    expiration time.
-}
update :: SessionMap v -> ByteString -> v -> IO ()
update m k v = case m of
    SessionMap lock _ _ _ _ -> withOpenMap lock (update' m k v)


{-
    Worker for 'update' and 'insert', run with the lock held.  It stores
    the value in the hash table, then refreshes the key's expiration
    time in the given queue via touch'.
-}
update' :: SessionMap v -> ByteString -> v
        -> PQ.PSQ ByteString UTCTime
        -> IO (Maybe (PQ.PSQ ByteString UTCTime), ())
update' m k v q = case m of
    SessionMap _ table _ _ _ -> HT.insert table k v >> touch' m k q


{-|
    Removes a key from the map.  Both its stored value and its entry in
    the expiration queue are dropped.
-}
delete :: SessionMap v -> ByteString -> IO ()
delete (SessionMap lock table _ _ _) key = withOpenMap lock dropKey
  where
    dropKey queue = do
        HT.delete table key
        return (Just (PQ.delete key queue), ())


{-|
    Looks up a session key in the map and returns the associated value,
    if any.  This does NOT reset the expiration time; for that, see
    'touch'.
-}
lookup :: SessionMap v -> ByteString -> IO (Maybe v)
lookup (SessionMap lock table _ _ _) key = withOpenMap lock $ \queue ->
    fmap (\found -> (Just queue, found)) (HT.lookup table key)


{-|
    Resets the expiration time of a key in the map to the configured
    timeout from now.  A key that is not present is left absent.
-}
touch :: SessionMap v -> ByteString -> IO ()
touch m k = case m of
    SessionMap lock _ _ _ _ -> withOpenMap lock (touch' m k)


{-
    Worker for 'touch' and 'update', run with the lock held.  It first
    pokes the watcher.  Then, if the key is present in the hash table,
    it sets the key's expiration time to the configured timeout (in
    seconds) from now.  For a key that is not present, the queue is
    returned unchanged.
-}
touch' :: SessionMap v -> ByteString
       -> PQ.PSQ ByteString UTCTime
       -> IO (Maybe (PQ.PSQ ByteString UTCTime), ())
touch' (SessionMap _ table watcher _ timeout) key queue = do
    throwTo watcher PokeWatcher
    now <- getCurrentTime
    let expiry = addUTCTime (fromIntegral timeout) now
    present <- HT.lookup table key
    case present of
        Just _  -> return (Just (PQ.insert key expiry queue), ())
        Nothing -> return (Just queue, ())