{-# LANGUAGE RecordWildCards, ScopedTypeVariables #-}

------------------------------------------------------------------------------
-- |
-- Module:      Database.PostgreSQL.Simple.Transaction
-- Copyright:   (c) 2011-2013 Leon P Smith
--              (c) 2013 Joey Adams
-- License:     BSD3
-- Maintainer:  Leon P Smith <leon@melding-monads.com>
--
------------------------------------------------------------------------------

module Database.PostgreSQL.Simple.Transaction
    (
    -- * Transaction handling
      withTransaction
    , withTransactionLevel
    , withTransactionMode
    , withTransactionModeRetry
    , withTransactionModeRetry'
    , withTransactionSerializable
    , TransactionMode(..)
    , IsolationLevel(..)
    , ReadWriteMode(..)
    , defaultTransactionMode
    , defaultIsolationLevel
    , defaultReadWriteMode
--    , Base.autocommit
    , begin
    , beginLevel
    , beginMode
    , commit
    , rollback

    -- * Savepoint
    , withSavepoint
    , Savepoint
    , newSavepoint
    , releaseSavepoint
    , rollbackToSavepoint
    , rollbackToAndReleaseSavepoint

    -- * Error predicates
    , isSerializationError
    , isNoActiveTransactionError
    , isFailedTransactionError
    ) where

import qualified Control.Exception as E
import qualified Data.ByteString as B
import Database.PostgreSQL.Simple.Internal
import Database.PostgreSQL.Simple.Types
import Database.PostgreSQL.Simple.Errors
import Database.PostgreSQL.Simple.Compat (mask, (<>))


-- | Of the four isolation levels defined by the SQL standard,
-- these are the three levels distinguished by PostgreSQL as of version 9.1.
-- See <https://www.postgresql.org/docs/9.5/static/transaction-iso.html>
-- for more information.  Note that prior to PostgreSQL 9.1, 'RepeatableRead'
-- was equivalent to 'Serializable'.
--
-- The constructor order is significant: the derived 'Ord', 'Enum' and
-- 'Bounded' instances follow it.
data IsolationLevel
   = DefaultIsolationLevel  -- ^ the isolation level will be taken from
                            --   PostgreSQL's per-connection
                            --   @default_transaction_isolation@ variable,
                            --   which is initialized according to the
                            --   server's config.  The default configuration
                            --   is 'ReadCommitted'.
   | ReadCommitted          -- ^ @ISOLATION LEVEL READ COMMITTED@
   | RepeatableRead         -- ^ @ISOLATION LEVEL REPEATABLE READ@
   | Serializable           -- ^ @ISOLATION LEVEL SERIALIZABLE@
     deriving (Show, Eq, Ord, Enum, Bounded)

-- | Access mode requested when a transaction is opened.  Constructor order
-- determines the derived 'Ord', 'Enum' and 'Bounded' instances.
data ReadWriteMode
  = -- | the read-write mode will be taken from PostgreSQL's per-connection
    -- @default_transaction_read_only@ variable, which is initialized
    -- according to the server's config.  The default configuration is
    -- 'ReadWrite'.
    DefaultReadWriteMode
  | -- | @READ WRITE@
    ReadWrite
  | -- | @READ ONLY@
    ReadOnly
  deriving (Bounded, Enum, Eq, Ord, Show)

-- | The options used when opening a transaction: an isolation level paired
-- with a read/write mode.  Both fields are strict.
data TransactionMode = TransactionMode
  { isolationLevel :: !IsolationLevel
    -- ^ isolation level requested when the transaction begins
  , readWriteMode  :: !ReadWriteMode
    -- ^ access mode requested when the transaction begins
  }
  deriving (Eq, Show)

-- | A mode that defers both the isolation level and the read/write mode to
-- the server's per-connection defaults.
defaultTransactionMode :: TransactionMode
defaultTransactionMode = TransactionMode
  { isolationLevel = defaultIsolationLevel
  , readWriteMode  = defaultReadWriteMode
  }

-- | Use whatever isolation level the server is configured with
-- ('DefaultIsolationLevel').
defaultIsolationLevel :: IsolationLevel
defaultIsolationLevel = DefaultIsolationLevel

-- | Use whatever read/write mode the server is configured with
-- ('DefaultReadWriteMode').
defaultReadWriteMode :: ReadWriteMode
defaultReadWriteMode = DefaultReadWriteMode

-- | Execute an action inside a SQL transaction.
--
-- This function initiates a transaction with a \"@BEGIN@\" statement,
-- then executes the supplied action.  If the action succeeds, the
-- transaction will be completed with 'commit' before this function returns.
--
-- If the action throws /any/ kind of exception (not just a
-- PostgreSQL-related exception), the transaction will be rolled back
-- (an 'IOError' raised by the @ROLLBACK@ itself is ignored), then the
-- exception will be rethrown.
--
-- The server's default isolation level and read/write mode are used; see
-- 'withTransactionMode' to choose them explicitly.
--
-- For nesting transactions, see 'withSavepoint'.
withTransaction :: Connection -> IO a -> IO a
withTransaction = withTransactionMode defaultTransactionMode

-- | Run an action in a 'Serializable', 'ReadWrite' transaction.  When the
-- failure is a serialization error ('isSerializationError'), the
-- transaction is rolled back and the whole action is attempted again, so
-- the IO action may be executed more than once.
--
-- A 'Serializable' transaction creates the illusion that your program has
-- exclusive access to the database.  Even in a concurrent setting, you can
-- perform queries in sequence without worrying about what might happen
-- between one statement and the next.
--
-- Think of it as STM, but without @retry@.
withTransactionSerializable :: Connection -> IO a -> IO a
withTransactionSerializable = withTransactionModeRetry serializableRW isSerializationError
  where
    -- Both fields are set explicitly; nothing is left to server defaults.
    serializableRW = TransactionMode
      { isolationLevel = Serializable
      , readWriteMode  = ReadWrite
      }

-- | Execute an action inside a SQL transaction with a given isolation level.
-- The read/write mode is taken from 'defaultTransactionMode'.
withTransactionLevel :: IsolationLevel -> Connection -> IO a -> IO a
withTransactionLevel lvl conn act = withTransactionMode chosenMode conn act
  where
    chosenMode = defaultTransactionMode { isolationLevel = lvl }

-- | Execute an action inside a SQL transaction with a given transaction mode.
--
-- @BEGIN@ and @COMMIT@ run inside 'mask'.  Only the user action is run
-- through mask's restore callback.  If the action throws, the transaction is
-- rolled back (ignoring 'IOError's from the rollback) and the exception is
-- rethrown.
withTransactionMode :: TransactionMode -> Connection -> IO a -> IO a
withTransactionMode mode conn act = mask $ \unmask -> do
    beginMode mode conn
    result <- E.onException (unmask act) (rollback_ conn)
    commit conn
    return result

-- | 'withTransactionModeRetry'' specialised so that the retry predicate
-- inspects 'SqlError's.
withTransactionModeRetry :: TransactionMode -> (SqlError -> Bool) -> Connection -> IO a -> IO a
withTransactionModeRetry mode shouldRetry = withTransactionModeRetry' mode shouldRetry

-- | Like 'withTransactionMode', but also takes a custom callback to
-- determine if a transaction should be retried if an exception occurs.
-- If the callback returns 'True', then the transaction will be retried.
-- If the callback returns 'False', or an exception other than an @e@
-- occurs then the transaction will be rolled back and the exception rethrown.
--
-- An @e@ thrown by the @COMMIT@ itself is also handed to the callback.
-- Each attempt issues a fresh @BEGIN@, so the action may run several times.
-- There is no upper bound on the number of retries.
--
-- This is used to implement 'withTransactionSerializable'.
withTransactionModeRetry' :: forall a e. E.Exception e => TransactionMode -> (e -> Bool) -> Connection -> IO a -> IO a
withTransactionModeRetry' mode shouldRetry conn act =
    mask $ \restore -> loop (attemptWith restore)
  where
    -- One attempt: run the action (rolling back if it throws), then commit.
    -- Exceptions of type @e@ are captured as 'Left'; all others escape.
    attemptWith :: (IO a -> IO a) -> IO (Either e a)
    attemptWith restore = E.try $ do
        value <- restore act `E.onException` rollback_ conn
        commit conn
        return value

    -- Open a transaction and run one attempt.  Start over for as long as
    -- the predicate accepts the captured exception.
    loop :: IO (Either e a) -> IO a
    loop attempt = do
        beginMode mode conn
        outcome <- attempt
        case outcome of
          Right value -> return value
          Left err
            | shouldRetry err -> loop attempt
            | otherwise       -> E.throwIO err

-- | Roll back the current transaction by issuing @ROLLBACK@.
rollback :: Connection -> IO ()
rollback conn = do
    _ <- execute_ conn "ROLLBACK"
    return ()

-- | Roll back the current transaction, ignoring any 'IOError' raised while
-- doing so.  Exceptions of other types still propagate.
rollback_ :: Connection -> IO ()
rollback_ conn = E.handle ignoreIOError (rollback conn)
  where
    ignoreIOError :: IOError -> IO ()
    ignoreIOError _ = return ()

-- | Commit the current transaction by issuing @COMMIT@.
commit :: Connection -> IO ()
commit conn = do
    _ <- execute_ conn "COMMIT"
    return ()

-- | Begin a transaction using 'defaultTransactionMode', i.e. the server's
-- default isolation level and read/write mode.
begin :: Connection -> IO ()
begin conn = beginMode defaultTransactionMode conn

-- | Begin a transaction with a given isolation level.  The read/write mode
-- is taken from 'defaultTransactionMode'.
beginLevel :: IsolationLevel -> Connection -> IO ()
beginLevel lvl conn = beginMode chosenMode conn
  where
    chosenMode = defaultTransactionMode { isolationLevel = lvl }

-- | Begin a transaction with a given transaction mode.
--
-- Sends @BEGIN@ followed by an optional @ISOLATION LEVEL ...@ clause and an
-- optional @READ WRITE@ / @READ ONLY@ clause.  The \"default\" constructors
-- contribute no clause, which leaves the choice to the server.
beginMode :: TransactionMode -> Connection -> IO ()
beginMode mode conn = do
    _ <- execute_ conn $! statement
    return ()
  where
    statement :: Query
    statement = Query (B.concat [ "BEGIN"
                                , isolationClause (isolationLevel mode)
                                , accessClause (readWriteMode mode)
                                ])

    isolationClause :: IsolationLevel -> B.ByteString
    isolationClause DefaultIsolationLevel = ""
    isolationClause ReadCommitted         = " ISOLATION LEVEL READ COMMITTED"
    isolationClause RepeatableRead        = " ISOLATION LEVEL REPEATABLE READ"
    isolationClause Serializable          = " ISOLATION LEVEL SERIALIZABLE"

    accessClause :: ReadWriteMode -> B.ByteString
    accessClause DefaultReadWriteMode = ""
    accessClause ReadWrite            = " READ WRITE"
    accessClause ReadOnly             = " READ ONLY"

------------------------------------------------------------------------
-- Savepoint

-- | Create a savepoint, and roll back to it if an error occurs.  This may only
-- be used inside of a transaction, and provides a sort of
-- \"nested transaction\".
--
-- On success the savepoint is released.  Suppose the release fails with a
-- 'SqlError' satisfying 'isFailedTransactionError'.  In that case the code
-- rolls back to the savepoint and releases it instead.  Any other 'SqlError'
-- from the release is rethrown.
--
-- See <https://www.postgresql.org/docs/9.5/static/sql-savepoint.html>
withSavepoint :: Connection -> IO a -> IO a
withSavepoint conn body = mask $ \restore -> do
    sp <- newSavepoint conn
    let discard = rollbackToAndReleaseSavepoint conn sp
        recover err
          | isFailedTransactionError err = discard
          | otherwise                    = E.throwIO err
    r <- restore body `E.onException` discard
    releaseSavepoint conn sp `E.catch` recover
    return r

-- | Create a new savepoint named by 'newTempName' and return a handle to it.
-- This may only be used inside of a transaction.
newSavepoint :: Connection -> IO Savepoint
newSavepoint conn =
    newTempName conn >>= \spName ->
        execute_ conn ("SAVEPOINT " <> spName) >> return (Savepoint spName)

-- | Destroy a savepoint, but retain its effects.
--
-- Warning: if the transaction has been aborted by an error, this throws a
-- 'SqlError' matching 'isFailedTransactionError'.  By contrast, 'commit'
-- would merely warn and roll back.
releaseSavepoint :: Connection -> Savepoint -> IO ()
releaseSavepoint conn (Savepoint name) = do
    _ <- execute_ conn ("RELEASE SAVEPOINT " <> name)
    return ()

-- | Undo everything done since the savepoint was created.  The savepoint
-- itself is kept, so it can be rolled back to again.
rollbackToSavepoint :: Connection -> Savepoint -> IO ()
rollbackToSavepoint conn (Savepoint name) = do
    _ <- execute_ conn ("ROLLBACK TO SAVEPOINT " <> name)
    return ()

-- | Roll back to a savepoint and release it.  This has the same effect as
-- 'rollbackToSavepoint' followed by 'releaseSavepoint'.  However, both
-- statements are sent in a single 'execute_' call, which saves a round trip
-- to the server.
rollbackToAndReleaseSavepoint :: Connection -> Savepoint -> IO ()
rollbackToAndReleaseSavepoint conn (Savepoint name) = do
    _ <- execute_ conn combined
    return ()
  where
    combined :: Query
    combined = "ROLLBACK TO SAVEPOINT " <> name <> "; RELEASE SAVEPOINT " <> name