{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}

-- | A session back end that keeps the whole session value on the client:
-- the value is serialized, encrypted with "Web.ClientSession", and stored
-- in a single cookie named @sessionval@.
module Snap.Extension.Session.Client
  ( HasClientSessionManager(..)
  , ClientSessionManager
  , clientSessionInitializer
  ) where

import Control.Applicative
import Control.Monad.Reader
import Data.ByteString.Char8 (ByteString)
import Data.Maybe
import Data.Serialize
import Data.Time.Clock
import Data.Time.Clock.POSIX
import Snap.Extension
import Snap.Extension.Session
import Snap.SessionUtil
import Snap.Types
import Web.ClientSession

-- | Application states that carry a 'ClientSessionManager'.
class HasClientSessionManager a where
  -- | The type of value stored in the session.
  type ClientSession a
  -- | Extract the session manager from the application state.
  clientSessionMgr :: a -> ClientSessionManager (ClientSession a)

-- | Configuration for client-side sessions holding values of type @t@.
data ClientSessionManager t = ClientSessionManager
  { clientSessionKey     :: Key
    -- ^ Key used to encrypt and decrypt the session cookie.
  , clientSessionDefault :: IO t
    -- ^ Action producing the value used whenever no valid session exists
    -- (no cookie, a cookie that fails to decrypt or decode, a remote
    -- address mismatch, or an expired record).
  , clientSessionTimeout :: Maybe NominalDiffTime
    -- ^ Maximum age of a session record; 'Nothing' disables expiry.
  }

-- | Build an initializer for a 'ClientSessionManager'.
clientSessionInitializer
  :: Key                    -- ^ key used to encrypt the session cookie
  -> Maybe NominalDiffTime  -- ^ session timeout ('Nothing' = never expire)
  -> IO t                   -- ^ action producing the default session value
  -> Initializer (ClientSessionManager t)
clientSessionInitializer key timeout defaultVal =
  mkInitializer (ClientSessionManager key defaultVal timeout)

-- | The manager holds no resources, so cleanup and reload are no-ops.
instance InitializerState (ClientSessionManager t) where
  extensionId = const "Session/Client"
  mkCleanup   = const $ return ()
  mkReload    = const $ return ()

{-
The record that is encrypted and stored in the cookie contains not only the
session value itself, but also the remote address of the client it was
issued to and the time it was written. A record presented from a different
remote address is ignored, which makes a stolen cookie useless from another
address. A record older than the configured timeout is also ignored, which
limits the replay of stale cookies.
-}
type SessionRecord t = (ByteString, UTCTime, t)

-- | Orphan instance: a 'UTCTime' is serialized as an 'Integer' count of
-- seconds since the POSIX epoch, rounded, so sub-second precision is lost.
-- NOTE(review): being an orphan, this will clash with any other
-- @Serialize UTCTime@ instance brought into scope; a newtype wrapper would
-- avoid that.
instance Serialize UTCTime where
  put t = put (round (utcTimeToPOSIXSeconds t) :: Integer)
  get   = posixSecondsToUTCTime . fromInteger <$> get

-- | Wrap a session value in a record stamped with the current request's
-- remote address and the current time.
valToRecord :: MonadSnap m => t -> m (SessionRecord t)
valToRecord val = do
  addr <- rqRemoteAddr <$> getRequest
  now  <- liftIO getCurrentTime
  return (addr, now, val)

-- | Extract the value from a decoded record. Falls back to the manager's
-- default when the record was issued to a different remote address, or
-- when a timeout is configured and the record is older than it.
recordToVal :: MonadSnap m => ClientSessionManager t -> SessionRecord t -> m t
recordToVal mgr (client, stamped, val) = do
  addr <- rqRemoteAddr <$> getRequest
  if addr /= client
    then fallback
    else case clientSessionTimeout mgr of
      Nothing  -> return val
      Just lim -> do
        -- The clock is only consulted when there is a timeout to enforce.
        now <- liftIO getCurrentTime
        if now `diffUTCTime` stamped > lim then fallback else return val
  where
    fallback = liftIO (clientSessionDefault mgr)

-- | Read the session record from the @sessionval@ cookie. A missing cookie,
-- or one that fails to decrypt or decode, yields a fresh record wrapping the
-- manager's default value.
getSessionRecord :: (MonadSnap m, Serialize t)
                 => ClientSessionManager t -> m (SessionRecord t)
getSessionRecord mgr = do
  cookie <- getCookie "sessionval"
  maybe fresh return (cookie >>= decodeCookie)
  where
    -- Decrypt, then deserialize; any failure collapses to 'Nothing'.
    decodeCookie c = do
      plain <- decrypt (clientSessionKey mgr) (cookieValue c)
      either (const Nothing) Just (decode plain)
    fresh = valToRecord =<< liftIO (clientSessionDefault mgr)

-- | Serialize and encrypt a session record and store it in the
-- @sessionval@ cookie.
-- NOTE(review): assuming the three 'Nothing's are the expiry, domain and
-- path fields of snap-core's 'Cookie', the cookie is set without a path.
-- Browsers then scope it to the directory of the request URI, so a session
-- written under one path may not be sent for another. Setting the path to
-- "/" would also require 'clearCookie' (in "Snap.SessionUtil") to use the
-- same path — confirm before changing.
putSessionRecord :: (MonadSnap m, Serialize t)
                 => ClientSessionManager t -> SessionRecord t -> m ()
putSessionRecord mgr record =
  setCookie $ Cookie "sessionval" ciphertext Nothing Nothing Nothing
  where
    ciphertext = encrypt (clientSessionKey mgr) (encode record)

-- | Client-side sessions for any extension state that has a
-- 'ClientSessionManager' whose session type is serializable.
instance (HasClientSessionManager s, Serialize (ClientSession s))
      => MonadSession (SnapExtend s) where
  type Session (SnapExtend s) = ClientSession s

  getSession = do
    mgr <- asks clientSessionMgr
    recordToVal mgr =<< getSessionRecord mgr

  setSession v = do
    mgr <- asks clientSessionMgr
    putSessionRecord mgr =<< valToRecord v

  -- Re-stamp the current value with a fresh time. The cookie is rewritten
  -- only when a timeout is configured, since the stamp is never checked
  -- otherwise.
  touchSession = do
    mgr <- asks clientSessionMgr
    v   <- recordToVal mgr =<< getSessionRecord mgr
    when (isJust (clientSessionTimeout mgr)) $
      putSessionRecord mgr =<< valToRecord v

  clearSession = clearCookie "sessionval"