module Web.ServerSession.Backend.Redis.Internal
( RedisStorage(..)
, RedisStorageException(..)
, transaction
, unwrap
, rSessionKey
, rAuthKey
, RedisSession(..)
, parseSession
, printSession
, parseUTCTime
, printUTCTime
, timeFormat
, getSessionImpl
, deleteSessionImpl
, removeSessionFromAuthId
, insertSessionForAuthId
, deleteAllSessionsOfAuthIdImpl
, insertSessionImpl
, replaceSessionImpl
, throwRS
) where
import Control.Applicative as A
import Control.Arrow (first)
import Control.Monad (void, when)
import Control.Monad.IO.Class (liftIO)
import Data.ByteString (ByteString)
import Data.List (partition)
import Data.Maybe (fromMaybe, catMaybes)
import Data.Proxy (Proxy(..))
import Data.Typeable (Typeable)
import Web.PathPieces (toPathPiece)
import Web.ServerSession.Core
import qualified Control.Exception as E
import qualified Database.Redis as R
import qualified Data.ByteString as B
import qualified Data.ByteString.Char8 as B8
import qualified Data.HashMap.Strict as HM
import qualified Data.Text.Encoding as TE
import qualified Data.Time.Clock as TI
import qualified Data.Time.Clock.POSIX as TP
import qualified Data.Time.Format as TI
#if MIN_VERSION_time(1,5,0)
import Data.Time.Format (defaultTimeLocale)
#else
import System.Locale (defaultTimeLocale)
#endif
-- | Session storage backend that keeps every session in Redis.
--
-- The @sess@ type parameter is a phantom: it does not appear in any field
-- and only selects the session data type ('SessionData') of this storage.
data RedisStorage sess =
  RedisStorage
    { connPool :: R.Connection
      -- ^ Redis connection used by 'runTransactionM' to run every
      -- storage operation.
    , idleTimeout :: Maybe TI.NominalDiffTime
      -- ^ When set, Redis is told to expire a session this long after
      -- its last access time.
    , absoluteTimeout :: Maybe TI.NominalDiffTime
      -- ^ When set, Redis is told to expire a session this long after
      -- its creation time.  If both timeouts are set, the earlier
      -- deadline is used.
    } deriving (Typeable)
-- | Redis-backed 'Storage'.  Operations run in hedis' 'R.Redis' monad and
-- are executed against the connection held in 'connPool'.  Each method
-- simply delegates to the corresponding @*Impl@ function of this module.
instance RedisSession sess => Storage (RedisStorage sess) where
  type SessionData (RedisStorage sess) = sess
  type TransactionM (RedisStorage sess) = R.Redis
  runTransactionM sto action = R.runRedis (connPool sto) action
  getSession _ sid = getSessionImpl sid
  deleteSession _ sid = deleteSessionImpl sid
  deleteAllSessionsOfAuthId _ authId = deleteAllSessionsOfAuthIdImpl authId
  insertSession sto session = insertSessionImpl sto session
  replaceSession sto session = replaceSessionImpl sto session
-- | Failures detected by the Redis backend itself.  These are distinct
-- from the generic 'StorageException's raised through 'throwRS'.
data RedisStorageException
  = ExpectedTxSuccess (R.TxResult ())
    -- ^ A transaction run by 'transaction' finished with something
    -- other than 'R.TxSuccess'.  The full result is attached.
  | ExpectedRight R.Reply
    -- ^ A command whose result was passed to 'unwrap' came back as
    -- 'Left'.  The reply is attached.
  deriving (Show, Typeable)

instance E.Exception RedisStorageException
-- | Run the given commands as a single @MULTI@/@EXEC@ transaction
-- (via 'R.multiExec').
--
-- Throws 'ExpectedTxSuccess' (in 'IO') when the transaction result is
-- anything other than @'R.TxSuccess' ()@.
transaction :: R.RedisTx (R.Queued ()) -> R.Redis ()
transaction tx =
  R.multiExec tx >>= \outcome ->
    case outcome of
      R.TxSuccess () -> return ()
      _ -> liftIO (E.throwIO (ExpectedTxSuccess outcome))
