module Network.Wai.Session.PostgreSQL
( dbStore
, defaultSettings
, clearSession
, purgeOldSessions
, purger
, ratherSecureGen
, WithPostgreSQLConn (..)
, StoreSettings (..)
) where
import Control.Concurrent
import Control.Exception.Base
import Control.Exception
import Control.Monad
import Control.Monad.IO.Class
import Data.Default
import Data.Int (Int64)
import Data.Serialize (encode, decode, Serialize)
import Data.String (fromString)
import Data.Time.Clock.POSIX (getPOSIXTime)
import Database.PostgreSQL.Simple
import Network.Wai (Request, requestHeaders)
import Network.Wai.Session
import Numeric (showHex)
import System.Entropy (getEntropy)
import Web.Cookie (parseCookies)
import qualified Data.ByteString as B
import qualified Data.Text as T
import qualified Data.Text.Encoding as TE
-- | Configuration for the PostgreSQL-backed session store.
data StoreSettings = StoreSettings
  { storeSettingsSessionTimeout :: Int64
    -- ^ Inactivity window in seconds; it is subtracted from the rounded
    -- POSIX time to decide which sessions count as expired.
  , storeSettingsKeyGen :: IO B.ByteString
    -- ^ Action producing a fresh session key.
  , storeSettingsCreateTable :: Bool
    -- ^ Whether 'dbStore' should attempt to create the @session@ table.
  , storeSettingsLog :: String -> IO ()
    -- ^ Sink for log messages such as purge counts.
  , storeSettingsPurgeInterval :: Int
    -- ^ Pause between purges, in microseconds (handed to 'threadDelay').
  }

-- | 'def' is exactly 'defaultSettings'.
instance Default StoreSettings where
  def = defaultSettings
-- | Sources from which a PostgreSQL 'Connection' can be supplied to an
-- action for the duration of that action.
class WithPostgreSQLConn a where
  -- | Run the action with a connection obtained from the source.
  withPostgreSQLConn :: a -> (Connection -> IO b) -> IO b

-- | A bare 'Connection' is passed straight to the action; nothing is
-- acquired beforehand or released afterwards.
instance WithPostgreSQLConn Connection where
  withPostgreSQLConn conn action = action conn
-- | DDL for the @session@ table; executed by 'dbStore' when
-- 'storeSettingsCreateTable' is set.
qryCreateTable :: Query
qryCreateTable = "CREATE TABLE session (id bigserial NOT NULL, session_key character varying NOT NULL, session_created_at bigint NOT NULL, session_last_access bigint NOT NULL, session_value bytea NOT NULL, session_invalidate_key bool NOT NULL DEFAULT FALSE, CONSTRAINT session_pkey PRIMARY KEY (id), CONSTRAINT session_session_key_key UNIQUE (session_key)) WITH (OIDS=FALSE);"

-- | Insert a new row: key, created-at, last-access, value.
qryCreateSession :: Query
qryCreateSession = "INSERT INTO session (session_key, session_created_at, session_last_access, session_value) VALUES (?,?,?,?)"

-- | Overwrite a session's value and last-access time, selected by key.
qryUpdateSession :: Query
qryUpdateSession = "UPDATE session SET session_value=?,session_last_access=? WHERE session_key=?"

-- | Fetch a session's value only if it was last accessed at or after the
-- given time.
qryLookupSession :: Query
qryLookupSession = "SELECT session_value FROM session WHERE session_key=? AND session_last_access>=?"

-- | Set a session's last-access time, selected by key.
qryLookupSession' :: Query
qryLookupSession' = "UPDATE session SET session_last_access=? WHERE session_key=?"

-- | Fetch a session's value by key, regardless of its age.
qryLookupSession'' :: Query
qryLookupSession'' = "SELECT session_value FROM session WHERE session_key=?"

-- | Delete every session last accessed before the given time.
qryPurgeOldSessions :: Query
qryPurgeOldSessions = "DELETE FROM session WHERE session_last_access<?"

-- | Read a session's key-rotation flag.
qryCheckNewKey :: Query
qryCheckNewKey = "SELECT session_invalidate_key FROM session WHERE session_key=?"

-- | Replace a session's value and raise its key-rotation flag.
qryInvalidateSess :: Query
qryInvalidateSess = "UPDATE session SET session_value=?,session_invalidate_key=TRUE WHERE session_key=?"

-- | Rename a session's key and clear its key-rotation flag.
qryUpdateKey :: Query
qryUpdateKey = "UPDATE session SET session_key=?,session_invalidate_key=FALSE WHERE session_key=?"
-- | Build a 'SessionStore' backed by the @session@ table.
--
-- When 'storeSettingsCreateTable' is set, table creation is attempted
-- first; any 'SqlError' raised by that attempt is ignored.
dbStore
  :: (WithPostgreSQLConn a, Serialize k, Eq k, Serialize v, MonadIO m)
  => a -> StoreSettings -> IO (SessionStore m k v)
