{-# LANGUAGE RecordWildCards, ScopedTypeVariables #-}
module Database.PostgreSQL.Simple.Transaction
    (
    
      withTransaction
    , withTransactionLevel
    , withTransactionMode
    , withTransactionModeRetry
    , withTransactionSerializable
    , TransactionMode(..)
    , IsolationLevel(..)
    , ReadWriteMode(..)
    , defaultTransactionMode
    , defaultIsolationLevel
    , defaultReadWriteMode
    , begin
    , beginLevel
    , beginMode
    , commit
    , rollback
    
    , withSavepoint
    , Savepoint
    , newSavepoint
    , releaseSavepoint
    , rollbackToSavepoint
    , rollbackToAndReleaseSavepoint
    
    , isSerializationError
    , isNoActiveTransactionError
    , isFailedTransactionError
    ) where
import qualified Control.Exception as E
import qualified Data.ByteString as B
import Database.PostgreSQL.Simple.Internal
import Database.PostgreSQL.Simple.Types
import Database.PostgreSQL.Simple.Errors
import Database.PostgreSQL.Simple.Compat (mask, (<>))
-- | Isolation level requested when a transaction is opened with 'beginMode'
-- (and the @with...@ wrappers built on top of it).
--
-- NOTE: the derived 'Ord', 'Enum' and 'Bounded' instances follow the
-- constructor order below; reordering the constructors changes them.
data IsolationLevel
   = DefaultIsolationLevel
     -- ^ Add no @ISOLATION LEVEL@ clause to @BEGIN@, so the level comes from
     -- the server's @default_transaction_isolation@ setting (read committed
     -- in a stock PostgreSQL configuration).
   | ReadCommitted
     -- ^ @ISOLATION LEVEL READ COMMITTED@
   | RepeatableRead
     -- ^ @ISOLATION LEVEL REPEATABLE READ@
   | Serializable
     -- ^ @ISOLATION LEVEL SERIALIZABLE@
   deriving (Eq, Ord, Show, Enum, Bounded)
-- | Access mode requested when a transaction is opened with 'beginMode'.
--
-- NOTE: the derived 'Ord', 'Enum' and 'Bounded' instances follow the
-- constructor order below; reordering the constructors changes them.
data ReadWriteMode
   = DefaultReadWriteMode
     -- ^ Add no access-mode clause to @BEGIN@, so the mode comes from the
     -- server's @default_transaction_read_only@ setting.
   | ReadWrite
     -- ^ @READ WRITE@
   | ReadOnly
     -- ^ @READ ONLY@
   deriving (Eq, Ord, Show, Enum, Bounded)
-- | Complete set of options that 'beginMode' turns into a @BEGIN@ statement.
data TransactionMode = TransactionMode
    { isolationLevel :: !IsolationLevel  -- ^ requested isolation level
    , readWriteMode  :: !ReadWriteMode   -- ^ requested access mode
    }
    deriving (Eq, Show)
-- | Isolation level that leaves the choice to the server.
defaultIsolationLevel :: IsolationLevel
defaultIsolationLevel = DefaultIsolationLevel

-- | Access mode that leaves the choice to the server.
defaultReadWriteMode :: ReadWriteMode
defaultReadWriteMode = DefaultReadWriteMode

-- | Transaction mode whose isolation level and access mode are both the
-- server defaults, i.e. one that adds no options to @BEGIN@.
defaultTransactionMode :: TransactionMode
defaultTransactionMode = TransactionMode
    { isolationLevel = defaultIsolationLevel
    , readWriteMode  = defaultReadWriteMode
    }
