module Web.Spock.Internal.SessionManager
( createSessionManager, withSessionManager
, SessionId, Session(..), SessionManager(..)
)
where
import Web.Spock.Internal.Types
import Web.Spock.Internal.CoreAction
import Web.Spock.Internal.Util
import qualified Web.Spock.Internal.SessionVault as SV
import Control.Arrow (first)
import Control.Concurrent
import Control.Concurrent.STM
import Control.Exception
import Control.Monad
import Control.Monad.Trans
import Data.Time
#if MIN_VERSION_time(1,5,0)
#else
import System.Locale (defaultTimeLocale)
#endif
import System.Random
import qualified Data.ByteString.Base64.URL as B64
import qualified Data.ByteString.Char8 as BSC
import qualified Data.ByteString.Lazy as BSL
import qualified Data.HashMap.Strict as HM
import qualified Data.Text as T
import qualified Data.Text.Encoding as T
import qualified Data.Text.Lazy as TL
import qualified Data.Text.Lazy.Encoding as TL
import qualified Data.Vault.Lazy as V
import qualified Network.Wai as Wai
-- | Run an action with a freshly created 'SessionManager'.
--
-- The manager is built with 'createSessionManager' and is shut down via its
-- 'sm_closeSessionManager' field when the action returns or throws.
withSessionManager :: SessionCfg sess -> (SessionManager conn sess st -> IO a) -> IO a
withSessionManager sessCfg action =
    bracket acquire release action
    where
      acquire = createSessionManager sessCfg
      release = sm_closeSessionManager
-- | Build a 'SessionManager' from the given configuration.
--
-- Previously persisted sessions (if a persistence config is present) are
-- loaded into a fresh session vault, a new vault key is allocated for tagging
-- requests with their session id, and a background thread is forked that runs
-- 'housekeepSessions' forever. Closing the manager kills that thread.
createSessionManager :: SessionCfg sess -> IO (SessionManager conn sess st)
createSessionManager cfg =
    do restored <- loadSessions
       vault <-
           atomically $
           do fresh <- SV.newSessionVault
              mapM_ (`SV.storeSession` fresh) restored
              return fresh
       vaultKey <- V.newKey
       keeper <- forkIO . forever $ housekeepSessions cfg vault storeSessions
       return
          SessionManager
          { sm_getSessionId = getSessionIdImpl vaultKey vault
          , sm_readSession = readSessionImpl vaultKey vault
          , sm_writeSession = writeSessionImpl vaultKey vault
          , sm_modifySession = modifySessionImpl vaultKey vault
          , sm_clearAllSessions = clearAllSessionsImpl vault
          , sm_middleware = sessionMiddleware cfg vaultKey vault
          , sm_addSafeAction = addSafeActionImpl vaultKey vault
          , sm_lookupSafeAction = lookupSafeActionImpl vaultKey vault
          , sm_removeSafeAction = removeSafeActionImpl vaultKey vault
          , sm_closeSessionManager = killThread keeper
          }
    where
      -- Without a persistence config: nothing to load, storing is a no-op.
      (loadSessions, storeSessions) =
          maybe (return [], const (return ())) persistent (sc_persistCfg cfg)
      persistent spc =
          ( fmap (map fromStored) (spc_load spc)
          , spc_store spc . map toStored . HM.elems
          )
      -- Only id, expiry and payload are persisted; safe actions are not.
      toStored sess =
          (sess_id sess, sess_validUntil sess, sess_data sess)
      -- Restored sessions start with an empty safe-action store.
      fromStored (sid, validUntil, theData) =
          Session
          { sess_id = sid
          , sess_validUntil = validUntil
          , sess_data = theData
          , sess_safeActions = SafeActionStore HM.empty HM.empty
          }
-- | Return the id of the session attached to the current request.
getSessionIdImpl :: V.Key SessionId
                 -> SV.SessionVault (Session conn sess st)
                 -> SpockAction conn sess st SessionId
getSessionIdImpl vK sessionRef =
    fmap sess_id (readSessionBase vK sessionRef)
-- | Atomically apply a function to the current request's 'Session' record and
-- return the function's secondary result.
--
-- The session id is taken from the request vault under the given key. The
-- load / modify / store sequence runs in a single STM transaction.
--
-- Throws an 'ErrorCall' when the request carries no session id or when the
-- id is not present in the session vault.
modifySessionBase :: V.Key SessionId
                  -> SV.SessionVault (Session conn sess st)
                  -> (Session conn sess st -> (Session conn sess st, a))
                  -> SpockAction conn sess st a
