module Database.PostgreSQL.PQTypes.Transaction (
    Savepoint(..)
  , withSavepoint
  , withTransaction
  , begin
  , commit
  , rollback
  , withTransaction'
  , begin'
  , commit'
  , rollback'
  ) where

import Control.Monad
import Control.Monad.Catch
import Data.Function
import Data.String
import Data.Typeable

import Data.Monoid.Utils
import Database.PostgreSQL.PQTypes.Class
import Database.PostgreSQL.PQTypes.Internal.Exception
import Database.PostgreSQL.PQTypes.SQL.Raw
import Database.PostgreSQL.PQTypes.Transaction.Settings
import Database.PostgreSQL.PQTypes.Utils

-- | Name of a savepoint. It wraps the raw SQL fragment that 'withSavepoint'
-- splices into its @SAVEPOINT@, @RELEASE SAVEPOINT@ and
-- @ROLLBACK TO SAVEPOINT@ statements, so it should be a valid SQL
-- identifier.
newtype Savepoint = Savepoint (RawSQL ())

-- | Lets savepoint names be written as string literals
-- (with @OverloadedStrings@).
instance IsString Savepoint where
  fromString name = Savepoint (fromString name)

-- | Create a savepoint and roll back to it if the given monadic action
-- throws. This may only be used if a transaction is already active; it
-- provides something like a \"nested transaction\".
--
-- @SAVEPOINT name@ is issued before the action runs. When the action
-- completes successfully, the savepoint is released with
-- @RELEASE SAVEPOINT name@. On any other exit (exception or abort), the work
-- done since the savepoint is first undone with @ROLLBACK TO SAVEPOINT name@,
-- and then the savepoint is released. The action's result is returned
-- unchanged.
--
-- See <http://www.postgresql.org/docs/current/static/sql-savepoint.html>
withSavepoint :: (MonadDB m, MonadMask m) => Savepoint -> m a -> m a
withSavepoint (Savepoint savepoint) action =
  fmap fst $ generalBracket createSavepoint finish (\() -> action)
  where
    createSavepoint = runQuery_ $ "SAVEPOINT" <+> savepoint
    releaseSavepoint = runQuery_ $ "RELEASE SAVEPOINT" <+> savepoint
    -- Release on success; on every other exit, roll back to the savepoint
    -- first and then release it.
    finish () exitCase = case exitCase of
      ExitCaseSuccess _ -> releaseSavepoint
      _ -> do
        runQuery_ $ "ROLLBACK TO SAVEPOINT" <+> savepoint
        releaseSavepoint

----------------------------------------

-- | Same as 'withTransaction'', except that it uses the current transaction
-- settings instead of custom ones. The settings are read once, before the
-- supplied action runs. Changing them inside the action therefore has no
-- effect on the final 'commit' / 'rollback': the settings that were in
-- effect when 'withTransaction' was called are used.
withTransaction :: (MonadDB m, MonadMask m) => m a -> m a
withTransaction action = do
  settings <- getTransactionSettings
  withTransaction' settings action