-- | Run an action inside a transaction opened with 'defaultTransactionMode'.
-- The transaction is committed when the action returns normally and rolled
-- back (best effort) when it throws.
withTransaction :: Connection -> IO a -> IO a
withTransaction conn act = withTransactionMode defaultTransactionMode conn act
-- | Run an action in a @SERIALIZABLE@, @READ WRITE@ transaction. Whenever
-- the attempt fails with an error that 'isSerializationError' accepts, the
-- whole transaction is started over.
--
-- The action may therefore run several times, so any effects it has outside
-- the database should be safe to repeat.
-- NOTE(review): the number of retries is not bounded.
withTransactionSerializable :: Connection -> IO a -> IO a
withTransactionSerializable conn act =
    withTransactionModeRetry serializableReadWrite isSerializationError conn act
  where
    serializableReadWrite = TransactionMode Serializable ReadWrite
-- | Like 'withTransaction', but opens the transaction at the given isolation
-- level. The access mode is left at the server default.
withTransactionLevel :: IsolationLevel -> Connection -> IO a -> IO a
withTransactionLevel lvl conn act = withTransactionMode requested conn act
  where
    requested = TransactionMode
        { isolationLevel = lvl
        , readWriteMode  = defaultReadWriteMode
        }
-- | Run an action inside a transaction opened with the given mode.
--
-- @BEGIN@ and @COMMIT@ run with asynchronous exceptions masked; only the
-- action itself runs unmasked. If the action throws, a best-effort
-- @ROLLBACK@ ('rollback_') is issued and the exception is rethrown.
-- Otherwise the transaction is committed and the action's result returned.
withTransactionMode :: TransactionMode -> Connection -> IO a -> IO a
withTransactionMode mode conn act = mask $ \unmask -> do
    beginMode mode conn
    result <- E.onException (unmask act) (rollback_ conn)
    commit conn
    return result
-- | Like 'withTransactionMode', but retries the whole transaction
-- (@BEGIN@, action, @COMMIT@) when an attempt fails with a 'SqlError' that
-- @shouldRetry@ accepts.
--
-- Any exception from the action or from @COMMIT@ is followed by a
-- best-effort rollback ('rollback_'). Exceptions that are not a 'SqlError',
-- or that @shouldRetry@ rejects, are then rethrown.
--
-- NOTE(review): retries are unbounded. A predicate that keeps matching loops
-- for as long as the error recurs. The action may run several times, so its
-- non-database side effects should be safe to repeat.
withTransactionModeRetry :: TransactionMode -> (SqlError -> Bool) -> Connection -> IO a -> IO a
withTransactionModeRetry mode shouldRetry conn act =
    mask $ \restore ->
        -- One attempt = run the action (unmasked) and then COMMIT. 'E.try'
        -- captures any exception, so 'retryLoop' decides what to do next.
        retryLoop $ E.try $ do
            a <- restore act
            commit conn
            return a
  where
    retryLoop :: IO (Either E.SomeException a) -> IO a
    retryLoop act' = do
        -- Every attempt, including retries, opens a fresh transaction.
        beginMode mode conn
        r <- act'
        case r of
            Left e -> do
                rollback_ conn
                -- Only a 'SqlError' can be retried. For any other exception
                -- 'fromException' yields 'Nothing' and it is rethrown.
                case fmap shouldRetry (E.fromException e) of
                  Just True -> retryLoop act'
                  _ -> E.throwIO e
            Right a ->
                return a
-- | Send @ROLLBACK@ on the connection, discarding the statement's result.
rollback :: Connection -> IO ()
rollback conn = do
    _ <- execute_ conn "ROLLBACK"
    return ()
-- | Best-effort 'rollback'. An 'IOError' raised while rolling back is
-- swallowed; any other exception still propagates.
rollback_ :: Connection -> IO ()
rollback_ conn = E.handle ignoreIOError (rollback conn)
  where
    ignoreIOError :: IOError -> IO ()
    ignoreIOError _ = return ()
-- | Send @COMMIT@ on the connection, discarding the statement's result.
commit :: Connection -> IO ()
commit conn = do
    _ <- execute_ conn "COMMIT"
    return ()
