-- | Internal module exposing the guts of the package.  Use at
-- your own risk.  No API stability guarantees apply.
module Web.ServerSession.Frontend.Wai.Internal
  ( withServerSession
  , sessionStore
  , mkSession
  , KeyValue(..)
  , createCookieTemplate
  , calculateMaxAge
  , forceInvalidate
  ) where

import Control.Applicative as A
import Control.Monad (guard)
import Control.Monad.IO.Class (MonadIO(..))
import Data.ByteString (ByteString)
import Data.Default (def)
import Data.Kind (Type)
import Data.Text (Text)
import Web.PathPieces (toPathPiece)
import Web.ServerSession.Core
import Web.ServerSession.Core.Internal (absoluteTimeout, idleTimeout, persistentCookies)

import qualified Data.ByteString.Char8 as B8
import qualified Data.HashMap.Strict as HM
import qualified Data.IORef as I
import qualified Data.Text.Encoding as TE
import qualified Data.Time as TI
import qualified Data.Vault.Lazy as V
import qualified Network.Wai as W
import qualified Network.Wai.Session as WS
import qualified Web.Cookie as C


-- | Construct the @wai-session@ middleware using the given
-- storage backend and options.  This is a convenience function
-- built on top of 'WS.withSession', 'createState',
-- 'sessionStore', 'getCookieName' and 'createCookieTemplate'.
--
-- The @serversession@ state is created once, when the returned
-- action runs, and is then shared by the resulting middleware.
withServerSession
  :: (Functor m, MonadIO m, MonadIO n, Storage sto, SessionData sto ~ SessionMap)
  => V.Key (WS.Session m Text ByteString) -- ^ 'V.Vault' key to use when passing the session through.
  -> (State sto -> State sto)             -- ^ Set any options on the @serversession@ state.
  -> sto                                  -- ^ Storage backend.
  -> n W.Middleware
withServerSession key opts storage =
  liftIO (buildMiddleware . opts <$> createState storage)
  where
    -- Wire the configured state into wai-session: store, cookie
    -- name (UTF-8 encoded), cookie attributes and vault key.
    buildMiddleware st =
      WS.withSession
        (sessionStore st)
        (TE.encodeUtf8 (getCookieName st))
        (createCookieTemplate st)
        key


-- | Construct the @wai-session@ session store using the given
-- state.  Note that the key and value types are fixed by the
-- session data's 'KeyValue' instance.
--
-- As @wai-session@ always requires a value to be provided, the
-- save action returns an empty @ByteString@ when the session was
-- not saved (e.g. an empty session); otherwise it returns the
-- saved session's ID, UTF-8 encoded.
sessionStore
  :: (Functor m, MonadIO m, Storage sto, KeyValue (SessionData sto))
  => State sto -- ^ @serversession@ state, incl. storage backend.
  -> WS.SessionStore m (Key (SessionData sto)) (Value (SessionData sto))
     -- ^ @wai-session@ session store.
sessionStore state mcookieVal = do
  -- Load whatever session the incoming cookie value refers to and
  -- keep its data in a mutable cell shared with the 'WS.Session'.
  (loaded, token) <- loadSession state mcookieVal
  ref <- I.newIORef loaded
  let persist = do
        -- atomicModifyIORef' is used as a read barrier (instead of
        -- readIORef) so writes made through the session are seen.
        current <- I.atomicModifyIORef' ref (\s -> (s, s))
        saved   <- saveSession state token current
        pure $ case saved of
          Nothing   -> ""
          Just sess -> TE.encodeUtf8 (toPathPiece (sessionKey sess))
  pure (mkSession ref, persist)


