{-# LANGUAGE CPP #-} {-# LANGUAGE FlexibleContexts, OverloadedStrings, DoAndIfThenElse, RankNTypes #-} module Web.Spock.Internal.SessionManager ( createSessionManager, withSessionManager , SessionId, Session(..), SessionManager(..) ) where import Web.Spock.Internal.Types import Web.Spock.Internal.CoreAction import Web.Spock.Internal.Wire import Web.Spock.Internal.Util import Web.Spock.Internal.Cookies import qualified Web.Spock.Internal.SessionVault as SV #if MIN_VERSION_base(4,8,0) #else import Control.Applicative #endif import Control.Concurrent import Control.Concurrent.STM import Control.Exception import Control.Monad import Control.Monad.Trans import Data.Time #if MIN_VERSION_time(1,5,0) #else import System.Locale (defaultTimeLocale) #endif import qualified Crypto.Random as CR import qualified Data.ByteString as BS import qualified Data.ByteString.Base64 as B64 import qualified Data.HashMap.Strict as HM import qualified Data.Text as T import qualified Data.Text.Encoding as T import qualified Data.Vault.Lazy as V import qualified Network.Wai as Wai withSessionManager :: SessionCfg sess -> (SessionManager conn sess st -> IO a) -> IO a withSessionManager sessCfg = bracket (createSessionManager sessCfg) sm_closeSessionManager createSessionManager :: SessionCfg sess -> IO (SessionManager conn sess st) createSessionManager cfg = do pool <- CR.createEntropyPool oldSess <- loadSessions cacheHM <- atomically $ do mapV <- SV.newSessionVault forM_ oldSess $ \v -> SV.storeSession v mapV return mapV vaultKey <- V.newKey housekeepThread <- forkIO (forever (housekeepSessions cfg cacheHM storeSessions)) return SessionManager { sm_getSessionId = getSessionIdImpl vaultKey cacheHM , sm_regenerateSessionId = regenerateSessionIdImpl vaultKey cacheHM pool cfg , sm_readSession = readSessionImpl vaultKey cacheHM , sm_writeSession = writeSessionImpl vaultKey cacheHM , sm_modifySession = modifySessionImpl vaultKey cacheHM , sm_clearAllSessions = clearAllSessionsImpl cacheHM , sm_mapSessions = mapAllSessionsImpl cacheHM , sm_middleware = sessionMiddleware pool cfg vaultKey cacheHM , sm_addSafeAction = addSafeActionImpl pool vaultKey cacheHM , sm_lookupSafeAction = lookupSafeActionImpl vaultKey cacheHM , sm_removeSafeAction = removeSafeActionImpl vaultKey cacheHM , sm_closeSessionManager = killThread housekeepThread } where (loadSessions, storeSessions) = case sc_persistCfg cfg of Nothing -> ( return [] , const $ return () ) Just spc -> ( do sessions <- spc_load spc return (map genSession sessions) , spc_store spc . map mkSerializable . HM.elems ) mkSerializable sess = (sess_id sess, sess_validUntil sess, sess_data sess) genSession (sid, validUntil, theData) = Session { sess_id = sid , sess_validUntil = validUntil , sess_data = theData , sess_safeActions = SafeActionStore HM.empty HM.empty } regenerateSessionIdImpl :: V.Key SessionId -> SV.SessionVault (Session conn sess st) -> CR.EntropyPool -> SessionCfg sess -> SpockActionCtx ctx conn sess st () regenerateSessionIdImpl vK sessionRef entropyPool cfg = do sess <- readSessionBase vK sessionRef liftIO $ deleteSessionImpl sessionRef (sess_id sess) newSession <- liftIO $ newSessionImpl entropyPool cfg sessionRef (sess_data sess) now <- liftIO getCurrentTime setRawMultiHeader MultiHeaderSetCookie $ makeSessionIdCookie cfg newSession now modifyVault $ V.insert vK (sess_id newSession) getSessionIdImpl :: V.Key SessionId -> SV.SessionVault (Session conn sess st) -> SpockActionCtx ctx conn sess st SessionId getSessionIdImpl vK sessionRef = do sess <- readSessionBase vK sessionRef return $ sess_id sess modifySessionBase :: V.Key SessionId -> SV.SessionVault (Session conn sess st) -> (Session conn sess st -> (Session conn sess st, a)) -> SpockActionCtx ctx conn sess st a modifySessionBase vK sessionRef modFun = do mValue <- queryVault vK case mValue of Nothing -> error "(3) Internal Spock Session Error. Please report this bug!" Just sid -> liftIO $ atomically $ do mSession <- SV.loadSession sid sessionRef case mSession of Nothing -> fail "Internal Spock Session Error: Unknown SessionId" Just session -> do let (sessionNew, result) = modFun session SV.storeSession sessionNew sessionRef return result readSessionBase :: V.Key SessionId -> SV.SessionVault (Session conn sess st) -> SpockActionCtx ctx conn sess st (Session conn sess st) readSessionBase vK sessionRef = do mValue <- queryVault vK case mValue of Nothing -> error "(1) Internal Spock Session Error. Please report this bug!" Just sid -> do mSession <- liftIO $ atomically $ SV.loadSession sid sessionRef case mSession of Nothing -> error "(2) Internal Spock Session Error. Please report this bug!" Just session -> return session addSafeActionImpl :: CR.EntropyPool -> V.Key SessionId -> SV.SessionVault (Session conn sess st) -> PackedSafeAction conn sess st -> SpockActionCtx ctx conn sess st SafeActionHash addSafeActionImpl pool vaultKey sessionMapVar safeAction = do base <- readSessionBase vaultKey sessionMapVar case HM.lookup safeAction (sas_reverse (sess_safeActions base)) of Just safeActionHash -> return safeActionHash Nothing -> do safeActionHash <- liftIO (randomHash pool 40) let f sas = sas { sas_forward = HM.insert safeActionHash safeAction (sas_forward sas) , sas_reverse = HM.insert safeAction safeActionHash (sas_reverse sas) } modifySessionBase vaultKey sessionMapVar (\s -> (s { sess_safeActions = f (sess_safeActions s) }, ())) return safeActionHash lookupSafeActionImpl :: V.Key SessionId -> SV.SessionVault (Session conn sess st) -> SafeActionHash -> SpockActionCtx ctx conn sess st (Maybe (PackedSafeAction conn sess st)) lookupSafeActionImpl vaultKey sessionMapVar hash = do base <- readSessionBase vaultKey sessionMapVar return $ HM.lookup hash (sas_forward (sess_safeActions base)) removeSafeActionImpl :: V.Key SessionId -> SV.SessionVault (Session conn sess st) -> PackedSafeAction conn sess st -> SpockActionCtx ctx conn sess st () removeSafeActionImpl vaultKey sessionMapVar action = modifySessionBase vaultKey sessionMapVar (\s -> (s { sess_safeActions = f (sess_safeActions s ) }, ())) where f sas = sas { sas_forward = case HM.lookup action (sas_reverse sas) of Just h -> HM.delete h (sas_forward sas) Nothing -> sas_forward sas , sas_reverse = HM.delete action (sas_reverse sas) } readSessionImpl :: V.Key SessionId -> SV.SessionVault (Session conn sess st) -> SpockActionCtx ctx conn sess st sess readSessionImpl vK sessionRef = do base <- readSessionBase vK sessionRef return (sess_data base) writeSessionImpl :: V.Key SessionId -> SV.SessionVault (Session conn sess st) -> sess -> SpockActionCtx ctx conn sess st () writeSessionImpl vK sessionRef value = modifySessionImpl vK sessionRef (const (value, ())) modifySessionImpl :: V.Key SessionId -> SV.SessionVault (Session conn sess st) -> (sess -> (sess, a)) -> SpockActionCtx ctx conn sess st a modifySessionImpl vK sessionRef f = do let modFun session = let (sessData', out) = f (sess_data session) in (session { sess_data = sessData' }, out) modifySessionBase vK sessionRef modFun makeSessionIdCookie :: SessionCfg sess -> Session conn sess st -> UTCTime -> BS.ByteString makeSessionIdCookie cfg sess now = generateCookieHeaderString name value settings now where name = sc_cookieName cfg value = sess_id sess settings = defaultCookieSettings { cs_EOL = CookieValidUntil (sess_validUntil sess) , cs_HTTPOnly = True } sessionMiddleware :: CR.EntropyPool -> SessionCfg sess -> V.Key SessionId -> SV.SessionVault (Session conn sess st) -> Wai.Middleware sessionMiddleware pool cfg vK sessionRef app req respond = case getCookieFromReq (sc_cookieName cfg) of Just sid -> do mSess <- loadSessionImpl cfg sessionRef sid case mSess of Nothing -> mkNew Just sess -> withSess False sess Nothing -> mkNew where getCookieFromReq name = lookup "cookie" (Wai.requestHeaders req) >>= lookup name . parseCookies defVal = sc_emptySession cfg v = Wai.vault req addCookie sess now responseHeaders = let cookieContent = makeSessionIdCookie cfg sess now cookieC = ("Set-Cookie", cookieContent) in (cookieC : responseHeaders) withSess shouldSetCookie sess = app (req { Wai.vault = V.insert vK (sess_id sess) v }) $ \unwrappedResp -> do now <- getCurrentTime respond $ if shouldSetCookie then mapReqHeaders (addCookie sess now) unwrappedResp else unwrappedResp mkNew = do newSess <- newSessionImpl pool cfg sessionRef defVal withSess True newSess newSessionImpl :: CR.EntropyPool -> SessionCfg sess -> SV.SessionVault (Session conn sess st) -> sess -> IO (Session conn sess st) newSessionImpl pool sessCfg sessionRef content = do sess <- createSession pool sessCfg content atomically $ SV.storeSession sess sessionRef return $! sess loadSessionImpl :: SessionCfg sess -> SV.SessionVault (Session conn sess st) -> SessionId -> IO (Maybe (Session conn sess st)) loadSessionImpl sessCfg sessionRef sid = do mSess <- atomically $ SV.loadSession sid sessionRef now <- getCurrentTime case mSess of Just sess -> do sessWithPossibleExpansion <- if sc_sessionExpandTTL sessCfg then do let expandedSession = sess { sess_validUntil = addUTCTime (sc_sessionTTL sessCfg) now } atomically $ SV.storeSession expandedSession sessionRef return expandedSession else return sess if sess_validUntil sessWithPossibleExpansion > now then return $ Just sessWithPossibleExpansion else do deleteSessionImpl sessionRef sid return Nothing Nothing -> return Nothing deleteSessionImpl :: SV.SessionVault (Session conn sess st) -> SessionId -> IO () deleteSessionImpl sessionRef sid = atomically $ SV.deleteSession sid sessionRef clearAllSessionsImpl :: SV.SessionVault (Session conn sess st) -> SpockActionCtx ctx conn sess st () clearAllSessionsImpl sessionRef = liftIO $ atomically $ SV.filterSessions (const False) sessionRef mapAllSessionsImpl :: SV.SessionVault (Session conn sess st) -> (sess -> STM sess) -> SpockActionCtx ctx conn sess st () mapAllSessionsImpl sessionRef f = liftIO $ atomically $ flip SV.mapSessions sessionRef $ \sess -> do newData <- f (sess_data sess) return $ sess { sess_data = newData } housekeepSessions :: SessionCfg sess -> SV.SessionVault (Session conn sess st) -> (HM.HashMap SessionId (Session conn sess st) -> IO ()) -> IO () housekeepSessions cfg sessionRef storeSessions = do now <- getCurrentTime (newStatus, oldStatus) <- atomically $ do oldSt <- SV.toList sessionRef SV.filterSessions (\sess -> sess_validUntil sess > now) sessionRef (,) <$> SV.toList sessionRef <*> pure oldSt let packSessionHm = HM.fromList . map (\v -> (SV.getSessionKey v, v)) oldHm = packSessionHm oldStatus newHm = packSessionHm newStatus storeSessions newHm sh_removed (sc_hooks cfg) (HM.map sess_data $ oldHm `HM.difference` newHm) threadDelay (1000 * 1000 * (round $ sc_housekeepingInterval cfg)) createSession :: CR.EntropyPool -> SessionCfg sess -> sess -> IO (Session conn sess st) createSession pool sessCfg content = do sid <- randomHash pool (sc_sessionIdEntropy sessCfg) now <- getCurrentTime let validUntil = addUTCTime (sc_sessionTTL sessCfg) now emptySafeActions = SafeActionStore HM.empty HM.empty return (Session sid validUntil content emptySafeActions) randomHash :: CR.EntropyPool -> Int -> IO T.Text randomHash pool len = do let sys :: CR.SystemRNG sys = CR.cprgCreate pool return $ T.replace "=" "" $ T.replace "/" "_" $ T.replace "+" "-" $ T.decodeUtf8 $ B64.encode $ fst $ CR.cprgGenerateWithEntropy len sys