{-# LANGUAGE ApplicativeDo #-}
{-# LANGUAGE QuasiQuotes #-}
{-# LANGUAGE TypeOperators #-}

-- | Database access for your @App@
module Freckle.App.Database
  ( HasSqlPool(..)
  , SqlPool
  , makePostgresPool
  , makePostgresPoolWith
  , runDB
  , runDBSimple
  , PostgresConnectionConf(..)
  , PostgresPasswordSource(..)
  , PostgresPassword(..)
  , PostgresStatementTimeout(..)
  , postgresStatementTimeoutMilliseconds
  , envParseDatabaseConf
  , envPostgresPasswordSource
  ) where

import Freckle.App.Prelude

import Blammo.Logging
import qualified Control.Immortal as Immortal
import Control.Monad.IO.Unlift (MonadUnliftIO(..))
import Control.Monad.Reader
import Data.ByteString (ByteString)
import qualified Data.ByteString.Char8 as BS8
import qualified Data.ByteString.Lazy as BSL
import Data.Char (isDigit)
import Data.Pool
import qualified Data.Text as T
import Database.Persist.Postgresql
  ( SqlBackend
  , SqlPersistT
  , createPostgresqlPoolModified
  , createSqlPool
  , openSimpleConn
  , runSqlConn
  , runSqlPool
  )
import Database.PostgreSQL.Simple
  (Connection, Only(..), connectPostgreSQL, execute)
import qualified Database.PostgreSQL.Simple as PG
import Database.PostgreSQL.Simple.SqlQQ (sql)
import qualified Freckle.App.Env as Env
import Freckle.App.OpenTelemetry (MonadTracer(..))
import Freckle.App.Stats (HasStatsClient)
import qualified Freckle.App.Stats as Stats
import Network.AWS.XRayClient.Persistent
import Network.AWS.XRayClient.WAI
import qualified Prelude as Unsafe (read)
import System.Process.Typed (proc, readProcessStdout_)
import UnliftIO.Concurrent (threadDelay)
import UnliftIO.Exception (displayException)
import qualified UnliftIO.Exception as Exception
import UnliftIO.IORef
import Yesod.Core.Types (HandlerData(..), RunHandlerEnv(..))

-- | A pool of persistent 'SqlBackend' connections
type SqlPool = Pool SqlBackend

-- | Types from which the application's 'SqlPool' can be obtained
--
-- 'runDB' and 'runDBSimple' use this to find the pool in their
-- 'MonadReader' environment.
class HasSqlPool app where
  getSqlPool :: app -> SqlPool

-- | A pool is trivially its own pool
instance HasSqlPool SqlPool where
  getSqlPool pool = pool

-- | Inside a Yesod handler, delegate to the site's pool
instance HasSqlPool site => HasSqlPool (HandlerData child site) where
  getSqlPool handlerData = getSqlPool $ rheSite $ handlerEnv handlerData

-- | Build a 'SqlPool' from configuration read out of the environment
--
-- @USE_RDS_IAM_AUTH@ is parsed first to decide where the password comes
-- from; the remaining settings are then read by 'envParseDatabaseConf'.
makePostgresPool :: (MonadUnliftIO m, MonadLoggerIO m) => m SqlPool
makePostgresPool = makePostgresPoolWith =<< liftIO readConf
 where
  readConf :: IO PostgresConnectionConf
  readConf = do
    passwordSource <- parseKept envPostgresPasswordSource
    parseKept $ envParseDatabaseConf passwordSource

  -- Parse with 'Env.kept' so the variables are left in place for anything
  -- else that reads them.
  parseKept :: Env.Parser Env.Error a -> IO a
  parseKept = Env.parse id . Env.kept

-- | Run a Database action with connection stats and tracing
--
-- The action is wrapped in the 'Stats.dbConnections' gauge. When XRay vault
-- data is available it runs through 'runSqlPoolXRay' (subsegment name
-- @\"runDB\"@); otherwise it falls back to plain 'runSqlPool'.
runDB
  :: ( MonadUnliftIO m
     , MonadTracer m
     , MonadReader app m
     , HasSqlPool app
     , HasStatsClient app
     )
  => SqlPersistT m a
  -> m a
runDB action = do
  pool <- asks getSqlPool
  mVaultData <- getVaultData
  Stats.withGauge Stats.dbConnections $ case mVaultData of
    Nothing -> runSqlPool action pool
    Just vaultData -> runSqlPoolXRay "runDB" vaultData action pool

-- | Run a Database action on the environment's pool, with no stats or tracing
runDBSimple
  :: (HasSqlPool app, MonadUnliftIO m, MonadReader app m)
  => SqlPersistT m a
  -> m a
runDBSimple action = runSqlPool action =<< asks getSqlPool

-- | @'runSqlPool'@ but with XRay tracing
--
-- A connection is taken from the pool, wrapped by 'xraySqlBackend' so that
-- its activity is reported to the given vault data, and the action is run on
-- it with 'runSqlConn'.
runSqlPoolXRay
  :: (backend ~ SqlBackend, MonadUnliftIO m)
  => Text
  -- ^ Subsegment name
  --
  -- The top-level subsegment will be named @\"<this> runSqlPool\"@, with a
  -- lower-level subsegment named @\"<this> query\"@.
  --
  -> XRayVaultData -- ^ Vault data to trace with
  -> ReaderT backend m a
  -> Pool backend
  -> m a
runSqlPoolXRay name vaultData action pool =
  traceXRaySubsegment' vaultData (name <> " runSqlPool") id
    $ withRunInIO
    $ \run -> withResource pool $ \backend -> do
        let
          sendTrace = atomicallyAddVaultDataSubsegment vaultData
          stdGenIORef = xrayVaultDataStdGen vaultData
          subsegmentName = name <> " query"
        -- Already in IO here, so no lifting is needed.
        tracedBackend <-
          xraySqlBackend sendTrace stdGenIORef subsegmentName backend
        run $ runSqlConn action tracedBackend

-- | Everything needed to open a pool of PostgreSQL connections
--
-- 'envParseDatabaseConf' fills this in from the @PG*@ environment variables.
data PostgresConnectionConf = PostgresConnectionConf
  { pccHost :: String
  -- ^ Server host (@PGHOST@)
  , pccPort :: Int
  -- ^ Server port (@PGPORT@)
  , pccUser :: String
  -- ^ User name (@PGUSER@)
  , pccPassword :: PostgresPassword
  -- ^ Static password or IAM auth
  , pccDatabase :: String
  -- ^ Database name (@PGDATABASE@)
  , pccPoolSize :: Int
  -- ^ Number of pooled connections (@PGPOOLSIZE@, default 10 from the env parser)
  , pccStatementTimeout :: PostgresStatementTimeout
  -- ^ @statement_timeout@ set on each connection (@PGSTATEMENTTIMEOUT@)
  }
  deriving stock (Eq, Show)

-- | Where the database password should come from
data PostgresPasswordSource
  = PostgresPasswordSourceIamAuth
  -- ^ Generate RDS IAM auth tokens (chosen via @USE_RDS_IAM_AUTH@)
  | PostgresPasswordSourceEnv
  -- ^ Read a static password from @PGPASSWORD@
  deriving stock (Eq, Show)

-- | The password to connect with
--
-- NOTE(review): the derived 'Show' instance prints a static password
-- verbatim, so avoid logging values of this type (or of
-- 'PostgresConnectionConf', which contains one).
data PostgresPassword
  = PostgresPasswordIamAuth
  -- ^ Use a generated IAM auth token as the password
  | PostgresPasswordStatic String
  -- ^ Use this fixed password
  deriving stock (Eq, Show)

-- | The @statement_timeout@ applied to each new connection, kept in the unit
-- it was configured with (see 'postgresStatementTimeoutMilliseconds')
data PostgresStatementTimeout
  = PostgresStatementTimeoutSeconds Int
  | PostgresStatementTimeoutMilliseconds Int
  deriving stock (Eq, Show)

-- | Normalize a 'PostgresStatementTimeout' to milliseconds
postgresStatementTimeoutMilliseconds :: PostgresStatementTimeout -> Int
postgresStatementTimeoutMilliseconds (PostgresStatementTimeoutSeconds secs) =
  secs * 1000
postgresStatementTimeoutMilliseconds (PostgresStatementTimeoutMilliseconds millis) =
  millis

-- | Read @PGSTATEMENTTIMEOUT@ as seconds or milliseconds
--
-- A bare number means seconds. PostgreSQL's @statement_timeout@ is an @int@
-- number of milliseconds, so the resulting timeout may be at most 2147483647
-- ms. Larger values are rejected here. Otherwise they would silently wrap
-- around in 'Int', or be rejected by the server on every new connection.
--
-- >>> readPostgresStatementTimeout "10"
-- Right (PostgresStatementTimeoutSeconds 10)
--
-- >>> readPostgresStatementTimeout "10s"
-- Right (PostgresStatementTimeoutSeconds 10)
--
-- >>> readPostgresStatementTimeout "10ms"
-- Right (PostgresStatementTimeoutMilliseconds 10)
--
-- >>> readPostgresStatementTimeout "20m"
-- Left "..."
--
-- >>> readPostgresStatementTimeout "2m0"
-- Left "..."
--
-- >>> readPostgresStatementTimeout "2147484s"
-- Left "..."
--
readPostgresStatementTimeout
  :: String -> Either String PostgresStatementTimeout
readPostgresStatementTimeout x = case span isDigit x of
  ("", _) -> Left "must be {digits}(s|ms)"
  (digits, "") -> PostgresStatementTimeoutSeconds <$> boundedMillis 1000 digits
  (digits, "s") -> PostgresStatementTimeoutSeconds <$> boundedMillis 1000 digits
  (digits, "ms") ->
    PostgresStatementTimeoutMilliseconds <$> boundedMillis 1 digits
  _ -> Left "must be {digits}(s|ms)"
 where
  -- Parse the (non-empty, ASCII-digit) string as an unbounded 'Integer' so
  -- long inputs cannot overflow. Then check that the timeout, scaled to
  -- milliseconds, fits PostgreSQL's limit before narrowing to 'Int'.
  boundedMillis :: Integer -> String -> Either String Int
  boundedMillis millisPerUnit digits
    | millis > maxStatementTimeoutMillis =
        Left
          $ "must be at most "
          <> show maxStatementTimeoutMillis
          <> "ms"
    | otherwise = Right $ fromInteger count
   where
    count = Unsafe.read digits :: Integer
    millis = count * millisPerUnit

  -- PostgreSQL's maximum statement_timeout (INT_MAX milliseconds)
  maxStatementTimeoutMillis :: Integer
  maxStatementTimeoutMillis = 2147483647

-- | Choose the password source from the @USE_RDS_IAM_AUTH@ flag
--
-- The "on" value of the flag selects IAM auth; otherwise the password is
-- read from the environment. See 'Env.flag' for what counts as on.
envPostgresPasswordSource :: Env.Parser Env.Error PostgresPasswordSource
envPostgresPasswordSource =
  Env.flag whenOff whenOn "USE_RDS_IAM_AUTH" mempty
 where
  whenOff = Env.Off PostgresPasswordSourceEnv
  whenOn = Env.On PostgresPasswordSourceIamAuth

-- | Parse a 'PostgresConnectionConf' from the @PG*@ environment variables
--
-- @PGUSER@, @PGHOST@ and @PGDATABASE@ must be non-empty and @PGPORT@ is
-- required. @PGPASSWORD@ is read only for 'PostgresPasswordSourceEnv'.
-- @PGPOOLSIZE@ defaults to 10. @PGSTATEMENTTIMEOUT@ defaults to 120 seconds
-- (format: see 'readPostgresStatementTimeout').
envParseDatabaseConf
  :: PostgresPasswordSource -> Env.Parser Env.Error PostgresConnectionConf
envParseDatabaseConf source = do
  -- Env parsers are Applicative only: every statement here is independent,
  -- and the block ends in @pure@ so ApplicativeDo can desugar it.
  user <- nonemptyVar "PGUSER"
  password <- passwordParser
  host <- nonemptyVar "PGHOST"
  database <- nonemptyVar "PGDATABASE"
  port <- Env.var Env.auto "PGPORT" mempty
  poolSize <- Env.var Env.auto "PGPOOLSIZE" $ Env.def 10
  statementTimeout <- Env.var
    (Env.eitherReader readPostgresStatementTimeout)
    "PGSTATEMENTTIMEOUT"
    (Env.def $ PostgresStatementTimeoutSeconds 120)
  pure PostgresConnectionConf
    { pccHost = host
    , pccPort = port
    , pccUser = user
    , pccPassword = password
    , pccDatabase = database
    , pccPoolSize = poolSize
    , pccStatementTimeout = statementTimeout
    }
 where
  nonemptyVar :: String -> Env.Parser Env.Error String
  nonemptyVar name = Env.var Env.nonempty name mempty

  passwordParser :: Env.Parser Env.Error PostgresPassword
  passwordParser = case source of
    PostgresPasswordSourceIamAuth -> pure PostgresPasswordIamAuth
    PostgresPasswordSourceEnv ->
      PostgresPasswordStatic <$> nonemptyVar "PGPASSWORD"

-- | An RDS IAM auth token, along with when and for which connection
-- configuration it was generated
--
-- NOTE(review): the derived 'Show' instance includes the token itself.
data AuroraIamToken = AuroraIamToken
  { aitToken :: Text
  -- ^ Token text, used as the connection password
  , aitCreatedAt :: UTCTime
  -- ^ Time the token was generated
  , aitPostgresConnectionConf :: PostgresConnectionConf
  -- ^ Configuration the token was generated for
  }
  deriving stock (Eq, Show)

-- | Generate a fresh IAM auth token by shelling out to
-- @aws rds generate-db-auth-token@
--
-- The command's stdout, decoded and stripped of surrounding whitespace,
-- becomes the token. 'readProcessStdout_' throws if the command fails.
createAuroraIamToken :: MonadIO m => PostgresConnectionConf -> m AuroraIamToken
createAuroraIamToken conf = do
  output <- readProcessStdout_ $ proc "aws" awsArgs
  createdAt <- liftIO getCurrentTime
  pure AuroraIamToken
    { aitToken = T.strip $ decodeUtf8 $ BSL.toStrict output
    , aitCreatedAt = createdAt
    , aitPostgresConnectionConf = conf
    }
 where
  awsArgs =
    [ "rds"
    , "generate-db-auth-token"
    , "--hostname"
    , pccHost conf
    , "--port"
    , show (pccPort conf)
    , "--username"
    , pccUser conf
    ]

-- | Spawns a thread that refreshes the IAM auth token every minute
--
-- The IAM auth token lasts 15 minutes, but we refresh it every minute just to
-- be super safe.
--
-- The initial token is created synchronously, before the thread starts, and
-- is returned in the 'IORef'. Each iteration of the (immortal, i.e.
-- automatically restarted) thread therefore waits a minute /before/
-- refreshing. This has two effects:
--
-- * the just-created token is not immediately regenerated, and
-- * a failing refresh (e.g. the @aws@ command exiting non-zero) is logged
--   and retried a minute later, instead of being restarted straight away in
--   a tight loop.
--
spawnIamTokenRefreshThread
  :: (MonadUnliftIO m, MonadLogger m)
  => PostgresConnectionConf
  -> m (IORef AuroraIamToken)
spawnIamTokenRefreshThread conf = do
  logInfo "Spawning thread to refresh IAM auth token"
  tokenIORef <- newIORef =<< createAuroraIamToken conf
  void $ Immortal.create $ \_ -> Immortal.onFinish onFinishCallback $ do
    -- Delay first: this is what provides back-off after a failed refresh,
    -- since the immortal thread restarts this action as soon as it ends.
    threadDelay oneMinuteInMicroseconds
    logDebug "Refreshing IAM auth token"
    refreshIamToken conf tokenIORef
  pure tokenIORef
 where
  oneMinuteInMicroseconds :: Int
  oneMinuteInMicroseconds = 60 * 1000000

  -- Called each time an iteration ends; only failures are worth logging.
  onFinishCallback = \case
    Left ex ->
      logError
        $ "Error refreshing IAM auth token"
        :# ["exception" .= displayException ex]
    Right () -> pure ()

-- | Generate a new IAM auth token and store it in the 'IORef', replacing the
-- previous one
refreshIamToken
  :: MonadIO m => PostgresConnectionConf -> IORef AuroraIamToken -> m ()
refreshIamToken conf tokenIORef =
  writeIORef tokenIORef =<< createAuroraIamToken conf

-- | Set the connection's @statement_timeout@ to the configured value, in
-- milliseconds
setTimeout :: MonadIO m => PostgresConnectionConf -> Connection -> m ()
setTimeout conf conn =
  liftIO $ void $ execute conn [sql| SET statement_timeout = ? |] (Only millis)
 where
  millis :: Int
  millis = postgresStatementTimeoutMilliseconds $ pccStatementTimeout conf

-- | Build a 'SqlPool' for an explicit 'PostgresConnectionConf'
--
-- IAM-auth configurations go through 'makePostgresPoolWithIamAuth'. Static
-- passwords use a plain pool whose connections get 'setTimeout' applied.
makePostgresPoolWith
  :: (MonadUnliftIO m, MonadLoggerIO m) => PostgresConnectionConf -> m SqlPool
makePostgresPoolWith conf = case pccPassword conf of
  PostgresPasswordIamAuth -> makePostgresPoolWithIamAuth conf
  PostgresPasswordStatic password ->
    createPostgresqlPoolModified
      (setTimeout conf)
      (postgresConnectionString conf password)
      (pccPoolSize conf)

-- | Creates a PostgreSQL pool using IAM auth for the password
--
-- A refresh thread ('spawnIamTokenRefreshThread') keeps the current token in
-- an 'IORef'. Each new pool connection reads whatever token is current when
-- it is opened, then has 'setTimeout' applied.
makePostgresPoolWithIamAuth
  :: (MonadUnliftIO m, MonadLoggerIO m) => PostgresConnectionConf -> m SqlPool
makePostgresPoolWithIamAuth conf = do
  tokenIORef <- spawnIamTokenRefreshThread conf
  createSqlPool (mkConn tokenIORef) (pccPoolSize conf)
 where
  mkConn tokenIORef logFunc = do
    token <- readIORef tokenIORef
    conn <-
      connectPostgreSQL
        $ postgresConnectionString conf
        $ unpack
        $ aitToken token
    -- Until 'openSimpleConn' hands the connection to persistent, nothing
    -- else owns it. If configuring or wrapping it throws, close it here
    -- rather than leaking it.
    (setTimeout conf conn >> openSimpleConn logFunc conn)
      `Exception.onException` PG.close conn

-- | Render a libpq keyword/value connection string for the given password
--
-- Every value is wrapped in single quotes, with embedded single quotes and
-- backslashes escaped by a backslash, as libpq requires for values that are
-- empty or contain spaces. Previously a password such as @my pass@ or @it's@
-- produced a malformed connection string.
--
-- NOTE(review): 'BS8.pack' keeps only the low 8 bits of each 'Char', so
-- non-ASCII values are not UTF-8 encoded. Revisit if such values (e.g.
-- passwords with accented letters) must be supported.
postgresConnectionString :: PostgresConnectionConf -> String -> ByteString
postgresConnectionString conf password =
  BS8.pack $ unwords
    [ keyword "host" $ pccHost conf
    , keyword "port" $ show $ pccPort conf
    , keyword "user" $ pccUser conf
    , keyword "password" password
    , keyword "dbname" $ pccDatabase conf
    ]
 where
  keyword :: String -> String -> String
  keyword key value = key <> "=" <> quoted value

  quoted :: String -> String
  quoted value = "'" <> concatMap escape value <> "'"

  escape :: Char -> String
  escape c
    | c `elem` ['\'', '\\'] = ['\\', c]
    | otherwise = [c]