module Web.Spock.Internal.SessionManager
( createSessionManager
, SessionId, Session(..), SessionManager(..)
)
where
import Web.Spock.Internal.Types
import Web.Spock.Internal.CoreAction
import Web.Spock.Internal.Util
import Control.Arrow (first)
import Control.Concurrent
import Control.Concurrent.STM
import Control.Monad
import Control.Monad.Trans
import Data.List (foldl')
import Data.Time
#if MIN_VERSION_time(1,5,0)
#else
import System.Locale (defaultTimeLocale)
#endif
import System.Random
import qualified Data.ByteString.Base64.URL as B64
import qualified Data.ByteString.Char8 as BSC
import qualified Data.ByteString.Lazy as BSL
import qualified Data.HashMap.Strict as HM
import qualified Data.Text as T
import qualified Data.Text.Encoding as T
import qualified Data.Text.Lazy as TL
import qualified Data.Text.Lazy.Encoding as TL
import qualified Data.Vault.Lazy as V
import qualified Network.Wai as Wai
-- | Build a 'SessionManager' whose sessions live in a single 'TVar' map.
--
-- When 'sc_persistCfg' is present, sessions returned by 'spc_load' are
-- loaded once at start-up (each with an empty safe-action store), and the
-- housekeeping thread forked here passes the live map to 'spc_store'.
-- Without a persistence config, start-up begins with an empty map and the
-- store step does nothing.
createSessionManager :: SessionCfg sess -> IO (SessionManager conn sess st)
createSessionManager cfg =
    do initial <- loadSessions
       sessionVar <- newTVarIO initial
       key <- V.newKey
       -- Background loop: expire old sessions and persist the rest.
       _ <- forkIO $ forever $ housekeepSessions sessionVar storeSessions
       return
         SessionManager
         { sm_getSessionId = getSessionIdImpl key sessionVar
         , sm_readSession = readSessionImpl key sessionVar
         , sm_writeSession = writeSessionImpl key sessionVar
         , sm_modifySession = modifySessionImpl key sessionVar
         , sm_clearAllSessions = clearAllSessionsImpl sessionVar
         , sm_middleware = sessionMiddleware cfg key sessionVar
         , sm_addSafeAction = addSafeActionImpl key sessionVar
         , sm_lookupSafeAction = lookupSafeActionImpl key sessionVar
         , sm_removeSafeAction = removeSafeActionImpl key sessionVar
         }
    where
      persistCfg = sc_persistCfg cfg
      loadSessions =
          maybe (return HM.empty) loadPersisted persistCfg
      storeSessions =
          maybe (\_ -> return ()) storePersisted persistCfg
      -- Later duplicates of the same id win, as with repeated HM.insert.
      loadPersisted spc =
          do stored <- spc_load spc
             return $
               HM.fromList
                 [ (sid, restore sid validUntil theData)
                 | (sid, validUntil, theData) <- stored
                 ]
      storePersisted spc sessionMap =
          spc_store spc
            [ (sess_id s, sess_validUntil s, sess_data s)
            | s <- HM.elems sessionMap
            ]
      restore sid validUntil theData =
          Session
          { sess_id = sid
          , sess_validUntil = validUntil
          , sess_data = theData
          , sess_safeActions = SafeActionStore HM.empty HM.empty
          }
-- | Return the id of the session attached to the current request.
getSessionIdImpl :: V.Key SessionId
                 -> UserSessions conn sess st
                 -> SpockAction conn sess st SessionId
getSessionIdImpl vK sessionRef =
    fmap sess_id (readSessionBase vK sessionRef)
-- | Apply @modFun@ to the session whose id the middleware stored in the
-- request vault.
--
-- Fails with an internal error if the vault holds no session id. Because
-- the update uses 'HM.adjust', it silently does nothing if that id is no
-- longer present in the map.
modifySessionBase :: V.Key SessionId
                  -> UserSessions conn sess st
                  -> (Session conn sess st -> Session conn sess st)
                  -> SpockAction conn sess st ()
modifySessionBase vK sessionRef modFun =
    request >>= \req ->
        maybe (error "(3) Internal Spock Session Error. Please report this bug!")
              updateSession
              (V.lookup vK (Wai.vault req))
    where
      updateSession sid =
          liftIO $ atomically $ modifyTVar' sessionRef (HM.adjust modFun sid)
-- | Fetch the full 'Session' record for the current request.
--
-- The session id is taken from the request vault, then looked up in the
-- shared map. Each of the two lookups fails with its own internal error
-- message if it comes back empty.
readSessionBase :: V.Key SessionId
                -> UserSessions conn sess st
                -> SpockAction conn sess st (Session conn sess st)
readSessionBase vK sessionRef =
    do req <- request
       sid <- maybe (error "(1) Internal Spock Session Error. Please report this bug!")
                    return
                    (V.lookup vK (Wai.vault req))
       sessions <- liftIO (readTVarIO sessionRef)
       maybe (error "(2) Internal Spock Session Error. Please report this bug!")
             return
             (HM.lookup sid sessions)
-- | Register a safe action in the current session and return its hash.
--
-- If the action is already known (it is present in the reverse map), its
-- existing hash is returned unchanged. Otherwise a new 40-byte random hash
-- is generated and recorded in both the forward and the reverse map.
addSafeActionImpl :: V.Key SessionId
                  -> UserSessions conn sess st
                  -> PackedSafeAction conn sess st
                  -> SpockAction conn sess st SafeActionHash
addSafeActionImpl vaultKey sessionMapVar safeAction =
    do current <- readSessionBase vaultKey sessionMapVar
       let known = HM.lookup safeAction (sas_reverse (sess_safeActions current))
       maybe registerNew return known
    where
      registerNew =
          do newHash <- liftIO (randomHash 40)
             modifySessionBase vaultKey sessionMapVar (addTo newHash)
             return newHash
      addTo h sess =
          let store = sess_safeActions sess
              updated =
                  store
                  { sas_forward = HM.insert h safeAction (sas_forward store)
                  , sas_reverse = HM.insert safeAction h (sas_reverse store)
                  }
          in sess { sess_safeActions = updated }
-- | Resolve a safe-action hash to its action within the current session, or
-- 'Nothing' when the hash is not in the forward map.
lookupSafeActionImpl :: V.Key SessionId
                     -> UserSessions conn sess st
                     -> SafeActionHash
                     -> SpockAction conn sess st (Maybe (PackedSafeAction conn sess st))
lookupSafeActionImpl vaultKey sessionMapVar hash =
    fmap (HM.lookup hash . sas_forward . sess_safeActions)
         (readSessionBase vaultKey sessionMapVar)
-- | Remove a safe action from the current session.
--
-- The action is removed from the reverse map. Its hash, if one is recorded
-- there, is also removed from the forward map.
removeSafeActionImpl :: V.Key SessionId
                     -> UserSessions conn sess st
                     -> PackedSafeAction conn sess st
                     -> SpockAction conn sess st ()
removeSafeActionImpl vaultKey sessionMapVar action =
    modifySessionBase vaultKey sessionMapVar dropAction
    where
      dropAction sess =
          sess { sess_safeActions = forget (sess_safeActions sess) }
      forget store =
          let fwd = sas_forward store
              rev = sas_reverse store
          in store
             { sas_forward = maybe fwd (\h -> HM.delete h fwd) (HM.lookup action rev)
             , sas_reverse = HM.delete action rev
             }
-- | Return the user-defined payload of the current session.
readSessionImpl :: V.Key SessionId
                -> UserSessions conn sess st
                -> SpockAction conn sess st sess
readSessionImpl vK sessionRef =
    fmap sess_data (readSessionBase vK sessionRef)
-- | Replace the user-defined payload of the current session with @value@.
writeSessionImpl :: V.Key SessionId
                 -> UserSessions conn sess st
                 -> sess
                 -> SpockAction conn sess st ()
writeSessionImpl vK sessionRef value =
    modifySessionImpl vK sessionRef (\_ -> value)
-- | Apply @f@ to the user-defined payload of the current session, leaving
-- its id, expiry and safe actions untouched.
modifySessionImpl :: V.Key SessionId
                  -> UserSessions conn sess st
                  -> (sess -> sess)
                  -> SpockAction conn sess st ()
modifySessionImpl vK sessionRef f =
    modifySessionBase vK sessionRef updateData
    where
      updateData session = session { sess_data = f (sess_data session) }
-- | WAI middleware that attaches a session to every request.
--
-- The session id is read from the cookie named 'sc_cookieName'. A new
-- session (initialised with 'sc_emptySession') is created, and a Set-Cookie
-- header is added to the response, when any of these holds:
--
--   * the cookie is absent or cannot be parsed;
--   * the Cookie header is not valid UTF-8;
--   * 'loadSessionImpl' rejects the session.
--
-- For an existing session, the cookie is re-sent when
-- 'sc_sessionExpandTTL' is on, so that the browser-side expiry tracks the
-- extended server-side expiry. The resolved id is stored in the request
-- vault under @vK@, where the session accessors read it.
sessionMiddleware :: SessionCfg sess
                  -> V.Key SessionId
                  -> UserSessions conn sess st
                  -> Wai.Middleware
sessionMiddleware cfg vK sessionRef app req respond =
    case getCookieFromReq (sc_cookieName cfg) of
      Just sid ->
          do mSess <- loadSessionImpl cfg sessionRef sid
             case mSess of
               Nothing ->
                   mkNew
               Just sess ->
                   -- With TTL expansion the session expiry just moved, so the
                   -- cookie expiry has to be refreshed as well.
                   withSess (sc_sessionExpandTTL cfg) sess
      Nothing ->
          mkNew
    where
      getCookieFromReq name =
          do rawHeader <- lookup "cookie" (Wai.requestHeaders req)
             -- Total decode: a header with invalid UTF-8 is treated as
             -- carrying no session cookie instead of throwing.
             decoded <- either (const Nothing) Just (T.decodeUtf8' rawHeader)
             lookup name (parseCookies decoded)
      -- NOTE(review): no HttpOnly or Secure attribute is emitted; consider
      -- whether the session cookie should carry them.
      renderCookie name value validUntil =
          let formattedTime =
                  TL.pack $ formatTime defaultTimeLocale "%a, %d-%b-%Y %X %Z" validUntil
          in TL.concat [ TL.fromStrict name
                       , "="
                       , TL.fromStrict value
                       , "; path=/; expires="
                       , formattedTime
                       , ";"
                       ]
      -- Splits the header into name/value pairs at the FIRST equals sign
      -- (RFC 6265). Segments without one, including the empty segment left
      -- by a trailing semicolon or an empty header, are skipped.
      parseCookies :: T.Text -> [(T.Text, T.Text)]
      parseCookies raw =
          [ (cookieName, T.drop 1 rest)
          | segment <- T.splitOn ";" (T.concat (T.words raw))
          , let (cookieName, rest) = T.breakOn "=" segment
          , not (T.null rest)
          ]
      defVal = sc_emptySession cfg
      v = Wai.vault req
      addCookie sess responseHeaders =
          let cookieContent =
                  renderCookie (sc_cookieName cfg) (sess_id sess) (sess_validUntil sess)
              cookieC = ("Set-Cookie", BSL.toStrict $ TL.encodeUtf8 cookieContent)
          in (cookieC : responseHeaders)
      withSess shouldSetCookie sess =
          app (req { Wai.vault = V.insert vK (sess_id sess) v }) $ \unwrappedResp ->
              respond $
                  if shouldSetCookie
                  then mapReqHeaders (addCookie sess) unwrappedResp
                  else unwrappedResp
      mkNew =
          do newSess <- newSessionImpl cfg sessionRef defVal
             withSess True newSess
-- | Create a fresh session holding @content@, add it to the shared map under
-- its own id, and return it (forced to WHNF).
newSessionImpl :: SessionCfg sess
               -> UserSessions conn sess st
               -> sess
               -> IO (Session conn sess st)
newSessionImpl sessCfg sessionRef content =
    do fresh <- createSession sessCfg content
       let register = HM.insert (sess_id fresh) fresh
       atomically (modifyTVar' sessionRef register)
       fresh `seq` return fresh
-- | Look up the session with id @sid@ and decide whether it may be reused.
--
-- Returns 'Nothing' when no such session exists, or when its
-- 'sess_validUntil' is not in the future. An expired session is also
-- removed from the map. When 'sc_sessionExpandTTL' is enabled, a
-- still-valid session gets its expiry moved to now plus 'sc_sessionTTL',
-- and that updated session is returned.
loadSessionImpl :: SessionCfg sess
                -> UserSessions conn sess st
                -> SessionId
                -> IO (Maybe (Session conn sess st))
loadSessionImpl sessCfg sessionRef sid =
    do now <- getCurrentTime
       mSess <- atomically (touchSession now)
       case mSess of
         Nothing ->
             return Nothing
         Just sess
             | sess_validUntil sess > now ->
                 return (Just sess)
             | otherwise ->
                 do deleteSessionImpl sessionRef sid
                    return Nothing
    where
      -- One STM transaction performs both the expiry check and the TTL
      -- expansion on the same map state. Only a session that is still valid
      -- is expanded, so an expired session is never revived. The write also
      -- reuses the map just read, so a concurrent update to this session
      -- cannot be overwritten by a stale copy.
      touchSession now =
          do sessHM <- readTVar sessionRef
             case HM.lookup sid sessHM of
               Just sess
                   | sc_sessionExpandTTL sessCfg && sess_validUntil sess > now ->
                       do let expanded =
                                  sess { sess_validUntil = addUTCTime (sc_sessionTTL sessCfg) now }
                          writeTVar sessionRef $! HM.insert sid expanded sessHM
                          return (Just expanded)
               other ->
                   return other
-- | Remove the session with id @sid@ from the shared map, if it is present.
deleteSessionImpl :: UserSessions conn sess st
                  -> SessionId
                  -> IO ()
deleteSessionImpl sessionRef sid =
    atomically (modifyTVar' sessionRef (HM.delete sid))
-- | Drop every session from the shared map.
clearAllSessionsImpl :: UserSessions conn sess st
                     -> SpockAction conn sess st ()
clearAllSessionsImpl sessionRef =
    liftIO (atomically (writeTVar sessionRef HM.empty))
-- | One round of session housekeeping, intended to be run in a loop.
--
-- Each round does three things:
--
--   * drops every session whose 'sess_validUntil' is not in the future;
--   * passes the surviving map to @storeSessions@ (the persistence hook);
--   * sleeps for one minute.
--
-- NOTE(review): createSessionManager runs this via forever inside forkIO,
-- and no exception handler is visible. An exception thrown by
-- @storeSessions@ would therefore stop housekeeping permanently. Confirm
-- whether that is acceptable.
housekeepSessions :: UserSessions conn sess st
                  -> (HM.HashMap SessionId (Session conn sess st) -> IO ())
                  -> IO ()
housekeepSessions sessionRef storeSessions =
    do now <- getCurrentTime
       remaining <-
           atomically $
           do modifyTVar' sessionRef (HM.filter (isAlive now))
              readTVar sessionRef
       storeSessions remaining
       threadDelay housekeepingDelayMicros
    where
      -- Filters the map in place instead of rebuilding it from a list,
      -- which would rehash every key on each round.
      isAlive now sess = sess_validUntil sess > now
      -- threadDelay takes microseconds: 60 * 10^6 is one minute.
      housekeepingDelayMicros = 60 * 1000 * 1000
-- | Build a new, unregistered session holding @content@.
--
-- The id is drawn from 'randomHash' using 'sc_sessionIdEntropy' bytes. The
-- expiry is the current time plus 'sc_sessionTTL', and the session starts
-- with empty safe-action maps.
createSession :: SessionCfg sess -> sess -> IO (Session conn sess st)
createSession sessCfg content =
    do sid <- randomHash (sc_sessionIdEntropy sessCfg)
       expiry <- fmap (addUTCTime (sc_sessionTTL sessCfg)) getCurrentTime
       return Session
              { sess_id = sid
              , sess_validUntil = expiry
              , sess_data = content
              , sess_safeActions = SafeActionStore HM.empty HM.empty
              }
-- | Produce a random, URL-safe identifier from @len@ random bytes.
--
-- Each random Char is truncated to one byte by 'BSC.pack'. The bytes are
-- base64url-encoded and the padding characters (equals signs) are removed,
-- so the result is ceiling (4 * len / 3) characters long.
--
-- NOTE(review): the bytes come from 'StdGen' (System.Random), which is not
-- a cryptographically secure generator. These values serve as session ids
-- and safe-action hashes, so a CSPRNG should be considered.
randomHash :: Int -> IO T.Text
randomHash len =
    do stdGen <- newStdGen :: IO StdGen
       let randomChars = take len (randoms stdGen)
           encoded = B64.encode (BSC.pack randomChars)
       return (T.replace "=" "" (T.decodeUtf8 encoded))