{-# LANGUAGE CPP #-}
{-# LANGUAGE FlexibleContexts, OverloadedStrings, DoAndIfThenElse, RankNTypes #-}
{-# LANGUAGE GADTs #-}
-- | Cookie based session handling: creation, lookup, modification and
-- periodic clean-up of sessions kept in the store configured via 'SessionCfg'.
module Web.Spock.Internal.SessionManager
    ( createSessionManager, withSessionManager
    , SessionId, Session(..), SessionManager(..)
    , SessionIf(..)
    )
where

import Web.Spock.Core
import Web.Spock.Internal.Types
import Web.Spock.Internal.Util
import Web.Spock.Internal.Cookies

#if MIN_VERSION_base(4,8,0)
#else
import Control.Applicative
#endif
import Control.Concurrent
import Control.Exception
import Control.Monad
import Control.Monad.Trans
import Data.Time
import qualified Crypto.Random as CR
import qualified Data.ByteString as BS
import qualified Data.ByteString.Base64 as B64
import qualified Data.HashMap.Strict as HM
import qualified Data.Traversable as T
import qualified Data.Text as T
import qualified Data.Text.Encoding as T
import qualified Data.Vault.Lazy as V
import qualified Network.Wai as Wai

-- | The operations the session manager needs from the surrounding monad @m@.
data SessionIf m
    = SessionIf
    { si_queryVault :: forall a. V.Key a -> m (Maybe a)
      -- ^ look up a value in the vault (used here to find the session id)
    , si_modifyVault :: (V.Vault -> V.Vault) -> m ()
      -- ^ update the vault (used here to record a new session id)
    , si_setRawMultiHeader :: MultiHeader -> BS.ByteString -> m ()
      -- ^ set a raw header (used here to emit the session cookie)
    , si_vaultKey :: IO (V.Key SessionId)
      -- ^ action yielding the vault key for the session id; it is run once
      -- by 'createSessionManager'
    }

-- | Run an action with a freshly created session manager. The manager is
-- closed via 'sm_closeSessionManager' when the action finishes or throws.
withSessionManager ::
    MonadIO m
    => SessionCfg conn sess st
    -> SessionIf m
    -> (SessionManager m conn sess st -> IO a)
    -> IO a
withSessionManager sessCfg sif =
    bracket (createSessionManager sessCfg sif) sm_closeSessionManager

-- | Build a 'SessionManager' for the given configuration.
--
-- This forks a thread that runs 'housekeepSessions' forever; the thread is
-- killed by the returned 'sm_closeSessionManager'.
createSessionManager ::
    MonadIO m
    => SessionCfg conn sess st
    -> SessionIf m
    -> IO (SessionManager m conn sess st)
createSessionManager cfg sif =
    do vaultKey <- si_vaultKey sif
       housekeepThread <- forkIO (forever (housekeepSessions cfg))
       return
          SessionManager
          { sm_getSessionId = getSessionIdImpl vaultKey cfg sif
          , sm_getCsrfToken = getCsrfTokenImpl vaultKey cfg sif
          , sm_regenerateSessionId = regenerateSessionIdImpl vaultKey store cfg sif
          , sm_readSession = readSessionImpl vaultKey cfg sif
          , sm_writeSession = writeSessionImpl vaultKey store cfg sif
          , sm_modifySession = modifySessionImpl vaultKey store cfg sif
          , sm_clearAllSessions = clearAllSessionsImpl store
          , sm_mapSessions = mapAllSessionsImpl store
          , sm_middleware = sessionMiddleware cfg vaultKey
          , sm_closeSessionManager = killThread housekeepThread
          }
    where
      store = sc_store cfg

-- | Replace the current session by a new one carrying the same session data.
--
-- The old session is deleted from the store, a new session is created with
-- the old data, the session cookie is set again and the vault is updated to
-- point at the new session id.
regenerateSessionIdImpl ::
    MonadIO m
    => V.Key SessionId
    -> SessionStoreInstance (Session conn sess st)
    -> SessionCfg conn sess st
    -> SessionIf m
    -> m ()
regenerateSessionIdImpl vK sessionRef cfg sif =
    do sess <- readSessionBase vK cfg sif
       liftIO $ deleteSessionImpl sessionRef (sess_id sess)
       newSession <- liftIO $ newSessionImpl cfg sessionRef (sess_data sess)
       now <- liftIO getCurrentTime
       si_setRawMultiHeader sif MultiHeaderSetCookie (makeSessionIdCookie cfg newSession now)
       si_modifyVault sif $ V.insert vK (sess_id newSession)

-- | The id of the current session.
getSessionIdImpl ::
    MonadIO m
    => V.Key SessionId
    -> SessionCfg conn sess st
    -> SessionIf m
    -> m SessionId
getSessionIdImpl vK cfg sif =
    do sess <- readSessionBase vK cfg sif
       return $ sess_id sess

-- | The CSRF token of the current session.
getCsrfTokenImpl ::
    MonadIO m
    => V.Key SessionId
    -> SessionCfg conn sess st
    -> SessionIf m
    -> m T.Text
getCsrfTokenImpl vK cfg sif =
    do sess <- readSessionBase vK cfg sif
       return $ sess_csrfToken sess

-- | Apply a function to the current session, store the resulting session and
-- return the second component of the function result.
--
-- Calls 'error' if the vault holds no session id.
--
-- NOTE(review): the session is read and written in separate store
-- transactions, so concurrent modifications of the same session can
-- overwrite each other.
modifySessionBase ::
    MonadIO m
    => V.Key SessionId
    -> SessionStoreInstance (Session conn sess st)
    -> SessionCfg conn sess st
    -> SessionIf m
    -> (Session conn sess st -> (Session conn sess st, a))
    -> m a
modifySessionBase vK (SessionStoreInstance sessionRef) cfg sif modFun =
    do mValue <- si_queryVault sif vK
       case mValue of
         Nothing ->
             error "(3) Internal Spock Session Error. Please report this bug!"
         Just sid ->
             do session <- readOrNewSession cfg vK sif (Just sid)
                let (sessionNew, result) = modFun session
                liftIO $ ss_runTx sessionRef $ ss_storeSession sessionRef sessionNew
                return result

-- | Read the current session. If the session id from the vault can not be
-- loaded, a new session is created (see 'readOrNewSession').
--
-- Calls 'error' if the vault holds no session id.
readSessionBase ::
    MonadIO m
    => V.Key SessionId
    -> SessionCfg conn sess st
    -> SessionIf m
    -> m (Session conn sess st)
readSessionBase vK cfg sif =
    do mValue <- si_queryVault sif vK
       case mValue of
         Nothing ->
             error "(1) Internal Spock Session Error. Please report this bug!"
         Just sid ->
             readOrNewSession cfg vK sif (Just sid)

-- | The user data of the current session.
readSessionImpl ::
    MonadIO m
    => V.Key SessionId
    -> SessionCfg conn sess st
    -> SessionIf m
    -> m sess
readSessionImpl vK cfg sif =
    do base <- readSessionBase vK cfg sif
       return (sess_data base)

-- | Overwrite the user data of the current session.
writeSessionImpl ::
    MonadIO m
    => V.Key SessionId
    -> SessionStoreInstance (Session conn sess st)
    -> SessionCfg conn sess st
    -> SessionIf m
    -> sess
    -> m ()
writeSessionImpl vK sessionRef cfg sif value =
    modifySessionImpl vK sessionRef cfg sif (const (value, ()))

-- | Modify the user data of the current session and return the extra result
-- produced by the modification function.
modifySessionImpl ::
    MonadIO m
    => V.Key SessionId
    -> SessionStoreInstance (Session conn sess st)
    -> SessionCfg conn sess st
    -> SessionIf m
    -> (sess -> (sess, a))
    -> m a
modifySessionImpl vK sessionRef cfg sif f =
    do let modFun session =
               let (sessDataNew, out) = f (sess_data session)
               in (session { sess_data = sessDataNew }, out)
       modifySessionBase vK sessionRef cfg sif modFun

-- | Render the cookie header value that carries the session id. The cookie is
-- named after 'sc_cookieName', never expires and is marked HTTP-only.
makeSessionIdCookie :: SessionCfg conn sess st -> Session conn sess st -> UTCTime -> BS.ByteString
makeSessionIdCookie cfg sess now =
    generateCookieHeaderString name value settings now
    where
      name = sc_cookieName cfg
      value = sess_id sess
      settings =
          defaultCookieSettings
          { cs_EOL = CookieValidForever
          , cs_HTTPOnly = True
          }

-- | Load the session with the given id or create a new one.
--
-- When a new session had to be created, the session cookie is set and the
-- vault is updated with the new session id.
readOrNewSession ::
    MonadIO m
    => SessionCfg conn sess st
    -> V.Key SessionId
    -> SessionIf m
    -> Maybe SessionId
    -> m (Session conn sess st)
readOrNewSession cfg vK sif mSid =
    do (sess, write) <- loadOrSpanSession cfg mSid
       when write $
            do now <- liftIO getCurrentTime
               si_setRawMultiHeader sif MultiHeaderSetCookie (makeSessionIdCookie cfg sess now)
               si_modifyVault sif $ V.insert vK (sess_id sess)
       return sess

-- | Load the session with the given id from the configured store. If no id
-- is given or the session can not be loaded, a new session holding
-- 'sc_emptySession' is created.
--
-- The returned flag is 'True' exactly when a new session was created.
loadOrSpanSession ::
    MonadIO m
    => SessionCfg conn sess st
    -> Maybe SessionId
    -> m (Session conn sess st, Bool)
loadOrSpanSession cfg mSid =
    do mSess <- liftIO $ join <$> T.mapM (loadSessionImpl cfg sessionRef) mSid
       case mSess of
         Nothing ->
             do newSess <- liftIO $ newSessionImpl cfg sessionRef (sc_emptySession cfg)
                return (newSess, True)
         Just s ->
             return (s, False)
    where
      sessionRef = sc_store cfg

-- | WAI middleware that attaches a session to every request.
--
-- The session id is taken from the cookie named 'sc_cookieName' in the
-- request's @cookie@ header. The id of the loaded (or newly created) session
-- is inserted into the request vault before the wrapped application runs.
-- If a new session was created, a @Set-Cookie@ header is added to the
-- response.
sessionMiddleware ::
    SessionCfg conn sess st
    -> V.Key SessionId
    -> Wai.Middleware
sessionMiddleware cfg vK app req respond =
    go $ getCookieFromReq (sc_cookieName cfg)
    where
      go mSid =
          do (sess, shouldWriteCookie) <- loadOrSpanSession cfg mSid
             withSess shouldWriteCookie sess
      getCookieFromReq name =
          lookup "cookie" (Wai.requestHeaders req) >>= lookup name . parseCookies
      v = Wai.vault req
      addCookie sess now responseHeaders =
          let cookieContent = makeSessionIdCookie cfg sess now
              cookieC = ("Set-Cookie", cookieContent)
          in (cookieC : responseHeaders)
      withSess shouldSetCookie sess =
          app (req { Wai.vault = V.insert vK (sess_id sess) v }) $ \unwrappedResp ->
              do now <- getCurrentTime
                 -- NOTE(review): mapReqHeaders is applied to the response
                 -- here, so it presumably maps over response headers despite
                 -- its name; confirm in Web.Spock.Internal.Util.
                 respond $
                   if shouldSetCookie
                   then mapReqHeaders (addCookie sess now) unwrappedResp
                   else unwrappedResp

-- | Create a new session with the given user data and persist it.
newSessionImpl ::
    SessionCfg conn sess st
    -> SessionStoreInstance (Session conn sess st)
    -> sess
    -> IO (Session conn sess st)
newSessionImpl sessCfg (SessionStoreInstance sessionRef) content =
    do sess <- createSession sessCfg content
       ss_runTx sessionRef $ ss_storeSession sessionRef sess
       return $! sess

-- | Load a session from the store, honouring its expiry time.
--
-- Returns 'Nothing' if the session is unknown or no longer valid; sessions
-- that are no longer valid are deleted from the store. If
-- 'sc_sessionExpandTTL' is set, a still valid session gets its expiry time
-- moved to @now + sc_sessionTTL@ and is stored again.
loadSessionImpl ::
    SessionCfg conn sess st
    -> SessionStoreInstance (Session conn sess st)
    -> SessionId
    -> IO (Maybe (Session conn sess st))
loadSessionImpl sessCfg sessionRef@(SessionStoreInstance store) sid =
    do mSess <- ss_runTx store $ ss_loadSession store sid
       now <- getCurrentTime
       case mSess of
         Nothing ->
             return Nothing
         Just sess
             -- The expiry check has to look at the stored timestamp. Checking
             -- only after the TTL expansion would revive sessions that have
             -- expired but were not yet removed by the housekeeping thread.
             | sess_validUntil sess <= now ->
                   expire
             | not (sc_sessionExpandTTL sessCfg) ->
                   return (Just sess)
             | otherwise ->
                   do let expandedSession =
                              sess
                              { sess_validUntil =
                                    addUTCTime (sc_sessionTTL sessCfg) now
                              }
                      -- With a non-positive TTL the expanded session is
                      -- already expired; treat it like any expired session.
                      if sess_validUntil expandedSession > now
                      then do ss_runTx store $ ss_storeSession store expandedSession
                              return (Just expandedSession)
                      else expire
    where
      expire =
          do deleteSessionImpl sessionRef sid
             return Nothing

-- | Remove the session with the given id from the store.
deleteSessionImpl ::
    SessionStoreInstance (Session conn sess st)
    -> SessionId
    -> IO ()
deleteSessionImpl (SessionStoreInstance sessionRef) sid =
    ss_runTx sessionRef $ ss_deleteSession sessionRef sid

-- | Remove all sessions from the store by filtering with a predicate that
-- rejects every session.
clearAllSessionsImpl ::
    MonadIO m
    => SessionStoreInstance (Session conn sess st)
    -> m ()
clearAllSessionsImpl (SessionStoreInstance sessionRef) =
    liftIO $ ss_runTx sessionRef $ ss_filterSessions sessionRef (const False)

-- | Apply a monadic function to the user data of every session in the store.
mapAllSessionsImpl ::
    MonadIO m
    => SessionStoreInstance (Session conn sess st)
    -> (forall n. Monad n => sess -> n sess)
    -> m ()
mapAllSessionsImpl (SessionStoreInstance sessionRef) f =
    liftIO $ ss_runTx sessionRef $ ss_mapSessions sessionRef $ \sess ->
        do newData <- f (sess_data sess)
           return $ sess { sess_data = newData }

-- | One housekeeping round.
--
-- Drops every session whose expiry time is not after the current time,
-- passes the user data of the dropped sessions (keyed by session id) to the
-- 'sh_removed' hook and then sleeps for 'sc_housekeepingInterval' seconds.
housekeepSessions :: SessionCfg conn sess st -> IO ()
housekeepSessions cfg =
    case sc_store cfg of
      SessionStoreInstance store ->
          do now <- getCurrentTime
             (newStatus, oldStatus) <-
                 ss_runTx store $
                 do oldSt <- ss_toList store
                    ss_filterSessions store (\sess -> sess_validUntil sess > now)
                    (,) <$> ss_toList store <*> pure oldSt
             let packSessionHm = HM.fromList . map (\v -> (sess_id v, v))
                 oldHm = packSessionHm oldStatus
                 newHm = packSessionHm newStatus
             sh_removed (sc_hooks cfg) (HM.map sess_data $ oldHm `HM.difference` newHm)
             -- threadDelay expects microseconds
             threadDelay (1000 * 1000 * (round $ sc_housekeepingInterval cfg))

-- | Build a new session value (without storing it): a random session id made
-- from 'sc_sessionIdEntropy' random bytes, a random CSRF token made from 12
-- random bytes and an expiry time of @now + sc_sessionTTL@.
createSession :: SessionCfg conn sess st -> sess -> IO (Session conn sess st)
createSession sessCfg content =
    do sid <- randomHash (sc_sessionIdEntropy sessCfg)
       csrfToken <- randomHash 12
       now <- getCurrentTime
       let validUntil = addUTCTime (sc_sessionTTL sessCfg) now
       return (Session sid csrfToken validUntil content)

-- | Random text built from the given number of random bytes. The bytes are
-- base64 encoded; @+@ is replaced by @-@, the slash by @_@ and the @=@
-- padding is dropped.
randomHash :: Int -> IO T.Text
randomHash len =
    do by <- CR.getRandomBytes len
       return $
           T.replace "=" "" $
           T.replace "/" "_" $
           T.replace "+" "-" $
           T.decodeUtf8 $
           B64.encode by