-- | Client-side sessions for Happstack.
--
-- Session data is serialised with 'SafeCopy', encrypted with a
-- "Web.ClientSession" 'Key' and carried in a cookie.  Handlers run in
-- 'ClientSessionT' and read or change the session with 'getSession',
-- 'putSession', 'expireSession' and 'withSession'; 'sessionPart' writes
-- any change back to the response as a cookie.
module Happstack.Server.ClientSession
( ClientSession(..)
, SessionConf(..)
, mkSessionConf
, ClientSessionT
, runClientSessionT
, getSession
, putSession
, expireSession
, withSession
, sessionPart
) where
import Control.Applicative (Applicative, Alternative, optional)
import Control.Monad (MonadPlus, when)
import Control.Monad.Reader (ReaderT, runReaderT, ask, asks)
import Control.Monad.State (StateT, State, evalStateT, runState, get, put)
import Control.Monad.Trans (MonadIO, liftIO)
import Data.ByteString.Char8 (pack, unpack)
import Data.SafeCopy (SafeCopy, safeGet, safePut)
import Data.Serialize (runGet, runPut)
import Happstack.Server (HasRqData, FilterMonad, WebMonad, ServerMonad, Happstack, Response, CookieLife(Session), Cookie(secure), lookCookieValue, addCookie, mkCookie, expireCookie)
import Web.ClientSession (Key, decrypt, encryptIO)
-- | Types that can be used as client-side session data.
--
-- The 'SafeCopy' superclass provides the serialisation used when the
-- session is written to, and read back from, the session cookie.
class SafeCopy st => ClientSession st where
    -- | The session value handed out by 'getSession' when no usable session
    -- is available: the cookie is missing, cannot be decrypted or
    -- deserialised, or the session has been expired.
    empty :: st
-- | Internal bookkeeping for the session during one run of 'ClientSessionT'.
--
-- * 'Encoded': initial state; the cookie has not been looked at yet.
--
-- * 'Decoded': the session was read by 'getSession' and has not been
--   changed since.
--
-- * 'Modified': a new value was stored with 'putSession'; 'sessionPart'
--   writes it out as a cookie.
--
-- * 'Expired': 'expireSession' was called; 'sessionPart' expires the cookie.
data SessionState st = Encoded | Decoded st | Modified st | Expired
-- | Configuration of the session cookie.  See 'mkSessionConf' for defaults.
data SessionConf = SessionConf
    { sessionCookieName :: String
      -- ^ Name of the cookie that carries the session.
    , sessionCookieLife :: CookieLife
      -- ^ Lifetime passed to 'addCookie' when the session is written.
    , sessionKey :: Key
      -- ^ Key used to encrypt and decrypt the cookie value.
    , sessionSecure :: Bool
      -- ^ Value assigned to the @secure@ field of the session cookie.
    }
-- | Build a 'SessionConf' from an encryption 'Key' using the defaults:
-- cookie name @Happstack.ClientSession@, cookie life 'Session', and the
-- @secure@ flag switched on.
mkSessionConf :: Key -> SessionConf
mkSessionConf key =
    SessionConf
        { sessionKey = key
        , sessionCookieName = defaultCookieName
        , sessionCookieLife = Session
        , sessionSecure = True
        }
  where
    defaultCookieName = "Happstack.ClientSession"
-- | Monad transformer that gives a handler access to a client-side session
-- of type @st@.
--
-- It is a 'ReaderT' carrying the 'SessionConf' stacked on a 'StateT' that
-- tracks the current 'SessionState'; the listed instances are derived from
-- that underlying stack.
newtype ClientSessionT st m a =
    ClientSessionT { unClientSessionT :: ReaderT SessionConf (StateT (SessionState st) m) a }
    deriving ( Functor, Applicative, Alternative
             , Monad, MonadIO, MonadPlus
             , HasRqData, FilterMonad r, WebMonad r, ServerMonad
             )
-- | 'ClientSessionT' over a 'Happstack' monad is itself a 'Happstack' monad.
-- The instance defines no methods of its own.
instance Happstack m => Happstack (ClientSessionT st m)
-- | Run a 'ClientSessionT' computation with the given configuration.
--
-- The session starts in the 'Encoded' state, and the final session state is
-- discarded; wrap the handler in 'sessionPart' to have changes written to
-- the response cookie.
runClientSessionT :: Monad m => ClientSessionT st m a -> SessionConf -> m a
runClientSessionT cs sc = evalStateT stateful Encoded
  where
    -- The computation with its configuration supplied, still needing an
    -- initial session state.
    stateful = runReaderT (unClientSessionT cs) sc
