{-|
Description: Spock Session implementation for Beam

This provides one essential function, 'patchConfig', to replace Spock's
default in-RAM session implementation with a Beam/Postgres-backed one.
-}
module TsWeb.Session
  ( patchConfig
  , UserData(..)
  ) where

import qualified TsWeb.Tables.Session as T
import qualified TsWeb.Types.Db as Db

import TsWeb.Tables.Session (SessionT(..))
import TsWeb.Types.Db (ReadOnlyConn, ReadWriteConn, SomeConn)

import qualified Database.Beam as Beam

import Control.Monad (unless)
import Data.Maybe (fromJust)
import Data.Pool (Pool)
import Data.Time.Format (buildTime, defaultTimeLocale)
import Database.Beam ((==.), all_, guard_, val_)
import Database.Beam.Backend.SQL.SQL92 (HasSqlValueSyntax)
import Database.Beam.Postgres (Postgres)
import Database.Beam.Postgres.Syntax (PgValueSyntax)
import Database.Beam.Schema.Tables (FieldsFulfillConstraint)
import Web.Spock.Config
  ( SessionStore(..)
  , SessionStoreInstance(..)
  , SpockCfg(..)
  , sc_store
  )
import Web.Spock.Internal.SessionManager as SM (Session(..), SessionId)

-- | One primitive session-store step. Every constructor carries its arguments
-- followed by the rest of the program (@next@); steps that yield a result
-- take the rest of the program as a function of that result.
data TxAction s next
  = LoadSession SM.SessionId (Maybe s -> next)
    -- ^ look up a session by its id
  | DeleteSession SM.SessionId next
    -- ^ remove the session with this id
  | StoreSession s next
    -- ^ persist a session
  | ToList ([s] -> next)
    -- ^ fetch every stored session
  | FilterSessions (s -> Bool) next
    -- ^ filter sessions by a predicate (not emitted by 'newPgSessionStore')

-- | A session-store transaction: a free-monad program built from 'TxAction'
-- steps, interpreted against the database by 'runTxProgram'. Left
-- partially applied on purpose so it can serve as Spock's @tx@ type.
type TxProgram s = Free (TxAction s)

-- | Update a Spock configuration stanza so that its session store is backed
-- by Postgres (through Beam) instead of Spock's default in-memory store. Only
-- the session store is swapped out; every other configuration setting is
-- kept as-is. This will typically be used like so:
--
-- @
--  spockCfg <-
--    'patchConfig' (_dbSession db) ropool rwpool <$>
--    'Web.Spock.Config.defaultSpockCfg' sess PCNoDatabase ()
--  'Web.Spock.runSpock' port ('Web.Spock.spock' spockCfg routes)
--  where
--    sess = ...
--    routes = ...
-- @
--
patchConfig ::
     ( Beam.FromBackendRow Postgres (sessdata Beam.Identity)
     , Beam.Beamable sessdata
     , Beam.Typeable sessdata
     , Beam.Database Postgres db
     , UserData (sessdata Beam.Identity)
     , FieldsFulfillConstraint (HasSqlValueSyntax PgValueSyntax) sessdata
     )
  => Beam.DatabaseEntity Postgres db (Beam.TableEntity (SessionT sessdata))
  -> Pool ReadOnlyConn
  -> Pool ReadWriteConn
  -> SpockCfg conn (sessdata Beam.Identity) st
  -> SpockCfg conn (sessdata Beam.Identity) st
patchConfig session ropool rwpool conf =
  conf {spc_sessionCfg = (spc_sessionCfg conf) {sc_store = pgStore}}
  where
    -- Wrap the Postgres-backed store in Spock's existential store box.
    pgStore = SessionStoreInstance (newPgSessionStore session ropool rwpool)

-- | Typeclass for user-supplied session data. The session store only needs to
-- know whether the user set a remember-me indicator at login, so that the
-- session's lifespan can be controlled accordingly. If your sessions never
-- need to be remembered, then @rememberMe = const False@ suffices.
class UserData c where
  -- | Should the associated session be kept (effectively) permanently?
  rememberMe :: c -> Bool