-- | Run a Redis command and return its successful result.  If the command
-- yields 'Left', the reply is thrown (in 'IO') as 'ExpectedRight'.
unwrap :: R.Redis (Either R.Reply a) -> R.Redis a
unwrap act = do
  result <- act
  case result of
    Left reply -> liftIO (E.throwIO (ExpectedRight reply))
    Right value -> return value
-- | Redis key of the hash that stores a session: the session ID's path
-- piece, UTF-8 encoded, prefixed with @ssr:session:@.
rSessionKey :: SessionId sess -> ByteString
rSessionKey sid = B.append "ssr:session:" (TE.encodeUtf8 (toPathPiece sid))
-- | Redis key of the set that indexes all sessions of an auth ID: the
-- auth ID prefixed with @ssr:authid:@.
rAuthKey :: AuthId -> ByteString
rAuthKey authId = "ssr:authid:" `B.append` authId
-- | Session data types that can be stored as the fields of a Redis hash.
class IsSessionData sess => RedisSession sess where
  -- | Encode the decomposed session data as hash field/value pairs.
  -- 'printSession' adds a @data:@ prefix to every field name.
  toHash :: Proxy sess -> Decomposed sess -> [(ByteString, ByteString)]
  -- | Decode the decomposed session data from hash field/value pairs.
  -- 'parseSession' strips the @data:@ prefix first.  This is expected to
  -- undo 'toHash'.
  fromHash :: Proxy sess -> [(ByteString, ByteString)] -> Decomposed sess
-- | A 'SessionMap' maps one-to-one onto hash fields.  The 'Text' keys are
-- UTF-8 encoded on the way in and decoded on the way out; values are
-- stored unchanged.
instance RedisSession SessionMap where
  toHash _ sm =
    [ (TE.encodeUtf8 key, val) | (key, val) <- HM.toList (unSessionMap sm) ]
  fromHash _ fields =
    SessionMap (HM.fromList [ (TE.decodeUtf8 key, val) | (key, val) <- fields ])
-- | Rebuild a 'Session' from the field/value pairs returned by @HGETALL@.
--
-- An empty list yields 'Nothing'.  Fields whose name starts with @data:@
-- lose that prefix and are passed to 'fromHash'.  All remaining fields are
-- treated as internal metadata (@internal:authId@, @internal:createdAt@,
-- @internal:accessedAt@).
--
-- If @internal:createdAt@ or @internal:accessedAt@ is missing, or fails to
-- parse, 'error' is raised when the corresponding session field is
-- evaluated.
parseSession
  :: forall sess. RedisSession sess
  => SessionId sess
  -> [(ByteString, ByteString)]
  -> Maybe (Session sess)
parseSession _ [] = Nothing
parseSession sid fields =
  Just Session
    { sessionKey = sid
    , sessionAuthId = lookup "internal:authId" internals
    , sessionData =
        fromHash (Proxy :: Proxy sess) (map (first stripDataPrefix) dataFields)
    , sessionCreatedAt = parseUTCTime (required "internal:createdAt")
    , sessionAccessedAt = parseUTCTime (required "internal:accessedAt")
    }
  where
    (dataFields, internals) = partition (B8.isPrefixOf "data:" . fst) fields
    -- Every name in dataFields is known to start with "data:" (see the
    -- partition above), so the lazy pattern below always matches.
    stripDataPrefix bs = let ("data:", key) = B8.splitAt 5 bs in key
    required k = fromMaybe (error msg) (lookup k internals)
      where
        msg = "serversession-backend-redis/parseSession: missing key " ++ show k
-- | Flatten a 'Session' into the field/value pairs written with @HMSET@.
--
-- Field order: @internal:authId@ (only when the session has an auth ID),
-- then @internal:createdAt@ and @internal:accessedAt@, then the 'toHash'
-- fields with each name prefixed by @data:@.  'parseSession' reverses this.
printSession :: forall sess. RedisSession sess => Session sess -> [(ByteString, ByteString)]
printSession Session {..} =
  authField ++ timestamps ++ payload
  where
    authField = maybe [] (\authId -> [("internal:authId", authId)]) sessionAuthId
    timestamps =
      [ ("internal:createdAt", printUTCTime sessionCreatedAt)
      , ("internal:accessedAt", printUTCTime sessionAccessedAt)
      ]
    payload =
      [ (B8.append "data:" key, val)
      | (key, val) <- toHash (Proxy :: Proxy sess) sessionData
      ]