-- | Fetch the whole 'SessionConf' from the reader layer.
askCS :: Monad m => ClientSessionT st m SessionConf
askCS = ClientSessionT ask
-- | Fetch a projection of the 'SessionConf', e.g. a single field.
asksCS :: Monad m => (SessionConf -> a) -> ClientSessionT st m a
asksCS project = ClientSessionT (asks project)
-- | Read the current 'SessionState' from the state layer.
getCS :: Monad m => ClientSessionT st m (SessionState st)
getCS = ClientSessionT get
-- | Replace the current 'SessionState' in the state layer.
putCS :: Monad m => SessionState st -> ClientSessionT st m ()
putCS next = ClientSessionT (put next)
-- | Return the current session value.
--
-- On first use the session cookie is looked up, decrypted and deserialised,
-- and the result is cached as 'Decoded' so later calls do not touch the
-- cookie again.  If the cookie is missing, fails to decrypt, or fails to
-- deserialise, 'empty' is used instead.  After 'expireSession' the result is
-- always 'empty'.
getSession :: (Functor m, MonadPlus m, HasRqData m, ClientSession st)
           => ClientSessionT st m st
getSession = getCS >>= fromState
  where
    fromState (Decoded cur)  = return cur
    fromState (Modified cur) = return cur
    fromState Expired        = fresh
    fromState Encoded        = do
        cur <- fromCookie
        putCS (Decoded cur)
        return cur

    -- Fallback whenever no usable session can be obtained.
    fresh = return empty

    -- Look the cookie up; 'optional' turns a missing cookie into 'Nothing'.
    fromCookie = do
        cookieName <- asksCS sessionCookieName
        found <- optional (lookCookieValue cookieName)
        case found of
          Nothing  -> fresh
          Just raw -> decodeCookie raw

    -- Decrypt with the configured key, then deserialise the payload.
    decodeCookie raw = do
        key <- asksCS sessionKey
        case decrypt key (pack raw) of
          Nothing      -> fresh
          Just payload -> either (const fresh) return (runGet safeGet payload)
-- | Store a new session value.
--
-- The session is marked as 'Modified', so an enclosing 'sessionPart' will
-- write it to the response as a cookie.
putSession :: (Monad m, ClientSession st) => st -> ClientSessionT st m ()
putSession st = putCS (Modified st)
-- | Mark the session as expired.
--
-- Subsequent calls to 'getSession' return 'empty', and an enclosing
-- 'sessionPart' expires the session cookie.
expireSession :: Monad m => ClientSessionT st m ()
expireSession = putCS Expired
-- | Run a pure 'State' action against the current session.
--
-- The session is fetched with 'getSession' and threaded through the action.
-- The resulting state is stored with 'putSession' only when it differs from
-- the value the action started with, so read-only actions do not mark the
-- session as modified.  The action's result is returned.
withSession :: (Functor m, MonadPlus m, HasRqData m, ClientSession st, Eq st)
            => State st a -> ClientSessionT st m a
withSession action = do
    before <- getSession
    let (result, after) = runState action before
    when (after /= before) (putSession after)
    return result
-- | Wrap a handler so that session changes end up in the response.
--
-- The wrapped handler runs first.  Afterwards, if the session is 'Modified'
-- it is serialised, encrypted and added as a cookie using the name, lifetime
-- and @secure@ setting from the 'SessionConf'; if it is 'Expired' the cookie
-- is expired; otherwise nothing is sent.  The handler's result is returned
-- unchanged.
sessionPart :: (Functor m, Monad m, MonadIO m, FilterMonad Response m, ClientSession st)
            => ClientSessionT st m a -> ClientSessionT st m a
sessionPart part = do
    result <- part
    getCS >>= commit
    return result
  where
    commit (Modified st) = writeCookie st
    commit Expired       = dropCookie
    commit _             = return ()

    -- Serialise, encrypt and attach the session as a cookie.
    writeCookie st = do
        conf <- askCS
        encrypted <- liftIO (encryptIO (sessionKey conf) (runPut (safePut st)))
        let cookie = mkCookie (sessionCookieName conf) (unpack encrypted)
        addCookie (sessionCookieLife conf) (cookie { secure = sessionSecure conf })

    -- Expire the session cookie.
    dropCookie = asksCS sessionCookieName >>= expireCookie