{-# LANGUAGE BlockArguments    #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards   #-}

-- |
-- Module: Network.Wai.Session.Redis
-- Copyright: (c) 2021, t4ccer
-- License: BSD3
-- Stability: experimental
-- Portability: portable
--
-- Simple Redis backed wai-session backend. This module allows you to store
-- session data of wai-sessions in a Redis database.
module Network.Wai.Session.Redis
  ( dbStore
  , clearSession
  , SessionSettings(..)
  ) where

import           Control.Exception      (bracket)
import           Control.Monad
import           Control.Monad.IO.Class
import           Data.ByteString        (ByteString)
import           Data.Default
import           Data.Either
import           Data.Serialize         (Serialize, decode, encode)
import           Database.Redis         hiding (decode)
import           Network.Wai.Session

-- | Settings to control session store
data SessionSettings = SessionSettings
  { -- | Connection parameters handed to Redis for every session operation.
    redisConnectionInfo :: ConnectInfo
  , -- | Session expiration time in seconds. The field name is misspelled,
    -- but it is exported as part of the public record, so it is kept as is.
    expiratinTime       :: Integer
  }

-- | Default settings: 'defaultConnectInfo' and a one-week expiration time.
instance Default SessionSettings where
  def = SessionSettings defaultConnectInfo oneWeek
    where
      -- seconds * minutes * hours * days
      oneWeek = 60 * 60 * 24 * 7

-- | Keep the 'Right' payload as 'Just'; any 'Left' value is discarded and
-- becomes 'Nothing'.
eitherToMaybe :: Either a b -> Maybe b
eitherToMaybe = either (const Nothing) Just

-- | Connect using the given 'ConnectInfo', run the Redis action on that
-- connection and return the action's result.
--
-- The connection is acquired and released through 'bracket', so
-- 'disconnect' runs even when the action throws; previously an exception
-- from 'runRedis' skipped the disconnect and leaked the connection.
connectAndRunRedis :: MonadIO m => ConnectInfo -> Redis b -> m b
connectAndRunRedis ci cmd =
  liftIO $ bracket (connect ci) disconnect (\conn -> runRedis conn cmd)

-- | Like 'connectAndRunRedis', but the result of the Redis action is
-- thrown away.
connectAndRunRedis_ :: MonadIO m => ConnectInfo -> Redis b -> m ()
connectAndRunRedis_ ci cmd = () <$ connectAndRunRedis ci cmd

-- | Generate a fresh session id, register it in Redis and return it.
--
-- The id is stored as a hash containing a single empty field/value pair,
-- and the key is given the configured expiration time. The Redis reply is
-- ignored.
createSession :: MonadIO m => SessionSettings -> m ByteString
createSession settings = liftIO $ do
  newId <- genSessionId
  _ <- connectAndRunRedis (redisConnectionInfo settings) $
    hset newId "" "" *> expire newId (expiratinTime settings)
  pure newId

-- | Ask Redis whether a key exists for the given session id.
-- An error reply from Redis is reported as 'False'.
isSesIdValid :: MonadIO m => SessionSettings -> ByteString -> m Bool
isSesIdValid settings sesId =
  liftIO $
    fromRight False
      <$> connectAndRunRedis (redisConnectionInfo settings) (exists sesId)

-- | Store a key/value pair in the hash belonging to a session and reset the
-- session's expiration time. Redis replies are ignored.
insertIntoSession :: MonadIO m => SessionSettings
  -> ByteString -- ^ Session id
  -> ByteString -- ^ Key
  -> ByteString -- ^ Value
  -> m ()
insertIntoSession settings sesId key value =
  connectAndRunRedis_ (redisConnectionInfo settings) $
    hset sesId key value >> expire sesId (expiratinTime settings)

-- | Fetch the value stored under a key in a session's hash and reset the
-- session's expiration time.
--
-- Yields 'Nothing' both when the field is absent and when Redis answers the
-- lookup with an error reply.
lookupFromSession :: MonadIO m => SessionSettings
  -> ByteString -- ^ Session id
  -> ByteString -- ^ Key
  -> m (Maybe ByteString)
lookupFromSession settings sesId key = do
  -- '<*' runs the lookup first, then the expire, and keeps the lookup reply.
  reply <- connectAndRunRedis (redisConnectionInfo settings) $
    hget sesId key <* expire sesId (expiratinTime settings)
  pure (either (const Nothing) id reply)

-- | Invalidate session id by deleting its key from Redis.
-- The Redis reply is ignored.
clearSession :: MonadIO m => SessionSettings
  -> ByteString -- ^ Session id
  -> m ()
clearSession settings sessionId =
  connectAndRunRedis_ (redisConnectionInfo settings) (del [sessionId])

-- | Create new redis backend wai session store.
--
-- The store is built purely from the settings; no Redis command is issued
-- until the returned store is applied to an optional session id.
dbStore :: (MonadIO m1, MonadIO m2, Eq k, Serialize k, Serialize v) => SessionSettings -> m2 (SessionStore m1 k v)
dbStore settings = pure (dbStore' settings)

-- | Resolve an optional client-supplied session id to a usable session.
--
-- A supplied id is reused only if 'isSesIdValid' accepts it; with no id, or
-- with a rejected one, a new session is created via 'createSession'. The
-- result pairs the session accessors with an action returning the id that
-- is actually in use.
dbStore' :: (MonadIO m1, MonadIO m2, Eq k, Serialize k, Serialize v, Monad m2) => SessionSettings -> Maybe ByteString -> m2 (Session m1 k v, m2 ByteString)
dbStore' s requested = do
  sesId <- case requested of
    Nothing        -> createSession s
    Just candidate -> do
      stillValid <- isSesIdValid s candidate
      if stillValid
        then pure candidate
        else createSession s
  pure (mkSessionFromSesId s sesId, pure sesId)

-- | Build the lookup/insert pair of a 'Session' bound to one session id.
--
-- Keys and values are serialised with 'encode' before being sent to Redis.
-- A lookup yields 'Nothing' when nothing is stored under the key or when
-- the stored bytes fail to 'decode'.
mkSessionFromSesId :: (MonadIO m1, Eq k, Serialize k, Serialize v) => SessionSettings -> ByteString -> Session m1 k v
mkSessionFromSesId s sesId = (lookupValue, insertValue)
  where
    lookupValue k = liftIO $ do
      raw <- lookupFromSession s sesId (encode k)
      pure (raw >>= eitherToMaybe . decode)
    insertValue k v =
      liftIO (insertIntoSession s sesId (encode k) (encode v))