{-# LANGUAGE RankNTypes #-}

module Control.Concurrent.Worker (
	Worker(..), WorkerStopped(..),
	startWorker, workerAlive, workerDone,
	sendTask, stopWorker, joinWorker, syncTask,
	inWorkerWith, inWorker,

	module Control.Concurrent.Async
	) where

import Control.Concurrent.MVar
import Control.Monad.IO.Class
import Control.Monad.Catch
import Control.Monad.Except
import Data.Maybe (isNothing)
import Data.Typeable

import Control.Concurrent.FiniteChan
import Control.Concurrent.Async

-- | Handle to a background thread that runs queued actions in the monad @m@.
data Worker m = Worker
	{ workerChan :: Chan (Async (), m ())
	-- ^ Queue of pending tasks. Each entry pairs the caller-side 'Async' waiting
	-- for the task's result with the action the worker should execute.
	, workerTask :: MVar (Async ())
	-- ^ 'Async' handle of the worker thread itself.
	}

-- | Exception reported to a task's caller when the worker channel refused the
-- task (presumably because the worker was already stopped).
data WorkerStopped = WorkerStopped
	deriving (Show, Typeable)

instance Exception WorkerStopped

-- | Create new worker.
--
-- Spawns a thread that evaluates @run (initialize loop)@, where @loop@ keeps
-- taking the next queued action from the worker channel and executes it wrapped
-- in @handleErrs@, returning once 'getChan' yields 'Nothing'. If the thread
-- terminates with an exception, the channel is closed and the waiter 'Async' of
-- every task still left in it is cancelled before the exception propagates.
--
-- NOTE(review): the channel is only closed on the exceptional path; on normal
-- completion this relies on @loop@ having returned because the channel was
-- already closed and drained.
startWorker :: MonadIO m => (m () -> IO ()) -> (m () -> m ()) -> (m () -> m ()) -> IO (Worker m)
startWorker run initialize handleErrs = do
	chan <- newChan
	taskVar <- newEmptyMVar
	let
		-- Run queued actions one after another until the channel has nothing more.
		loop = liftIO (getChan chan) >>= \next -> case next of
			Nothing -> return ()
			Just (_, action) -> handleErrs action >> loop
		-- Empty the remaining queue, cancelling the waiter of each pending task.
		cancelPending = getChan chan >>= \next -> case next of
			Nothing -> return ()
			Just (waiter, _) -> cancel waiter >> cancelPending
		shutdown = closeChan chan >> cancelPending
	thread <- async (run (initialize loop) `onException` shutdown)
	putMVar taskVar thread
	return (Worker chan taskVar)

-- | Check whether the worker thread is still running, i.e. 'poll' on its
-- 'Async' has not yet produced a result or an exception.
workerAlive :: Worker m -> IO Bool
workerAlive w = readMVar (workerTask w) >>= fmap isNothing . poll

-- | Ask 'doneChan' about the worker's task channel (presumably 'True' once the
-- channel is closed and no tasks remain — see 'doneChan').
workerDone :: Worker m -> IO Bool
workerDone w = doneChan (workerChan w)

-- | Queue an action on the worker and return an 'Async' for its outcome.
--
-- The returned 'Async' blocks until the worker has run the action, then yields
-- its result or rethrows the exception it raised. If 'sendChan' refuses the task,
-- the 'Async' fails with 'WorkerStopped' instead.
--
-- The 'Async' refers to itself (via 'mfix') so that the worker can cancel this
-- waiter if it shuts down before reaching the task.
--
-- NOTE(review): the action is guarded by a 'catch' on 'SomeException', which also
-- intercepts asynchronous exceptions delivered to the worker thread while the
-- action runs; such exceptions are forwarded to the caller, not rethrown.
sendTask :: (MonadCatch m, MonadIO m) => Worker m -> m a -> IO (Async a)
sendTask w act = mfix $ \self -> do
	resultVar <- newEmptyMVar
	let
		onErr :: MonadIO m => SomeException -> m ()
		onErr err = liftIO (putMVar resultVar (Left err))
		-- Runs inside the worker: store the action's result, or the failure.
		job = (act >>= \x -> liftIO (putMVar resultVar (Right x))) `catch` onErr
		-- Runs in the caller-side Async: enqueue, then wait for the outcome.
		waiter = do
			accepted <- sendChan (workerChan w) (void self, void job)
			if accepted
				then return ()
				else putMVar resultVar (Left (SomeException WorkerStopped))
			takeMVar resultVar >>= either throwM return
	async waiter

-- | Close the worker's task channel via 'closeChan'; afterwards 'sendChan' is
-- expected to reject new tasks.
stopWorker :: Worker m -> IO ()
stopWorker w = closeChan (workerChan w)

-- | Stop the worker, then block until its thread has finished. Whether the
-- thread ended normally or with an exception is discarded.
joinWorker :: Worker m -> IO ()
joinWorker w = do
	stopWorker w
	readMVar (workerTask w) >>= void . waitCatch

-- | Queue a no-op task and block until the worker has run it. Tasks are executed
-- one after another, so (assuming the channel is FIFO) everything queued earlier
-- has been processed by then. Rethrows 'WorkerStopped' if the task is refused.
syncTask :: (MonadCatch m, MonadIO m) => Worker m -> IO ()
syncTask w = do
	marker <- sendTask w (return ())
	wait marker

-- | Run action in worker and wait for result. Any exception raised by the action,
-- or 'WorkerStopped' if the task was refused, is handed to @err@ instead of
-- being rethrown.
inWorkerWith :: (MonadIO m, MonadCatch m, MonadIO n) => (SomeException -> n a) -> Worker m -> m a -> n a
inWorkerWith err w act = do
	task <- liftIO (sendTask w act)
	outcome <- liftIO (waitCatch task)
	either err return outcome

-- | Run action in worker and wait for result, rethrowing any exception the action
-- raised (or 'WorkerStopped' if the worker refused the task).
inWorker :: (MonadIO m, MonadCatch m) => Worker m -> m a -> IO a
inWorker w act = do
	task <- sendTask w act
	wait task