{- |
Copyright : Flipstone Technology Partners 2023
License   : MIT
Stability : Stable

Low-level functions for executing 'RawSql.SqlExpression' values directly.

@since 1.0.0.0
-}
module Orville.PostgreSQL.Execution.Execute
  ( executeAndDecode
  , executeAndReturnAffectedRows
  , executeVoid
  , executeAndDecodeIO
  , executeAndReturnAffectedRowsIO
  , executeVoidIO
  , AffectedRowsDecodingError
  )
where

import Control.Exception (Exception, throwIO)
import Control.Monad (void)
import Control.Monad.IO.Class (liftIO)
import qualified Database.PostgreSQL.LibPQ as LibPQ

import Orville.PostgreSQL.Execution.QueryType (QueryType)
import qualified Orville.PostgreSQL.Marshall.SqlMarshaller as SqlMarshaller
import Orville.PostgreSQL.Monad (MonadOrville, askOrvilleState, withConnection)
import Orville.PostgreSQL.OrvilleState (OrvilleState, orvilleErrorDetailLevel, orvilleSqlCommenterAttributes, orvilleSqlExecutionCallback)
import Orville.PostgreSQL.Raw.Connection (Connection)
import qualified Orville.PostgreSQL.Raw.RawSql as RawSql
import qualified Orville.PostgreSQL.Raw.SqlCommenter as SqlCommenter
import qualified Orville.PostgreSQL.Raw.SqlValue as SqlValue

{- |
  Executes a SQL query and decodes the result set using the provided
  marshaller. Any SQL Execution callbacks that have been added to the
  'OrvilleState' will be called.

  The 'OrvilleState' is read from the monad, and the query runs on a
  connection obtained via 'withConnection'. The work itself is delegated to
  'executeAndDecodeIO'.

  If the query fails or if any row is unable to be decoded by the marshaller,
  an exception will be raised.

@since 1.0.0.0
-}
executeAndDecode ::
  (MonadOrville m, RawSql.SqlExpression sql) =>
  QueryType ->
  sql ->
  SqlMarshaller.AnnotatedSqlMarshaller writeEntity readEntity ->
  m [readEntity]
executeAndDecode queryType sql marshaller = do
  orvilleState <- askOrvilleState
  withConnection
    (liftIO . executeAndDecodeIO queryType sql marshaller orvilleState)

{- |
  Executes a SQL query and returns the number of rows affected by the query.
  Any SQL Execution callbacks that have been added to the 'OrvilleState' will
  be called.

  This function can only be used for the execution of a SELECT, CREATE
  TABLE AS, INSERT, UPDATE, DELETE, MOVE, FETCH, or COPY statement, or an
  EXECUTE of a prepared query that contains an INSERT, UPDATE, or DELETE
  statement. If the query is anything else, an 'AffectedRowsDecodingError'
  will be raised after the query is executed when the result is read.

  If the query fails, an exception will be raised.

  The 'OrvilleState' is read from the monad, and the query runs on a
  connection obtained via 'withConnection'. The work itself is delegated to
  'executeAndReturnAffectedRowsIO'.

@since 1.0.0.0
-}
executeAndReturnAffectedRows ::
  (MonadOrville m, RawSql.SqlExpression sql) =>
  QueryType ->
  sql ->
  m Int
executeAndReturnAffectedRows queryType sql = do
  orvilleState <- askOrvilleState
  withConnection
    (liftIO . executeAndReturnAffectedRowsIO queryType sql orvilleState)

{- |
  Executes a SQL query and ignores the result. Any SQL Execution callbacks
  that have been added to the 'OrvilleState' will be called.

  The 'OrvilleState' is read from the monad, and the query runs on a
  connection obtained via 'withConnection'. The work itself is delegated to
  'executeVoidIO'.

  If the query fails, an exception will be raised.

@since 1.0.0.0
-}
executeVoid ::
  (MonadOrville m, RawSql.SqlExpression sql) =>
  QueryType ->
  sql ->
  m ()
executeVoid queryType sql = do
  orvilleState <- askOrvilleState
  withConnection (liftIO . executeVoidIO queryType sql orvilleState)

{- |
  Executes a SQL query and decodes the result set using the provided
  marshaller. Any SQL Execution callbacks that have been added to the
  'OrvilleState' will be called.

  Decoding uses the error detail level configured in the 'OrvilleState'
  ('orvilleErrorDetailLevel').

  If the query fails or if any row is unable to be decoded by the marshaller,
  an exception will be raised.

@since 1.0.0.0
-}
executeAndDecodeIO ::
  RawSql.SqlExpression sql =>
  QueryType ->
  sql ->
  SqlMarshaller.AnnotatedSqlMarshaller writeEntity readEntity ->
  OrvilleState ->
  Connection ->
  IO [readEntity]
