{-# LANGUAGE FlexibleContexts           #-}
{-# LANGUAGE UndecidableInstances       #-}
{-# LANGUAGE OverloadedStrings          #-}
{-# LANGUAGE TypeFamilies               #-}

-- | Client-side sessions for Snap extensions: the session value is
-- serialised, encrypted with a 'Key', and stored in a cookie, so no
-- server-side session storage is needed.
module Snap.Extension.Session.Client
    ( HasClientSessionManager(..)
    , ClientSessionManager
    , clientSessionInitializer
    ) where

import Control.Applicative
import Control.Monad.Reader
import Data.ByteString.Char8 (ByteString)
import Data.Maybe
import Data.Serialize
import Data.Time.Clock
import Data.Time.Clock.POSIX
import Snap.Extension
import Snap.Extension.Session
import Snap.SessionUtil
import Snap.Types
import Web.ClientSession

-- | State types (the @s@ of a 'SnapExtend' application) that carry a
-- 'ClientSessionManager'.  The associated type 'ClientSession' fixes the
-- type of value kept in the session.
class HasClientSessionManager app where
    -- | The type of value stored in the client-side session.
    type ClientSession app
    -- | Project the session manager out of the state.
    clientSessionMgr :: app -> ClientSessionManager (ClientSession app)

-- | Configuration for cookie-stored, encrypted client sessions holding a
-- value of type @t@.
data ClientSessionManager t = ClientSessionManager {
    -- | Key used to encrypt and decrypt the session cookie.
    clientSessionKey     :: Key,
    -- | Action producing the value used when no valid session is present.
    clientSessionDefault :: IO t,
    -- | Maximum session age; 'Nothing' means sessions never expire.
    clientSessionTimeout :: Maybe NominalDiffTime
    }

-- | Build an 'Initializer' for a 'ClientSessionManager' from its settings.
clientSessionInitializer :: Key                    -- ^ cookie encryption key
                         -> Maybe NominalDiffTime  -- ^ timeout ('Nothing' = none)
                         -> IO t                   -- ^ default session value
                         -> Initializer (ClientSessionManager t)
clientSessionInitializer key timeout defaultVal =
    mkInitializer $ ClientSessionManager
        { clientSessionKey     = key
        , clientSessionDefault = defaultVal
        , clientSessionTimeout = timeout
        }

-- | The manager holds only configuration, so there is nothing to clean up
-- or reload.
instance InitializerState (ClientSessionManager t) where
    extensionId _ = "Session/Client"
    mkCleanup   _ = return ()
    mkReload    _ = return ()

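{-|
    The value actually encrypted and stored in the session contains not only
    the desired value, but also the client's remote address and the time at
    which the record was created.  This protects the session from being
    replayed from a different client address (for example, with a stolen
    cookie), and prevents stale sessions from being used once the configured
    timeout has elapsed.
-}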
-- Layout: (client remote address, creation time, session value).
type SessionRecord v = (ByteString, UTCTime, v)

-- | Orphan instance: a 'UTCTime' is serialised as a whole number of seconds
-- since the POSIX epoch (via 'round'), so sub-second precision is lost on a
-- round trip.
instance Serialize UTCTime where
    put = put . toEpochSeconds
      where
        toEpochSeconds :: UTCTime -> Integer
        toEpochSeconds = round . utcTimeToPOSIXSeconds
    get = do
        secs <- get
        return $ posixSecondsToUTCTime (fromInteger secs)

-- | Wrap a session value in a 'SessionRecord', stamping it with the current
-- request's remote address and the current time.
valToRecord :: (MonadSnap m, Serialize t) => t -> m (SessionRecord t)
valToRecord val =
    getRequest            >>= \req ->
    liftIO getCurrentTime >>= \now ->
    return (rqRemoteAddr req, now, val)

-- | Recover the session value from a decoded 'SessionRecord'.
--
-- The stored value is returned only if the record was issued to the same
-- remote address as the current request and, when a timeout is configured,
-- it is no older than that timeout.  Otherwise the manager's default value is
-- produced.
recordToVal :: (MonadSnap m, Serialize t)
            => ClientSessionManager t -> SessionRecord t -> m t
recordToVal mgr (client, stamp, val) = do
    sameClient <- (== client) . rqRemoteAddr <$> getRequest
    ok         <- if sameClient then fresh else return False
    if ok then return val else liftIO (clientSessionDefault mgr)
  where
    -- True unless a timeout is set and the record's age exceeds it.
    fresh = do
        now <- liftIO getCurrentTime
        let age = now `diffUTCTime` stamp
        return $ maybe True (age <=) (clientSessionTimeout mgr)

-- | Read the @sessionval@ cookie and decrypt and decode it into a
-- 'SessionRecord'.  A fresh record wrapping the manager's default value is
-- built if the cookie is absent, cannot be decrypted, or fails to decode.
getSessionRecord :: (MonadSnap m, Serialize t)
                 => ClientSessionManager t -> m (SessionRecord t)
getSessionRecord mgr = do
    cookie <- lookupCookie "sessionval"
    maybe fallback return (cookie >>= unseal)
  where
    fallback = valToRecord =<< liftIO (clientSessionDefault mgr)
    -- Decrypt a cookie and decode its payload, discarding any failure.
    unseal c = do
        plain <- decrypt (clientSessionKey mgr) (cookieValue c)
        either (const Nothing) Just (decode plain)

-- | Serialise and encrypt a 'SessionRecord' and store it in the
-- @sessionval@ cookie.
--
-- The cookie is scoped to path @/@.  Without an explicit Path attribute,
-- browsers default it to the directory of the request that set it
-- (RFC 6265, section 5.1.4), so a session created under e.g.
-- @/login/submit@ would not be sent back for requests to @/@.
-- NOTE(review): 'clearCookie' must clear with the same path for
-- 'clearSession' to work everywhere; verify in "Snap.SessionUtil".
putSessionRecord :: (MonadSnap m, Serialize t)
                 => ClientSessionManager t -> SessionRecord t -> m ()
putSessionRecord mgr record =
    setCookie $ Cookie "sessionval" sealed Nothing Nothing (Just "/")
  where
    sealed = encrypt (clientSessionKey mgr) (encode record)

-- | Client-side session backend for any 'SnapExtend' application whose state
-- provides a 'ClientSessionManager'.
instance (HasClientSessionManager s, Serialize (ClientSession s))
        => MonadSession (SnapExtend s) where
    type Session (SnapExtend s) = ClientSession s

    -- Decode the cookie.  Fall back to the default value when the cookie is
    -- missing or invalid, was issued to another address, or has expired.
    getSession = do
        mgr <- asks clientSessionMgr
        recordToVal mgr =<< getSessionRecord mgr

    -- Stamp the value with the current client address and time, then store
    -- it.
    setSession v = do
        mgr <- asks clientSessionMgr
        putSessionRecord mgr =<< valToRecord v

    -- Re-issue the cookie with a fresh timestamp to extend its lifetime.
    -- With no timeout configured nothing can expire, so no cookie is written.
    touchSession = do
        mgr <- asks clientSessionMgr
        v   <- recordToVal mgr =<< getSessionRecord mgr
        when (isJust (clientSessionTimeout mgr)) $
            putSessionRecord mgr =<< valToRecord v

    clearSession = clearCookie "sessionval"