modifySessionBase vK sessionRef modFun =
    do req <- request
       case V.lookup vK (Wai.vault req) of
         Nothing ->
             error "(3) Internal Spock Session Error. Please report this bug!"
         Just sid ->
             liftIO $ atomically $
             do mSession <- SV.loadSession sid sessionRef
                case mSession of
                  Nothing ->
                      -- STM has no MonadFail instance, so 'fail' is not
                      -- available here on current base versions. 'throwSTM'
                      -- with 'ErrorCall' raises the same exception type the
                      -- old default 'fail' produced.
                      throwSTM (ErrorCall "Internal Spock Session Error: Unknown SessionId")
                  Just session ->
                      do let (sessionNew, result) = modFun session
                         SV.storeSession sessionNew sessionRef
                         return result
-- | Fetch the full 'Session' record belonging to the current request.
--
-- Calls 'error' if the request vault holds no session id under the given key,
-- or if that id is missing from the session vault.
readSessionBase :: V.Key SessionId
                -> SV.SessionVault (Session conn sess st)
                -> SpockAction conn sess st (Session conn sess st)
readSessionBase vK sessionRef =
    request >>= \req ->
        case V.lookup vK (Wai.vault req) of
          Just sid ->
              fetch sid
          Nothing ->
              error "(1) Internal Spock Session Error. Please report this bug!"
    where
      fetch sid =
          liftIO (atomically (SV.loadSession sid sessionRef)) >>=
          maybe (error "(2) Internal Spock Session Error. Please report this bug!") return
-- | Register a safe action for the current session and return its hash.
--
-- If the action is already registered its existing hash is returned and the
-- session is left unchanged. Otherwise a fresh random hash is generated and
-- recorded in both the forward (hash -> action) and reverse (action -> hash)
-- maps of the session's safe-action store.
addSafeActionImpl :: V.Key SessionId
                  -> SV.SessionVault (Session conn sess st)
                  -> PackedSafeAction conn sess st
                  -> SpockAction conn sess st SafeActionHash
addSafeActionImpl vaultKey sessionMapVar safeAction =
    do base <- readSessionBase vaultKey sessionMapVar
       -- Cheap pre-check so no random hash is generated for known actions.
       case HM.lookup safeAction (sas_reverse (sess_safeActions base)) of
         Just safeActionHash ->
             return safeActionHash
         Nothing ->
             do candidate <- liftIO (randomHash 40)
                modifySessionBase vaultKey sessionMapVar (register candidate)
    where
      -- Runs inside the modifying transaction. The reverse map is checked
      -- again because another request on the same session may have
      -- registered the action since the pre-check; inserting blindly would
      -- leave a second forward entry that 'removeSafeActionImpl' never
      -- deletes.
      register candidate s =
          let sas = sess_safeActions s
          in case HM.lookup safeAction (sas_reverse sas) of
               Just existing ->
                   (s, existing)
               Nothing ->
                   ( s { sess_safeActions =
                             sas
                             { sas_forward = HM.insert candidate safeAction (sas_forward sas)
                             , sas_reverse = HM.insert safeAction candidate (sas_reverse sas)
                             }
                       }
                   , candidate
                   )
-- | Look up the safe action registered under the given hash in the current
-- session, if any.
lookupSafeActionImpl :: V.Key SessionId
                     -> SV.SessionVault (Session conn sess st)
                     -> SafeActionHash
                     -> SpockAction conn sess st (Maybe (PackedSafeAction conn sess st))
lookupSafeActionImpl vaultKey sessionMapVar hash =
    fmap (HM.lookup hash . sas_forward . sess_safeActions) $
    readSessionBase vaultKey sessionMapVar
-- | Remove a safe action from the current session.
--
-- The action's hash is found through the reverse map and deleted from the
-- forward map; the action itself is then deleted from the reverse map.
-- Removing an action that is not registered leaves the store unchanged.
removeSafeActionImpl :: V.Key SessionId
                     -> SV.SessionVault (Session conn sess st)
                     -> PackedSafeAction conn sess st
                     -> SpockAction conn sess st ()
removeSafeActionImpl vaultKey sessionMapVar action =
    modifySessionBase vaultKey sessionMapVar dropAction
    where
      dropAction s =
          (s { sess_safeActions = prune (sess_safeActions s) }, ())
      prune store =
          let forward = sas_forward store
              backward = sas_reverse store
          in store
             { sas_forward =
                   maybe forward (`HM.delete` forward) (HM.lookup action backward)
             , sas_reverse = HM.delete action backward
             }
-- | Read the user-defined payload of the current request's session.
readSessionImpl :: V.Key SessionId
                -> SV.SessionVault (Session conn sess st)
                -> SpockAction conn sess st sess