executeAndDecodeIO queryType sql marshaller orvilleState conn = do
  libPqResult <- executeWithCallbacksIO queryType sql orvilleState conn

  decodingResult <-
    SqlMarshaller.marshallResultFromSql
      (orvilleErrorDetailLevel orvilleState)
      marshaller
      libPqResult

  -- A decoding failure is raised as an exception rather than returned, so
  -- callers only ever see successfully decoded rows.
  either throwIO pure decodingResult

{- |
  Executes a SQL query and returns the number of rows affected by the query.
  Any SQL Execution callbacks that have been added to the 'OrvilleState' will
  be called.

  This function can only be used for the execution of a SELECT, CREATE
  TABLE AS, INSERT, UPDATE, DELETE, MOVE, FETCH, or COPY statement, or an
  EXECUTE of a prepared query that contains an INSERT, UPDATE, or DELETE
  statement. If the query is anything else, an 'AffectedRowsDecodingError'
  will be raised after the query is executed when the result is read.

  If the query fails, an exception will be raised.

@since 1.0.0.0
-}
executeAndReturnAffectedRowsIO ::
  RawSql.SqlExpression sql =>
  QueryType ->
  sql ->
  OrvilleState ->
  Connection ->
  IO Int
executeAndReturnAffectedRowsIO queryType sql orvilleState conn = do
  libPqResult <- executeWithCallbacksIO queryType sql orvilleState conn
  mbTupleCount <- LibPQ.cmdTuples libPqResult
  case mbTupleCount of
    -- libpq reports an empty string (rather than no value) for commands
    -- that do not produce a row count. Treat that the same as a missing
    -- count so the caller gets a clear message, not a numeric parse error.
    Just bs
      | bs /= mempty ->
          case SqlValue.toInt (SqlValue.fromRawBytes bs) of
            Left err ->
              throwIO (AffectedRowsDecodingError err)
            Right n ->
              pure n
    _ ->
      throwIO
        . AffectedRowsDecodingError
        $ "No affected row count was produced by the query"

{- |
  Executes a SQL query and ignores the result. Any SQL Execution callbacks
  that have been added to the 'OrvilleState' will be called.

  If the query fails, an exception will be raised.

@since 1.0.0.0
-}
executeVoidIO ::
  RawSql.SqlExpression sql =>
  QueryType ->
  sql ->
  OrvilleState ->
  Connection ->
  IO ()
executeVoidIO queryType sql orvilleState =
  -- Partially applied: the remaining 'Connection' argument is supplied by
  -- the caller, and the LibPQ result is discarded.
  void . executeWithCallbacksIO queryType sql orvilleState

{- |
  Internal helper shared by all of the execute functions in this module.

  Renders @sql@ to raw SQL and, when the 'OrvilleState' carries sqlcommenter
  attributes ('orvilleSqlCommenterAttributes'), adds them to the rendered SQL.
  It then runs the SQL on the given connection with 'RawSql.execute',
  wrapped in the state's 'orvilleSqlExecutionCallback'. The callback receives
  the 'QueryType' and the final SQL, including any commenter attributes.
-}
executeWithCallbacksIO ::
  RawSql.SqlExpression sql =>
  QueryType ->
  sql ->
  OrvilleState ->
  Connection ->
  IO LibPQ.Result
executeWithCallbacksIO queryType sql orvilleState conn =
  let
    baseSql = RawSql.toRawSql sql

    rawSql =
      maybe
        baseSql
        (\attributes -> SqlCommenter.addSqlCommenterAttributes attributes baseSql)
        (orvilleSqlCommenterAttributes orvilleState)
  in
    orvilleSqlExecutionCallback
      orvilleState
      queryType
      rawSql
      (RawSql.execute conn rawSql)

{- |
  Thrown by 'executeAndReturnAffectedRows' and 'executeAndReturnAffectedRowsIO'
  if the number of affected rows cannot be successfully read from the LibPQ
  command result. The wrapped 'String' describes why decoding failed.

@since 1.0.0.0
-}
newtype AffectedRowsDecodingError
  = AffectedRowsDecodingError String
  deriving
    ( -- | @since 1.0.0.0
      Show
    )

-- | @since 1.0.0.0
instance Exception AffectedRowsDecodingError