module Snap.Extension.Session.Memory (
HasMemorySessionManager(..),
MemorySessionManager,
memorySessionInitializer
) where
import Control.Applicative
import Control.Concurrent
import Control.Monad.Reader
import Data.Time.Clock
import Snap.Extension
import Snap.Extension.Session
import Snap.SessionUtil
import Snap.Types
import Data.Map (Map)
import qualified Data.Map as M
import Data.ByteString.Char8 (ByteString)
import qualified Data.ByteString.Char8 as B
-- | Application states that carry an in-memory session manager.
class HasMemorySessionManager app where
  -- | Type of the object stored for each session.
  type MemorySession app
  -- | Project the session manager out of the application state.
  memorySessionMgr :: app -> MemorySessionManager (MemorySession app)
-- | Build an 'Initializer' for an in-memory session manager.
--
-- The manager is created eagerly (in 'IO') and then handed to
-- 'mkInitializer'.
memorySessionInitializer
  :: NominalDiffTime  -- ^ idle time after which the background sweep drops a session
  -> IO t             -- ^ action producing the initial object of a new session
  -> Initializer (MemorySessionManager t)
memorySessionInitializer time def =
  liftIO (makeSessionManager time def) >>= mkInitializer
-- | Lifecycle hooks for the in-memory session manager.
instance InitializerState (MemorySessionManager t) where
  extensionId _ = "Session/Memory"
  -- Shutting down marks the manager closed and empties the session table.
  mkCleanup = closeSessionManager
  -- Nothing to reload: all state lives in memory.
  mkReload _ = return ()
-- | In-memory session store.
--
-- It holds three things: an open/closed flag, the table of live sessions keyed
-- by 'SessionKey', and the action used to build the object of a new session.
data MemorySessionManager obj =
  MemorySessionManager
    (MVar Bool)                                   -- ^ 'True' until the manager is closed
    (MVar (Map SessionKey (SessionWrapper obj)))  -- ^ live sessions
    (IO obj)                                      -- ^ initial object for new sessions
-- | Book-keeping for one live session.
data SessionWrapper obj = SessionWrapper
  { sessionClient      :: ByteString    -- ^ remote address of the request that created it
  , sessionLastTouched :: MVar UTCTime  -- ^ last time the session was set or touched
  , sessionObject      :: MVar obj      -- ^ the session's current value
  }
-- | Sweep predicate: 'True' while the session was last touched no more than
-- @timeout@ ago. The key half of the pair is ignored.
goodSession :: NominalDiffTime -> (SessionKey, SessionWrapper obj) -> IO Bool
goodSession timeout (_, SessionWrapper { sessionLastTouched = touchedRef }) = do
  lastTouch <- readMVar touchedRef
  now <- getCurrentTime
  let idle = now `diffUTCTime` lastTouch
  return $ idle <= timeout
-- | Run @action@ repeatedly for as long as @cond@ yields 'True'.
--
-- The condition is checked before every iteration, so the action may run
-- zero times. The action's results are discarded.
whileM :: Monad m => m Bool -> m a -> m ()
whileM cond action = go
  where
    go = do
      keepGoing <- cond
      when keepGoing (action >> go)
-- | Create an open session manager and fork its background sweep thread.
--
-- On each pass the sweeper keeps only sessions accepted by 'goodSession'
-- (touched within @timeout@). It then sleeps and repeats until the manager
-- is closed.
makeSessionManager :: NominalDiffTime -> IO obj -> IO (MemorySessionManager obj)
makeSessionManager timeout def = do
  oref <- newMVar True
  sref <- newMVar M.empty
  _ <- forkIO $ whileM (readMVar oref) $ do
    -- modifyMVar_ puts the old table back if the sweep is interrupted.
    -- A bare takeMVar/putMVar pair would leave the MVar empty and
    -- deadlock every later session operation.
    modifyMVar_ sref $ \sess ->
      -- 'M.assocs' is in ascending key order and 'filterM' preserves order,
      -- so the surviving pairs can be rebuilt in O(n).
      M.fromDistinctAscList <$> filterM (goodSession timeout) (M.assocs sess)
    -- threadDelay takes microseconds: this pauses 5 ms between sweeps.
    -- NOTE(review): a 5 ms period is very aggressive for second-scale
    -- timeouts; confirm that is intended.
    threadDelay 5000
  return $ MemorySessionManager oref sref def