-- | Build a Spock 'SessionStore' whose transaction type is 'TxProgram'.
--
-- Each store operation only describes work as 'TxAction' steps. Nothing
-- touches the database until 'ss_runTx' ('runTxProgram') interprets the
-- program against the given read-only and read-write pools. Filtering and
-- mapping are expressed through the primitive steps: list every session,
-- then delete (filter) or re-store (map) each one individually.
newPgSessionStore ::
     ( Beam.FromBackendRow Postgres (sessdata Beam.Identity)
     , Beam.Beamable sessdata
     , Beam.Typeable sessdata
     , Beam.Database Postgres db
     , UserData (sessdata Beam.Identity)
     , FieldsFulfillConstraint (HasSqlValueSyntax PgValueSyntax) sessdata
     )
  => Beam.DatabaseEntity Postgres db (Beam.TableEntity (SessionT sessdata))
  -> Pool ReadOnlyConn
  -> Pool ReadWriteConn
  -> SessionStore (SM.Session conn (sessdata Beam.Identity) st) (TxProgram (SM.Session conn (sessdata Beam.Identity) st))
newPgSessionStore session ropool rwpool =
  SessionStore
    { ss_runTx = runTxProgram session ropool rwpool
    , ss_loadSession = loadOne
    , ss_deleteSession = deleteOne
    , ss_storeSession = storeOne
    , ss_toList = listAll
    , ss_filterSessions = \keep ->
        -- drop every session the predicate rejects
        listAll >>= mapM_ (\s -> unless (keep s) (deleteOne (sess_id s)))
    , ss_mapSessions = \update ->
        -- rewrite each session and store the updated version
        listAll >>= mapM_ (\s -> update s >>= storeOne)
    }
  where
    -- Single-step programs, one per primitive 'TxAction'.
    loadOne sid = liftF (LoadSession sid id)
    deleteOne sid = liftF (DeleteSession sid ())
    storeOne s = liftF (StoreSession s ())
    listAll = liftF (ToList id)

-- | Fetch the session row with the given id, if any, and convert it into
-- Spock's 'SM.Session' representation. The query runs via 'Db.readOnly'.
loadSession' ::
     ( Beam.FromBackendRow Postgres (sessdata Beam.Identity)
     , Beam.Beamable sessdata
     , Beam.Typeable sessdata
     , Beam.Database Postgres db
     )
  => Beam.DatabaseEntity Postgres db (Beam.TableEntity (SessionT sessdata))
  -> SomeConn c
  -> SM.SessionId
  -> IO (Maybe (SM.Session conn (sessdata Beam.Identity) st))
loadSession' session conn sessionid = do
  found <-
    Db.readOnly conn $
    Beam.runSelectReturningOne $
    Beam.select $ do
      row <- all_ session
      guard_ (_sessionId row ==. val_ sessionid)
      pure row
  -- traverse inspects the Maybe before returning, like the original case did
  traverse
    (\row ->
       pure $
       SM.Session
         (_sessionId row)
         (_sessionCsrf row)
         (_sessionExpires row)
         (_sessionData row))
    found

-- | Delete the row whose id equals the given session id (a no-op when no
-- such row exists). Runs via 'Db.readWrite'.
deleteSession' ::
     Beam.DatabaseEntity Postgres db (Beam.TableEntity (SessionT sessdata))
  -> ReadWriteConn
  -> SM.SessionId
  -> IO ()
deleteSession' session conn sessionid = Db.readWrite conn removal
  where
    removal =
      Beam.runDelete $
      Beam.delete session (\row -> _sessionId row ==. val_ sessionid)