-- | Parse a timestamp written by 'printUTCTime' (format 'timeFormat').
-- Calls 'error' if the input does not match the format.
parseUTCTime :: ByteString -> TI.UTCTime
#if MIN_VERSION_time(1,5,0)
parseUTCTime bs = TI.parseTimeOrError True defaultTimeLocale timeFormat (B8.unpack bs)
#else
parseUTCTime bs =
  case TI.parseTime defaultTimeLocale timeFormat (B8.unpack bs) of
    Just t -> t
    Nothing -> error "Web.ServerSession.Backend.Redis.Internal.parseUTCTime"
#endif
-- | Render a timestamp with 'timeFormat' for storage in a hash field.
printUTCTime :: TI.UTCTime -> ByteString
printUTCTime t = B8.pack (TI.formatTime defaultTimeLocale timeFormat t)
-- | Format string shared by 'printUTCTime' and 'parseUTCTime': a date and a
-- clock time joined by @T@, with @%Q@ adding the fractional seconds.
-- No time zone is written; both functions work on 'TI.UTCTime'.
timeFormat :: String
timeFormat = datePart ++ "T" ++ clockPart
  where
    datePart = "%Y-%m-%d"
    clockPart = "%H:%M:%S%Q"
-- | Apply @f@ to consecutive chunks of at most @511 * 1024@ elements, in
-- order, and return the result of the last chunk.  This keeps each
-- command's argument list bounded (e.g. an @HMSET@ of that many pairs).
-- An empty input still calls @f []@ once.
batched :: Monad m => ([a] -> m b) -> [a] -> m b
batched f xs =
  case splitAt chunkSize xs of
    (chunk, []) -> f chunk
    (chunk, remaining) -> f chunk >> batched f remaining
  where
    chunkSize = 511 * 1024
-- | Load a session with @HGETALL@ on its key.  Returns 'Nothing' when the
-- hash is empty (see 'parseSession').  A 'Left' reply is thrown by
-- 'unwrap'.
getSessionImpl :: RedisSession sess => SessionId sess -> R.Redis (Maybe (Session sess))
getSessionImpl sid = do
  fields <- unwrap (R.hgetall (rSessionKey sid))
  return (parseSession sid fields)
-- | Delete a session.  If it exists, one 'transaction' deletes its hash and
-- removes its key from its auth ID's index set.  If 'getSessionImpl' finds
-- nothing, this is a no-op.
deleteSessionImpl :: RedisSession sess => SessionId sess -> R.Redis ()
deleteSessionImpl sid = getSessionImpl sid >>= maybe (return ()) purge
  where
    -- The session is needed only for its auth ID, to update the index set.
    purge session = transaction $ do
      reply <- R.del [rSessionKey sid]
      removeSessionFromAuthId sid (sessionAuthId session)
      return (() <$ reply)
-- | @SREM@ the session's key from its auth ID's index set.  A no-op for
-- 'Nothing'.
removeSessionFromAuthId :: (R.RedisCtx m f, Functor m) => SessionId sess -> Maybe AuthId -> m ()
removeSessionFromAuthId sid mAuthId = fooSessionBarAuthId R.srem sid mAuthId
-- | @SADD@ the session's key to its auth ID's index set.  A no-op for
-- 'Nothing'.
insertSessionForAuthId :: (R.RedisCtx m f, Functor m) => SessionId sess -> Maybe AuthId -> m ()
insertSessionForAuthId sid mAuthId = fooSessionBarAuthId R.sadd sid mAuthId
-- | Shared implementation of 'removeSessionFromAuthId' and
-- 'insertSessionForAuthId'.  It applies a set command to the auth ID's index
-- set ('rAuthKey'), with the session's key ('rSessionKey') as the only
-- member, and discards the reply.  Sessions without an auth ID are never
-- indexed, so 'Nothing' does nothing.
fooSessionBarAuthId
  :: (R.RedisCtx m f, Functor m)
  => (ByteString -> [ByteString] -> m (f Integer))
  -> SessionId sess
  -> Maybe AuthId
  -> m ()
fooSessionBarAuthId setCommand sid =
  maybe (return ()) (\authId -> void (setCommand (rAuthKey authId) [rSessionKey sid]))
