{-# LANGUAGE CPP #-}
{-# LANGUAGE LambdaCase #-}

-- | More efficient query execution functions for @beam-postgres@. These
-- functions use the @conduit@ package to execute @beam-postgres@ statements in
-- an arbitrary 'MonadIO'. They may be more efficient for streaming operations
-- than 'MonadBeam'.
module Database.Beam.Postgres.Conduit where

import           Database.Beam
import           Database.Beam.Postgres.Connection
import           Database.Beam.Postgres.Full
import           Database.Beam.Postgres.Syntax
import           Database.Beam.Postgres.Types

import           Control.Exception.Lifted (finally)
import           Control.Monad.Trans.Control (MonadBaseControl)

import qualified Database.PostgreSQL.LibPQ as Pg hiding
  (Connection, escapeStringConn, escapeIdentifier, escapeByteaConn, exec)
import qualified Database.PostgreSQL.Simple as Pg
import qualified Database.PostgreSQL.Simple.Internal as Pg (withConnection)
import qualified Database.PostgreSQL.Simple.Types as Pg (Query(..))

import qualified Data.Conduit as C
import           Data.Int (Int64)
import           Data.Maybe (fromMaybe)
#if !MIN_VERSION_base(4, 11, 0)
import           Data.Semigroup
#endif

import qualified Control.Monad.Fail as Fail

#if MIN_VERSION_conduit(1,3,0)
#define CONDUIT_TRANSFORMER C.ConduitT
#else
#define CONDUIT_TRANSFORMER C.ConduitM
#endif

-- * @SELECT@

-- | Run a PostgreSQL @SELECT@ statement in any 'MonadIO'.
--
-- The result rows are not returned directly. Instead, the continuation is
-- given a conduit that fetches them from the connection one at a time (see
-- 'runQueryReturning'). Consume the conduit only inside the continuation:
-- once the continuation returns, a query that is still active is cancelled.
runSelect :: ( MonadIO m, Fail.MonadFail m, MonadBaseControl IO m, FromBackendRow Postgres a )
          => Pg.Connection -> SqlSelect Postgres a
          -> (CONDUIT_TRANSFORMER () a m () -> m b) -> m b
runSelect conn (SqlSelect (PgSelectSyntax syntax)) withSrc =
  runQueryReturning conn syntax withSrc

-- * @INSERT@

-- | Run a PostgreSQL @INSERT@ statement in any 'MonadIO'. Returns the number of
-- rows affected.
--
-- An insert with no rows ('SqlInsertNoRows') sends nothing to the server and
-- reports @0@.
runInsert :: MonadIO m
          => Pg.Connection -> SqlInsert Postgres tbl -> m Int64
runInsert _ SqlInsertNoRows = pure 0
runInsert conn (SqlInsert _ (PgInsertSyntax i)) =
  executeStatement conn i

-- | Run a PostgreSQL @INSERT ... RETURNING ...@ statement in any 'MonadIO'.
--
-- The continuation is given a conduit that streams the rows produced by the
-- @RETURNING@ clause (see 'runQueryReturning'). For 'PgInsertReturningEmpty',
-- no query is issued and the continuation receives an empty conduit.
runInsertReturning :: ( MonadIO m, Fail.MonadFail m, MonadBaseControl IO m, FromBackendRow Postgres a )
                   => Pg.Connection
                   -> PgInsertReturning a
                   -> (CONDUIT_TRANSFORMER () a m () -> m b)
                   -> m b
runInsertReturning _ PgInsertReturningEmpty withSrc = withSrc (pure ())
runInsertReturning conn (PgInsertReturning i) withSrc =
  runQueryReturning conn i withSrc

-- * @UPDATE@

-- | Run a PostgreSQL @UPDATE@ statement in any 'MonadIO'. Returns the number of
-- rows affected.
--
-- An update that changes nothing ('SqlIdentityUpdate') sends nothing to the
-- server and reports @0@.
runUpdate :: MonadIO m
          => Pg.Connection -> SqlUpdate Postgres tbl -> m Int64
runUpdate _ SqlIdentityUpdate = pure 0
runUpdate conn (SqlUpdate _ (PgUpdateSyntax i)) =
  executeStatement conn i

-- | Run a PostgreSQL @UPDATE ... RETURNING ...@ statement in any 'MonadIO'.
--
-- The continuation is given a conduit that streams the rows produced by the
-- @RETURNING@ clause (see 'runQueryReturning'). For 'PgUpdateReturningEmpty',
-- no query is issued and the continuation receives an empty conduit.
runUpdateReturning :: ( MonadIO m, Fail.MonadFail m, MonadBaseControl IO m, FromBackendRow Postgres a)
                   => Pg.Connection
                   -> PgUpdateReturning a
                   -> (CONDUIT_TRANSFORMER () a m () -> m b)
                   -> m b
runUpdateReturning _ PgUpdateReturningEmpty withSrc = withSrc (pure ())
runUpdateReturning conn (PgUpdateReturning u) withSrc =
  runQueryReturning conn u withSrc

-- * @DELETE@

-- | Run a PostgreSQL @DELETE@ statement in any 'MonadIO'. Returns the number of
-- rows affected.
runDelete :: MonadIO m
          => Pg.Connection -> SqlDelete Postgres tbl
          -> m Int64
runDelete conn (SqlDelete _ (PgDeleteSyntax d)) =
  executeStatement conn d

-- | Run a PostgreSQL @DELETE ... RETURNING ...@ statement in any 'MonadIO'.
--
-- The continuation is given a conduit that streams the rows produced by the
-- @RETURNING@ clause, i.e. the deleted rows (see 'runQueryReturning').
runDeleteReturning :: ( MonadIO m, Fail.MonadFail m, MonadBaseControl IO m, FromBackendRow Postgres a )
                   => Pg.Connection -> PgDeleteReturning a
                   -> (CONDUIT_TRANSFORMER () a m () -> m b) -> m b
runDeleteReturning conn (PgDeleteReturning d) withSrc =
  runQueryReturning conn d withSrc

-- * Convenience functions

-- | Run any DML statement and return the number of rows affected.
--
-- The statement is rendered with 'pgRenderSyntax' and executed with
-- 'Pg.execute_', whose row count is returned unchanged.
executeStatement ::  MonadIO m => Pg.Connection -> PgSyntax -> m Int64
executeStatement conn x =
  liftIO $ do
    syntax <- pgRenderSyntax conn x
    Pg.execute_ conn (Pg.Query syntax)

-- | Run any query that returns a set of rows, streaming the rows to a
-- continuation.
--
-- The query is rendered with 'pgRenderSyntax', sent with libpq's
-- asynchronous 'Pg.sendQuery', and the connection is switched to single-row
-- mode. Rows are therefore fetched one at a time as the conduit passed to
-- @withSrc@ is consumed.
--
-- The conduit is only usable while @withSrc@ runs. Once the query has been
-- sent, a cleanup step always runs afterwards, whether @withSrc@ returns
-- normally or throws. If libpq still reports the query as active (for
-- example because the consumer stopped early), the cleanup cancels the query
-- and drains its remaining results.
--
-- Failures are raised with 'Fail.fail': the query could not be sent,
-- single-row mode could not be enabled, a row could not be decoded, or the
-- server returned an error status.
runQueryReturning
  :: ( MonadIO m, Fail.MonadFail m, MonadBaseControl IO m, Functor m, FromBackendRow Postgres r )
  => Pg.Connection -> PgSyntax
  -> (CONDUIT_TRANSFORMER () r m () -> m b)
  -> m b
runQueryReturning conn x withSrc = do
  success <- liftIO $ do
    syntax <- pgRenderSyntax conn x
    Pg.withConnection conn (\conn' -> Pg.sendQuery conn' syntax)

  if success
    -- The query is now in flight, so the cleanup must cover everything from
    -- here on, including the single-row-mode failure path. Otherwise a
    -- failure there would leave an undrained query on the connection.
    then streamInSingleRowMode `finally` gracefulShutdown
    else do
      errMsg <- fromMaybe "No libpq error provided" <$>
                liftIO (Pg.withConnection conn Pg.errorMessage)
      Fail.fail (show errMsg)

  where
    -- Enable single-row mode and hand the row stream to the caller.
    streamInSingleRowMode = do
      singleRowModeSet <- liftIO (Pg.withConnection conn Pg.setSingleRowMode)
      if singleRowModeSet
        then withSrc (streamResults Nothing)
        else Fail.fail "Could not enable single row mode"

    -- Pull results one at a time, yielding each decoded row. Field metadata
    -- is read from the first row's result and reused for every later row.
    streamResults fields = do
      nextRow <- liftIO (Pg.withConnection conn Pg.getResult)
      case nextRow of
        Nothing -> pure ()
        Just row ->
          liftIO (Pg.resultStatus row) >>= \case
            Pg.SingleTuple -> do
              fields' <- liftIO (maybe (getFields row) pure fields)
              -- Each single-row result holds exactly one row, at index 0.
              parsedRow <- liftIO (runPgRowReader conn 0 row fields' fromBackendRow)
              case parsedRow of
                Left err ->
                  liftIO (bailEarly row ("Could not read row: " <> show err))
                Right parsedRow' -> do
                  C.yield parsedRow'
                  streamResults (Just fields')
            -- End of the row stream: drain whatever results remain.
            Pg.TuplesOk -> liftIO (Pg.withConnection conn finishQuery)
            Pg.EmptyQuery -> Fail.fail "No query"
            Pg.CommandOk -> pure ()
            _ -> do
              errMsg <- liftIO (Pg.resultErrorMessage row)
              Fail.fail ("Postgres error: " <> show errMsg)

    -- Free the result that failed to decode, cancel the rest of the query,
    -- then raise the error in IO.
    bailEarly row errorString = do
      Pg.unsafeFreeResult row
      Pg.withConnection conn cancelQuery
      Fail.fail errorString

    -- Ask the server to cancel the running query, then drain its results.
    -- If no cancel handle is available, nothing is done.
    cancelQuery conn' = do
      cancel <- Pg.getCancel conn'
      case cancel of
        Nothing -> pure ()
        Just cancel' -> do
          res <- Pg.cancel cancel'
          case res of
            Right () -> finishQuery conn'
            Left err -> Fail.fail ("Could not cancel: " <> show err)

    -- Consume results until libpq reports that none are left.
    finishQuery conn' = do
      nextRow <- Pg.getResult conn'
      case nextRow of
        Nothing -> pure ()
        Just _ -> finishQuery conn'

    -- Cleanup run after the stream: cancel the query only if the connection
    -- still reports a command as active. Every constructor is listed so that
    -- a new transaction status triggers an incomplete-pattern warning.
    gracefulShutdown =
      liftIO . Pg.withConnection conn $ \conn' -> do
        sts <- Pg.transactionStatus conn'
        case sts of
          Pg.TransActive -> cancelQuery conn'
          Pg.TransIdle -> pure ()
          Pg.TransInTrans -> pure ()
          Pg.TransInError -> pure ()
          Pg.TransUnknown -> pure ()