{-# LANGUAGE TypeApplications, MultiWayIf, OverloadedStrings #-}

module HsDev.Database.SQLite.Transaction (
	TransactionType(..),
	Retries(..), def, noRetry, retryForever, retryN,

	-- * Transactions
	withTransaction, beginTransaction, commitTransaction, rollbackTransaction,
	transaction, transaction_,

	-- * Retry functions
	retry, retry_
	) where

import Control.Concurrent
import Control.Monad.Catch
import Control.Monad.IO.Class
import Data.Default
import Database.SQLite.Simple as SQL hiding (withTransaction)

import HsDev.Server.Types (SessionMonad, serverSqlDatabase)

-- | The kind of SQLite transaction to open; selects which @BEGIN@ statement
-- 'beginTransaction' issues.
data TransactionType
	= Deferred   -- ^ @begin transaction;@
	| Immediate  -- ^ @begin immediate transaction;@
	| Exclusive  -- ^ @begin exclusive transaction;@
	deriving (Eq, Ord, Read, Show)

-- | Retry policy for operations that may fail with an 'SQLError'.
data Retries = Retries
	{ retriesIntervals :: [Int]
	-- ^ Delays, in microseconds (as passed to 'threadDelay'), waited before
	-- each successive retry; one retry per element, so the list length bounds
	-- the number of retries
	, retriesError :: SQLError -> Bool
	-- ^ Decides whether a caught 'SQLError' should be retried
	}

-- | Default policy: up to 10 retries, 100000 microseconds apart, retrying only
-- errors whose 'sqlError' code is 'ErrorBusy' or 'ErrorLocked'.
instance Default Retries where
	def = Retries {
		retriesIntervals = replicate 10 100000,
		retriesError = isTransient }
		where
			isTransient err = case sqlError err of
				ErrorBusy -> True
				ErrorLocked -> True
				_ -> False

-- | Policy that never retries: no intervals, and no error counts as retriable.
noRetry :: Retries
noRetry = Retries { retriesIntervals = [], retriesError = \_ -> False }

-- | Retry without limit, waiting @interval@ microseconds between attempts.
-- The set of retriable errors is the one from the default policy.
retryForever :: Int -> Retries
retryForever interval = (def :: Retries) { retriesIntervals = cycle [interval] }

-- | Retry at most @times@ times, waiting @interval@ microseconds before each
-- retry. The set of retriable errors is the one from the default policy.
retryN :: Int -> Int -> Retries
retryN interval times = (def :: Retries) { retriesIntervals = take times (repeat interval) }

-- | Run an action inside an SQLite transaction of the given type.
--
-- The whole operation runs under 'mask'. Asynchronous exceptions are restored
-- only while @act@ runs and while waiting between retries.
--
-- @BEGIN@ and @COMMIT@ are retried according to @rs@. A failure for which
-- 'retriesError' holds is retried after the next interval from
-- 'retriesIntervals'. Any other failure, or a failure once the intervals are
-- exhausted, is rethrown.
--
-- If @act@ or the final @COMMIT@ throws, the transaction is rolled back and
-- the original exception is rethrown. An 'SQLError' raised by the @ROLLBACK@
-- itself is ignored, so it cannot hide the exception that caused the
-- rollback. SQLite may already have rolled the transaction back on its own
-- after some errors, and then an explicit @ROLLBACK@ fails.
withTransaction :: (MonadIO m, MonadMask m) => Connection -> TransactionType -> Retries -> m a -> m a
withTransaction conn t rs act = mask $ \restore -> do
	mretry restore (beginTransaction conn t)
	(restore act <* mretry restore (commitTransaction conn)) `onException` safeRollback
	where
		-- Roll back, but never let a failing ROLLBACK replace the exception
		-- that triggered it: 'onException' rethrows that exception afterwards.
		safeRollback = catch @_ @SQLError (rollbackTransaction conn) (const (return ()))
		-- Run @fn@, retrying on retriable 'SQLError's; @restore'@ unmasks
		-- asynchronous exceptions while waiting between attempts.
		mretry restore' fn = mretry' (retriesIntervals rs) where
			mretry' [] = fn
			mretry' (tm:tms) = catch @_ @SQLError fn $ \e -> if
				| retriesError rs e -> do
					_ <- restore' $ liftIO $ threadDelay tm
					mretry' tms
				| otherwise -> throwM e

-- | Open a transaction of the given type on the connection.
beginTransaction :: MonadIO m => Connection -> TransactionType -> m ()
beginTransaction conn t = liftIO (SQL.execute_ conn (beginQuery t)) where
	beginQuery Deferred = "begin transaction;"
	beginQuery Immediate = "begin immediate transaction;"
	beginQuery Exclusive = "begin exclusive transaction;"

-- | Commit the current transaction on the connection.
commitTransaction :: MonadIO m => Connection -> m ()
commitTransaction conn = liftIO (SQL.execute_ conn "commit transaction;")

-- | Roll back the current transaction on the connection.
rollbackTransaction :: MonadIO m => Connection -> m ()
rollbackTransaction conn = liftIO (SQL.execute_ conn "rollback transaction;")

-- | Run an action inside a transaction on the session's SQLite connection,
-- obtained with 'serverSqlDatabase' and passed to 'withTransaction'.
transaction :: SessionMonad m => TransactionType -> Retries -> m a -> m a
transaction t rs act = serverSqlDatabase >>= \conn -> withTransaction conn t rs act

-- | 'transaction' with the default retry policy ('def').
transaction_ :: SessionMonad m => TransactionType -> m a -> m a
transaction_ t act = transaction t def act

-- | Run an action, retrying it when it throws an 'SQLError' accepted by
-- 'retriesError'.
--
-- Before each retry, the next interval from 'retriesIntervals' is waited
-- (in microseconds, via 'threadDelay'). Once the intervals run out, the
-- action is run one last time without catching. Errors that are not
-- retriable are rethrown immediately.
retry :: (MonadIO m, MonadCatch m) => Retries -> m a -> m a
retry rs act = go (retriesIntervals rs) where
	go [] = act
	go (delay:delays) = catch @_ @SQLError act $ \err ->
		if retriesError rs err
			then liftIO (threadDelay delay) >> go delays
			else throwM err

-- | 'retry' with the default retry policy ('def').
retry_ :: (MonadIO m, MonadCatch m) => m a -> m a
retry_ act = retry def act