module Yesod.Session.Memcache.Expiration
  ( MemcacheExpiration (..)
  , noExpiration
  , fromUTC
  , maxTimestamp
  , minTimestamp
  ) where

import Internal.Prelude

import Data.Fixed (Pico)
import Database.Memcache.Types qualified as Memcache
import Time (UTCTime, nominalDiffTimeToSeconds, utcTimeToPOSIXSeconds)

-- | Errors raised by this module.
--
-- 'InvalidExpiration' carries the offending number of seconds since the POSIX
-- epoch that could not be represented as a Memcache expiration timestamp.
newtype Exceptions = InvalidExpiration Pico
  deriving stock (Eq, Show)
  deriving anyclass (Exception)

-- | Whether sessions stored in Memcache carry an expiration time.
data MemcacheExpiration
  = NoMemcacheExpiration
  -- ^ Never set an expiration time; Memcache evicts entries only once the
  -- cache is full.
  | UseMemcacheExpiration
  -- ^ Store each session with the same expiration time that is sent to the
  -- HTTP client, i.e. the lesser of the idle and absolute timeouts.

-- | Do not expire the session via Memcache's expiration mechanism.
--
-- An expiration of @0@ means the entry has no expiry of its own; Memcache
-- will evict the session when the cache is full.
noExpiration :: Memcache.Expiration
noExpiration = 0

-- | Convert a 'UTCTime' to a 'Word32' expiration timestamp, with possibility
-- of failure.
--
-- The time is expressed as seconds since the POSIX epoch and rounded /up/ to
-- the next whole second.
--
-- This function guards against 'UTCTime' values that, converted to a
-- timestamp, would be too big or too small: if the number of seconds is
-- greater than 'maxTimestamp' or less than 'minTimestamp', an
-- 'InvalidExpiration' carrying that number of seconds is thrown via
-- 'throwWithCallStack'.
fromUTC :: MonadThrow m => UTCTime -> m Word32
fromUTC utcTime = do
  when (tooLarge || tooSmall) $ throwWithCallStack $ InvalidExpiration seconds
  -- The bounds check above guarantees the rounded value fits in a 'Word32'.
  pure $ ceiling seconds
 where
  -- Seconds since the POSIX epoch, at picosecond resolution.
  seconds :: Pico
  seconds = nominalDiffTimeToSeconds $ utcTimeToPOSIXSeconds utcTime

  tooLarge :: Bool
  tooLarge = seconds > maxTimestamp

  tooSmall :: Bool
  tooSmall = seconds < minTimestamp

-- | Minimum value that will be interpreted as a timestamp by Memcache
--
-- Values lower than this are considered to be "number of seconds" in the future
-- to expire a key / value pair. This is /not/ the interpretation we want.
--
-- See: https://github.com/dterei/memcache-hs/blob/83957ee9c4983f87447b0e7476a9a9155474dc80/Database/Memcache/Client.hs#L49-L59
--
-- This value is 30 days (2_592_000 seconds) plus one second after the POSIX
-- epoch, i.e. late January 1970.
minTimestamp :: Num a => a
minTimestamp = 2592000 + 1 -- 2592000 = 60 * 60 * 24 * 30; values up to it are relative offsets

-- | Maximum value that fits in a Memcache expiration timestamp.
--
-- Used to check that we don't overflow 'Word32' when converting a time.
--
-- 4_294_967_295 seconds after the POSIX epoch is in February 2106.
maxTimestamp :: Num a => a
maxTimestamp = fromIntegral (maxBound :: Word32)