{-# LANGUAGE RecordWildCards #-}

module Database.PostgreSQL.Simple.Transaction
    (
    -- * Transaction handling
      withTransaction
    , withTransactionLevel
    , withTransactionMode
    , withTransactionModeRetry
    , withTransactionSerializable
    , isSerializationError
    , TransactionMode(..)
    , IsolationLevel(..)
    , ReadWriteMode(..)
    , defaultTransactionMode
    , defaultIsolationLevel
    , defaultReadWriteMode
--    , Base.autocommit
    , begin
    , beginLevel
    , beginMode
    , commit
    , rollback
    ) where

import Control.Exception hiding (mask)
import qualified Data.ByteString as B
import Database.PostgreSQL.Simple.Internal
import Database.PostgreSQL.Simple.Types
import Database.PostgreSQL.Simple.Compat(mask)


-- | Of the four isolation levels defined by the SQL standard,
-- these are the three levels distinguished by PostgreSQL as of version 9.1.
-- See <http://www.postgresql.org/docs/9.1/static/transaction-iso.html>
-- for more information.   Note that prior to PostgreSQL 9.1, 'RepeatableRead'
-- was equivalent to 'Serializable'.

data IsolationLevel
   = DefaultIsolationLevel  -- ^ the isolation level will be taken from
                            --   PostgreSQL's per-connection
                            --   @default_transaction_isolation@ variable,
                            --   which is initialized according to the
                            --   server's config.  The default configuration
                            --   is 'ReadCommitted'.
   | ReadCommitted          -- ^ 'beginMode' starts the transaction with
                            --   @ISOLATION LEVEL READ COMMITTED@.
   | RepeatableRead         -- ^ 'beginMode' starts the transaction with
                            --   @ISOLATION LEVEL REPEATABLE READ@.
   | Serializable           -- ^ 'beginMode' starts the transaction with
                            --   @ISOLATION LEVEL SERIALIZABLE@.
     deriving (Show, Eq, Ord, Enum, Bounded)

-- | Whether a transaction is started read-write or read-only.
-- 'beginMode' turns the chosen constructor into the access-mode clause
-- appended to the @BEGIN@ statement.
data ReadWriteMode
   = DefaultReadWriteMode   -- ^ the read-write mode will be taken from
                            --   PostgreSQL's per-connection
                            --   @default_transaction_read_only@ variable,
                            --   which is initialized according to the
                            --   server's config.  The default configuration
                            --   is 'ReadWrite'.
   | ReadWrite              -- ^ 'beginMode' starts the transaction with
                            --   @READ WRITE@.
   | ReadOnly               -- ^ 'beginMode' starts the transaction with
                            --   @READ ONLY@.
     deriving (Show, Eq, Ord, Enum, Bounded)

-- | The settings a transaction is started with: an isolation level paired
-- with a read-write mode.  Both fields are strict.
data TransactionMode = TransactionMode {
       -- | Isolation level requested when the transaction begins.
       isolationLevel :: !IsolationLevel,
       -- | Read-write mode requested when the transaction begins.
       readWriteMode  :: !ReadWriteMode
     } deriving (Show, Eq)

-- | The transaction mode used by 'withTransaction' and 'begin': it pairs
-- 'defaultIsolationLevel' with 'defaultReadWriteMode'.
defaultTransactionMode :: TransactionMode
defaultTransactionMode = TransactionMode
    { isolationLevel = defaultIsolationLevel
    , readWriteMode  = defaultReadWriteMode
    }

-- | The isolation level used when none is requested explicitly:
-- 'DefaultIsolationLevel'.
defaultIsolationLevel :: IsolationLevel
defaultIsolationLevel = DefaultIsolationLevel

-- | The read-write mode used when none is requested explicitly:
-- 'DefaultReadWriteMode'.
defaultReadWriteMode :: ReadWriteMode
defaultReadWriteMode = DefaultReadWriteMode

-- | Run an IO action inside a SQL transaction that uses
-- 'defaultTransactionMode'.
--
-- The transaction is opened before the action runs.  When the action
-- returns normally, the transaction is finished with 'commit' and the
-- action's result is returned.
--
-- When the action throws /any/ kind of exception (not just a
-- PostgreSQL-related one), the transaction is undone with 'rollback'
-- and the exception is rethrown.
withTransaction :: Connection -> IO a -> IO a
withTransaction conn act = withTransactionMode defaultTransactionMode conn act

-- | Run an IO action inside a read-write 'Serializable' transaction.
-- Whenever the attempt fails with a serialization failure (see
-- 'isSerializationError'), the transaction is rolled back and the whole
-- attempt is made again.  Be warned that the IO action may therefore be
-- executed multiple times.
--
-- A 'Serializable' transaction creates the illusion that your program has
-- exclusive access to the database.  Even in a concurrent setting you can
-- issue queries one after another without worrying about what might happen
-- between one statement and the next.
--
-- Think of it as STM, but without @retry@.
withTransactionSerializable :: Connection -> IO a -> IO a
withTransactionSerializable conn act =
    withTransactionModeRetry serializableMode isSerializationError conn act
  where
    serializableMode :: TransactionMode
    serializableMode = TransactionMode Serializable ReadWrite


-- | 'True' exactly when the error's SQLSTATE is @40001@
-- (@serialization_failure@).
isSerializationError :: SqlError -> Bool
isSerializationError exception = sqlState exception == serializationFailure
  where
    -- SQLSTATE code for serialization_failure; see
    -- http://www.postgresql.org/docs/current/static/errcodes-appendix.html
    serializationFailure = "40001"

-- | Run an IO action inside a SQL transaction that uses the given isolation
-- level.  The read-write mode is the one from 'defaultTransactionMode'.
withTransactionLevel :: IsolationLevel -> Connection -> IO a -> IO a
withTransactionLevel lvl conn act = withTransactionMode mode conn act
  where
    mode :: TransactionMode
    mode = defaultTransactionMode { isolationLevel = lvl }

-- | Execute an action inside a SQL transaction with a given transaction mode.
--
-- The transaction is started with 'beginMode'.  If the action returns
-- normally, the transaction is completed with 'commit' and the action's
-- result is returned.  If the action throws any exception, the transaction
-- is rolled back and the exception is rethrown.
--
-- An 'IOError' raised by that cleanup rollback is ignored, so that it
-- cannot replace the exception that made the action fail.
withTransactionMode :: TransactionMode -> Connection -> IO a -> IO a
withTransactionMode mode conn act =
  -- Asynchronous exceptions are masked around begin/rollback/commit and
  -- only re-enabled (via 'restore') while the user's action runs.
  mask $ \restore -> do
    beginMode mode conn
    r <- restore act `onException` rollbackQuietly
    commit conn
    return r
  where
    -- Best-effort rollback for the failure path: a secondary IOError from
    -- the rollback must not hide the original exception.
    rollbackQuietly :: IO ()
    rollbackQuietly = rollback conn `catch` ignoreIOError

    ignoreIOError :: IOError -> IO ()
    ignoreIOError _ = return ()

-- | Like 'withTransactionMode', but also takes a custom callback to
-- determine if a transaction should be retried if an 'SqlError' occurs.
-- If the callback returns True, then the transaction will be retried.
-- If the callback returns False, or an exception other than an 'SqlError'
-- occurs then the transaction will be rolled back and the exception rethrown.
--
-- Be warned that a retry runs the supplied IO action again from the start,
-- so the action may be executed multiple times.
--
-- An 'IOError' raised by the rollback itself is ignored, so that it cannot
-- replace the exception that made the attempt fail.
--
-- This is used to implement 'withTransactionSerializable'.
withTransactionModeRetry :: TransactionMode -> (SqlError -> Bool) -> Connection -> IO a -> IO a
withTransactionModeRetry mode shouldRetry conn act =
    -- Asynchronous exceptions are only re-enabled (via 'restore') while the
    -- user's action runs; begin/commit/rollback run masked.
    mask $ \restore ->
        retryLoop $ try $ do
            a <- restore act
            commit conn
            return a
  where
    -- One attempt: begin, run the (action + commit) step, and on failure
    -- roll back and either start over or rethrow.
    retryLoop :: IO (Either SomeException a) -> IO a
    retryLoop act' = do
        beginMode mode conn
        r <- act'
        case r of
            Left e -> do
                rollbackQuietly
                -- Only an 'SqlError' accepted by the callback is retried;
                -- every other exception is rethrown unchanged.
                case fmap shouldRetry (fromException e) of
                  Just True -> retryLoop act'
                  _ -> throwIO e
            Right a ->
                return a

    -- Best-effort rollback for the failure path: a secondary IOError from
    -- the rollback must not hide the original exception.
    rollbackQuietly :: IO ()
    rollbackQuietly = rollback conn `catch` ignoreIOError

    ignoreIOError :: IOError -> IO ()
    ignoreIOError _ = return ()

-- | Roll back the current transaction by sending an @ABORT@ statement
-- on the connection.  The statement's result is discarded.
rollback :: Connection -> IO ()
rollback conn = do
    _ <- execute_ conn "ABORT"
    return ()

-- | Commit the current transaction by sending a @COMMIT@ statement
-- on the connection.  The statement's result is discarded.
commit :: Connection -> IO ()
commit conn = do
    _ <- execute_ conn "COMMIT"
    return ()

-- | Begin a transaction using 'defaultTransactionMode'.
begin :: Connection -> IO ()
begin conn = beginMode defaultTransactionMode conn

-- | Begin a transaction with the given isolation level.  The read-write
-- mode is the one from 'defaultTransactionMode'.
beginLevel :: IsolationLevel -> Connection -> IO ()
beginLevel lvl conn = beginMode mode conn
  where
    mode :: TransactionMode
    mode = defaultTransactionMode { isolationLevel = lvl }

-- | Begin a transaction with the given transaction mode.
--
-- The statement sent is @BEGIN@, followed by an isolation-level clause and
-- then an access-mode clause.  A clause is omitted when the corresponding
-- field of the mode is 'DefaultIsolationLevel' or 'DefaultReadWriteMode'.
beginMode :: TransactionMode -> Connection -> IO ()
beginMode mode conn = do
    -- ($!) forces the assembled statement before it is handed to execute_.
    _ <- execute_ conn $! Query statement
    return ()
  where
    statement :: B.ByteString
    statement = B.concat ["BEGIN", isolationClause, accessClause]

    isolationClause :: B.ByteString
    isolationClause =
        case isolationLevel mode of
          DefaultIsolationLevel -> ""
          ReadCommitted         -> " ISOLATION LEVEL READ COMMITTED"
          RepeatableRead        -> " ISOLATION LEVEL REPEATABLE READ"
          Serializable          -> " ISOLATION LEVEL SERIALIZABLE"

    accessClause :: B.ByteString
    accessClause =
        case readWriteMode mode of
          DefaultReadWriteMode -> ""
          ReadWrite            -> " READ WRITE"
          ReadOnly             -> " READ ONLY"