-- | Internal module exposing the guts of the package.  Use at
-- your own risk.  No API stability guarantees apply.
module Web.ServerSession.Frontend.Yesod.Internal
  ( simpleBackend
  , backend
  , IsSessionMap(..)
  , createCookie
  , findSessionId
  , forceInvalidate
  ) where

import Control.Monad (guard)
import Control.Monad.IO.Class (MonadIO)
import Data.ByteString (ByteString)
import Data.Default (def)
import Data.Text (Text)
import Web.PathPieces (toPathPiece)
import Web.ServerSession.Core
import Yesod.Core (MonadHandler)
import Yesod.Core.Handler (setSessionBS)
import Yesod.Core.Types (Header(AddCookie), SessionBackend(..))

import qualified Data.ByteString.Char8 as B8
import qualified Data.HashMap.Strict as HM
import qualified Data.Map as M
import qualified Data.Text.Encoding as TE
import qualified Data.Time as TI
import qualified Network.Wai as W
import qualified Web.Cookie as C


-- | Build a Yesod 'SessionBackend' that keeps session data on the
-- server, using the given storage backend.
--
-- Example usage for the Yesod scaffold using
-- @serversession-backend-persistent@:
--
-- @
-- import Web.ServerSession.Backend.Persistent (SqlStorage(..))
-- import Web.ServerSession.Frontend.Yesod (simpleBackend)
--
-- instance Yesod App where
--   ...
--   makeSessionBackend = simpleBackend id . SqlStorage . appConnPool
--   -- Do not forget to add migration code to your Application.hs!
--   -- Please check serversession-backend-persistent's documentation.
--   ...
-- @
--
-- For example, to disable the idle timeout, reduce the absolute
-- timeout to one day, and mark cookies as \"Secure\", change that
-- line to:
--
-- @
--   makeSessionBackend = simpleBackend opts . SqlStorage . appConnPool
--     where opts = setIdleTimeout Nothing
--                . setAbsoluteTimeout (Just $ 60*60*24)
--                . setSecureCookies True
-- @
--
-- This is 'backend' specialised to 'SessionMap' as the
-- 'SessionData'.  For any other session data type, call 'backend'
-- directly (this function's body shows how little is involved).
simpleBackend
  :: (MonadIO m, Storage sto, SessionData sto ~ SessionMap)
  => (State sto -> State sto) -- ^ Set any options on the @serversession@ state.
  -> sto                      -- ^ Storage backend.
  -> m (Maybe SessionBackend) -- ^ Yesod session backend (always @Just@).
simpleBackend opts s = do
  -- Create the serversession state for this storage, let the caller
  -- adjust it, and only then wrap it as a Yesod backend.
  initialState <- createState s
  pure (Just (backend (opts initialState)))


-- | Build a Yesod 'SessionBackend' from an existing @serversession@
-- 'State'.  'simpleBackend' is a specialised form of this function.
--
-- To use the Yesod frontend, your 'SessionData' needs an
-- 'IsSessionMap' instance, so that it can be converted to and from
-- Yesod's session variables.
backend
  :: (Storage sto, IsSessionMap (SessionData sto))
  => State sto      -- ^ @serversession@ state, incl. storage backend.
  -> SessionBackend -- ^ Yesod session backend.
backend state = SessionBackend { sbLoadSession = loadFromRequest }
  where
    -- Cookie name configured on the state, UTF-8 encoded so it can be
    -- compared with the raw cookie names found in the request.
    cookieNameBS = TE.encodeUtf8 (getCookieName state)

    -- Load the session referenced by the request's cookie (if any) and
    -- hand Yesod both its data and the action that saves it later.
    loadFromRequest req = do
      let rawSessionId = findSessionId cookieNameBS req
      (sessData, token) <- loadSession state rawSessionId
      let persist yesodSession = do
            stored <- saveSession state token (fromSessionMap yesodSession)
            -- When 'saveSession' yields 'Nothing' the client cookie is
            -- removed; otherwise a cookie for the saved session is set.
            pure [ maybe (deleteCookie state cookieNameBS)
                         (createCookie state cookieNameBS)
                         stored ]
      pure (toSessionMap sessData, persist)


----------------------------------------------------------------------


-- | Class for session data types meant to be used with the Yesod
-- frontend.  The only session interface Yesod provides is via
-- session variables, so your data type needs to be convertible
-- from/to a 'M.Map' of 'Text' to 'ByteString'.
--
-- 'backend' uses 'toSessionMap' when exposing loaded session data
-- to Yesod, and 'fromSessionMap' when converting Yesod's session
-- variables back before saving them.
-- NOTE(review): the two methods are presumably expected to
-- round-trip; confirm this holds for your data type.
class IsSessionMap sess where
  -- | Convert session data into Yesod's session-variable map.
  toSessionMap   :: sess -> M.Map Text ByteString
  -- | Rebuild session data from Yesod's session-variable map.
  fromSessionMap :: M.Map Text ByteString -> sess


-- | 'SessionMap' wraps a 'HM.HashMap', whereas Yesod expects an
-- ordered 'M.Map'.  Both conversions simply copy every key/value
-- pair across, since neither side contains duplicate keys.
instance IsSessionMap SessionMap where
  toSessionMap (SessionMap hashed) = M.fromList (HM.toList hashed)
  fromSessionMap ordered = SessionMap (HM.fromList (M.toList ordered))


----------------------------------------------------------------------


-- | Build the 'AddCookie' header that stores the given session's ID
-- on the client.
--
-- The cookie expiration is taken from 'cookieExpires'.  This is
-- only an optimization, because the expiration is also checked on
-- the server side.
createCookie :: State sto -> ByteString -> Session sess -> Header
createCookie state cookieNameBS session = AddCookie sessionCookie
  where
    -- The session key rendered as a path piece and UTF-8 encoded.
    sessionIdBS = TE.encodeUtf8 (toPathPiece (sessionKey session))

    -- Generate a cookie with the final session ID.
    sessionCookie = def
      { C.setCookieName     = cookieNameBS
      , C.setCookieValue    = sessionIdBS
      , C.setCookiePath     = Just "/"
      , C.setCookieExpires  = cookieExpires state session
      , C.setCookieDomain   = Nothing
      , C.setCookieHttpOnly = getHttpOnlyCookies state
      , C.setCookieSecure   = getSecureCookies state
      }


-- | Remove the session cookie from the client.  This is used
-- when 'saveSession' returns @Nothing@:
--
--   * If the user didn't have a session cookie, this cookie
--   deletion is harmless.
--
--   * If the user had a session cookie that was invalidated,
--   this removes the invalid cookie from the client.
--
-- The cookie is overwritten with an empty value, an expiry date in
-- the past, and a @Max-Age@ of 0.  Name, path (@\/@) and domain (unset)
-- are the same as in 'createCookie', because browsers only replace a
-- cookie whose name, path and domain all match.
deleteCookie :: State sto -> ByteString -> Header
deleteCookie state cookieNameBS =
  AddCookie def
    { C.setCookieName     = cookieNameBS
    , C.setCookieValue    = ""
    , C.setCookiePath     = Just "/"
    , C.setCookieExpires  = Just aLongTimeAgo
    , C.setCookieMaxAge   = Just 0
    , C.setCookieDomain   = Nothing
    , C.setCookieHttpOnly = getHttpOnlyCookies state
    , C.setCookieSecure   = getSecureCookies state
    }
  where
    -- 1970-01-01 00:00:01 UTC.  This is built directly instead of
    -- going through the partial 'read', so no string parsing can
    -- fail at runtime.
    aLongTimeAgo :: TI.UTCTime
    aLongTimeAgo = TI.UTCTime (TI.fromGregorian 1970 1 1) 1


-- | Fetch the raw session ID, i.e. the value of the cookie with the
-- given name, looking through every @Cookie@ header of the request.
-- Returns @Nothing@ if:
--
--   * There are zero cookies with the given name.
--
--   * There is more than one cookie with the given name.
findSessionId :: ByteString -> W.Request -> Maybe ByteString
findSessionId cookieNameBS req =
  case matches of
    [raw] -> Just raw
    _     -> Nothing
  where
    -- Values of all cookies named 'cookieNameBS', across all of the
    -- request's @Cookie@ headers.  Other headers are skipped by the
    -- failing pattern match in the list monad.
    matches = do
      ("Cookie", header) <- W.requestHeaders req
      (name, value) <- C.parseCookies header
      guard (name == cookieNameBS)
      pure value


-- | Invalidate the current session ID (and possibly more, see
-- 'ForceInvalidate').  This is useful to avoid session fixation
-- attacks (cf. <http://www.acrossecurity.com/papers/session_fixation.pdf>).
--
-- Note that the invalidation /does not/ happen when this action is
-- called!  Sessions are invalidated at the end of handler
-- processing.  Because every call writes the same session variable,
-- a later call to 'forceInvalidate' in the same handler overrides
-- any earlier one.
--
-- This function works by setting a session variable that is
-- checked when saving the session.  That variable is then
-- discarded and is not persisted across requests.
forceInvalidate :: MonadHandler m => ForceInvalidate -> m ()
forceInvalidate which =
  setSessionBS forceInvalidateKey (B8.pack (show which))