{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}

module Database.Persist.Monad.Internal.SqlTransaction (
  SqlTransaction (..),
  SqlTransactionEnv (..),
  runSqlTransaction,
  catchSqlTransaction,
) where

import Control.Monad.Fix (MonadFix)
import Control.Monad.IO.Class (MonadIO (..))
import Control.Monad.IO.Unlift (MonadUnliftIO (..))
import Control.Monad.Reader (ReaderT, ask, withReaderT)
import Database.Persist.Sql (SqlBackend, runSqlConn)
import qualified GHC.TypeLits as GHC
import UnliftIO.Exception (Exception, SomeException, catchJust, fromException)

import Control.Monad.IO.Rerunnable (MonadRerunnableIO)
import Control.Monad.Trans.Rerunnable (MonadRerunnableTrans)
import Database.Persist.Monad.Class
import Database.Persist.Monad.SqlQueryRep

-- | The monad that tracks transaction state.
--
-- Conceptually equivalent to 'Database.Persist.Sql.SqlPersistT', but restricts
-- IO operations, for two reasons:
--
--   1. Forking a thread that uses the same 'SqlBackend' as the current thread
--      causes Bad Things to happen.
--   2. Transactions may need to be retried, in which case IO operations in
--      a transaction are required to be rerunnable.
--
-- You shouldn't need to explicitly use this type; your functions should only
-- declare the 'MonadSqlQuery' constraint.
--
-- The constructor is prefixed with @Unsafe@ because wrapping an arbitrary
-- 'ReaderT' action sidesteps the IO restrictions described above.
newtype SqlTransaction m a = UnsafeSqlTransaction
  { unSqlTransaction :: ReaderT SqlTransactionEnv m a
  -- ^ Unwrap to the underlying reader action over 'SqlTransactionEnv'.
  }
  deriving
    ( Functor
    , Applicative
    , Monad
    , MonadFix
    , MonadRerunnableIO
    , MonadRerunnableTrans
    )

-- | Deliberately unusable instance: the 'GHC.TypeError' constraint makes any
-- attempt to use 'liftIO' inside a 'SqlTransaction' fail at compile time with
-- the message below, steering callers towards @rerunnableIO@ instead.
instance
  ( GHC.TypeError ('GHC.Text "Cannot run arbitrary IO actions within a transaction. If the IO action is rerunnable, use rerunnableIO")
  , Monad m
  ) =>
  MonadIO (SqlTransaction m)
  where
  -- Never evaluated: the instance context can never be satisfied, so no
  -- program that reaches this method type checks.
  liftIO = undefined

-- | Queries issued inside a 'SqlTransaction' are executed against the
-- 'sqlBackend' stored in the current 'SqlTransactionEnv'.
instance (MonadSqlQuery m, MonadUnliftIO m) => MonadSqlQuery (SqlTransaction m) where
  type TransactionM (SqlTransaction m) = TransactionM m

  -- 'runSqlQueryRep' produces an action that reads a 'SqlBackend'; 'withReaderT'
  -- adapts it to read the backend out of the 'SqlTransactionEnv' instead.
  runQueryRep = UnsafeSqlTransaction . withReaderT sqlBackend . runSqlQueryRep

  -- Delegate to 'm', since 'm' is in charge of starting/stopping transactions.
  -- 'SqlTransaction' is ONLY in charge of executing queries.
  withTransaction = UnsafeSqlTransaction . withTransaction

-- | The environment available to every action running in a 'SqlTransaction'.
data SqlTransactionEnv = SqlTransactionEnv
  { sqlBackend :: SqlBackend
  -- ^ The backend that queries in the transaction are run against.
  , ignoreCatch :: SomeException -> Bool
  -- ^ Predicate consulted by 'catchSqlTransaction': exceptions for which this
  -- returns 'True' are not caught and propagate past the handler.
  }

-- | Run a 'SqlTransaction' in the underlying monad.
--
-- The action is executed via 'runSqlConn' on the 'sqlBackend' of the given
-- environment (see 'runSqlConn' for the transaction semantics). The connection
-- that 'runSqlConn' supplies to the action replaces 'sqlBackend' in the
-- environment seen by the transaction; all other fields are kept as given.
runSqlTransaction ::
  (MonadUnliftIO m) =>
  SqlTransactionEnv
  -> SqlTransaction m a
  -> m a
runSqlTransaction opts =
  (`runSqlConn` sqlBackend opts)
    . withReaderT (\conn -> opts{sqlBackend = conn})
    . unSqlTransaction

-- | Like normal 'catch', except ignores errors specified by 'ignoreCatch'.
--
-- An exception thrown by the first action is passed to the handler only if
-- 'ignoreCatch' returns 'False' for it /and/ it can be converted to the
-- handler's exception type @e@. Anything else propagates unchanged.
catchSqlTransaction ::
  (MonadUnliftIO m, Exception e) =>
  SqlTransaction m a
  -> (e -> SqlTransaction m a)
  -> SqlTransaction m a
catchSqlTransaction (UnsafeSqlTransaction action) handler =
  UnsafeSqlTransaction $ do
    SqlTransactionEnv{ignoreCatch} <- ask
    catchJust
      -- Returning 'Nothing' tells 'catchJust' to leave the exception alone.
      (\e -> if ignoreCatch e then Nothing else fromException e)
      action
      (unSqlTransaction . handler)