-- | Delete every session of an auth ID, together with the set that indexes
-- them.
--
-- The session keys are read from the index set with @SMEMBERS@ and removed
-- with @DEL@.  The @DEL@ goes through 'batched', the same chunking this
-- module uses for @HMSET@, so an auth ID with a huge number of sessions
-- never becomes one command with an unbounded argument list.  The index
-- set's own key is placed last.  If an earlier chunk fails (thrown by
-- 'unwrap'), the set therefore still lists the sessions that were not yet
-- deleted.
--
-- NOTE(review): the read and the deletes are not in a single transaction.
-- A session added to the set between them could lose its index entry
-- without being deleted.
deleteAllSessionsOfAuthIdImpl :: AuthId -> R.Redis ()
deleteAllSessionsOfAuthIdImpl authId = do
  let authKey = rAuthKey authId
  sessionRefs <- unwrap $ R.smembers authKey
  void $ batched (unwrap . R.del) (sessionRefs ++ [authKey])
-- | Store a new session.
--
-- Throws 'SessionAlreadyExists' (via 'throwRS') if 'getSessionImpl' already
-- finds a session with the same key.  Otherwise one 'transaction' writes
-- the hash (in 'batched' @HMSET@ chunks), sets its expiry
-- ('expireSession') and adds it to its auth ID's index set.
--
-- NOTE(review): the existence check runs before the transaction, outside
-- it, so it is not atomic with the write.
insertSessionImpl :: RedisSession sess => RedisStorage sess -> Session sess -> R.Redis ()
insertSessionImpl sto session = do
  getSessionImpl sid >>= maybe (return ()) alreadyExists
  transaction $ do
    reply <- batched (R.hmset (rSessionKey sid)) (printSession session)
    expireSession session sto
    insertSessionForAuthId sid (sessionAuthId session)
    return (() <$ reply)
  where
    sid = sessionKey session
    alreadyExists oldSession = throwRS (SessionAlreadyExists oldSession session)
-- | Overwrite an existing session.
--
-- Throws 'SessionDoesNotExist' (via 'throwRS') if 'getSessionImpl' finds
-- no session with this key.  Otherwise one 'transaction' does the
-- following:
--
-- * deletes the old hash and writes the new one ('batched' @HMSET@);
-- * resets the expiry ('expireSession');
-- * moves the key between auth ID index sets when the auth ID changed.
replaceSessionImpl :: RedisSession sess => RedisStorage sess -> Session sess -> R.Redis ()
replaceSessionImpl sto session = do
  moldSession <- getSessionImpl sid
  oldSession <- maybe (throwRS (SessionDoesNotExist session)) return moldSession
  transaction $ do
    -- Delete first so that fields missing from the new session disappear.
    _ <- R.del [sk]
    reply <- batched (R.hmset sk) (printSession session)
    expireSession session sto
    let oldAuthId = sessionAuthId oldSession
        newAuthId = sessionAuthId session
    when (oldAuthId /= newAuthId) $ do
      removeSessionFromAuthId sid oldAuthId
      insertSessionForAuthId sid newAuthId
    return (() <$ reply)
  where
    sid = sessionKey session
    sk = rSessionKey sid
-- | Throw a generic 'StorageException' for this backend from inside
-- 'R.Redis' (using 'E.throwIO').
throwRS
  :: Storage (RedisStorage sess)
  => StorageException (RedisStorage sess)
  -> R.Redis a
throwRS err = liftIO (E.throwIO err)
-- | Queue an @EXPIREAT@ on the session's hash, for the earlier of two
-- deadlines: last access plus 'idleTimeout' and creation plus
-- 'absoluteTimeout'.  The deadline is rounded to whole POSIX seconds.  If
-- neither timeout is configured, nothing is queued.
expireSession :: Session sess -> RedisStorage sess -> R.RedisTx ()
expireSession Session {..} RedisStorage {..} =
  case deadlines of
    [] -> return ()
    ds ->
      let posix = round (TP.utcTimeToPOSIXSeconds (minimum ds))
      in void (R.expireat (rSessionKey sessionKey) posix)
  where
    deadlines =
      catMaybes
        [ fmap (`TI.addUTCTime` sessionAccessedAt) idleTimeout
        , fmap (`TI.addUTCTime` sessionCreatedAt) absoluteTimeout
        ]