module Yesod.Session.Embedding.Map
  ( SessionMapEmbedding
  , MapOperations (..)
  , bsKeyEmbedding
  , showReadKeyEmbedding
  ) where

import Internal.Prelude

import Control.Monad.State (StateT (..))
import Control.Monad.State qualified as State
import Data.Map.Strict qualified as Map
import Data.Text qualified as T
import Data.Text.Encoding (decodeUtf8', encodeUtf8)
import Embedding
import Yesod.Core (HandlerFor, deleteSession, lookupSessionBS, setSessionBS)

-- | How a value of type @a@ is represented inside a 'SessionMap'.
--
-- The session is deliberately (ab)used here: key rotation and freezing are
-- signalled by embedding special values alongside the ordinary session data.
-- Those special values are pulled out of the map before it is persisted to
-- storage, so they are never actually saved.
type SessionMapEmbedding a =
  Embedding (MapOperations Text ByteString) () a

-- | Operations over some 'Map'-like state, available in a monad @m@.
--
-- The class abstracts over pure manipulation of a 'Map' (via 'StateT') and
-- the narrower session-manipulation functions that Yesod provides (via
-- 'HandlerFor'), so the same embedding code works against either.
-- (See the instance list for this class.)
--
-- The functional dependency means the monad alone fixes the key type @k@
-- and the value type @v@.
class (Monad m, Ord k) => MapOperations k v m | m -> k v where
  -- | Read the value currently stored at a key, if there is one.
  lookup :: k -> m (Maybe v)

  -- | Update a key. In both instances in this module, @'Just' x@ stores @x@
  -- at the key and 'Nothing' removes the key.
  assign :: k -> Maybe v -> m ()

-- | Operates directly on the Yesod session of the current handler.
instance MapOperations Text ByteString (HandlerFor site) where
  lookup :: Text -> HandlerFor site (Maybe ByteString)
  lookup = lookupSessionBS

  -- Absent values delete the session key; present values overwrite it.
  assign :: Text -> Maybe ByteString -> HandlerFor site ()
  assign key Nothing = deleteSession key
  assign key (Just bytes) = setSessionBS key bytes

-- | Pure map operations, threading a 'Map' through 'StateT'.
instance (Monad m, Ord k) => MapOperations k v (StateT (Map k v) m) where
  lookup :: k -> StateT (Map k v) m (Maybe v)
  lookup key = Map.lookup key <$> State.get

  -- 'Map.alter' with a constant function inserts for 'Just' and deletes
  -- for 'Nothing'; the strict 'modify'' avoids accumulating thunks in the state.
  assign :: k -> Maybe v -> StateT (Map k v) m ()
  assign key new = State.modify' (Map.alter (\_ -> new) key)

-- | An embedding which stores a value at some particular key in a map-like
-- structure.
--
-- Embedding delegates straight to 'assign' at the key. Extraction reads
-- whatever is stored at the key and then assigns 'Nothing' to it, so the key
-- is consumed by extraction. Extraction itself never fails; it always yields
-- 'Right'.
bsKeyEmbedding :: k -> Embedding (MapOperations k a) e a
bsKeyEmbedding key =
  Embedding
    { embed = assign key
    , extract = do
        stored <- lookup key
        assign key Nothing
        pure (Right stored)
    }

-- | Represents a value in a 'SessionMap' by storing the UTF-8 encoding of
-- its 'show' representation at the given key.
--
-- Decoding fails with @()@ when the stored bytes are not valid UTF-8, or
-- when 'readMaybe' cannot parse the decoded text.
showReadKeyEmbedding
  :: (Read a, Show a) => k -> Embedding (MapOperations k ByteString) () a
showReadKeyEmbedding key =
  dimapEmbedding fromBytes toBytes (bsKeyEmbedding key)
  where
    -- ByteString -> UTF-8 Text -> String -> parsed value; both failure
    -- modes collapse to the unit error.
    fromBytes bytes = case decodeUtf8' bytes of
      Left _ -> Left ()
      Right txt -> maybe (Left ()) Right (readMaybe (T.unpack txt))

    toBytes value = encodeUtf8 (T.pack (show value))