{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}

module Database.Persist.Monad.Internal.SqlTransaction (
  SqlTransaction (..),
  SqlTransactionEnv (..),
  runSqlTransaction,
  catchSqlTransaction,
) where

import Control.Monad.Fix (MonadFix)
import Control.Monad.IO.Class (MonadIO (..))
import Control.Monad.IO.Unlift (MonadUnliftIO (..))
import Control.Monad.Reader (ReaderT, ask, withReaderT)
import Database.Persist.Sql (SqlBackend, runSqlConn)
import qualified GHC.TypeLits as GHC
import UnliftIO.Exception (Exception, SomeException, catchJust, fromException)

import Control.Monad.IO.Rerunnable (MonadRerunnableIO)
import Control.Monad.Trans.Rerunnable (MonadRerunnableTrans)
import Database.Persist.Monad.Class
import Database.Persist.Monad.SqlQueryRep

{-| The monad that tracks transaction state.

 Conceptually equivalent to 'Database.Persist.Sql.SqlPersistT', but restricts
 IO operations, for two reasons:

   1. Forking a thread that uses the same 'SqlBackend' as the current thread
      causes Bad Things to happen.
   2. Transactions may need to be retried, in which case IO operations in
      a transaction are required to be rerunnable.

 You shouldn't need to explicitly use this type; your functions should only
 declare the 'MonadSqlQuery' constraint.
-}
newtype SqlTransaction m a = UnsafeSqlTransaction
  { unSqlTransaction :: ReaderT SqlTransactionEnv m a
  }
  -- All instances are inherited from the underlying 'ReaderT' via
  -- GeneralizedNewtypeDeriving. 'MonadIO' is deliberately NOT derived:
  -- arbitrary IO is rejected by a hand-written 'MonadIO' instance in this
  -- module, and only rerunnable IO is allowed (via 'MonadRerunnableIO').
  deriving
    ( Functor
    , Applicative
    , Monad
    , MonadFix
    , MonadRerunnableIO
    , MonadRerunnableTrans
    )

-- | Arbitrary 'IO' is rejected at compile time: any code that needs
-- @'MonadIO' ('SqlTransaction' m)@ fails with the custom 'GHC.TypeError'
-- message below. IO actions that are safe to rerun should use
-- @rerunnableIO@ instead.
instance
  ( GHC.TypeError ('GHC.Text "Cannot run arbitrary IO actions within a transaction. If the IO action is rerunnable, use rerunnableIO")
  , Monad m
  ) =>
  MonadIO (SqlTransaction m)
  where
  -- Unreachable: the 'GHC.TypeError' constraint turns every use of this
  -- instance into a compile error, so this body can never run.
  liftIO = error "SqlTransaction.liftIO: unreachable (rejected by TypeError constraint)"

-- | Queries run directly against the 'SqlBackend' held in the
-- 'SqlTransactionEnv'; transaction boundaries are delegated to @m@.
instance (MonadSqlQuery m, MonadUnliftIO m) => MonadSqlQuery (SqlTransaction m) where
  type TransactionM (SqlTransaction m) = TransactionM m

  -- Narrow the reader environment down to the backend and execute the query
  -- on it.
  runQueryRep = UnsafeSqlTransaction . withReaderT sqlBackend . runSqlQueryRep

  -- Delegate to 'm', since 'm' is in charge of starting/stopping transactions.
  -- 'SqlTransaction' is ONLY in charge of executing queries.
  withTransaction = UnsafeSqlTransaction . withTransaction

-- | The environment read by a running 'SqlTransaction'.
data SqlTransactionEnv = SqlTransactionEnv
  { sqlBackend :: SqlBackend
  -- ^ The connection that queries in the transaction are executed against.
  , ignoreCatch :: SomeException -> Bool
  -- ^ Exceptions for which this returns 'True' are never passed to a
  -- 'catchSqlTransaction' handler; they propagate past it instead.
  }

-- | Run a 'SqlTransaction' on the connection stored in the given environment.
--
-- The action is handed to persistent's 'runSqlConn' together with
-- @'sqlBackend' opts@. The environment the action sees is rebuilt from the
-- backend that 'runSqlConn' supplies; 'ignoreCatch' is carried over from
-- @opts@ unchanged.
runSqlTransaction ::
  MonadUnliftIO m =>
  SqlTransactionEnv ->
  SqlTransaction m a ->
  m a
runSqlTransaction opts =
  (`runSqlConn` sqlBackend opts)
    . withReaderT (\conn -> opts{sqlBackend = conn})
    . unSqlTransaction

-- | Like normal 'catch', except ignores errors specified by 'ignoreCatch'.
--
-- An exception for which the environment's 'ignoreCatch' returns 'True' is
-- never given to @handler@ and is rethrown unchanged. Any other exception is
-- given to @handler@ if 'fromException' can convert it to the handler's
-- exception type @e@; otherwise it is rethrown as well.
catchSqlTransaction ::
  (MonadUnliftIO m, Exception e) =>
  SqlTransaction m a ->
  (e -> SqlTransaction m a) ->
  SqlTransaction m a
catchSqlTransaction (UnsafeSqlTransaction action) handler =
  UnsafeSqlTransaction $ action `catchUnlessIgnored` (unSqlTransaction . handler)
  where
    -- Selecting on 'SomeException' lets us consult 'ignoreCatch' for every
    -- exception before deciding whether the handler may see it.
    catchUnlessIgnored body onError = do
      SqlTransactionEnv{ignoreCatch} <- ask
      let select e
            | ignoreCatch e = Nothing
            | otherwise = fromException e
      catchJust select body onError