{-# LANGUAGE FlexibleInstances     #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedStrings     #-}

{-|
    This is the server-side memory-backed implementation of typed
    sessions.  It has the advantage of being able to store arbitrary
    data structures including functions and other non-serializable data.
    As a result, though, it is limited to a single server-side process
    since it's not possible to migrate arbitrary data between nodes.
    Load balancing with this snaplet requires "sticky sessions" or a
    similar technique to ensure that a given client always reaches the
    same server-side node.
-}
module Snap.Snaplet.TypedSession.Memory (
    MemorySessionManager,
    initMemorySessions,
    module Snap.Snaplet.TypedSession
    ) where

import Control.Monad.State
import Data.ByteString (ByteString)
import Snap.Core
import Snap.Snaplet
import Snap.Snaplet.Session (getSecureCookie, setSecureCookie)
import Snap.Snaplet.TypedSession
import Web.ClientSession

import Snap.Snaplet.TypedSession.SessionMap (SessionMap)
import qualified Snap.Snaplet.TypedSession.SessionMap as SM

{-|
    Session manager that keeps typed session values in a server-side
    map keyed by session ID; the ID itself is handed to the client in
    a secure cookie.
-}
data MemorySessionManager t = MemorySessionManager
    { memorySessionCache   :: !(Maybe (ByteString, t))
      -- ^ Session ID and value already resolved through this manager,
      --   if any; consulted before falling back to the cookie
    , memorySessionName    :: !ByteString
      -- ^ Name of the cookie that carries the session ID
    , memorySessionKey     :: !Key
      -- ^ Key handed to the secure-cookie functions
    , memorySessionDefault :: !(IO t)
      -- ^ Action run to produce the value of a newly created session
    , memorySessionTimeout :: !Int
      -- ^ Session timeout in seconds
    , memorySessionData    :: !(SessionMap t)
      -- ^ Server-side map from session ID to session value
    }


{-|
    Build the memory-backed typed session snaplet.  The encryption key
    is obtained with 'getKey' from the given path and a new session map
    is created with the given timeout; the manager starts out with an
    empty cache.
-}
initMemorySessions :: FilePath   -- ^ Location of an encryption key
                   -> ByteString -- ^ Name for the session ID cookie
                   -> Int        -- ^ Session timeout in seconds
                   -> IO t       -- ^ Initializer for new sessions
                   -> SnapletInit b (MemorySessionManager t)
initMemorySessions fp name timeout defaulter =
    makeSnaplet "TypedSession.Memory"
                "Typed sessions stored in server-side memory"
                Nothing
                (liftIO buildManager)
  where
    buildManager = do
        sessionKey <- getKey fp
        sessions   <- SM.new timeout
        -- ($!) evaluates the record, and with it every strict field,
        -- before the manager is handed back.
        return $! MemorySessionManager
            { memorySessionCache   = Nothing
            , memorySessionName    = name
            , memorySessionKey     = sessionKey
            , memorySessionDefault = defaulter
            , memorySessionTimeout = timeout
            , memorySessionData    = sessions
            }
        -- TODO: Maybe wrap routes with touchSession?


{-|
    Fetch the current session value.  The manager's cache is consulted
    first; otherwise the session ID is read from the secure cookie and
    looked up in the session map.  When there is no cookie, or the map
    has no entry for its ID, a new session is created via 'newSession'.
-}
getSessionImpl :: Handler b (MemorySessionManager t) t
getSessionImpl = do
    mgr <- get
    case memorySessionCache mgr of
        Just (_, cached) -> return cached
        Nothing          -> fromCookie mgr
  where
    -- No cached value: recover the session ID from the request cookie.
    fromCookie mgr = do
        cookieSid <- getSecureCookie (memorySessionName mgr)
                                     (memorySessionKey  mgr)
                                     (Just (memorySessionTimeout mgr))
        maybe newSession (fromStore mgr) cookieSid

    -- Cookie present: resolve its ID against the server-side map and
    -- remember the hit in the cache.
    fromStore mgr sid = do
        stored <- liftIO (SM.lookup (memorySessionData mgr) sid)
        case stored of
            Nothing    -> newSession
            Just found -> do
                put (mgr { memorySessionCache = Just (sid, found) })
                return found


{-|
    Create a brand-new session: run the default-value action, insert
    the result into the session map, cache the resulting ID and value
    in the manager, and hand the ID to the client in a secure cookie.
    Returns the new session value.
-}
newSession :: Handler b (MemorySessionManager t) t
newSession = do
    mgr <- get
    -- Both IO steps run together: build the value, then register it.
    (freshSid, initial) <- liftIO $ do
        v <- memorySessionDefault mgr
        k <- SM.insert (memorySessionData mgr) v
        return (k, v)
    put mgr { memorySessionCache = Just (freshSid, initial) }
    setSecureCookie (memorySessionName mgr)
                    (memorySessionKey  mgr)
                    (Just (memorySessionTimeout mgr))
                    freshSid
    return initial


{-|
    Replace the current session value.  If a session ID is already
    known (from the cache or the cookie) the map entry for that ID is
    updated; otherwise the value is inserted under a new ID.  In both
    cases the cache is refreshed and the secure cookie is re-issued.
-}
setSessionImpl :: t -> Handler b (MemorySessionManager t) ()
setSessionImpl val = do
    mgr      <- get
    existing <- getSessionId
    let store = memorySessionData mgr
    sid <- liftIO $ case existing of
        Nothing    -> SM.insert store val
        -- NOTE(review): relies on SM.update storing the value even
        -- when the ID is no longer present in the map -- confirm.
        Just known -> SM.update store known val >> return known
    put mgr { memorySessionCache = Just (sid, val) }
    setSecureCookie (memorySessionName mgr)
                    (memorySessionKey  mgr)
                    (Just (memorySessionTimeout mgr))
                    sid


{-|
    Determine the current session ID without creating a session.  The
    ID held in the manager's cache wins; otherwise the result of
    reading the secure cookie is returned as-is ('Nothing' when the
    cookie lookup yields nothing).
-}
getSessionId :: Handler b (MemorySessionManager t) (Maybe ByteString)
getSessionId = do
    mgr <- get
    case memorySessionCache mgr of
        Just (sid, _) -> return (Just sid)
        -- The cookie lookup already yields a 'Maybe', so it is
        -- returned directly rather than unpacked and re-wrapped.
        Nothing       -> getSecureCookie (memorySessionName mgr)
                                         (memorySessionKey  mgr)
                                         (Just (memorySessionTimeout mgr))


{-|
    Touch the current session, if there is one: the entry is touched
    in the session map and the secure cookie is re-issued with the
    configured timeout.  Does nothing when no session ID is known.
-}
touchSessionImpl :: Handler b (MemorySessionManager t) ()
touchSessionImpl = getSessionId >>= maybe (return ()) refresh
  where
    refresh sid = do
        mgr <- get
        liftIO (SM.touch (memorySessionData mgr) sid)
        setSecureCookie (memorySessionName mgr)
                        (memorySessionKey  mgr)
                        (Just (memorySessionTimeout mgr))
                        sid


{-|
    Discard the current session: empty the manager's cache and, when a
    session ID is known, delete its entry from the session map and
    expire the session cookie.
-}
clearSessionImpl :: Handler b (MemorySessionManager t) ()
clearSessionImpl = do
    -- Resolve the ID *before* emptying the cache.  'getSessionId'
    -- prefers the cached ID, which is the only place the ID lives when
    -- the session was created during the current request; clearing the
    -- cache first would leave that session in the map and its cookie
    -- in place.
    msid <- getSessionId
    mgr  <- get
    put (mgr { memorySessionCache = Nothing })
    case msid of
        Nothing  -> return ()
        Just sid -> do
            liftIO (SM.delete (memorySessionData mgr) sid)
            expireCookie (memorySessionName mgr) Nothing


-- The typed-session interface is implemented by delegating each class
-- method to the corresponding memory-backed handler in this module.
instance HasTypedSession (MemorySessionManager t) t where
    clearSession = clearSessionImpl
    touchSession = touchSessionImpl
    setSession = setSessionImpl
    getSession = getSessionImpl