readSessionImpl vK sessionRef =
    readSessionBase vK sessionRef >>= return . sess_data
-- | Replace the user-defined payload of the current request's session.
writeSessionImpl :: V.Key SessionId
                 -> SV.SessionVault (Session conn sess st)
                 -> sess
                 -> SpockAction conn sess st ()
writeSessionImpl vK sessionRef value =
    modifySessionImpl vK sessionRef overwrite
    where
      -- The previous payload is ignored entirely.
      overwrite _ = (value, ())
-- | Apply a function to the user-defined payload of the current session,
-- storing the new payload and returning the function's secondary result.
modifySessionImpl :: V.Key SessionId
                  -> SV.SessionVault (Session conn sess st)
                  -> (sess -> (sess, a))
                  -> SpockAction conn sess st a
modifySessionImpl vK sessionRef f =
    modifySessionBase vK sessionRef onSession
    where
      -- 'fst' / 'snd' keep the result pair lazy, like a let-bound tuple.
      onSession session =
          (session { sess_data = fst updated }, snd updated)
          where
            updated = f (sess_data session)
-- | WAI middleware that makes sure every request runs with a session.
--
-- The session id is read from the request's cookie header using the
-- configured cookie name. If the id resolves to a live session, that session
-- is used as-is; otherwise (no cookie, undecodable header, unknown or expired
-- id) a new session is created from 'sc_emptySession' and a @Set-Cookie@
-- header is added to the response. In both cases the session id is placed in
-- the request vault under the given key before the wrapped application runs.
sessionMiddleware :: SessionCfg sess
                  -> V.Key SessionId
                  -> SV.SessionVault (Session conn sess st)
                  -> Wai.Middleware