-- | Persist a Spock session.
--
-- If no row exists yet for the session's id, a new row is inserted.
-- Otherwise the existing row's CSRF token, expiry and data are overwritten.
-- When the session data reports 'rememberMe', the stored expiry is the far
-- future (year 2999) instead of the session's own 'sess_validUntil'.
--
-- NOTE(review): the lookup and the insert are separate statements. Unless
-- 'Db.readWrite' runs them in a suitably isolated transaction, two concurrent
-- stores of the same brand-new session id can both take the insert branch.
-- A Postgres upsert would close that window.
storeSession' ::
     ( Beam.FromBackendRow Postgres (sessdata Beam.Identity)
     , Beam.Beamable sessdata
     , Beam.Typeable sessdata
     , Beam.Database Postgres db
     , UserData (sessdata Beam.Identity)
     , FieldsFulfillConstraint (HasSqlValueSyntax PgValueSyntax) sessdata
     )
  => Beam.DatabaseEntity Postgres db (Beam.TableEntity (SessionT sessdata))
  -> ReadWriteConn
  -> SM.Session conn (sessdata Beam.Identity) st
  -> IO ()
storeSession' session conn sess = Db.readWrite conn go
  where
    go =
      getSess >>= \case
        Nothing ->
          Beam.runInsert $
          Beam.insert session $
          Beam.insertExpressions
            [ T.Session
                (val_ $ sess_id sess)
                (val_ $ sess_csrfToken sess)
                (val_ expires)
                (val_ $ sess_data sess)
            ]
        Just exist ->
          Beam.runUpdate $
          Beam.save
            session
            (exist
               { _sessionCsrf = sess_csrfToken sess
               , _sessionExpires = expires
               , _sessionData = sess_data sess
               })
    getSess =
      Beam.runSelectReturningOne $
      Beam.lookup_ session (T.SessionId $ sess_id sess)
    -- The one expiry rule shared by the insert and update paths, so the two
    -- can never disagree: remembered sessions (practically) never expire.
    expires
      | rememberMe (sess_data sess) = farOut
      | otherwise = sess_validUntil sess
    -- Year-only spec; buildTime fills the remaining fields with defaults
    -- (presumably 2999-01-01 00:00:00 UTC), so this is expected to be 'Just'.
    farOut = fromJust $ buildTime defaultTimeLocale [('Y', "2999")]

-- | Fetch every session row and convert each into Spock's 'SM.Session'
-- representation. The query runs via 'Db.readOnly'.
toList' ::
     ( Beam.FromBackendRow Postgres (sessdata Beam.Identity)
     , Beam.Beamable sessdata
     , Beam.Typeable sessdata
     , Beam.Database Postgres db
     )
  => Beam.DatabaseEntity Postgres db (Beam.TableEntity (SessionT sessdata))
  -> SomeConn a
  -> IO [SM.Session conn (sessdata Beam.Identity) st]
toList' session conn = do
  rows <-
    Db.readOnly conn $
    Beam.runSelectReturningList $ Beam.select $ all_ session
  pure $
    map
      (\row ->
         SM.Session
           (_sessionId row)
           (_sessionCsrf row)
           (_sessionExpires row)
           (_sessionData row))
      rows

-- | Map over the continuation of a single step. The step's payload (session
-- id, session, or predicate) is passed through untouched.
instance Functor (TxAction s) where
  fmap f = \case
    LoadSession sessid k -> LoadSession sessid (f . k)
    DeleteSession sessid rest -> DeleteSession sessid (f rest)
    StoreSession sess rest -> StoreSession sess (f rest)
    ToList k -> ToList (f . k)
    FilterSessions keep rest -> FilterSessions keep (f rest)

-- | A minimal free monad over @f@: either one more layer of @f@ wrapping the
-- rest of the program, or a finished computation holding its result.
data Free f r
  = Free (f (Free f r)) -- ^ one more instruction layer
  | Pure r              -- ^ a finished computation

-- | Map over the final result; every intermediate layer is traversed with
-- the underlying functor's own 'fmap'.
instance Functor f => Functor (Free f) where
  fmap f (Pure a) = Pure (f a)
  fmap f (Free layer) = Free (fmap (fmap f) layer)

-- | Applicative for 'Free'. A finished function is simply mapped over the
-- argument program. A pending layer defers the application into every branch
-- of that layer.
instance (Functor f) => Applicative (Free f) where
  pure = Pure
  Pure g <*> arg = fmap g arg
  Free layer <*> arg = Free (fmap (<*> arg) layer)

