module Database.PostgreSQL.Consumers.Utils (
    finalize
  , ThrownFrom(..)
  , stopExecution
  , forkP
  , gforkP
  ) where

import Control.Concurrent.Lifted
import Control.Monad.Base
import Control.Monad.Catch
import Control.Monad.Trans.Control
import Data.Typeable
import Prelude
import qualified Control.Concurrent.Thread.Group.Lifted as TG
import qualified Control.Concurrent.Thread.Lifted as T
import qualified Control.Exception.Lifted as E

-- | Run the action @m@, which returns a finalizer, then run @action@.
-- The returned finalizer is performed once @action@ completes, whether it
-- returns normally or throws.
--
-- If @m@ itself throws, no finalizer is available and @action@ is not run.
--
-- Asynchronous exceptions are masked from the moment @m@ returns its
-- finalizer until the finalizer is registered with 'finally'. The previous
-- implementation left that window open, so an exception arriving there
-- could drop the finalizer. Both @m@ and @action@ still run with the
-- caller's masking state, because they are run through @restore@.
--
-- The 'MonadBase' constraint is no longer needed by the implementation. It
-- is kept so the signature stays unchanged for existing callers.
finalize :: (MonadMask m, MonadBase IO m) => m (m ()) -> m a -> m a
finalize m action = mask $ \restore -> do
  -- Still masked when 'restore m' returns, so nothing can interrupt us
  -- before the finalizer is handed to 'finally' below.
  finalizer <- restore m
  restore action `finally` finalizer

----------------------------------------

-- | Exception used to ask a thread to stop (see 'stopExecution').
--
-- A thread spawned by 'forkP' or 'gforkP' that receives this exception
-- simply ends. Any other exception escaping such a thread is wrapped in
-- 'ThrownFrom' and rethrown in the thread that spawned it.
data StopExecution
  = StopExecution
  deriving (Show, Typeable)

instance Exception StopExecution

-- | Exception rethrown in a parent thread when one of its children
-- (spawned by 'forkP' or 'gforkP') dies with an exception other than
-- 'StopExecution'.
--
-- The 'String' is the name the child thread was forked with. The
-- 'SomeException' is the exception that escaped from the child.
data ThrownFrom
  = ThrownFrom String SomeException
  deriving (Show, Typeable)

instance Exception ThrownFrom

-- | Ask the given thread to stop by throwing 'StopExecution' to it
-- asynchronously with 'throwTo'.
--
-- Threads started with 'forkP' or 'gforkP' treat this exception as a quiet
-- termination request: it is not forwarded to their parent.
stopExecution :: MonadBase IO m => ThreadId -> m ()
stopExecution tid = throwTo tid StopExecution

----------------------------------------

-- | Like 'fork', but exceptions escaping the child are not lost.
--
-- A 'StopExecution' ends the child quietly. Any other exception is wrapped
-- in 'ThrownFrom' (tagged with the given thread name) and thrown to the
-- thread that called 'forkP'.
forkP :: MonadBaseControl IO m => String -> m () -> m ThreadId
forkP threadName body = forkImpl fork threadName body

-- | Like 'TG.fork': the child is started in the given thread group, and
-- the result is the pair returned by 'TG.fork'. Exceptions escaping the
-- child are handled as in 'forkP'.
--
-- A 'StopExecution' ends the child quietly. Any other exception is wrapped
-- in 'ThrownFrom' (tagged with the given thread name) and thrown to the
-- thread that called 'gforkP'.
gforkP :: MonadBaseControl IO m
       => TG.ThreadGroup
       -> String
       -> m ()
       -> m (ThreadId, m (T.Result ()))
gforkP tg = forkImpl (TG.fork tg)

----------------------------------------

-- | Shared implementation of 'forkP' and 'gforkP'.
--
-- Calls the forking function @ffork@ with asynchronous exceptions masked.
-- The child action @m@ is wrapped in @restore@, so @m@ itself runs with
-- the masking state the caller had.
--
-- Assuming @ffork@, like 'fork' and 'TG.fork', makes the child inherit the
-- current masking state, the handlers below are installed before the child
-- can be interrupted by an asynchronous exception.
--
-- In the child:
--
-- * 'StopExecution' makes it finish quietly;
--
-- * any other exception @e@ is thrown to the thread that called
--   'forkImpl' as @'ThrownFrom' tname e@.
forkImpl :: MonadBaseControl IO m
         => (m () -> m a)
         -> String
         -> m ()
         -> m a
forkImpl ffork tname m = E.mask $ \restore -> do
  parentId <- myThreadId
  let handlers =
        [ E.Handler $ \StopExecution -> return ()
        , E.Handler $ \e -> throwTo parentId (ThrownFrom tname e)
        ]
  ffork (restore m `E.catches` handlers)