-- | Begin a transaction using the current transaction settings
-- (delegates to 'begin'').
begin :: MonadDB m => m ()
begin = begin' =<< getTransactionSettings

-- | Commit the active transaction using the current transaction settings
-- (delegates to 'commit'').
commit :: MonadDB m => m ()
commit = commit' =<< getTransactionSettings

-- | Roll back the active transaction using the current transaction settings
-- (delegates to 'rollback'').
rollback :: MonadDB m => m ()
rollback = rollback' =<< getTransactionSettings

----------------------------------------

-- | Execute monadic action within a transaction using given transaction
-- settings. Note that it won't work as expected if a transaction is already
-- active (in such case 'withSavepoint' should be used instead).
--
-- The action runs between 'begin'' and 'commit''. If it throws (or is
-- otherwise aborted), 'rollback'' is issued instead and the exception
-- propagates.
--
-- If 'tsRestartPredicate' is set, an exception escaping the transaction is
-- offered to the predicate together with the current attempt number
-- (starting at 1). The exception is offered either directly or, when it is
-- wrapped in a 'DBException', as its 'dbeError' field. When the predicate
-- returns 'True', the whole transaction is run again from the beginning.
-- Exceptions the predicate does not accept are rethrown unchanged.
withTransaction' :: (MonadDB m, MonadMask m)
                 => TransactionSettings -> m a -> m a
withTransaction' ts m = (`fix` 1) $ \loop n -> do
  -- Optimization for squashing possible space leaks.
  -- It looks like GHC doesn't like 'catch' and passes
  -- on introducing strictness in some cases.
  --
  -- The restart is decided inside 'tryJust' but performed outside of it.
  -- Handlers of the 'catch' family run with asynchronous exceptions
  -- implicitly masked (in IO-based monads), so calling 'loop' from inside a
  -- handler, as 'handleJust' would, runs every retried transaction masked.
  let maybeRestart = case tsRestartPredicate ts of
        Just _  -> \action ->
          tryJust (expred n) action >>= either (\() -> loop $ n + 1) return
        Nothing -> id
  maybeRestart $ fst <$> generalBracket
    (begin' ts)
    (\() -> \case
        ExitCaseSuccess _ -> commit' ts
        _                 -> rollback' ts
    )
    (\() -> m)
  where
    -- Decide whether exception @e@, raised during attempt @n@, should cause
    -- a restart. Returns @Just ()@ only if a restart predicate is configured,
    -- @e@ (or the 'dbeError' inside a 'DBException') has the type the
    -- predicate expects, and the predicate returns 'True'.
    expred :: Integer -> SomeException -> Maybe ()
    expred !n e = do
      -- check if the predicate exists
      RestartPredicate f <- tsRestartPredicate ts
      -- cast exception to the type expected by the predicate
      err <- msum [
          -- either cast the exception itself...
          fromException e
          -- ...or extract it from DBException
        , fromException e >>= \DBException{..} -> cast dbeError
        ]
      -- check if the predicate allows for the restart
      guard $ f err n

-- | Begin a transaction using the given transaction settings.
--
-- Issues @BEGIN@, followed by an isolation level clause and an access mode
-- clause when the settings request a non-default one. Clauses left at their
-- defaults are omitted from the list instead of being joined in as empty
-- fragments, which would otherwise leave redundant separators in the
-- statement text.
begin' :: MonadDB m => TransactionSettings -> m ()
begin' ts = runSQL_ . mintercalate " " $
  "BEGIN" : [clause | Just clause <- [isolationLevel, permissions]]
  where
    -- Isolation level clause; 'Nothing' for the default level (no clause).
    isolationLevel = case tsIsolationLevel ts of
      DefaultLevel   -> Nothing
      ReadCommitted  -> Just "ISOLATION LEVEL READ COMMITTED"
      RepeatableRead -> Just "ISOLATION LEVEL REPEATABLE READ"
      Serializable   -> Just "ISOLATION LEVEL SERIALIZABLE"
    -- Access mode clause; 'Nothing' for default permissions (no clause).
    permissions = case tsPermissions ts of
      DefaultPermissions -> Nothing
      ReadOnly           -> Just "READ ONLY"
      ReadWrite          -> Just "READ WRITE"

-- | Commit the active transaction using the given transaction settings.
--
-- Issues @COMMIT@. When 'tsAutoTransaction' is enabled, a new transaction
-- is started right away with the same settings via 'begin''.
commit' :: MonadDB m => TransactionSettings -> m ()
commit' ts =
  runSQL_ "COMMIT" >> when (tsAutoTransaction ts) (begin' ts)

-- | Roll back the active transaction using the given transaction settings.
--
-- Issues @ROLLBACK@. When 'tsAutoTransaction' is enabled, a new transaction
-- is started right away with the same settings via 'begin''.
rollback' :: MonadDB m => TransactionSettings -> m ()
rollback' ts =
  runSQL_ "ROLLBACK" >> when (tsAutoTransaction ts) (begin' ts)