{-# LANGUAGE OverloadedStrings          #-}
{-# LANGUAGE TypeFamilies               #-}

{-|
    'Snap.Extension.Session.Memory' exports the 'HasMemorySessionManager'
    interface, which allows you to keep an in-memory session object for each
    client session of a web application.
-}

module Snap.Extension.Session.Memory (
    HasMemorySessionManager(..),
    MemorySessionManager,
    memorySessionInitializer
    ) where

import Control.Applicative
import Control.Concurrent
import Control.Monad.Reader
import Data.Time.Clock
import Snap.Extension
import Snap.Extension.Session
import Snap.SessionUtil
import Snap.Types

import Data.Map (Map)
import qualified Data.Map as M

import Data.ByteString.Char8 (ByteString)
import qualified Data.ByteString.Char8 as B

{-|
    Implemented by application state types that carry a
    'MemorySessionManager', so the memory-based session manager can be
    bundled into a single application state object.
-}
class HasMemorySessionManager a where
    -- | The per-client session object type managed for this state.
    type MemorySession a
    -- | Projects the session manager out of the application state.
    memorySessionMgr :: a -> MemorySessionManager (MemorySession a)

{-|
    Builds an 'Initializer' for a 'MemorySessionManager'.  The first argument
    is how long a session may go untouched before the reaper discards it; the
    second produces the initial session object for each new client session.
-}
memorySessionInitializer :: NominalDiffTime
                         -> IO t
                         -> Initializer (MemorySessionManager t)
memorySessionInitializer time def = do
    mgr <- liftIO (makeSessionManager time def)
    mkInitializer mgr

-- | Registers the manager under the id @Session/Memory@.  Cleanup shuts the
-- manager down via 'closeSessionManager'; reloading does nothing.
instance InitializerState (MemorySessionManager t) where
    extensionId _ = "Session/Memory"
    mkCleanup     = closeSessionManager
    mkReload _    = return ()

{-|
    Keeps track of the sessions held in memory for one application.
-}
data MemorySessionManager obj = MemorySessionManager
    (MVar Bool)
    -- ^ 'True' while the manager is open; 'closeSessionManager' sets it to
    -- 'False', which stops the reaper loop and makes session lookups fail.
    (MVar (Map SessionKey (SessionWrapper obj)))
    -- ^ The live sessions, indexed by the key sent in the @sessionid@ cookie.
    (IO obj)
    -- ^ Produces the initial session object for each newly added session.

{-|
    One client's session: the remote address it was created for, the time it
    was last touched (consulted by the reaper thread), and the mutable
    application-defined session object.
-}
data SessionWrapper obj = SessionWrapper {
    -- | Remote address of the client the session was created for; lookups
    -- from any other address are rejected.
    sessionClient      :: ByteString,
    -- | When the session was last touched; sessions idle longer than the
    -- manager's timeout are dropped by the reaper.
    sessionLastTouched :: MVar UTCTime,
    -- | The application's per-client session object.
    sessionObject      :: MVar obj
    }

{-|
    Decides whether a session should be kept.  A session is good when it was
    last touched no more than @timeout@ ago.  The session key is ignored.
-}
goodSession :: NominalDiffTime -> (SessionKey, SessionWrapper obj) -> IO Bool
goodSession timeout (_, session) =
    isFresh <$> readMVar (sessionLastTouched session) <*> getCurrentTime
  where
    isFresh touched now = diffUTCTime now touched <= timeout

{-|
    Monadic while loop.  It runs @cond@ and, as long as the result is 'True',
    runs @action@ and checks again.  The results of @action@ are discarded.
-}
whileM :: Monad m => m Bool -> m a -> m ()
whileM cond action = loop
  where
    loop = cond >>= \keepGoing ->
        if keepGoing then action >> loop else return ()

{-|
    Creates a new 'MemorySessionManager' to manage the sessions of a web
    application.  This also spawns the session reaper thread.  Until the
    manager is closed, the reaper periodically removes sessions that have not
    been touched within @timeout@.
-}
makeSessionManager :: NominalDiffTime -> IO obj -> IO (MemorySessionManager obj)
makeSessionManager timeout def = do
    oref <- newMVar True
    sref <- newMVar M.empty
    _    <- forkIO $ whileM (readMVar oref) $ do
        -- modifyMVar_ puts the map back if the sweep is interrupted by an
        -- exception, instead of leaving 'sref' empty and deadlocking every
        -- request that needs a session.
        modifyMVar_ sref $ \sess ->
            -- 'M.assocs' is in ascending key order and filtering keeps that
            -- order, so the map can be rebuilt in linear time.
            M.fromDistinctAscList <$> filterM (goodSession timeout) (M.assocs sess)
        threadDelay reapInterval
    return $ MemorySessionManager oref sref def
  where
    -- Pause between sweeps.  'threadDelay' counts microseconds, so this is
    -- 5 seconds.  Each sweep locks and rebuilds the whole session map, and a
    -- 5 ms pause (5000) would do that ~200 times a second, far finer than
    -- session timeouts need.
    reapInterval :: Int
    reapInterval = 5 * 1000 * 1000

{-|
    Shuts down a 'MemorySessionManager'.  It marks the manager closed, so that
    'getExistingSession' fails and the reaper loop exits at its next check.
    It also discards every stored session.
-}
closeSessionManager :: MemorySessionManager obj -> IO ()
closeSessionManager (MemorySessionManager openRef sessionsRef _) =
    swapMVar openRef False >> swapMVar sessionsRef M.empty >> return ()

{-|
    Creates a blank session for the current client and stores it under a
    fresh unique key.  It also sets the @sessionid@ cookie to that key.  This
    always creates a new session, so it should only be used when the client
    has no session already.
-}
addSession :: MonadSnap m => MemorySessionManager obj -> m (SessionWrapper obj)
addSession (MemorySessionManager _ sref def) = do
    client <- fmap rqRemoteAddr getRequest
    (k, session) <- liftIO $ do
        -- Build the wrapper before locking the map.  'def' is arbitrary
        -- caller-supplied IO, so it should not run while other requests wait
        -- on the lock, and it must not be able to leave the map empty by
        -- throwing.
        ct      <- getCurrentTime
        session <- SessionWrapper client <$> newMVar ct <*> (newMVar =<< def)
        -- modifyMVar restores the map if anything inside throws.
        modifyMVar sref $ \smap -> do
            k <- uniqueKey smap
            return (M.insert k session smap, (k, session))
    setCookie $ Cookie "sessionid" (B.pack $ show k) Nothing Nothing Nothing
    return session

{-|
    Retrieves the current client's existing session, if one exists.

    The result is 'Nothing' in any of these cases:

      * there is no @sessionid@ cookie;
      * the cookie does not parse as a session key;
      * no session is stored under that key;
      * the stored session was created for a different remote address.

    Fails (via 'mzero') if the manager has been closed.
-}
getExistingSession :: MonadSnap m
                   => MemorySessionManager obj
                   -> m (Maybe (SessionWrapper obj))
getExistingSession (MemorySessionManager oref sref _) = do
    accepting <- liftIO (readMVar oref)
    unless accepting mzero
    client <- fmap rqRemoteAddr getRequest
    ck     <- lookupCookie "sessionid"
    -- The cookie value is client-controlled, so it is parsed totally: a
    -- malformed value means "no session" rather than a 'read' crash.
    case ck >>= parseKey . cookieValue of
        Nothing -> return Nothing
        Just k  -> do
            smap <- liftIO $ readMVar sref
            return $ case M.lookup k smap of
                Just s | sessionClient s == client -> Just s
                _                                  -> Nothing
  where
    -- Same acceptance rule as 'read' (exactly one parse, with only whitespace
    -- left over), but it yields 'Nothing' instead of calling 'error'.
    parseKey raw =
        case [key | (key, rest) <- reads (B.unpack raw), ("", "") <- lex rest] of
            [key] -> Just key
            _     -> Nothing

{-|
    Ensures that there is a session in place for the current client and
    returns it, adding a blank one if necessary.

    An existing session has its last-touched time set to now, which keeps
    the reaper thread from removing it for a while.  A newly added session
    starts out stamped with the current time.
-}
getAnySession :: MonadSnap m
              => MemorySessionManager obj
              -> m (SessionWrapper obj)
getAnySession mgr = getExistingSession mgr >>= maybe (addSession mgr) touch
  where
    -- Refresh on every access, not only on writes, so that a session which
    -- is merely read is not reaped while it is still in use.
    touch ses = do
        liftIO $ do
            now <- getCurrentTime
            _   <- swapMVar (sessionLastTouched ses) now
            return ()
        return ses

-- | Session operations for any 'SnapExtend' application whose state carries
-- a memory session manager.  Every operation except 'clearSession' first
-- ensures the client has a session (creating one if needed).
instance HasMemorySessionManager s => MonadSession (SnapExtend s) where
    type Session (SnapExtend s) = MemorySession s

    -- Returns the current session object.
    getSession = do
        ses <- getAnySession =<< asks memorySessionMgr
        liftIO (readMVar (sessionObject ses))

    -- Stamps the session with the current time, then replaces its object.
    setSession val = do
        ses <- getAnySession =<< asks memorySessionMgr
        liftIO $ do
            now <- getCurrentTime
            _   <- swapMVar (sessionLastTouched ses) now
            _   <- swapMVar (sessionObject ses) val
            return ()

    -- Stamps the session with the current time.
    touchSession = do
        ses <- getAnySession =<< asks memorySessionMgr
        liftIO $ do
            now <- getCurrentTime
            _   <- swapMVar (sessionLastTouched ses) now
            return ()

    -- Only the cookie is cleared; the stored session is left for the reaper.
    clearSession = clearCookie "sessionid"