------------------------------------------------------------------------------ {-# LANGUAGE CPP #-} {-# LANGUAGE DeriveDataTypeable #-} {-# LANGUAGE GeneralizedNewtypeDeriving #-} {-# LANGUAGE OverloadedStrings #-} module Snap.Snaplet.Session.Backends.RedisSession ( initRedisSessionManager ) where ------------------------------------------------------------------------------ import Control.Monad.Reader import Data.ByteString (ByteString) import Data.HashMap.Strict (HashMap) import qualified Data.HashMap.Strict as HM import Data.Serialize (Serialize) import qualified Data.Serialize as S import Data.Text (Text) import Data.Text.Encoding import Data.Typeable -- import GHC.Generics import Snap.Core (Snap) import Web.ClientSession import Database.Redis ------------------------------------------------------------------------------ import Snap.Snaplet import Snap.Snaplet.RedisDB import Snap.Snaplet.Session import Snap.Snaplet.Session.SessionManager ------------------------------------------------------------------------------- ------------------------------------------------------------------------------ -- | Session data are kept in a 'HashMap' for this backend -- type Session = HashMap Text Text ------------------------------------------------------------------------------ -- | This is what the 'Payload' will be for the RedisSession backend -- | Only the rsCSRFToken is sent to the client. -- | The Session hash is stored in Redis. data RedisSession = RedisSession { rsCSRFToken :: Text , rsSession :: Session } deriving (Eq, Show) ------------------------------------------------------------------------------ --Only serialize the rsCSRFToken to send to the client instance Serialize RedisSession where put (RedisSession a _) = S.put $ encodeUtf8 a get = let unpack a = RedisSession (decodeUtf8 a) HM.empty in unpack <$> S.get encodeTuple :: (Text, Text) -> (ByteString, ByteString) encodeTuple (a,b) = (encodeUtf8 a, encodeUtf8 b) decodeTuple :: (ByteString, ByteString) -> (Text, Text) decodeTuple (a,b) = (decodeUtf8 a, decodeUtf8 b) ------------------------------------------------------------------------------ mkCookieSession :: RNG -> IO RedisSession mkCookieSession rng = do t <- liftIO $ mkCSRFToken rng return $ RedisSession t HM.empty ------------------------------------------------------------------------------ -- | The manager data type to be stuffed into 'SessionManager' -- data RedisSessionManager = RedisSessionManager { session :: Maybe RedisSession -- ^ Per request cache for 'CookieSession' , siteKey :: Key -- ^ A long encryption key used for secure cookie transport , cookieName :: ByteString -- ^ Cookie name for the session system , cookieDomain :: Maybe ByteString -- ^ Cookie domain for session system. You may want to set it to -- dot prefixed domain name like ".example.com", so the cookie is -- available to sub domains. , timeOut :: Maybe Int -- ^ Session cookies will be considered "stale" after this many -- seconds. , randomNumberGenerator :: RNG -- ^ handle to a random number generator , _redisConnection :: Connection -- ^ Redis connection to store session info } deriving (Typeable) ------------------------------------------------------------------------------ loadDefSession :: RedisSessionManager -> IO RedisSessionManager loadDefSession mgr@(RedisSessionManager ses _ _ _ _ rng _) = case ses of Nothing -> do ses' <- mkCookieSession rng return $! mgr { session = Just ses' } Just _ -> return mgr ------------------------------------------------------------------------------ modSession :: (Session -> Session) -> RedisSession -> RedisSession modSession f (RedisSession t ses) = RedisSession t (f ses) ------------------------------------------------------------------------------ sessionKey :: Text -> ByteString sessionKey t = encodeUtf8 $ mappend "session:" t ------------------------------------------------------------------------------ -- | Initialize a cookie-backed session, returning a 'SessionManager' to be -- stuffed inside your application's state. This 'SessionManager' will enable -- the use of all session storage functionality defined in -- 'Snap.Snaplet.Session' -- initRedisSessionManager :: FilePath -- ^ Path to site-wide encryption key -> ByteString -- ^ Session cookie name -> Maybe ByteString -- ^ Cookie Domain (has no effect with snap < 1.0) -> Maybe Int -- ^ Session time-out (replay attack protection) -> RedisDB -- ^ Redis connection -> SnapletInit b SessionManager initRedisSessionManager fp cn cd to c = makeSnaplet "RedisSession" "A snaplet providing sessions via HTTP cookies with a Redis backend." Nothing $ liftIO $ do key <- getKey fp rng <- liftIO mkRNG return $! SessionManager $ RedisSessionManager Nothing key cn cd to rng (_connection c) ------------------------------------------------------------------------------ instance ISessionManager RedisSessionManager where -------------------------------------------------------------------------- --load grabs the session from redis. load mgr@(RedisSessionManager r _ _ _ _ rng con) = case r of Just _ -> return mgr Nothing -> do pl <- getPayload mgr case pl of Nothing -> liftIO $ loadDefSession mgr Just (Payload x) -> do let c = S.decode x case c of Left _ -> liftIO $ loadDefSession mgr Right cs -> liftIO $ do sess <- runRedis con $ do l <- hgetall (sessionKey $ rsCSRFToken cs) case l of Left _ -> liftIO $ mkCookieSession rng Right l' -> do let rs = cs { rsSession = HM.fromList $ map decodeTuple l'} return rs return mgr { session = Just sess } -------------------------------------------------------------------------- --commit writes to redis and sends the csrf to client and also sets the --timeout. commit mgr@(RedisSessionManager r _ _ _ to rng con) = do pl <- case r of Just r' -> liftIO $ runRedis con $ do res <- multiExec $ do _ <- del [sessionKey (rsCSRFToken r')] --Clear old values let sess = map encodeTuple $ HM.toList (rsSession r') res1 <- case sess of [] -> hmset (sessionKey (rsCSRFToken r')) [("","")] _ -> hmset (sessionKey (rsCSRFToken r')) sess res2 <- case to of Just i -> expire (sessionKey (rsCSRFToken r')) $ toInteger i Nothing -> persist (sessionKey (rsCSRFToken r')) return $ (,) <$> res1 <*> res2 case res of TxSuccess _ -> return . Payload $ S.encode r' TxError e -> error e TxAborted -> error "transaction aborted" Nothing -> liftIO $ Payload . S.encode <$> mkCookieSession rng setPayload mgr pl -------------------------------------------------------------------------- --clear the session from redis and return a new empty one {-reset mgr@(RedisSessionManager _ _ _ _ _ _) = trace "RedisSessionManager reset" $ do-} reset mgr@(RedisSessionManager r _ _ _ _ _ con) = do case r of Just r' -> liftIO $ runRedis con $ do res1 <- del [sessionKey $ rsCSRFToken r'] case res1 of Left e -> error $ show e _ -> return () _ -> return () cs <- liftIO $ mkCookieSession (randomNumberGenerator mgr) return $ mgr { session = Just cs } -------------------------------------------------------------------------- touch = id -------------------------------------------------------------------------- insert k v mgr@(RedisSessionManager r _ _ _ _ _ _) = case r of Just r' -> mgr { session = Just $ modSession (HM.insert k v) r' } Nothing -> mgr -------------------------------------------------------------------------- lookup k (RedisSessionManager r _ _ _ _ _ _) = r >>= HM.lookup k . rsSession -------------------------------------------------------------------------- delete k mgr@(RedisSessionManager r _ _ _ _ _ _) = case r of Just r' -> mgr { session = Just $ modSession (HM.delete k) r' } Nothing -> mgr -------------------------------------------------------------------------- csrf (RedisSessionManager r _ _ _ _ _ _) = case r of Just r' -> rsCSRFToken r' Nothing -> "" -------------------------------------------------------------------------- toList (RedisSessionManager r _ _ _ _ _ _) = case r of Just r' -> HM.toList . rsSession $ r' Nothing -> [] ------------------------------------------------------------------------------ -- | A session payload to be stored in a SecureCookie. newtype Payload = Payload ByteString deriving (Eq, Show, Ord, Serialize) ------------------------------------------------------------------------------ -- | Get the current client-side value getPayload :: RedisSessionManager -> Snap (Maybe Payload) getPayload mgr = getSecureCookie (cookieName mgr) (siteKey mgr) (timeOut mgr) ------------------------------------------------------------------------------ -- | Set the client-side value setPayload :: RedisSessionManager -> Payload -> Snap () setPayload mgr = setSecureCookie (cookieName mgr) #if MIN_VERSION_snap(1,0,0) (cookieDomain mgr) #endif (siteKey mgr) (timeOut mgr)