-- | Open a transaction with 'defaultTransactionMode' (no options on @BEGIN@).
begin :: Connection -> IO ()
begin conn = beginMode defaultTransactionMode conn
-- | Open a transaction at the given isolation level, leaving the access mode
-- at the server default.
beginLevel :: IsolationLevel -> Connection -> IO ()
beginLevel lvl conn = beginMode (TransactionMode lvl defaultReadWriteMode) conn
-- | Open a transaction by sending @BEGIN@ followed by the isolation-level and
-- access-mode clauses that @mode@ requests. A @Default...@ setting
-- contributes no clause, which leaves that choice to the server.
beginMode :: TransactionMode -> Connection -> IO ()
beginMode mode conn = do
    _ <- execute_ conn $! Query statement
    return ()
  where
    statement :: B.ByteString
    statement = B.concat
        [ "BEGIN"
        , isolationClause (isolationLevel mode)
        , accessClause (readWriteMode mode)
        ]

    isolationClause :: IsolationLevel -> B.ByteString
    isolationClause DefaultIsolationLevel = ""
    isolationClause ReadCommitted         = " ISOLATION LEVEL READ COMMITTED"
    isolationClause RepeatableRead        = " ISOLATION LEVEL REPEATABLE READ"
    isolationClause Serializable          = " ISOLATION LEVEL SERIALIZABLE"

    accessClause :: ReadWriteMode -> B.ByteString
    accessClause DefaultReadWriteMode = ""
    accessClause ReadWrite            = " READ WRITE"
    accessClause ReadOnly             = " READ ONLY"
-- | Run @body@ inside a fresh savepoint, giving a kind of nested
-- transaction. PostgreSQL only accepts savepoints inside a transaction
-- block, so this must be called while a transaction is open.
--
-- * If @body@ throws, the savepoint is rolled back to and released, and the
--   exception is rethrown.
--
-- * If @body@ returns normally, the savepoint is released. If releasing
--   fails with an error that 'isFailedTransactionError' accepts (for
--   instance because @body@ caught an SQL error itself, leaving the
--   transaction aborted), the code rolls back to the savepoint and releases
--   it instead. That discards @body@'s database work, but its result is
--   still returned. Any other release error is rethrown.
withSavepoint :: Connection -> IO a -> IO a
withSavepoint conn body =
  mask $ \restore -> do
    sp <- newSavepoint conn
    r <- restore body `E.onException` rollbackToAndReleaseSavepoint conn sp
    -- Recover the enclosing transaction if @body@ left it in a failed state.
    releaseSavepoint conn sp `E.catch` \err ->
        if isFailedTransactionError err
            then rollbackToAndReleaseSavepoint conn sp
            else E.throwIO err
    return r
-- | Create a savepoint in the current transaction. Its name comes from
-- 'newTempName'; the returned handle refers to that name.
newSavepoint :: Connection -> IO Savepoint
newSavepoint conn =
    newTempName conn >>= \spName ->
        execute_ conn ("SAVEPOINT " <> spName) >> return (Savepoint spName)
-- | Send @RELEASE SAVEPOINT@ for the given savepoint.
releaseSavepoint :: Connection -> Savepoint -> IO ()
releaseSavepoint conn (Savepoint sp) = do
    _ <- execute_ conn ("RELEASE SAVEPOINT " <> sp)
    return ()
-- | Send @ROLLBACK TO SAVEPOINT@ for the given savepoint. This undoes work
-- done since the savepoint but does not release it.
rollbackToSavepoint :: Connection -> Savepoint -> IO ()
rollbackToSavepoint conn (Savepoint sp) = do
    _ <- execute_ conn ("ROLLBACK TO SAVEPOINT " <> sp)
    return ()
-- | Roll back to the given savepoint and then release it. Both commands are
-- sent together in one multi-statement query string.
rollbackToAndReleaseSavepoint :: Connection -> Savepoint -> IO ()
rollbackToAndReleaseSavepoint conn (Savepoint sp) = do
    _ <- execute_ conn statements
    return ()
  where
    statements = "ROLLBACK TO SAVEPOINT " <> sp <> "; RELEASE SAVEPOINT " <> sp