module Yesod.Core.Internal.Session
    ( encodeClientSession
    , decodeClientSession
    , clientSessionDateCacher
    , ClientSessionDateCache(..)
    , SaveSession
    , SessionBackend(..)
    ) where

import qualified Web.ClientSession as CS
import Data.Serialize
import Data.Time
import Data.ByteString (ByteString)
import Control.Monad (guard)
import Yesod.Core.Types
import Yesod.Core.Internal.Util
import Control.AutoUpdate

-- | Build the encrypted cookie value for a client-side session.
--
-- The expiry, remote host and session map are packed into a
-- 'SessionCookie', serialized with 'encode', and then encrypted with
-- 'CS.encrypt' using the supplied key and IV.
encodeClientSession :: CS.Key
                    -> CS.IV
                    -> ClientSessionDateCache  -- ^ expire time
                    -> ByteString -- ^ remote host
                    -> SessionMap -- ^ session
                    -> ByteString -- ^ cookie value
encodeClientSession key iv date rhost session' =
    CS.encrypt key iv $ encode $ SessionCookie expires rhost session'
  where
    -- Use the already-serialized expiry carried by the date cache (the
    -- 'Right' alternative) instead of a raw 'UTCTime'.
    expires = Right (csdcExpiresSerialized date)

-- | Decrypt and validate a client-session cookie.
--
-- Returns 'Nothing' when any of the following holds:
--
-- * decryption with the given key fails;
-- * the decrypted payload does not deserialize to a 'SessionCookie';
-- * the decoded expiry is not a 'Left' timestamp;
-- * the expiry is not strictly later than the cached current time;
-- * the host stored in the cookie differs from the supplied one.
decodeClientSession :: CS.Key
                    -> ClientSessionDateCache  -- ^ current time
                    -> ByteString -- ^ remote host field
                    -> ByteString -- ^ cookie value
                    -> Maybe SessionMap
decodeClientSession key date rhost encrypted = do
    decrypted <- CS.decrypt key encrypted
    -- Refutable pattern in the 'Maybe' monad: a decoded cookie whose expiry
    -- is not a 'Left' makes the whole computation yield 'Nothing'.
    SessionCookie (Left expire) rhost' session' <-
        either (const Nothing) Just $ decode decrypted
    guard $ expire > csdcNow date  -- reject sessions that have expired
    guard $ rhost' == rhost        -- reject cookies recorded for another host
    return session'


----------------------------------------------------------------------


-- Originally copied from Kazu's date-cache, but now using mkAutoUpdate.
--
-- The cached date is updated every 10s; we don't need second
-- resolution for session expiration times.
--
-- The second component of the returned tuple used to be an action that
-- killed the updater thread, but is now a no-op that's just there
-- to preserve the type.

-- | Create a periodically refreshed source of the current time together
-- with the matching session expiry.
--
-- The first component of the result is an action that yields the cached
-- 'ClientSessionDateCache'; the second component is a no-op action.
clientSessionDateCacher ::
     NominalDiffTime -- ^ Inactive session validity.
  -> IO (IO ClientSessionDateCache, IO ())
clientSessionDateCacher validity = do
    getClientSessionDateCache <- mkAutoUpdate defaultUpdateSettings
      { updateAction = getUpdated
      , updateFreq   = 10000000 -- 10s
      }

    return (getClientSessionDateCache, return ())
  where
    -- Snapshot the clock and derive the expiry (now + validity), in both
    -- 'UTCTime' and serialized form.
    getUpdated = do
      now <- getCurrentTime
      let expires  = validity `addUTCTime` now
          expiresS = runPut (putTime expires)
      -- '$!' forces the constructed cache value to WHNF before returning it.
      return $! ClientSessionDateCache now expires expiresS