module Database.PostgreSQL.PQTypes.Transaction (
    Savepoint(..)
  , withSavepoint
  , withTransaction
  , begin
  , commit
  , rollback
  , withTransaction'
  , begin'
  , commit'
  , rollback'
  ) where

import Control.Monad
import Control.Monad.Catch
import Data.Function
import Data.String
import Data.Typeable

import Data.Monoid.Utils
import Database.PostgreSQL.PQTypes.Class
import Database.PostgreSQL.PQTypes.Internal.Exception
import Database.PostgreSQL.PQTypes.SQL.Raw
import Database.PostgreSQL.PQTypes.Transaction.Settings
import Database.PostgreSQL.PQTypes.Utils

-- | Wrapper that represents savepoint name.
--
-- The wrapped 'RawSQL' fragment is spliced directly after the
-- @SAVEPOINT@, @ROLLBACK TO SAVEPOINT@ and @RELEASE SAVEPOINT@ keywords by
-- 'withSavepoint'. The 'IsString' instance lets a name be written as a
-- string literal (with @OverloadedStrings@).
newtype Savepoint = Savepoint (RawSQL ())

instance IsString Savepoint where
  fromString name = Savepoint (fromString name)

-- | Create a savepoint and roll back to it if given monadic action throws.
-- This may only be used if a transaction is already active. Note that it
-- provides something like \"nested transaction\".
--
-- Sequence of statements issued:
--
-- * @SAVEPOINT name@ before the action runs;
-- * on success, @RELEASE SAVEPOINT name@, then the action's result is
--   returned;
-- * if the action throws, @ROLLBACK TO SAVEPOINT name@ followed by
--   @RELEASE SAVEPOINT name@, then the exception is rethrown.
--
-- Asynchronous exceptions are masked everywhere except while the action
-- itself runs.
--
-- NOTE(review): only thrown exceptions are detected; a monad that aborts
-- without throwing (e.g. a short-circuiting transformer) skips both the
-- release and the rollback.
--
-- See <http://www.postgresql.org/docs/current/static/sql-savepoint.html>
{-# INLINABLE withSavepoint #-}
withSavepoint :: (MonadDB m, MonadMask m) => Savepoint -> m a -> m a
withSavepoint (Savepoint name) action = mask $ \unmask -> do
  runQuery_ $ "SAVEPOINT" <+> name
  result <- unmask action `onException` undo
  release
  return result
  where
    -- Drop the savepoint (its effects stay part of the outer transaction).
    release = runQuery_ $ "RELEASE SAVEPOINT" <+> name
    -- Discard everything done since the savepoint, then drop it as well.
    undo = do
      runQuery_ $ "ROLLBACK TO SAVEPOINT" <+> name
      release

----------------------------------------

-- | Same as 'withTransaction'' except that it uses current
-- transaction settings instead of custom ones. It is worth
-- noting that changing transaction settings inside supplied
-- monadic action won't have any effect on the final 'commit'
-- / 'rollback' as settings that were in effect during the call
-- to 'withTransaction' will be used.
{-# INLINABLE withTransaction #-}
withTransaction :: (MonadDB m, MonadMask m) => m a -> m a
withTransaction m = do
  -- Settings are read once, up front; later changes made by @m@ are not
  -- seen by the commit/rollback performed by 'withTransaction''.
  settings <- getTransactionSettings
  withTransaction' settings m

-- | Begin transaction using current transaction settings.
--
-- Reads the settings with 'getTransactionSettings' and delegates to
-- 'begin''.
{-# INLINABLE begin #-}
begin :: MonadDB m => m ()
begin = begin' =<< getTransactionSettings

-- | Commit active transaction using current transaction settings.
--
-- Reads the settings with 'getTransactionSettings' and delegates to
-- 'commit''.
{-# INLINABLE commit #-}
commit :: MonadDB m => m ()
commit = commit' =<< getTransactionSettings

-- | Rollback active transaction using current transaction settings.
--
-- Reads the settings with 'getTransactionSettings' and delegates to
-- 'rollback''.
{-# INLINABLE rollback #-}
rollback :: MonadDB m => m ()
rollback = rollback' =<< getTransactionSettings

----------------------------------------

-- | Execute monadic action within a transaction using given transaction
-- settings. Note that it won't work as expected if a transaction is already
-- active (in such case 'withSavepoint' should be used instead).
--
-- Each attempt issues 'begin'', runs the action with asynchronous
-- exceptions restored, and finishes with 'commit''. If the action throws,
-- 'rollback'' is issued and the exception is rethrown. Everything outside
-- the action runs with asynchronous exceptions masked.
--
-- When the settings carry a 'RestartPredicate', an exception escaping an
-- attempt (including one thrown by 'begin'' or 'commit'') is matched
-- against the predicate's argument type, either directly or as the
-- 'dbeError' wrapped inside a 'DBException'. On a match the predicate is
-- given the exception and the number of the failed attempt (counting from
-- 1). If it returns 'True', the whole transaction is started over.
-- Otherwise the exception propagates.
{-# INLINABLE withTransaction' #-}
withTransaction' :: (MonadDB m, MonadMask m)
                 => TransactionSettings -> m a -> m a
withTransaction' ts m = mask $ \restore ->
  let -- One attempt at the transaction; @retry@ is the knot tied by 'fix'
      -- and @attemptNo@ is the attempt counter.
      attempt retry attemptNo = restartable attemptNo retry $ do
        begin' ts
        res <- restore m `onException` rollback' ts
        commit' ts
        return res
  in fix attempt 1
  where
    -- Optimization for squashing possible space leaks.
    -- It looks like GHC doesn't like 'catch' and passes
    -- on introducing strictness in some cases, so the handler
    -- is only installed when there is a predicate to consult.
    restartable attemptNo retry = case tsRestartPredicate ts of
      Nothing -> id
      Just _  -> handleJust (expred attemptNo) (\_ -> retry (attemptNo + 1))

    -- 'Just ()' iff the exception should trigger another attempt.
    expred :: Integer -> SomeException -> Maybe ()
    expred !attemptNo ex = do
      -- the settings must carry a predicate at all
      RestartPredicate shouldRestart <- tsRestartPredicate ts
      -- the exception must have the predicate's type, either as thrown
      -- or as the error wrapped inside a DBException
      err <- fromException ex
        `mplus` (fromException ex >>= \DBException{dbeError = inner} -> cast inner)
      -- finally, the predicate itself has to agree
      guard $ shouldRestart err attemptNo

-- | Begin transaction using given transaction settings.
--
-- Issues @BEGIN@, followed by the isolation level and the access mode
-- requested in the settings. Clauses whose setting is left at its default
-- ('DefaultLevel', 'DefaultPermissions') are omitted entirely. The
-- statement therefore contains no empty fragments and no stray separators
-- (previously e.g. @\"BEGIN  \"@).
{-# INLINABLE begin' #-}
begin' :: MonadDB m => TransactionSettings -> m ()
begin' ts = runSQL_ . mintercalate " " $ "BEGIN" : isolationLevel ++ permissions
  where
    -- Each clause is a list so that a default setting contributes nothing
    -- at all, instead of an empty fragment still surrounded by separators.
    isolationLevel = case tsIsolationLevel ts of
      DefaultLevel   -> []
      ReadCommitted  -> ["ISOLATION LEVEL READ COMMITTED"]
      RepeatableRead -> ["ISOLATION LEVEL REPEATABLE READ"]
      Serializable   -> ["ISOLATION LEVEL SERIALIZABLE"]
    permissions = case tsPermissions ts of
      DefaultPermissions -> []
      ReadOnly           -> ["READ ONLY"]
      ReadWrite          -> ["READ WRITE"]

-- | Commit active transaction using given transaction settings.
--
-- Issues @COMMIT@. If 'tsAutoTransaction' is set, a new transaction is
-- begun right away with the same settings (via 'begin'').
{-# INLINABLE commit' #-}
commit' :: MonadDB m => TransactionSettings -> m ()
commit' ts = runSQL_ "COMMIT" >> reopenIfAuto
  where
    reopenIfAuto
      | tsAutoTransaction ts = begin' ts
      | otherwise            = return ()

-- | Rollback active transaction using given transaction settings.
--
-- Issues @ROLLBACK@. If 'tsAutoTransaction' is set, a new transaction is
-- begun right away with the same settings (via 'begin'').
{-# INLINABLE rollback' #-}
rollback' :: MonadDB m => TransactionSettings -> m ()
rollback' ts = runSQL_ "ROLLBACK" >> reopenIfAuto
  where
    reopenIfAuto
      | tsAutoTransaction ts = begin' ts
      | otherwise            = return ()