{-# LANGUAGE ApplicativeDo #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE QuasiQuotes #-}
{-# LANGUAGE StrictData #-}

-- | Database access for your @App@
module Freckle.App.Database
  (
  -- * Abstract over access to a sql database
    HasSqlPool(..)
  , SqlPool
  , makePostgresPool
  , makePostgresPoolWith
  , runDB
  , PostgresConnectionConf(..)
  , PostgresPasswordSource(..)
  , PostgresPassword(..)
  , PostgresStatementTimeout(..)
  , postgresStatementTimeoutMilliseconds
  , envParseDatabaseConf
  , envPostgresPasswordSource
  ) where

import Freckle.App.Prelude

import Control.Concurrent
import qualified Control.Immortal as Immortal
import Control.Monad.Logger (runNoLoggingT)
import Control.Monad.Reader
import Data.ByteString (ByteString)
import qualified Data.ByteString.Char8 as BS8
import Data.Char (isDigit)
import Data.IORef
import Data.Pool
import qualified Data.Text as T
import Database.Persist.Postgresql
  ( SqlBackend
  , SqlPersistT
  , createPostgresqlPoolModified
  , createSqlPool
  , openSimpleConn
  , runSqlPool
  )
import Database.PostgreSQL.Simple
  (Connection, Only(..), connectPostgreSQL, execute)
import Database.PostgreSQL.Simple.SqlQQ (sql)
import qualified Freckle.App.Env as Env
import qualified Prelude as Unsafe (read)
import System.Process (readProcess)

-- | The connection pool type handed out by this module's pool constructors.
type SqlPool = Pool SqlBackend

-- | Anything a 'SqlPool' can be pulled out of. 'runDB' uses this to find the
-- pool inside its reader environment.
class HasSqlPool app where
  getSqlPool :: app -> SqlPool

-- | A pool on its own is already an environment that carries a pool.
instance HasSqlPool SqlPool where
  getSqlPool pool = pool

-- | Build a 'SqlPool' entirely from environment variables.
--
-- First decides where the password comes from ('envPostgresPasswordSource'),
-- then parses the remaining settings with 'envParseDatabaseConf' and hands the
-- result to 'makePostgresPoolWith'.
makePostgresPool :: IO SqlPool
makePostgresPool = do
  source <- Env.parse envPostgresPasswordSource
  makePostgresPoolWith =<< Env.parse (envParseDatabaseConf source)

-- | Run a persistent action using the pool obtained (via 'getSqlPool') from
-- the current reader environment.
runDB
  :: (HasSqlPool app, MonadUnliftIO m, MonadReader app m)
  => SqlPersistT m a
  -> m a
runDB action = runSqlPool action =<< asks getSqlPool

-- | Settings used to open connections to a PostgreSQL database.
--
-- 'envParseDatabaseConf' builds one of these from @PG*@ environment variables.
--
-- NOTE(review): the derived 'Show' instance renders 'pccPassword', so a
-- 'PostgresPasswordStatic' password appears in plain text if this is shown.
data PostgresConnectionConf = PostgresConnectionConf
  { pccHost :: String -- ^ server host (@PGHOST@)
  , pccPort :: Int -- ^ server port (@PGPORT@)
  , pccUser :: String -- ^ role to connect as (@PGUSER@)
  , pccPassword :: PostgresPassword -- ^ static password or IAM auth marker
  , pccDatabase :: String -- ^ database name (@PGDATABASE@)
  , pccPoolSize :: Int -- ^ pool size (@PGPOOLSIZE@, default 1)
  , pccStatementTimeout :: PostgresStatementTimeout
    -- ^ per-connection statement timeout (@PGSTATEMENTTIMEOUT@, default 120s)
  }
  deriving stock (Eq, Show)

-- | Where the database password should come from. 'envPostgresPasswordSource'
-- picks 'PostgresPasswordSourceIamAuth' when @USE_RDS_IAM_AUTH@ is switched
-- on, and 'PostgresPasswordSourceEnv' (read @PGPASSWORD@) otherwise.
data PostgresPasswordSource
  = PostgresPasswordSourceIamAuth
  | PostgresPasswordSourceEnv
  deriving stock (Eq, Show)

-- | The password half of a 'PostgresConnectionConf'.
--
-- With 'PostgresPasswordIamAuth', no password is stored; 'makePostgresPoolWith'
-- instead builds the pool with generated RDS IAM auth tokens.
-- 'PostgresPasswordStatic' carries a literal password.
data PostgresPassword
  = PostgresPasswordIamAuth
  | PostgresPasswordStatic String
  deriving stock (Eq, Show)

-- | A statement timeout, expressed in either seconds or milliseconds.
-- Use 'postgresStatementTimeoutMilliseconds' to normalise it.
data PostgresStatementTimeout
  = PostgresStatementTimeoutSeconds Int
  | PostgresStatementTimeoutMilliseconds Int
  deriving stock (Eq, Show)

-- | Convert a 'PostgresStatementTimeout' to milliseconds (the unit passed to
-- @SET statement_timeout@).
postgresStatementTimeoutMilliseconds :: PostgresStatementTimeout -> Int
postgresStatementTimeoutMilliseconds (PostgresStatementTimeoutSeconds secs) =
  secs * 1000
postgresStatementTimeoutMilliseconds (PostgresStatementTimeoutMilliseconds millis) =
  millis

-- | Read @PGSTATEMENTTIMEOUT@ as seconds or milliseconds
--
-- The value must be one or more digits, optionally followed by @s@ (seconds)
-- or @ms@ (milliseconds). A bare number is treated as seconds.
--
-- NOTE(review): the digits are converted with a partial 'read'. It cannot fail
-- on a pure digit string, but very long inputs may silently overflow 'Int'.
--
-- >>> readPostgresStatementTimeout "10"
-- Right (PostgresStatementTimeoutSeconds 10)
--
-- >>> readPostgresStatementTimeout "10s"
-- Right (PostgresStatementTimeoutSeconds 10)
--
-- >>> readPostgresStatementTimeout "10ms"
-- Right (PostgresStatementTimeoutMilliseconds 10)
--
-- >>> readPostgresStatementTimeout "20m"
-- Left "..."
--
-- >>> readPostgresStatementTimeout "2m0"
-- Left "..."
--
readPostgresStatementTimeout
  :: String -> Either String PostgresStatementTimeout
readPostgresStatementTimeout input =
  case span isDigit input of
    ("", _) -> invalid
    (digits, unit) -> case unit of
      "" -> Right (PostgresStatementTimeoutSeconds (Unsafe.read digits))
      "s" -> Right (PostgresStatementTimeoutSeconds (Unsafe.read digits))
      "ms" -> Right (PostgresStatementTimeoutMilliseconds (Unsafe.read digits))
      _ -> invalid
 where
  invalid :: Either String PostgresStatementTimeout
  invalid = Left "must be {digits}(s|ms)"

-- | Decide the password source from the @USE_RDS_IAM_AUTH@ switch (off by
-- default): on selects IAM auth, off selects @PGPASSWORD@.
envPostgresPasswordSource :: Env.Parser PostgresPasswordSource
envPostgresPasswordSource =
  fromSwitch <$> Env.switch "USE_RDS_IAM_AUTH" (Env.def False)
 where
  fromSwitch True = PostgresPasswordSourceIamAuth
  fromSwitch False = PostgresPasswordSourceEnv

-- | Parse a 'PostgresConnectionConf' from the standard @PG*@ variables.
--
-- @PGUSER@, @PGHOST@, @PGDATABASE@ and @PGPORT@ are required and non-empty.
-- @PGPOOLSIZE@ defaults to 1 and @PGSTATEMENTTIMEOUT@ to 120 seconds.
-- @PGPASSWORD@ is only read when the password source is
-- 'PostgresPasswordSourceEnv'.
envParseDatabaseConf
  :: PostgresPasswordSource -> Env.Parser PostgresConnectionConf
envParseDatabaseConf source = do
  pgUser <- Env.var Env.str "PGUSER" Env.nonEmpty
  pgPassword <- passwordParser
  pgHost <- Env.var Env.str "PGHOST" Env.nonEmpty
  pgDatabase <- Env.var Env.str "PGDATABASE" Env.nonEmpty
  pgPort <- Env.var Env.auto "PGPORT" Env.nonEmpty
  pgPoolSize <- Env.var Env.auto "PGPOOLSIZE" (Env.def 1)
  pgStatementTimeout <-
    Env.var
      (Env.eitherReader readPostgresStatementTimeout)
      "PGSTATEMENTTIMEOUT"
      (Env.def (PostgresStatementTimeoutSeconds 120))
  pure PostgresConnectionConf
    { pccHost = pgHost
    , pccPort = pgPort
    , pccUser = pgUser
    , pccPassword = pgPassword
    , pccDatabase = pgDatabase
    , pccPoolSize = pgPoolSize
    , pccStatementTimeout = pgStatementTimeout
    }
 where
  -- IAM auth needs no variable; otherwise the password is read from the env.
  passwordParser :: Env.Parser PostgresPassword
  passwordParser = case source of
    PostgresPasswordSourceIamAuth -> pure PostgresPasswordIamAuth
    PostgresPasswordSourceEnv ->
      PostgresPasswordStatic <$> Env.var Env.str "PGPASSWORD" Env.nonEmpty

-- | A generated RDS IAM auth token, together with when it was generated and
-- the settings it was generated for.
--
-- NOTE(review): the derived 'Show' instance renders the token itself.
data AuroraIamToken = AuroraIamToken
  { aitToken :: String -- ^ token text, whitespace-stripped
  , aitCreatedAt :: UTCTime -- ^ time the token was generated
  , aitPostgresConnectionConf :: PostgresConnectionConf
    -- ^ settings used to request the token
  }
  deriving stock (Eq, Show)

-- | Generate a fresh IAM auth token by shelling out to
-- @aws rds generate-db-auth-token@ for the configured host, port and user.
-- The process output is whitespace-stripped and stamped with the current time.
createAuroraIamToken :: PostgresConnectionConf -> IO AuroraIamToken
createAuroraIamToken conf = do
  -- TODO: Consider recording how long creating an auth token takes
  -- somewhere, even if it is just in the logs, so we get an idea of how long
  -- it takes in prod.
  rawOutput <- readProcess "aws" awsArgs ""
  createdAt <- getCurrentTime
  pure AuroraIamToken
    { aitToken = T.unpack $ T.strip $ T.pack rawOutput
    , aitCreatedAt = createdAt
    , aitPostgresConnectionConf = conf
    }
 where
  awsArgs :: [String]
  awsArgs =
    [ "rds"
    , "generate-db-auth-token"
    , "--hostname"
    , pccHost conf
    , "--port"
    , show (pccPort conf)
    , "--username"
    , pccUser conf
    ]

-- | Spawns a thread that refreshes the IAM auth token every minute
--
-- The IAM auth token lasts 15 minutes, but we refresh it every minute just to
-- be super safe.
--
-- The first token is generated synchronously, so the returned 'IORef' is
-- always populated. The background thread then repeatedly waits one minute
-- and replaces the token.
--
-- The wait happens outside 'Immortal.onFinish' on purpose. A failed refresh is
-- caught and printed there, and 'Immortal.create' then restarts the loop at
-- once. If the delay were inside the failing action, it would be skipped, and
-- a persistently failing @aws@ call would retry in a tight loop.
spawnIamTokenRefreshThread
  :: PostgresConnectionConf -> IO (IORef AuroraIamToken)
spawnIamTokenRefreshThread conf = do
  tokenIORef <- newIORef =<< createAuroraIamToken conf
  void $ Immortal.create $ \_ -> do
    -- Sleep first: the token was just created above, and after a failure we
    -- still want a pause before trying again.
    threadDelay oneMinuteInMicroseconds
    Immortal.onFinish onFinishCallback $ refreshIamToken conf tokenIORef
  pure tokenIORef
 where
  oneMinuteInMicroseconds :: Int
  oneMinuteInMicroseconds = 60 * 1000000

  onFinishCallback (Left ex) =
    -- TODO: Somehow get MonadLogger-style error log message in here
    putStrLn $ "Error refreshing IAM auth token: " ++ show ex
  onFinishCallback (Right ()) = pure ()

-- | Generate a new IAM token and store it in the given 'IORef'. If generation
-- throws, the 'IORef' keeps its previous token.
refreshIamToken :: PostgresConnectionConf -> IORef AuroraIamToken -> IO ()
refreshIamToken conf tokenIORef =
  writeIORef tokenIORef =<< createAuroraIamToken conf

-- isAuroraIamTokenExpired :: AuroraIamToken -> IO Bool
-- isAuroraIamTokenExpired AuroraIamToken {..} = do
--   now <- getCurrentTime
--   let fifteenMinutesInSeconds = 60 * 15
--   pure $ now `diffUTCTime` aitCreatedAt > fifteenMinutesInSeconds

-- | Apply the configured statement timeout to an open connection by running
-- @SET statement_timeout@ with the timeout converted to milliseconds.
setTimeout :: PostgresConnectionConf -> Connection -> IO ()
setTimeout conf conn = do
  _affected <- execute conn [sql| SET statement_timeout = ? |] (Only millis)
  pure ()
 where
  millis = postgresStatementTimeoutMilliseconds (pccStatementTimeout conf)

-- | Build a 'SqlPool' from explicit settings.
--
-- IAM auth is delegated to 'makePostgresPoolWithIamAuth'. A static password
-- goes straight to persistent's pool constructor, which runs 'setTimeout' on
-- every new connection. Logging from pool creation is discarded.
makePostgresPoolWith :: PostgresConnectionConf -> IO SqlPool
makePostgresPoolWith conf = case pccPassword conf of
  PostgresPasswordIamAuth -> makePostgresPoolWithIamAuth conf
  PostgresPasswordStatic staticPassword ->
    runNoLoggingT $ createPostgresqlPoolModified
      (setTimeout conf)
      (postgresConnectionString conf staticPassword)
      (pccPoolSize conf)

-- | Creates a PostgreSQL pool using IAM auth for the password.
--
-- 'spawnIamTokenRefreshThread' keeps the current token in an 'IORef' and
-- regenerates it on a timer. Each new pool connection reads whichever token
-- is current at the moment it is opened, so no token is generated on the
-- connection path. Every connection also gets 'setTimeout' applied before
-- being handed to persistent. Logging from pool creation is discarded.
makePostgresPoolWithIamAuth :: PostgresConnectionConf -> IO SqlPool
makePostgresPoolWithIamAuth conf = do
  tokenIORef <- spawnIamTokenRefreshThread conf
  runNoLoggingT $ createSqlPool (mkConn tokenIORef) (pccPoolSize conf)
 where
  mkConn tokenIORef logFunc = do
    token <- readIORef tokenIORef
    conn <- connectPostgreSQL $ postgresConnectionString conf (aitToken token)
    setTimeout conf conn
    openSimpleConn logFunc conn

-- | Render a libpq keyword/value connection string from the settings plus an
-- explicit password (a static password or an IAM auth token).
--
-- Every value is wrapped in single quotes, with @\\@ and @'@ escaped by a
-- backslash, as libpq requires for values containing whitespace or quotes.
-- Unquoted, a password with a space in it would be cut at the space, and the
-- rest would be parsed as additional connection keywords.
--
-- NOTE(review): 'BS8.pack' keeps only the low 8 bits of each 'Char', so
-- characters outside Latin-1 in any value are mangled.
postgresConnectionString :: PostgresConnectionConf -> String -> ByteString
postgresConnectionString conf password =
  BS8.pack $ unwords
    [ keyword "host" (pccHost conf)
    , keyword "port" (show (pccPort conf))
    , keyword "user" (pccUser conf)
    , keyword "password" password
    , keyword "dbname" (pccDatabase conf)
    ]
 where
  -- @name='value'@ with the value escaped for libpq.
  keyword :: String -> String -> String
  keyword name value = name <> "='" <> escape value <> "'"

  -- Backslash-escape the two characters libpq treats specially in quotes.
  escape :: String -> String
  escape [] = []
  escape ('\\' : rest) = '\\' : '\\' : escape rest
  escape ('\'' : rest) = '\\' : '\'' : escape rest
  escape (c : rest) = c : escape rest