dbStore pool stos = do
  when (storeSettingsCreateTable stos) tryCreateTable
  pure (dbStore' pool stos)
  where
    tryCreateTable =
      withPostgreSQLConn pool (\conn -> unerror (execute_ conn qryCreateTable))
-- | Delete every session whose last access lies further in the past than
-- 'storeSettingsSessionTimeout', log the number of deleted rows through
-- 'storeSettingsLog', and return that number.
purgeOldSessions :: WithPostgreSQLConn a => a -> StoreSettings -> IO Int64
purgeOldSessions pool stos = do
  now <- round <$> getPOSIXTime
  -- Rows last touched before this POSIX second are expired.
  let cutoff = now - storeSettingsSessionTimeout stos
  count <- withPostgreSQLConn pool $ \conn ->
    execute conn qryPurgeOldSessions (Only cutoff)
  storeSettingsLog stos $ "Purged " ++ show count ++ " session(s)."
  return count
-- | Fork a thread that repeatedly runs 'purgeOldSessions' and then sleeps
-- for 'storeSettingsPurgeInterval' microseconds.
--
-- A 'SqlError' from a purge is reported through 'storeSettingsLog' and
-- does not end the loop. The sleep always follows, so a failing database
-- is retried at the configured interval rather than in a tight loop.
purger :: WithPostgreSQLConn a => a -> StoreSettings -> IO ThreadId
purger pool stos = forkIO . forever $ do
  void (purgeOldSessions pool stos) `catch` reportFailure
  threadDelay (storeSettingsPurgeInterval stos)
  where
    reportFailure :: SqlError -> IO ()
    reportFailure err =
      storeSettingsLog stos ("Session purge failed: " ++ show err)
-- | Default configuration:
--
-- * logging to stdout via 'putStrLn';
-- * table creation enabled;
-- * keys from 'ratherSecureGen' with 24 bytes of entropy;
-- * a one-hour (3600 s) inactivity timeout;
-- * a purge interval of 600,000,000 microseconds (ten minutes).
defaultSettings :: StoreSettings
defaultSettings =
  StoreSettings
    { storeSettingsLog = putStrLn
    , storeSettingsCreateTable = True
    , storeSettingsKeyGen = ratherSecureGen 24
    , storeSettingsSessionTimeout = 60 * 60
    , storeSettingsPurgeInterval = 10 * 60 * 1000000
    }
-- | Session-key generator: draw @n@ bytes from 'getEntropy', render them
-- as hexadecimal text with 'prettyPrint', and UTF-8 encode the result.
ratherSecureGen :: Int -> IO B.ByteString
ratherSecureGen n = do
  raw <- getEntropy n
  pure (TE.encodeUtf8 (prettyPrint raw))
-- | Render bytes as lowercase hexadecimal, exactly two digits per byte.
--
-- Each byte is zero-padded (0x0a becomes @"0a"@), so the output length is
-- always twice the input length and the encoding is injective. Without
-- padding, different byte strings could render to the same key text.
prettyPrint :: B.ByteString -> T.Text
prettyPrint = T.pack . B.foldr hexByte ""
  where
    -- High nibble then low nibble; each nibble is a single hex digit.
    hexByte w = showHex (w `div` 16) . showHex (w `mod` 16)
-- | The 'SessionStore' returned by 'dbStore'.
--
-- * @Nothing@: generate a key, insert a row with an empty value stamped
--   with the current time as both created-at and last-access, and open it.
-- * @Just key@: open that session if its row was last accessed within
--   'storeSettingsSessionTimeout' seconds. Otherwise start a new session
--   as in the @Nothing@ case.
dbStore'
  :: (WithPostgreSQLConn a, Serialize k, Eq k, Serialize v, MonadIO m)
  => a -> StoreSettings -> SessionStore m k v
dbStore' pool stos Nothing = do
  newKey <- storeSettingsKeyGen stos
  now <- round <$> getPOSIXTime
  withPostgreSQLConn pool $ \conn ->
    void $ execute conn qryCreateSession
      (newKey, now :: Int64, now, Binary B.empty)
  backend pool stos newKey []
dbStore' pool stos (Just key) = do
  now <- round <$> getPOSIXTime
  -- Rows last accessed before this POSIX second count as expired.
  let cutoff = now - storeSettingsSessionTimeout stos
  -- Only the presence of a live row matters; the fetched value is ignored.
  live <- withPostgreSQLConn pool $ \conn ->
    query conn qryLookupSession (key, cutoff) :: IO [Only B.ByteString]
  case live of
    [Only _] -> backend pool stos key []
    _ -> dbStore' pool stos Nothing
-- | Mark the session named by the @cookieName@ cookie for key rotation.
--
-- The row's stored value is replaced by an empty one and its
-- @session_invalidate_key@ flag is raised. The key-returning action built
-- by 'backend' then issues a new key. If the request carries no such
-- cookie there is nothing to clear, and the call is a no-op (previously a
-- pattern-match failure).
clearSession :: (WithPostgreSQLConn a) => a -> B.ByteString -> Request -> IO ()
clearSession pool cookieName req =
  forM_ sessionKey $ \key ->
    withPostgreSQLConn pool $ \conn ->
      void $ execute conn qryInvalidateSess (Binary B.empty, key)
  where
    sessionKey =
      lookup (fromString "Cookie") (requestHeaders req)
        >>= lookup cookieName . parseCookies
-- | Assemble the wai-session 'Session' (reader, writer) for @key@, paired
-- with the action that yields the key to hand back to the client.
--
-- That action reads the row's @session_invalidate_key@ flag. When the flag
-- is set, a fresh key is generated and written over the old one (which also
-- clears the flag), and the fresh key is returned. Otherwise, including
-- when the row no longer exists, @key@ is returned unchanged.
backend
  :: (WithPostgreSQLConn a, Serialize k, Eq k, Serialize v, MonadIO m)
  => a -> StoreSettings -> B.ByteString -> [(k, v)]
  -> IO (Session m k v, IO B.ByteString)
backend pool stos key mappe =
  return ((reader pool key mappe, writer pool key mappe), nextKey)
  where
    nextKey = withPostgreSQLConn pool $ \conn -> do
      flags <- query conn qryCheckNewKey (Only key)
      case flags of
        [Only True] -> do
          newKey <- storeSettingsKeyGen stos
          void $ execute conn qryUpdateKey (newKey, key)
          return newKey
        _ -> return key
-- | Session lookup for key @k@.
--
-- It first sets the row's @session_last_access@ to the current time. It
-- then reads the stored value, decodes it as an association list, and looks
-- up @k@. The result is 'Nothing' when the row is absent, the value does
-- not decode, or @k@ is not present. The list argument is unused.
reader
  :: (WithPostgreSQLConn a, Serialize k, Eq k, Serialize v, MonadIO m)
  => a -> B.ByteString -> [(k, v)] -> k -> m (Maybe v)
reader pool key _mappe k = do
  now <- round <$> liftIO getPOSIXTime
  rows <- liftIO $ withPostgreSQLConn pool $ \conn -> do
    -- Refresh the last-access stamp before reading the value.
    void $ execute conn qryLookupSession' (now :: Int64, key)
    query conn qryLookupSession'' (Only key)
  return $ case rows of
    [Only stored] -> either (const Nothing) (lookup k) (decode (fromBinary stored))
    _ -> Nothing
-- | Session insert for key @k@.
--
-- It reads and decodes the stored association list, treating a missing row
-- or an undecodable value as empty. Any existing entry for @k@ is replaced
-- with @(k, v)@, and the re-encoded list is written back together with the
-- current time as @session_last_access@. If the row does not exist, the
-- UPDATE matches nothing.
--
-- The row list is matched with @case@ rather than a refutable @<-@ pattern,
-- so the block needs no 'MonadFail' constraint on @m@.
--
-- NOTE(review): the read and the write are separate statements with no
-- transaction, so concurrent writes to the same session can lose updates.
writer
  :: (WithPostgreSQLConn a, Serialize k, Eq k, Serialize v, MonadIO m)
  => a -> B.ByteString -> [(k, v)] -> k -> v -> m ()
writer pool key _mappe k v = do
  now <- round <$> liftIO getPOSIXTime
  liftIO $ withPostgreSQLConn pool $ \conn -> do
    rows <- query conn qryLookupSession'' (Only key)
    let current = case rows of
          [Only stored] -> either (const []) id (decode (fromBinary stored))
          _ -> []
        updated = (k, v) : filter ((/= k) . fst) current
    void $ execute conn qryUpdateSession (Binary (encode updated), now :: Int64, key)
-- | Exception handler that discards a 'SqlError' and does nothing else.
ignoreSqlError :: SqlError -> IO ()
ignoreSqlError = const (pure ())

-- | Run an action for its effects, dropping its result. A 'SqlError'
-- thrown by the action is swallowed; any other exception propagates.
unerror :: IO a -> IO ()
unerror = handle ignoreSqlError . void