{-# LANGUAGE ApplicativeDo #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE QuasiQuotes #-}
{-# LANGUAGE TypeOperators #-}

-- | Database access for your @App@
module Freckle.App.Database
  ( -- * Running transactions
    MonadSqlTx (..)
  , runDB
  , runDBSimple

    -- * Running queries
  , SqlBackend
  , HasSqlBackend (..)
  , MonadSqlBackend (..)
  , liftSql

    -- * Telemetry
  , MonadTracer
  , HasStatsClient

    -- * Connection pools
  , HasSqlPool (..)
  , SqlPool
  , makePostgresPool
  , makePostgresPoolWith

    -- * Setup
  , PostgresConnectionConf (..)
  , PostgresPasswordSource (..)
  , PostgresPassword (..)
  , PostgresStatementTimeout
  , postgresStatementTimeoutMilliseconds
  , envParseDatabaseConf
  , envPostgresPasswordSource
  ) where

import Freckle.App.Prelude

import Blammo.Logging
import qualified Control.Immortal as Immortal
import Control.Monad.Reader
import Data.ByteString (ByteString)
import qualified Data.ByteString.Char8 as BS8
import qualified Data.ByteString.Lazy as BSL
import Data.Pool
import qualified Data.Text as T
import Database.Persist.Postgresql
  ( SqlBackend
  , SqlPersistT
  , createPostgresqlPoolModified
  , createSqlPool
  , openSimpleConn
  , runSqlPool
  , runSqlPoolWithExtensibleHooks
  )
import Database.Persist.SqlBackend.Internal.SqlPoolHooks (SqlPoolHooks (..))
import Database.Persist.SqlBackend.SqlPoolHooks
import Database.PostgreSQL.Simple
  ( Connection
  , Only (..)
  , connectPostgreSQL
  , execute
  )
import Database.PostgreSQL.Simple.SqlQQ (sql)
import Freckle.App.Env (Timeout (..))
import qualified Freckle.App.Env as Env
import Freckle.App.Exception.MonadUnliftIO
import Freckle.App.OpenTelemetry
import Freckle.App.Stats (HasStatsClient)
import qualified Freckle.App.Stats as Stats
import OpenTelemetry.Instrumentation.Persistent
import System.Process.Typed (proc, readProcessStdout_)
import UnliftIO.Concurrent (threadDelay)
import qualified UnliftIO.Exception as UnliftIO
import UnliftIO.IORef
import Yesod.Core.Types (HandlerData (..), RunHandlerEnv (..))

-- | A monadic context in which a SQL backend is available
--   for running database queries
class MonadUnliftIO m => MonadSqlBackend m where
  -- | Obtain the 'SqlBackend' that queries should run against
  getSqlBackendM :: m SqlBackend

-- | Any 'ReaderT' whose environment provides a 'SqlBackend' (via
--   'HasSqlBackend') is a 'MonadSqlBackend'; the backend is read from the
--   environment with 'getSqlBackend'.
instance (HasSqlBackend r, MonadUnliftIO m) => MonadSqlBackend (ReaderT r m) where
  getSqlBackendM = asks getSqlBackend

-- | Generalize from 'SqlPersistT' to 'MonadSqlBackend'
--
-- The reader action is run against the backend returned by
-- 'getSqlBackendM'; the whole computation is wrapped in
-- 'checkpointCallStack' so the caller's 'HasCallStack' is available to it.
liftSql :: (MonadSqlBackend m, HasCallStack) => ReaderT SqlBackend m a -> m a
liftSql (ReaderT f) = checkpointCallStack $ getSqlBackendM >>= f

-- | @'MonadSqlTx' db m@ says that the monad @m@ knows how to execute
--   @db@ actions, usually by wrapping each one in a SQL transaction.
--
--   In the common case @db@ needs a database connection and @m@ is able
--   to supply one, for instance by taking it from a connection pool.
--   The functional dependency @m -> db@ means a given @m@ determines a
--   single @db@ type.
class
  (MonadSqlBackend db, MonadUnliftIO m) =>
  MonadSqlTx db m
    | m -> db
  where
  -- | Execute the given action inside a SQL transaction
  runSqlTx :: HasCallStack => db a -> m a

-- | Types from which a 'SqlBackend' can be extracted
class HasSqlBackend a where
  getSqlBackend :: a -> SqlBackend

-- | A 'SqlBackend' provides itself
instance HasSqlBackend SqlBackend where
  getSqlBackend = id

-- | A pool of 'SqlBackend' connections
type SqlPool = Pool SqlBackend

-- | Types from which a 'SqlPool' can be extracted
class HasSqlPool app where
  getSqlPool :: app -> SqlPool

-- | A 'SqlPool' provides itself
instance HasSqlPool SqlPool where
  getSqlPool = id

-- | Inside a Yesod handler, the pool comes from the handler's site value
--   ('rheSite' of the handler's 'handlerEnv').
instance HasSqlPool site => HasSqlPool (HandlerData child site) where
  getSqlPool = getSqlPool . rheSite . handlerEnv

-- | Build a connection pool from configuration read from the environment
--
-- The @USE_RDS_IAM_AUTH@ flag is parsed first (see
-- 'envPostgresPasswordSource') to decide how the password is obtained; the
-- remaining @PG*@ variables are then parsed by 'envParseDatabaseConf' and the
-- result is handed to 'makePostgresPoolWith'.
makePostgresPool :: (MonadUnliftIO m, MonadLoggerIO m) => m SqlPool
makePostgresPool = do
  conf <- liftIO $ do
    postgresPasswordSource <- Env.parse id envPostgresPasswordSource
    Env.parse id $ envParseDatabaseConf postgresPasswordSource
  makePostgresPoolWith conf

-- | Run a Database action with connection stats and tracing
--
-- This uses OpenTelemetry and 'MonadTracer'. For callstacks in traces to be
-- useful, ensure you have 'HasCallStack' on functions that call this (and
-- functions that call those, for as far as you require to get to a useful
-- source location).
--
-- The whole call is wrapped in @'Stats.withGauge' 'Stats.dbEnqueuedAndProcessing'@
-- and an @"runDB"@ span. The 'Stats.dbConnections' gauge is incremented when
-- a pooled connection starts its transaction and decremented exactly once
-- when that transaction finishes, fails, or fails to start.
runDB
  :: ( MonadUnliftIO m
     , MonadTracer m
     , MonadReader app m
     , HasSqlPool app
     , HasStatsClient app
     , HasCallStack
     )
  => SqlPersistT m a
  -> m a
runDB action = do
  pool <- asks getSqlPool
  gauge <- Stats.lookupGauge Stats.dbConnections
  let
    hooks = setAlterBackend defaultSqlPoolHooks $ wrapSqlBackend mempty
    -- Setting the SqlPoolHooks for metrics collection. Every increment must
    -- be matched by exactly one decrement:
    --
    -- * runBefore increments and then delegates. If the delegate throws,
    --   runSqlPoolWithExtensibleHooks invokes neither runAfter nor
    --   runOnException, so the decrement has to happen right here.
    --
    -- * runAfter and runOnException never both run. When the action throws,
    --   runOnException is performed and the exception is rethrown once it
    --   completes, so runAfter is not reached. See the implementation:
    --   https://hackage.haskell.org/package/persistent-2.14.6.0/docs/src/Database.Persist.Sql.Run.html#runSqlPoolWithExtensibleHooks
    --
    -- * Both of those decrement before delegating, so a failure inside the
    --   delegated hook still leaves the gauge balanced.
    hooks' =
      hooks
        { runBefore = \conn mi -> do
            Stats.incGauge gauge
            runBefore hooks conn mi `UnliftIO.onException` Stats.decGauge gauge
        , runAfter = \conn mi -> do
            Stats.decGauge gauge
            runAfter hooks conn mi
        , runOnException = \conn mi e -> do
            Stats.decGauge gauge
            runOnException hooks conn mi e
        }
  Stats.withGauge Stats.dbEnqueuedAndProcessing $
    inSpan "runDB" clientSpanArguments $
      runSqlPoolWithExtensibleHooks action pool Nothing hooks'

-- | Run a database action on a connection taken from the app's pool
--
-- Unlike 'runDB', this uses plain 'runSqlPool': no stats gauges are touched
-- and no trace span is created.
runDBSimple
  :: (HasSqlPool app, MonadUnliftIO m, MonadReader app m)
  => SqlPersistT m a
  -> m a
runDBSimple action = do
  pool <- asks getSqlPool
  runSqlPool action pool

-- | Settings for connecting to PostgreSQL and sizing the connection pool
--
-- The environment variable named for each field is the one read by
-- 'envParseDatabaseConf'.
data PostgresConnectionConf = PostgresConnectionConf
  { pccHost :: String
  -- ^ @PGHOST@
  , pccPort :: Int
  -- ^ @PGPORT@
  , pccUser :: String
  -- ^ @PGUSER@
  , pccPassword :: PostgresPassword
  -- ^ Static @PGPASSWORD@, or IAM auth (see 'PostgresPasswordSource')
  , pccDatabase :: String
  -- ^ @PGDATABASE@
  , pccPoolSize :: Int
  -- ^ @PGPOOLSIZE@
  , pccStatementTimeout :: PostgresStatementTimeout
  -- ^ @PGSTATEMENTTIMEOUT@; sent as @statement_timeout@ on each connection
  , pccSchema :: Maybe String
  -- ^ @PGSCHEMA@; when present, sent as @search_path@ on each connection
  }
  -- NOTE(review): the derived Show includes pccPassword, which shows a
  -- static password in clear text; avoid logging this value.
  deriving stock (Show, Eq)

-- | Where the database password should come from
data PostgresPasswordSource
  = -- | Use an RDS IAM auth token (selected when @USE_RDS_IAM_AUTH@ is on)
    PostgresPasswordSourceIamAuth
  | -- | Read a static password from @PGPASSWORD@
    PostgresPasswordSourceEnv
  deriving stock (Show, Eq)

-- | The password to authenticate with
data PostgresPassword
  = -- | Authenticate with an RDS IAM auth token obtained separately
    --   (see @createAuroraIamToken@)
    PostgresPasswordIamAuth
  | -- | A fixed password
    PostgresPasswordStatic String
  -- NOTE(review): the derived Show prints the static password verbatim.
  deriving stock (Show, Eq)

-- | Timeout used for PostgreSQL's @statement_timeout@ session setting
type PostgresStatementTimeout = Timeout

-- | Convert a 'PostgresStatementTimeout' to a whole number of milliseconds,
--   the unit in which it is sent as @statement_timeout@
postgresStatementTimeoutMilliseconds :: PostgresStatementTimeout -> Int
postgresStatementTimeoutMilliseconds = \case
  TimeoutSeconds s -> s * 1000
  TimeoutMilliseconds ms -> ms

-- | Parse the @USE_RDS_IAM_AUTH@ flag
--
-- When the flag is on, 'PostgresPasswordSourceIamAuth' is selected;
-- otherwise 'PostgresPasswordSourceEnv'.
envPostgresPasswordSource :: Env.Parser Env.Error PostgresPasswordSource
envPostgresPasswordSource =
  Env.flag
    (Env.Off PostgresPasswordSourceEnv)
    (Env.On PostgresPasswordSourceIamAuth)
    "USE_RDS_IAM_AUTH"
    mempty

-- | Parse a 'PostgresConnectionConf' from @PG*@ environment variables
--
-- * @PGUSER@, @PGHOST@, @PGDATABASE@: required, non-empty
-- * @PGPASSWORD@: required and non-empty for 'PostgresPasswordSourceEnv';
--   not read at all for 'PostgresPasswordSourceIamAuth'
-- * @PGPORT@: required
-- * @PGPOOLSIZE@: defaults to 10
-- * @PGSCHEMA@: optional, non-empty when given
-- * @PGSTATEMENTTIMEOUT@: defaults to 120 seconds
--
-- The body is written in ApplicativeDo style because 'Env.Parser' is used
-- applicatively; no bound value feeds into a later parser.
envParseDatabaseConf
  :: PostgresPasswordSource -> Env.Parser Env.Error PostgresConnectionConf
envParseDatabaseConf source = do
  user <- Env.var Env.nonempty "PGUSER" mempty
  password <- case source of
    PostgresPasswordSourceIamAuth -> pure PostgresPasswordIamAuth
    PostgresPasswordSourceEnv ->
      PostgresPasswordStatic <$> Env.var Env.nonempty "PGPASSWORD" mempty
  host <- Env.var Env.nonempty "PGHOST" mempty
  database <- Env.var Env.nonempty "PGDATABASE" mempty
  port <- Env.var Env.auto "PGPORT" mempty
  poolSize <- Env.var Env.auto "PGPOOLSIZE" $ Env.def 10
  schema <- optional $ Env.var Env.nonempty "PGSCHEMA" mempty
  statementTimeout <-
    Env.var Env.timeout "PGSTATEMENTTIMEOUT" $ Env.def (TimeoutSeconds 120)
  pure
    PostgresConnectionConf
      { pccHost = host
      , pccPort = port
      , pccUser = user
      , pccPassword = password
      , pccDatabase = database
      , pccPoolSize = poolSize
      , pccStatementTimeout = statementTimeout
      , pccSchema = schema
      }

-- | An RDS IAM auth token, when it was generated, and the connection
--   settings it was generated for
data AuroraIamToken = AuroraIamToken
  { aitToken :: Text
  -- ^ Whitespace-stripped output of @aws rds generate-db-auth-token@
  , aitCreatedAt :: UTCTime
  -- ^ Current time read right after the token was generated
  , aitPostgresConnectionConf :: PostgresConnectionConf
  -- ^ Settings whose host, port, and user the token was generated for
  }
  deriving stock (Show, Eq)

-- | Generate a fresh RDS IAM auth token for the configured host, port, and
--   user by running @aws rds generate-db-auth-token@
--
-- The command's stdout is decoded, stripped of surrounding whitespace, and
-- stored alongside the current time. 'readProcessStdout_' throws if the
-- command exits unsuccessfully. The @aws@ executable must be resolvable by
-- 'proc' (presumably via @PATH@ — confirm for the deployment image).
createAuroraIamToken :: MonadIO m => PostgresConnectionConf -> m AuroraIamToken
createAuroraIamToken aitPostgresConnectionConf@PostgresConnectionConf {..} = do
  aitToken <-
    T.strip . decodeUtf8 . BSL.toStrict
      <$> readProcessStdout_
        ( proc
            "aws"
            [ "rds"
            , "generate-db-auth-token"
            , "--hostname"
            , pccHost
            , "--port"
            , show pccPort
            , "--username"
            , pccUser
            ]
        )
  aitCreatedAt <- liftIO getCurrentTime
  pure AuroraIamToken {..}

-- | Spawns a thread that refreshes the IAM auth token every minute
--
-- The IAM auth token lasts 15 minutes, but we refresh it every minute just to
-- be super safe.
--
-- The first token is created synchronously before the thread is spawned, so
-- the returned 'IORef' holds a usable token immediately. The refresh loop runs
-- in an immortal thread (see "Control.Immortal"). That thread re-runs the loop
-- body whenever it returns or throws. A thrown exception is logged by
-- @onFinishCallback@ before the restart.
spawnIamTokenRefreshThread
  :: (MonadUnliftIO m, MonadLogger m)
  => PostgresConnectionConf
  -> m (IORef AuroraIamToken)
spawnIamTokenRefreshThread conf = do
  logInfo "Spawning thread to refresh IAM auth token"
  tokenIORef <- newIORef =<< createAuroraIamToken conf
  void $ Immortal.create $ \_ -> Immortal.onFinish onFinishCallback $ do
    -- Sleep *before* refreshing. The token was just created above, so there
    -- is nothing to refresh yet. More importantly, Immortal restarts this
    -- body immediately when it throws. Sleeping first ensures a persistently
    -- failing refresh is retried once a minute, not in a tight loop that
    -- spawns an @aws@ process on every iteration.
    threadDelay oneMinuteInMicroseconds
    logDebug "Refreshing IAM auth token"
    refreshIamToken conf tokenIORef
  pure tokenIORef
 where
  oneMinuteInMicroseconds :: Int
  oneMinuteInMicroseconds = 60 * 1000000

  onFinishCallback = \case
    Left ex ->
      logError $
        "Error refreshing IAM auth token"
          :# ["exception" .= displayException ex]
    Right () -> pure ()

-- | Generate a new IAM auth token for @conf@ and store it in the 'IORef',
--   replacing the previous one
--
-- If token generation throws, the exception propagates and the 'IORef' is
-- left unchanged.
refreshIamToken
  :: MonadIO m => PostgresConnectionConf -> IORef AuroraIamToken -> m ()
refreshIamToken conf tokenIORef = do
  token' <- createAuroraIamToken conf
  writeIORef tokenIORef token'

-- | Apply per-session settings from the configuration to a raw connection
--
-- Always sets @statement_timeout@ (in milliseconds, from
-- 'pccStatementTimeout'). When 'pccSchema' is present, also sets
-- @search_path@ to that schema. Values are passed as query parameters rather
-- than spliced into the SQL text.
setStartupOptions :: MonadIO m => PostgresConnectionConf -> Connection -> m ()
setStartupOptions PostgresConnectionConf {..} conn = do
  let timeoutMillis = postgresStatementTimeoutMilliseconds pccStatementTimeout
  liftIO $ do
    void $
      execute
        conn
        [sql| SET statement_timeout = ? |]
        (Only timeoutMillis)
    for_ pccSchema $ \schema -> execute conn [sql| SET search_path TO ? |] (Only schema)

-- | Build a connection pool for the given configuration
--
-- The authentication strategy is chosen from 'pccPassword':
--
-- * 'PostgresPasswordIamAuth' delegates to 'makePostgresPoolWithIamAuth'.
-- * 'PostgresPasswordStatic' builds the pool with
--   'createPostgresqlPoolModified'. 'setStartupOptions' is its connection
--   modifier, and 'pccPoolSize' is the pool size.
makePostgresPoolWith
  :: (MonadUnliftIO m, MonadLoggerIO m) => PostgresConnectionConf -> m SqlPool
makePostgresPoolWith conf =
  case pccPassword conf of
    PostgresPasswordIamAuth ->
      makePostgresPoolWithIamAuth conf
    PostgresPasswordStatic password ->
      let connStr = postgresConnectionString conf password
      in  createPostgresqlPoolModified
            (setStartupOptions conf)
            connStr
            (pccPoolSize conf)

-- | Creates a PostgreSQL pool using IAM auth for the password
--
-- 'spawnIamTokenRefreshThread' supplies an 'IORef' holding the current
-- 'AuroraIamToken'. Every time the pool opens a connection, the callback
-- proceeds as follows:
--
-- 1. It reads the token at that moment, so later connections see any
--    replacement token written to the ref.
-- 2. It connects with the token text as the password.
-- 3. It applies 'setStartupOptions'.
-- 4. It wraps the connection with 'openSimpleConn'.
--
-- The pool holds at most 'pccPoolSize' connections.
makePostgresPoolWithIamAuth
  :: (MonadUnliftIO m, MonadLoggerIO m) => PostgresConnectionConf -> m SqlPool
makePostgresPoolWithIamAuth conf = do
  tokenRef <- spawnIamTokenRefreshThread conf
  createSqlPool (openWithCurrentToken tokenRef) (pccPoolSize conf)
 where
  -- Open one backend using whatever token is in the ref right now.
  openWithCurrentToken :: IORef AuroraIamToken -> LogFunc -> IO SqlBackend
  openWithCurrentToken tokenRef logFunc = do
    currentToken <- readIORef tokenRef
    conn <-
      connectPostgreSQL
        . postgresConnectionString conf
        . unpack
        $ aitToken currentToken
    setStartupOptions conf conn
    openSimpleConn logFunc conn

-- | Render a libpq key/value connection string
--
-- The string contains @host@, @port@, @user@, @password@ and @dbname@. The
-- password is passed separately so it can be a static secret or a freshly
-- generated IAM token.
--
-- Every value is wrapped in single quotes, and any embedded @'@ or @\\@ is
-- backslash-escaped, as the libpq connection-string grammar requires.
-- Unquoted values are mis-parsed when they are empty or contain whitespace or
-- quotes. For example, @password= dbname=x@ makes libpq read @dbname=x@ as
-- the password and drop the database name.
--
-- NOTE(review): 'BS8.pack' keeps only the low 8 bits of each 'Char', so
-- non-Latin-1 characters in any field are truncated, not UTF-8 encoded.
postgresConnectionString :: PostgresConnectionConf -> String -> ByteString
postgresConnectionString PostgresConnectionConf {..} password =
  BS8.pack $
    unwords
      [ keyValue "host" pccHost
      , keyValue "port" (show pccPort)
      , keyValue "user" pccUser
      , keyValue "password" password
      , keyValue "dbname" pccDatabase
      ]
 where
  -- @key='value'@ with libpq-style escaping of the value
  keyValue :: String -> String -> String
  keyValue key value =
    key <> "='" <> [out | c <- value, out <- escapeChar c] <> "'"

  -- Backslash-escape the two characters that are special inside quotes
  escapeChar :: Char -> String
  escapeChar c
    | c == '\'' || c == '\\' = ['\\', c]
    | otherwise = [c]