{-# LANGUAGE QuasiQuotes #-}

module MessageDb.Temp
  ( ConnectionStrings (..),
    withConnectionStrings,
  )
where

import Control.Exception (Exception, IOException)
import Control.Exception.Safe (bracket, throwIO, tryAny)
import Control.Monad (void)
import Control.Monad.Catch (Handler (Handler))
import qualified Control.Retry as Retry
import Data.ByteString (ByteString)
import qualified Data.Map.Strict as Map
import Data.Maybe (fromMaybe)
import Data.Monoid (getLast)
import qualified Database.PostgreSQL.Simple as Postgres
import Database.PostgreSQL.Simple.Options (Options)
import qualified Database.PostgreSQL.Simple.Options as PostgresOptions
import Database.PostgreSQL.Simple.SqlQQ (sql)
import Database.Postgres.Temp (Accum (Merge))
import qualified Database.Postgres.Temp as PostgresTemp
import qualified Paths_message_db_temp
import System.Environment (getEnvironment)
import qualified System.Process.Typed as Process


-- | Connection strings used to connect to your temporary message-db.
--
-- One string authenticates as a privileged user, the other as a user that
-- is restricted to the 'message_store' role.
data ConnectionStrings = ConnectionStrings
  { -- | Connection string used to connect to the database as a privileged user.
    privilegedConnectionString :: ByteString,
    -- | Connection string used to connect to the database as a user that only has the 'message_store' role.
    normalConnectionString :: ByteString
  }


-- | Install message-db into the PostgreSQL server described by @options@,
-- then prepare the @normal_user@ and @privileged_user@ roles.
--
-- 1. Runs the vendored upstream @install.sh@ (located through the package's
--    data files) with @bash@. libpq's @PGHOST@, @PGDATABASE@, @PGPORT@ and
--    @PGUSER@ are taken from @options@; an option that is unset is passed as
--    an empty string. Inherited copies of those four variables are removed
--    first, so the script sees exactly one value for each. The script's
--    output is discarded; a non-zero exit status is thrown by
--    'Process.readProcess_'.
--
-- 2. Connects using @options@, creates @normal_user@ (password @password@,
--    member of @message_store@ — presumably the role created by
--    @install.sh@), and sets the @search_path@ of both @normal_user@ and
--    @privileged_user@ to @message_store,public@.
migrate :: Options -> IO ()
migrate options = do
  hostEnv <- getEnvironment
  installScriptPath <-
    Paths_message_db_temp.getDataFileName
      "official-message-db-upstream/database/install.sh"

  let fromLast = fromMaybe "" . getLast

      pgEnv =
        [ ("PGHOST", fromLast (PostgresOptions.host options))
        , ("PGDATABASE", fromLast (PostgresOptions.dbname options))
        , ("PGPORT", fromLast (show <$> PostgresOptions.port options))
        , ("PGUSER", fromLast (PostgresOptions.user options))
        ]

      -- Strip inherited copies of the variables set above. Otherwise the
      -- child receives duplicate keys, and which value wins depends on how
      -- it resolves duplicates rather than on this code.
      inheritedEnv = filter ((`notElem` map fst pgEnv) . fst) hostEnv

      installCommand =
        Process.setEnv
          (inheritedEnv <> pgEnv)
          (Process.proc "bash" [installScriptPath])

  void (Process.readProcess_ installCommand)

  let url = PostgresOptions.toConnectionString options

      -- Executed in order on a single connection.
      roleSetup =
        [ [sql|
            CREATE ROLE normal_user
            WITH LOGIN PASSWORD 'password'
            IN ROLE message_store;
          |]
        , [sql|
            ALTER ROLE normal_user
            SET search_path TO message_store,public;
          |]
        , [sql|
            ALTER ROLE privileged_user
            SET search_path TO message_store,public;
          |]
        ]

  bracket (Postgres.connectPostgreSQL url) Postgres.close $ \connection ->
    mapM_ (Postgres.execute_ connection) roleSetup


-- | Create and use a temporary message-db for testing.
--
-- Starts a throwaway PostgreSQL server with "Database.Postgres.Temp" and
-- runs 'migrate' against it. It then passes @use@ the 'ConnectionStrings'
-- for its @message_store@ database. The server is torn down by
-- 'PostgresTemp.withConfig' once @use@ finishes.
--
-- Start-up and installation are retried when they fail with
-- 'PostgresTemp.StartError' or an 'IOException': up to 10 retries, with no
-- delay between attempts. Synchronous exceptions thrown by @use@ itself are
-- not treated as start-up failures. They are captured, the server is torn
-- down, and the exception is rethrown unchanged outside the retry loop.
withConnectionStrings :: (ConnectionStrings -> IO a) -> IO a
withConnectionStrings use = do
  let tempConfig =
        mempty
          { PostgresTemp.connectionOptions =
              mempty {PostgresOptions.user = pure "privileged_user"}
          , -- Pass --username=privileged_user to initdb so the cluster's
            -- bootstrap user matches the connection user above.
            PostgresTemp.initDbConfig =
              Merge $
                mempty
                  { PostgresTemp.commandLine =
                      mempty
                        { PostgresTemp.keyBased =
                            Map.fromList [("--username=", Just "privileged_user")]
                        }
                  }
          , -- NOTE(review): presumably enables message-db's SQL `condition`
            -- filtering on reads; confirm against the message-db docs.
            PostgresTemp.postgresConfigFile =
              [("message_store.sql_condition", "on")]
          }

      retryPolicy = Retry.limitRetries 10

      retryAlways :: Exception e => e -> IO Bool
      retryAlways _ = pure True

      exceptionHandlers =
        [ const (Handler (retryAlways :: PostgresTemp.StartError -> IO Bool))
        , const (Handler (retryAlways :: IOException -> IO Bool))
        ]

      toConnectionStrings :: Options -> ConnectionStrings
      toConnectionStrings dbOptions =
        ConnectionStrings
          { privilegedConnectionString =
              PostgresOptions.toConnectionString
                (dbOptions {PostgresOptions.dbname = pure "message_store"})
          , normalConnectionString =
              PostgresOptions.toConnectionString
                ( dbOptions
                    { PostgresOptions.user = pure "normal_user"
                    , PostgresOptions.password = pure "password"
                    , PostgresOptions.dbname = pure "message_store"
                    }
                )
          }

  outcome <-
    Retry.recovering retryPolicy exceptionHandlers $ \_ -> do
      started <- PostgresTemp.withConfig tempConfig $ \db -> do
        let options = PostgresTemp.toConnectionOptions db
        -- Failures here count as start-up failures and may be retried.
        migrate options
        -- Return the caller's synchronous failures as values, so the retry
        -- handlers never see them and never re-run the caller's action.
        tryAny (use (toConnectionStrings options))
      either throwIO pure started

  -- Outside the retry loop: rethrow the caller's exception, if any.
  either throwIO pure outcome