sessionMiddleware cfg vK sessionRef app req respond =
    case getCookieFromReq (sc_cookieName cfg) of
      Just sid ->
          do mSess <- loadSessionImpl cfg sessionRef sid
             case mSess of
               Nothing ->
                   mkNew
               Just sess ->
                   withSess False sess
      Nothing ->
          mkNew
    where
      -- The cookie header is client-controlled. Bytes that are not valid
      -- UTF-8 are treated as "no cookie" instead of raising a decoding
      -- exception.
      getCookieFromReq name =
          do rawCookies <- lookup "cookie" (Wai.requestHeaders req)
             cookieText <- either (const Nothing) Just (T.decodeUtf8' rawCookies)
             lookup name (parseCookies cookieText)

      renderCookie name value validUntil =
          let formattedTime =
                  TL.pack $ formatTime defaultTimeLocale "%a, %d-%b-%Y %X %Z" validUntil
          in TL.concat [ TL.fromStrict name
                       , "="
                       , TL.fromStrict value
                       , "; path=/; expires="
                       , formattedTime
                       , ";"
                       ]

      -- All whitespace is removed before splitting on ';'.
      parseCookies :: T.Text -> [(T.Text, T.Text)]
      parseCookies = concatMap parseCookie . T.splitOn ";" . T.concat . T.words

      -- Splits a segment at its last '='. A segment without any '=' (for
      -- example the empty segment after a trailing ';') yields an empty
      -- prefix; it is skipped, because 'T.init' is partial on empty input
      -- and would otherwise throw while the cookie list is searched.
      parseCookie segment
          | T.null prefix = []
          | otherwise = [first T.init pair]
          where
            pair@(prefix, _) = T.breakOnEnd "=" segment

      defVal = sc_emptySession cfg
      v = Wai.vault req

      addCookie sess responseHeaders =
          let cookieContent =
                  renderCookie (sc_cookieName cfg) (sess_id sess) (sess_validUntil sess)
              cookieC = ("Set-Cookie", BSL.toStrict $ TL.encodeUtf8 cookieContent)
          in (cookieC : responseHeaders)

      -- Runs the wrapped application with the session id in the request
      -- vault; optionally adds the session cookie to the response headers.
      withSess shouldSetCookie sess =
          app (req { Wai.vault = V.insert vK (sess_id sess) v }) $ \unwrappedResp ->
              respond $
              if shouldSetCookie
              then mapReqHeaders (addCookie sess) unwrappedResp
              else unwrappedResp

      mkNew =
          do newSess <- newSessionImpl cfg sessionRef defVal
             withSess True newSess
-- | Create a new session holding the given payload, store it in the session
-- vault and return it.
newSessionImpl :: SessionCfg sess
               -> SV.SessionVault (Session conn sess st)
               -> sess
               -> IO (Session conn sess st)
newSessionImpl sessCfg sessionRef content =
    createSession sessCfg content >>= \sess ->
        do atomically (SV.storeSession sess sessionRef)
           return $! sess
-- | Look up a session by id, enforcing expiry.
--
-- Returns 'Nothing' if the id is unknown or the session's @sess_validUntil@
-- is not in the future; an expired session is also deleted from the vault.
-- For a live session, when 'sc_sessionExpandTTL' is set, the expiry is pushed
-- to now plus 'sc_sessionTTL', the updated session is stored, and that
-- updated session is returned.
loadSessionImpl :: SessionCfg sess
                -> SV.SessionVault (Session conn sess st)
                -> SessionId
                -> IO (Maybe (Session conn sess st))
loadSessionImpl sessCfg sessionRef sid =
    do mSess <- atomically $ SV.loadSession sid sessionRef
       now <- getCurrentTime
       case mSess of
         Nothing ->
             return Nothing
         Just sess
             -- Expiry is checked against the stored timestamp BEFORE any TTL
             -- extension; extending first would revive a session that has
             -- expired but not yet been removed by housekeeping.
             | sess_validUntil sess <= now ->
                 do deleteSessionImpl sessionRef sid
                    return Nothing
             | sc_sessionExpandTTL sessCfg ->
                 do let expandedSession =
                            sess
                            { sess_validUntil =
                                  addUTCTime (sc_sessionTTL sessCfg) now
                            }
                    atomically $ SV.storeSession expandedSession sessionRef
                    return (Just expandedSession)
             | otherwise ->
                 return (Just sess)
-- | Remove the session with the given id from the session vault.
deleteSessionImpl :: SV.SessionVault (Session conn sess st)
                  -> SessionId
                  -> IO ()
deleteSessionImpl sessionRef =
    atomically . (`SV.deleteSession` sessionRef)
-- | Drop every session from the vault by filtering with a predicate that
-- rejects all of them.
clearAllSessionsImpl :: SV.SessionVault (Session conn sess st)
                     -> SpockAction conn sess st ()
clearAllSessionsImpl sessionRef =
    liftIO . atomically $ SV.filterSessions keepNone sessionRef
    where
      keepNone _ = False
-- | One housekeeping pass: drop expired sessions, hand the surviving ones to
-- the given store callback, then sleep for 'sc_housekeepingInterval'.
--
-- Filtering and snapshotting happen in a single STM transaction; the store
-- callback runs afterwards in IO on that snapshot, keyed by session key.
housekeepSessions :: SessionCfg sess
                  -> SV.SessionVault (Session conn sess st)
                  -> (HM.HashMap SessionId (Session conn sess st) -> IO ())
                  -> IO ()
housekeepSessions cfg sessionRef storeSessions =
    do now <- getCurrentTime
       newStatus <-
           atomically $
           do SV.filterSessions (\sess -> sess_validUntil sess > now) sessionRef
              SV.toList sessionRef
       storeSessions (HM.fromList $ map (\v -> (SV.getSessionKey v, v)) newStatus)
       -- Convert to microseconds first and round afterwards. Rounding the
       -- interval to whole seconds first loses sub-second precision and turns
       -- any interval of half a second or less into a zero delay.
       threadDelay (round (1000 * 1000 * sc_housekeepingInterval cfg))
-- | Build a new session value (without storing it): a random id sized by
-- 'sc_sessionIdEntropy', an expiry of now plus 'sc_sessionTTL', the given
-- payload and an empty safe-action store.
createSession :: SessionCfg sess -> sess -> IO (Session conn sess st)
createSession sessCfg content =
    do sid <- randomHash (sc_sessionIdEntropy sessCfg)
       expiry <- fmap (addUTCTime (sc_sessionTTL sessCfg)) getCurrentTime
       return
          Session
          { sess_id = sid
          , sess_validUntil = expiry
          , sess_data = content
          , sess_safeActions = SafeActionStore HM.empty HM.empty
          }
-- | Produce a random URL-safe token.
--
-- @len@ random characters are drawn from a new 'StdGen', packed into a
-- 'BSC.ByteString', Base64-URL encoded, and stripped of @=@ padding. @len@ is
-- therefore the number of random input bytes, not the length of the result.
--
-- NOTE(review): 'StdGen' is a general-purpose generator, not a
-- cryptographically secure one; confirm this is acceptable wherever the
-- result is used as a secret.
randomHash :: Int -> IO T.Text
randomHash len =
    fmap (stripPadding . encode) newStdGen
    where
      encode :: StdGen -> T.Text
      encode = T.decodeUtf8 . B64.encode . BSC.pack . take len . randoms
      stripPadding = T.replace "=" ""