{-# LANGUAGE OverloadedStrings #-}
module Database.Bolt.Transaction
  ( transact
  ) where

import           Control.Monad                  ( void )
import           Control.Monad.Trans            ( MonadIO(..) )
import           Control.Monad.Except           ( MonadError(..) )

import           Database.Bolt.Connection       ( BoltActionT
                                                , query'
                                                )

-- | Runs a sequence of actions as a single transaction.
--
-- Sends @BEGIN@, runs @actions@, then sends @COMMIT@ and returns the
-- result of @actions@.
--
-- If @actions@ raises a @BoltError@ through the 'BoltActionT' error
-- channel (i.e. via 'throwError'), @ROLLBACK@ is sent and that same error
-- is re-thrown to the caller.
--
-- Caveats:
--
-- * Only errors in the 'MonadError' channel are intercepted. An exception
--   thrown from inside 'liftIO' is presumably not caught here (a plain
--   'MonadIO' constraint offers no way to catch it), so no @ROLLBACK@ is
--   sent in that case.
--
-- * If @ROLLBACK@ itself fails, its error propagates instead of the
--   original one.
--
-- * @COMMIT@ runs outside the error handler, so an error raised by it is
--   not followed by a @ROLLBACK@.
transact :: MonadIO m => BoltActionT m a -> BoltActionT m a
transact actions = do
    txBegin
    -- On failure, roll back first, then hand the original error back.
    result <- actions `catchError` \e -> txRollback >> throwError e
    txCommit
    pure result

-- | Sends the @BEGIN@ statement through 'query'' and throws away the
-- records that come back.
txBegin :: MonadIO m => BoltActionT m ()
txBegin = () <$ query' "BEGIN"

-- | Sends the @COMMIT@ statement through 'query'' and throws away the
-- records that come back.
txCommit :: MonadIO m => BoltActionT m ()
txCommit = () <$ query' "COMMIT"

-- | Sends the @ROLLBACK@ statement through 'query'' and throws away the
-- records that come back.
txRollback :: MonadIO m => BoltActionT m ()
txRollback = () <$ query' "ROLLBACK"