-- | Shut the manager down and drop every stored session.
--
-- It marks the manager closed, so the sweep loop exits after its current
-- pass and 'getExistingSession' fails.
closeSessionManager :: MemorySessionManager obj -> IO ()
closeSessionManager (MemorySessionManager running sessions _) =
  swapMVar running False >> swapMVar sessions M.empty >> return ()
-- | Create a new session for the current request.
--
-- The session is bound to the request's remote address. It is registered
-- under a fresh key from 'uniqueKey', and that key is sent back in a
-- @sessionid@ cookie.
addSession :: MonadSnap m => MemorySessionManager obj -> m (SessionWrapper obj)
addSession (MemorySessionManager _ sref def) = do
  client <- fmap rqRemoteAddr getRequest
  (k, session) <- liftIO $ do
    -- Build the session (including the user-supplied 'def' action) before
    -- touching the shared table, so a slow or failing 'def' neither holds
    -- the lock nor leaves the table MVar empty.
    ct <- getCurrentTime
    session <- SessionWrapper client <$> newMVar ct <*> (newMVar =<< def)
    -- modifyMVar restores the table if key generation throws.
    k <- modifyMVar sref $ \smap -> do
      key <- uniqueKey smap
      return (M.insert key session smap, key)
    return (k, session)
  setCookie $ Cookie "sessionid" (B.pack $ show k) Nothing Nothing Nothing
  return session
-- | Look up the session named by the request's @sessionid@ cookie.
--
-- Fails the handler with 'mzero' once the manager has been closed.
-- Returns 'Nothing' in any of these cases:
--
-- * there is no cookie;
-- * the cookie value does not parse as a 'SessionKey';
-- * no live session has that key;
-- * the session was created from a different remote address.
getExistingSession :: MonadSnap m
                   => MemorySessionManager obj
                   -> m (Maybe (SessionWrapper obj))
getExistingSession (MemorySessionManager oref sref _) = do
  liftIO (readMVar oref) >>= flip unless mzero
  client <- fmap rqRemoteAddr getRequest
  ck <- lookupCookie "sessionid"
  -- The cookie value is client-controlled. Parse it totally: 'read' would
  -- crash the handler on a malformed value.
  case ck >>= parseKey . cookieValue of
    Nothing -> return Nothing
    Just key -> do
      smap <- liftIO $ readMVar sref
      return $ case M.lookup key smap of
        -- Reject sessions presented from a different remote address.
        Just s | sessionClient s == client -> Just s
        _                                  -> Nothing
  where
    -- Same acceptance rule as Prelude 'read' (exactly one parse, only
    -- whitespace left over), but yields 'Nothing' instead of calling 'error'.
    parseKey :: ByteString -> Maybe SessionKey
    parseKey raw =
      case [k | (k, rest) <- reads (B.unpack raw), ("", "") <- lex rest] of
        [k] -> Just k
        _   -> Nothing
-- | Return the request's valid existing session, or create and register a
-- new one when there is none.
getAnySession :: MonadSnap m
              => MemorySessionManager obj
              -> m (SessionWrapper obj)
getAnySession mgr = do
  found <- getExistingSession mgr
  case found of
    Just ses -> return ses
    Nothing  -> addSession mgr
-- | Session access for any application state with an in-memory manager.
--
-- Every operation except 'clearSession' first obtains the request's session,
-- creating one if needed.
instance HasMemorySessionManager s => MonadSession (SnapExtend s) where
  type Session (SnapExtend s) = MemorySession s

  -- Reads the stored object; does not update the last-touched time.
  getSession =
    asks memorySessionMgr >>= getAnySession >>= liftIO . readMVar . sessionObject

  -- Refreshes the last-touched time, then replaces the stored object.
  setSession val = do
    ses <- asks memorySessionMgr >>= getAnySession
    liftIO $ do
      now <- getCurrentTime
      _ <- swapMVar (sessionLastTouched ses) now
      _ <- swapMVar (sessionObject ses) val
      return ()

  -- Refreshes the last-touched time only.
  touchSession = do
    ses <- asks memorySessionMgr >>= getAnySession
    liftIO $ getCurrentTime >>= swapMVar (sessionLastTouched ses) >> return ()

  -- Only removes the client's cookie. The stored session stays in the table
  -- until the sweep thread expires it.
  clearSession = clearCookie "sessionid"