module Yesod.Session.Memcache.Storage
  ( memcacheStorage
  , SessionPersistence (..)
  , getMemcacheExpiration
  ) where

import Internal.Prelude

import Database.Memcache.Client qualified as Memcache
import Database.Memcache.Types qualified as Memcache
import Session.Key
import Session.Timing.Math (nextExpires)
import Session.Timing.Options (TimingOptions (timeout))
import Session.Timing.Time (Time (..))
import Session.Timing.Timeout (Timeout (..))
import Time (NominalDiffTime, UTCTime)
import Yesod.Core (SessionMap)
import Yesod.Session.Memcache.Expiration
  ( MemcacheExpiration (NoMemcacheExpiration, UseMemcacheExpiration)
  , fromUTC
  , noExpiration
  )
import Yesod.Session.Options (Options (timing))
import Yesod.Session.SessionType
import Yesod.Session.Storage.Exceptions
import Yesod.Session.Storage.Operation

-- | Mapping between a 'Session' and its Memcache representation, together
-- with the client and expiration policy that 'memcacheStorage' uses.
data SessionPersistence = SessionPersistence
  { databaseKey :: SessionKey -> Memcache.Key
  -- ^ Memcache key under which the session with the given key is stored
  , toDatabase :: (SessionMap, Time UTCTime) -> Memcache.Value
  -- ^ Encode a session's map and 'Time' record as a Memcache value
  , fromDatabase
      :: Memcache.Value
      -> Either SomeException (SessionMap, Time UTCTime)
  -- ^ Decode a stored value; a 'Left' is rethrown by 'memcacheStorage'
  , client :: Memcache.Client
  -- ^ Client used for every Memcache call made by 'memcacheStorage'
  , expiration :: MemcacheExpiration
  -- ^ Whether stored values get a Memcache expiration
  -- (see 'getMemcacheExpiration')
  }

-- | Run a single 'StorageOperation' against Memcache.
--
-- * 'GetSession' looks the key up. A missing key yields 'Nothing'. A stored
--   value that 'fromDatabase' rejects is rethrown with 'throwM'.
-- * 'DeleteSession' deletes the key without a CAS check. The result of the
--   delete is discarded, so deleting an absent key is not reported.
-- * 'InsertSession' stores the session with @Memcache.add@ and throws
--   'SessionAlreadyExists' when no version is returned.
-- * 'ReplaceSession' stores the session with @Memcache.replace@ (no CAS
--   check) and throws 'SessionDoesNotExist' when no version is returned.
--
-- Insert and replace both set the expiration computed by
-- 'getMemcacheExpiration' from the configured timeout and the session time.
memcacheStorage
  :: forall m result
   . (MonadThrow m, MonadIO m)
  => SessionPersistence
  -> Options IO IO
  -> StorageOperation result
  -> m result
memcacheStorage sp opt = \case
  GetSession sessionKey -> do
    mValue <-
      liftIO $
        fmap fstOf3 <$> Memcache.get sp.client (sp.databaseKey sessionKey)

    case mValue of
      Nothing -> pure Nothing
      Just value -> do
        -- A value that cannot be decoded is an error, not a missing session.
        (sessionMap, sessionTime) <- either throwM pure $ sp.fromDatabase value
        pure $
          Just
            Session
              { key = sessionKey
              , map = sessionMap
              , time = sessionTime
              }
  DeleteSession sessionKey ->
    void $
      liftIO $
        Memcache.delete sp.client (sp.databaseKey sessionKey) bypassCAS
  InsertSession session -> do
    expiry <- expirationFor session
    mVersion <-
      liftIO $
        Memcache.add
          sp.client
          (sp.databaseKey session.key)
          (storedValue session)
          defaultFlags
          expiry
    throwOnNothing SessionAlreadyExists mVersion
  ReplaceSession session -> do
    expiry <- expirationFor session
    mVersion <-
      liftIO $
        Memcache.replace
          sp.client
          (sp.databaseKey session.key)
          (storedValue session)
          defaultFlags
          expiry
          bypassCAS
    throwOnNothing SessionDoesNotExist mVersion
 where
  -- The part of a 'Session' that is written as the Memcache value.
  storedValue :: Session -> Memcache.Value
  storedValue session = sp.toDatabase (session.map, session.time)

  -- Memcache expiration for a session, per the configured policy and timeout.
  expirationFor :: MonadThrow n => Session -> n Memcache.Expiration
  expirationFor session =
    getMemcacheExpiration sp.expiration opt.timing.timeout session.time

  -- Throw the given exception when Memcache returned no version.
  throwOnNothing :: (MonadThrow n, Exception e) => e -> Maybe a -> n ()
  throwOnNothing exception =
    maybe (throwWithCallStack exception) (const $ pure ())

  fstOf3 :: (a, b, c) -> a
  fstOf3 (a, _, _) = a

-- | Determine what 'Memcache.Expiration' value to use.
--
-- * With 'NoMemcacheExpiration', 'noExpiration' is always used.
-- * With 'UseMemcacheExpiration', the result of 'nextExpires' (for the given
--   timeout and session time) is converted with 'fromUTC'. When
--   'nextExpires' yields 'Nothing', 'noExpiration' is used instead.
--
-- The conversion runs in @m@ because 'fromUTC' carries a 'MonadThrow'
-- constraint.
getMemcacheExpiration
  :: MonadThrow m
  => MemcacheExpiration
  -> Timeout NominalDiffTime
  -> Time UTCTime
  -> m Memcache.Expiration
getMemcacheExpiration NoMemcacheExpiration _sessionTimeout _sessionTime =
  pure noExpiration
getMemcacheExpiration UseMemcacheExpiration sessionTimeout sessionTime =
  maybe (pure noExpiration) fromUTC $ nextExpires sessionTimeout sessionTime

-- | Flags stored alongside every session value written by 'memcacheStorage'.
defaultFlags :: Memcache.Flags
defaultFlags = 0

-- | Do not do any CAS checking.
--
-- Logically, a 'Memcache.Version' (a.k.a. CAS) value is optional. However,
-- this optionality is represented by a 'Memcache.Version' of /0/. This is
-- documented in the Memcache docs for the /set/, /add/, and /replace/
-- commands:
--
-- https://github.com/memcached/memcached/wiki/BinaryProtocolRevamped#set-add-replace
--
-- The sentinel is defined by the binary protocol itself, not by those
-- commands alone: a 'Memcache.Version' of /0/ means "do not do any CAS
-- checking".
bypassCAS :: Memcache.Version
bypassCAS = 0