-- | Build a 'WS.Session' from an 'I.IORef' containing the
-- session data.  The first component of the returned pair looks
-- a key up in the current session data; the second inserts a
-- key/value pair into it.
mkSession :: (Functor m, MonadIO m, KeyValue sess) => I.IORef sess -> WS.Session m (Key sess) (Value sess)
mkSession sessionRef = (lookupValue, insertValue)
  where
    -- We need to use atomicModifyIORef instead of readIORef
    -- because the latter may be reordered (cf. "Memory Model" in
    -- Data.IORef's documentation).
    readCurrent = liftIO $ I.atomicModifyIORef' sessionRef $ \s -> (s, s)

    lookupValue k = kvLookup k A.<$> readCurrent

    insertValue k v =
      liftIO $ I.atomicModifyIORef' sessionRef $ \s -> (kvInsert k v s, ())


----------------------------------------------------------------------


-- | Class for session data types that can be used as key-value
-- stores.  'mkSession' uses it to expose the session data through
-- @wai-session@'s lookup/insert pair.
class IsSessionData sess => KeyValue sess where
  -- | Type of the keys indexing the session data.
  type Key   sess :: Type
  -- | Type of the values stored in the session data.
  type Value sess :: Type
  -- | Look a key up in the session data.
  kvLookup :: Key sess -> sess -> Maybe (Value sess)
  -- | Return the session data with the given key/value pair inserted.
  kvInsert :: Key sess -> Value sess -> sess -> sess


-- | 'SessionMap' wraps a 'HM.HashMap' from 'Text' keys to
-- 'ByteString' values, so both operations delegate to it.
instance KeyValue SessionMap where
  type Key   SessionMap = Text
  type Value SessionMap = ByteString
  kvLookup key (SessionMap hm) = HM.lookup key hm
  kvInsert key val (SessionMap hm) = SessionMap $ HM.insert key val hm


----------------------------------------------------------------------


-- | Create a cookie template given a state.
--
-- Since we don't have access to the 'Session', we can't fill the
-- @Expires@ field.  Besides, as the template is constant,
-- eventually the @Expires@ field would become outdated.  This is
-- a limitation of @wai-session@'s interface, not a
-- @serversession@ limitation.  Other frontends support the
-- @Expires@ field.
--
-- Instead, we fill only the @Max-Age@ field (see
-- 'calculateMaxAge').  It works fine for modern browsers, but
-- some older browsers don't support it and will treat the cookie
-- as non-persistent (notably IE 6, 7 and 8).
--
-- The template starts from 'C.defaultSetCookie' rather than
-- 'Data.Default.def': newer releases of @cookie@ no longer provide
-- a @Default@ instance for 'C.SetCookie', whereas
-- 'C.defaultSetCookie' is the equivalent value exported by
-- @cookie@ itself.
createCookieTemplate :: State sto -> C.SetCookie
createCookieTemplate state =
  -- Only the cookie attributes are set here; the cookie name is
  -- supplied to 'WS.withSession' separately (see 'withServerSession').
  C.defaultSetCookie
    { C.setCookiePath     = Just "/"
    , C.setCookieMaxAge   = calculateMaxAge state
    , C.setCookieDomain   = Nothing
    , C.setCookieHttpOnly = getHttpOnlyCookies state
    , C.setCookieSecure   = getSecureCookies state
    }


-- | Calculate the @Max-Age@ of a cookie template for the given
-- state.
--
--   * If the state asks for non-persistent sessions, the result
--   is @Nothing@.
--
--   * If no timeout is defined, the result is 10 years.
--
--   * Otherwise, the max age is the larger of the defined
--   timeouts (an undefined timeout is ignored).
calculateMaxAge :: State sto -> Maybe TI.DiffTime
calculateMaxAge st =
  maybe tenYears realToFrac longestTimeout <$ guard (persistentCookies st)
  where
    -- 3652 days (roughly ten years), expressed in seconds.
    tenYears = 60 * 60 * 24 * 3652
    -- 'max' on 'Maybe' orders 'Nothing' below every 'Just', so an
    -- undefined timeout never wins, and the result is 'Nothing'
    -- only when neither timeout is defined.
    longestTimeout = max (idleTimeout st) (absoluteTimeout st)


-- | Invalidate the current session ID (and possibly more, check
-- 'ForceInvalidate').  This is useful to avoid session fixation
-- attacks (cf. <http://www.acrossecurity.com/papers/session_fixation.pdf>).
--
-- The request is recorded by inserting the 'show'n
-- 'ForceInvalidate' value, packed into a 'ByteString', under
-- 'forceInvalidateKey' using the session's insert function.  The
-- session's lookup function is not used.
forceInvalidate :: WS.Session m Text ByteString -> ForceInvalidate -> m ()
forceInvalidate (_, insertValue) =
  \how -> insertValue forceInvalidateKey (B8.pack (show how))