-- | Sequencing for 'Free'. Binding a finished computation feeds its result to
-- the continuation; binding a layer pushes the continuation down into every
-- branch of that layer. 'return' is intentionally left to its default
-- ('pure'), the canonical definition (avoids -Wnoncanonical-monad-instances).
instance (Functor f) => Monad (Free f) where
  (Free x) >>= f = Free (fmap (>>= f) x)
  (Pure r) >>= f = f r

-- | Debugging aid: render the step sequence of a 'TxProgram'.
--
-- Steps whose continuation expects a query result are fed a dummy value
-- ('Nothing' for 'LoadSession', @[]@ for 'ToList'), so only the path taken
-- under those results is shown. Nothing about the session data is rendered,
-- hence no @Show sessdata@ requirement.
instance Show (TxProgram (SM.Session conn sessdata st) a) where
  show = render
    where
      render (Free (LoadSession sid k)) =
        "Load " ++ show sid ++ " / " ++ render (k Nothing)
      render (Free (DeleteSession sid rest)) =
        "Delete " ++ show sid ++ " / " ++ render rest
      render (Free (StoreSession sess rest)) =
        "Store " ++ show (sess_id sess) ++ " / " ++ render rest
      render (Free (ToList k)) = "ToList / " ++ render (k [])
      render (Free (FilterSessions _ rest)) = "Filter / " ++ render rest
      render (Pure _) = "Pure"

-- | Interpret a 'TxProgram' against the database.
--
-- The program is first run on a connection from the read-only pool. If it
-- consists only of 'LoadSession' and 'ToList' steps, that run yields the
-- result. As soon as the read-only interpreter meets any other step it gives
-- up. The whole program is then re-run from the start on a connection from
-- the read-write pool, so any reads already done are repeated there.
runTxProgram ::
     ( Beam.FromBackendRow Postgres (sessdata Beam.Identity)
     , Beam.Beamable sessdata
     , Beam.Typeable sessdata
     , Beam.Database Postgres db
     , UserData (sessdata Beam.Identity)
     , FieldsFulfillConstraint (HasSqlValueSyntax PgValueSyntax) sessdata
     )
  => Beam.DatabaseEntity Postgres db (Beam.TableEntity (SessionT sessdata))
  -> Db.ReadOnlyPool
  -> Db.ReadWritePool
  -> TxProgram (SM.Session conn (sessdata Beam.Identity) st) a
  -> IO a
runTxProgram session ropool rwpool tx =
  Db.withConnection ropool (\conn -> ro conn tx) >>= \case
    Just res -> return res
    Nothing -> Db.withConnection rwpool (\conn -> rw conn tx)
  where
    -- Read-only interpreter; 'Nothing' means "this program needs to write".
    ro conn (Free (LoadSession sessid g)) =
      loadSession' session conn sessid >>= ro conn . g
    ro conn (Free (ToList g)) = toList' session conn >>= ro conn . g
    ro _conn (Pure p) = pure $ Just p
    ro _conn _free = pure Nothing
    -- Read-write interpreter; handles every step.
    rw conn (Free (LoadSession sessid g)) =
      loadSession' session conn sessid >>= rw conn . g
    rw conn (Free (DeleteSession sessid g)) =
      deleteSession' session conn sessid >> rw conn g
    rw conn (Free (StoreSession sess g)) =
      storeSession' session conn sess >> rw conn g
    rw conn (Free (ToList g)) = toList' session conn >>= rw conn . g
    rw conn (Free (FilterSessions keep g)) = do
      -- Previously the predicate was silently ignored. Honour it the same
      -- way ss_filterSessions does: keep matches, delete everything else.
      sessions <- toList' session conn
      mapM_
        (deleteSession' session conn . sess_id)
        (filter (not . keep) sessions)
      rw conn g
    rw _conn (Pure p) = pure p

-- | Lift a single functor layer into a one-step 'Free' program whose result
-- is the value that layer carries.
liftF :: Functor f => f r -> Free f r
liftF = Free . fmap Pure