{-# LANGUAGE FlexibleContexts           #-}
{-# LANGUAGE UndecidableInstances       #-}
{-# LANGUAGE OverloadedStrings          #-}
{-# LANGUAGE TypeFamilies               #-}

module Snap.Extension.Session.Client (
    HasClientSessionManager(..),
    ClientSessionManager,
    clientSessionInitializer
    ) where

import Control.Applicative
import Control.Monad.Reader
import Data.ByteString.Char8 (ByteString)
import Data.Maybe
import Data.Serialize
import Data.Time.Clock
import Data.Time.Clock.POSIX
import Snap.Extension
import Snap.Extension.Session
import Snap.SessionUtil
import Snap.Types
import Web.ClientSession

-- | Types from which a 'ClientSessionManager' can be extracted.  The
-- 'MonadSession' instance in this module requires this class of the state
-- type @s@ of @SnapExtend s@.
class HasClientSessionManager a where
    -- | The type of value kept in the session for state type @a@.
    type ClientSession a
    -- | Project the session manager out of the state.
    clientSessionMgr :: a -> ClientSessionManager (ClientSession a)

-- | Configuration for cookie-backed sessions holding a value of type @t@.
data ClientSessionManager t = ClientSessionManager
    { clientSessionKey     :: Key
      -- ^ Key used to encrypt and decrypt the session cookie.
    , clientSessionDefault :: IO t
      -- ^ Action run to obtain a session value whenever no usable session
      --   is available (missing, undecryptable, undecodable, wrong client
      --   address, or expired).
    , clientSessionTimeout :: Maybe NominalDiffTime
      -- ^ Maximum age of a session record; 'Nothing' disables the age check.
    }

-- | Build the 'Initializer' for a 'ClientSessionManager'.
--
-- Arguments, in order: the encryption key, the optional session timeout,
-- and the action producing the default session value.  Note that this
-- order differs from the field order of the 'ClientSessionManager' record.
clientSessionInitializer :: Key
                         -> Maybe NominalDiffTime
                         -> IO t
                         -> Initializer (ClientSessionManager t)
clientSessionInitializer key timeout defaultVal = mkInitializer manager
  where
    manager = ClientSessionManager
        { clientSessionKey     = key
        , clientSessionDefault = defaultVal
        , clientSessionTimeout = timeout
        }

-- | The manager is registered under the id @\"Session/Client\"@; its cleanup
-- and reload hooks ignore the manager and do nothing.
instance InitializerState (ClientSessionManager t) where
    extensionId _ = "Session/Client"
    mkCleanup _   = return ()
    mkReload _    = return ()

{-
    The value that is encrypted and stored in the session cookie is not the
    bare session value.  It is a triple of:

      * the remote address of the client the record was built for,
      * the time at which the record was built, and
      * the session value itself.

    The address is compared against the current request's remote address when
    the record is read back, which protects the session from cross-site
    session stealing attacks.  The timestamp is compared against the
    configured timeout (if any), which prevents stale session attacks.
-}
type SessionRecord t = (ByteString, UTCTime, t)

-- | Serialise a 'UTCTime' as its POSIX time rounded to whole seconds, stored
-- as an 'Integer'.  Sub-second precision is therefore lost on a round trip.
--
-- NOTE(review): neither 'Serialize' nor 'UTCTime' is defined in this module,
-- so this looks like an orphan instance.
instance Serialize UTCTime where
    put = put . wholeSeconds
      where
        wholeSeconds :: UTCTime -> Integer
        wholeSeconds = round . utcTimeToPOSIXSeconds

    get = fmap (posixSecondsToUTCTime . fromInteger) get

-- | Wrap a session value into a 'SessionRecord', stamping it with the remote
-- address of the current request and the current time.
valToRecord :: (MonadSnap m, Serialize t) => t -> m (SessionRecord t)
valToRecord val = do
    client <- fmap rqRemoteAddr getRequest
    now    <- liftIO getCurrentTime
    return (client, now, val)

-- | Validate a 'SessionRecord' against the current request and extract its
-- session value.
--
-- The stored value is returned only if the record's client address equals
-- the current request's remote address and, when a timeout is configured,
-- the record's age does not exceed it.  Otherwise the manager's default
-- action is run and its result returned instead.
recordToVal :: (MonadSnap m, Serialize t)
            => ClientSessionManager t -> SessionRecord t -> m t
recordToVal mgr (client, stamp, val) = do
    addr <- fmap rqRemoteAddr getRequest
    if addr == client then checkAge else fallback
  where
    -- Produce a fresh default value in place of the rejected one.
    fallback = liftIO (clientSessionDefault mgr)

    -- Only reached once the client address has been confirmed.
    checkAge = do
        now <- liftIO getCurrentTime
        case clientSessionTimeout mgr of
            Just lim | now `diffUTCTime` stamp > lim -> fallback
            _                                        -> return val

-- | Fetch the session record from the @sessionval@ cookie.
--
-- If the cookie is absent, cannot be decrypted with the manager's key, or
-- does not decode to a 'SessionRecord', a fresh record is built from the
-- manager's default value instead.  No validation of the record's client
-- address or age happens here.
getSessionRecord :: (MonadSnap m, Serialize t)
                 => ClientSessionManager t -> m (SessionRecord t)
getSessionRecord mgr =
    lookupCookie "sessionval" >>= maybe freshRecord fromCookie
  where
    -- Single fallback shared by every failure path.
    freshRecord = valToRecord =<< liftIO (clientSessionDefault mgr)

    -- Decrypt, then decode; a failure at either stage yields the fallback.
    fromCookie c = maybe freshRecord return $
        decrypt (clientSessionKey mgr) (cookieValue c) >>= hush . decode

    -- Discard the decoder's error message.
    hush = either (const Nothing) Just

-- | Serialise and encrypt a 'SessionRecord' with the manager's key and store
-- it in the @sessionval@ cookie.
putSessionRecord :: (MonadSnap m, Serialize t)
                 => ClientSessionManager t -> SessionRecord t -> m ()
putSessionRecord mgr record = setCookie sessionCookie
  where
    payload       = encrypt (clientSessionKey mgr) (encode record)
    -- The remaining three cookie fields are left unset (presumably expiry,
    -- domain and path -- confirm against the 'Cookie' type).
    sessionCookie = Cookie "sessionval" payload Nothing Nothing Nothing

-- | Cookie-backed session support for any @SnapExtend s@ whose state carries
-- a 'ClientSessionManager' and whose session type can be serialised.
instance (HasClientSessionManager s, Serialize (ClientSession s))
        => MonadSession (SnapExtend s) where
    type Session (SnapExtend s) = ClientSession s

    -- Read the record from the cookie and validate it; an invalid or missing
    -- session yields the manager's default value.
    getSession = do
        mgr <- asks clientSessionMgr
        recordToVal mgr =<< getSessionRecord mgr

    -- Store the value in a freshly stamped record (current client address
    -- and current time).
    setSession v = do
        mgr <- asks clientSessionMgr
        putSessionRecord mgr =<< valToRecord v

    -- Re-stamp the session so that its age is counted from now.  Without a
    -- configured timeout the age is never checked, so nothing is done at
    -- all: the cookie is not read, and the manager's default action (an
    -- arbitrary IO action) is not run just to have its result discarded.
    touchSession = do
        mgr <- asks clientSessionMgr
        when (isJust (clientSessionTimeout mgr)) $ do
            v <- recordToVal mgr =<< getSessionRecord mgr
            putSessionRecord mgr =<< valToRecord v

    clearSession = clearCookie "sessionval"