module Control.Concurrent.Async
(
Async
, async
, asyncOn
, asyncWithUnmask
, asyncOnWithUnmask
, withAsync
, withAsyncOn
, withAsyncWithUnmask
, withAsyncOnWithUnmask
, wait, waitSTM
, poll, pollSTM
, waitCatch, waitCatchSTM
, cancel
, cancelWith
, asyncThreadId
, waitAny, waitAnySTM
, waitAnyCatch, waitAnyCatchSTM
, waitAnyCancel
, waitAnyCatchCancel
, waitEither, waitEitherSTM
, waitEitherCatch, waitEitherCatchSTM
, waitEitherCancel
, waitEitherCatchCancel
, waitEither_, waitEitherSTM_
, waitBoth, waitBothSTM
, link
, link2
, race
, race_
, concurrently
, mapConcurrently
, forConcurrently
, Concurrently(..)
) where
import Control.Applicative
import Control.Exception (AsyncException(ThreadKilled), BlockedIndefinitelyOnSTM(..), Exception, SomeException)
import Control.Monad
import Control.Monad.Catch (finally, try, onException)
import Control.Monad.Conc.Class
import Control.Monad.STM.Class
#if MIN_VERSION_dejafu(0,3,0)
import Control.Concurrent.Classy.STM.TMVar (newEmptyTMVar, putTMVar, readTMVar)
#else
import Control.Concurrent.STM.CTMVar (CTMVar, newEmptyCTMVar, putCTMVar, readCTMVar)
#endif
#if !MIN_VERSION_base(4,8,0)
import Data.Traversable
#endif
#if !MIN_VERSION_dejafu(0,3,0)
-- Compatibility shims for dejafu < 0.3.
--
-- Older dejafu releases used different names for the same primitives
-- (CVar instead of MVar; STMLike and CTMVar instead of STM and TMVar).
-- Each alias below simply forwards to the older name, so the rest of
-- this module can be written once against the MVar / TMVar spelling.

-- | Alias for the old mutable-variable type.
type MVar m = CVar m
-- | Forwards to 'newEmptyCVar'.
newEmptyMVar :: MonadConc m => m (MVar m a)
newEmptyMVar = newEmptyCVar
-- | Forwards to 'putCVar'.
putMVar :: MonadConc m => MVar m a -> a -> m ()
putMVar = putCVar
-- | Forwards to 'takeCVar'.
takeMVar :: MonadConc m => MVar m a -> m a
takeMVar = takeCVar
-- | Alias for the old STM monad type family.
type STM m = STMLike m
-- | Alias for the old transactional MVar type.
type TMVar m = CTMVar m
-- | Forwards to 'newEmptyCTMVar'.
newEmptyTMVar :: MonadSTM stm => stm (TMVar stm a)
newEmptyTMVar = newEmptyCTMVar
-- | Forwards to 'putCTMVar'.
putTMVar :: MonadSTM stm => TMVar stm a -> a -> stm ()
putTMVar = putCTMVar
-- | Forwards to 'readCTMVar'.
readTMVar :: MonadSTM stm => TMVar stm a -> stm a
readTMVar = readCTMVar
#endif
-- | A handle to a computation running in its own thread.
--
-- The @a@ parameter is the type of value the computation produces.
data Async m a = Async
  { asyncThreadId :: !(ThreadId m)
    -- ^ The thread running the computation.
  , _asyncWait :: STM m (Either SomeException a)
    -- ^ Transaction reading the outcome: either the result or the exception
    -- the computation threw. It is built from 'readTMVar', so it retries
    -- until the outcome has been written.
  }
-- | Two handles are equal exactly when they refer to the same thread.
instance MonadConc m => Eq (Async m a) where
  x == y = asyncThreadId x == asyncThreadId y
-- | Mapping over a handle post-processes its successful result.
--
-- The thread is unchanged, and an exceptional outcome passes through as-is.
instance MonadConc m => Functor (Async m) where
  fmap f (Async tid outcome) = Async tid (fmap (fmap f) outcome)
-- | Wrapper whose 'Applicative' instance runs actions at the same time (via
-- 'concurrently') and whose 'Alternative' instance races them (via 'race').
newtype Concurrently m a = Concurrently { runConcurrently :: m a }
-- | Applies a function to the eventual result of the wrapped action.
instance MonadConc m => Functor (Concurrently m) where
  fmap f = Concurrently . fmap f . runConcurrently
-- | 'pure' does no work.
--
-- '<*>' runs the function-producing and argument-producing actions at the
-- same time via 'concurrently', then applies one result to the other once
-- both are available.
instance MonadConc m => Applicative (Concurrently m) where
  pure x = Concurrently (return x)
  Concurrently mf <*> Concurrently mx = Concurrently $ do
    (g, x) <- concurrently mf mx
    return (g x)
-- | 'empty' never completes; it yields in a loop.
--
-- '<|>' races both actions via 'race' and returns whichever result comes
-- first.
instance MonadConc m => Alternative (Concurrently m) where
  empty = Concurrently (forever yield)
  Concurrently ma <|> Concurrently mb = Concurrently (winner <$> race ma mb)
    where
      winner = either id id
-- | Run an action in a newly forked thread and return a handle to its
-- eventual outcome.
async :: MonadConc m => m a -> m (Async m a)
async action = asyncUsing fork action

-- | Like 'async', but the thread is created with 'forkOn' using the given
-- 'Int' argument.
asyncOn :: MonadConc m => Int -> m a -> m (Async m a)
asyncOn cap action = asyncUsing (forkOn cap) action

-- | Like 'async', but the thread is created with 'forkWithUnmask', and the
-- action receives the unmasking function that primitive supplies.
asyncWithUnmask :: MonadConc m => ((forall b. m b -> m b) -> m a) -> m (Async m a)
asyncWithUnmask action = asyncUnmaskUsing forkWithUnmask action

-- | Combination of 'asyncOn' and 'asyncWithUnmask'.
asyncOnWithUnmask :: MonadConc m => Int -> ((forall b. m b -> m b) -> m a) -> m (Async m a)
asyncOnWithUnmask cap action = asyncUnmaskUsing (forkOnWithUnmask cap) action
-- | Shared implementation of 'async' and 'asyncOn'.
--
-- The child is forked inside 'mask', so its 'try' is in place before the
-- user action starts. The action itself runs under @restore@, i.e. with the
-- caller's original masking state. The outcome (result or exception) is
-- written to a fresh 'TMVar', and the returned handle reads from it.
asyncUsing :: MonadConc m => (m () -> m (ThreadId m)) -> m a -> m (Async m a)
asyncUsing doFork action = do
  slot <- atomically newEmptyTMVar
  tid <- mask $ \restore ->
    doFork $ do
      outcome <- try (restore action)
      atomically (putTMVar slot outcome)
  return (Async tid (readTMVar slot))
-- | Shared implementation of 'asyncWithUnmask' and 'asyncOnWithUnmask'.
--
-- The fork primitive hands the child an unmasking function, which is passed
-- straight to the user action. The outcome (result or exception) is written
-- to a fresh 'TMVar' that the returned handle reads from.
--
-- NOTE(review): unlike 'asyncUsing', no 'mask' is applied around the fork.
-- Confirm callers rely on inheriting their own masking state here.
asyncUnmaskUsing :: MonadConc m => (((forall b. m b -> m b) -> m ()) -> m (ThreadId m)) -> ((forall b. m b -> m b) -> m a) -> m (Async m a)
asyncUnmaskUsing doFork action = do
  slot <- atomically newEmptyTMVar
  tid <- doFork $ \unmask -> do
    outcome <- try (action unmask)
    atomically (putTMVar slot outcome)
  return (Async tid (readTMVar slot))
-- | Fork an action and pass its handle to @inner@.
--
-- The forked thread is cancelled once @inner@ returns or throws.
withAsync :: MonadConc m => m a -> (Async m a -> m b) -> m b
withAsync action inner = withAsyncUsing fork action inner

-- | Like 'withAsync', but the thread is created with 'forkOn' using the
-- given 'Int' argument.
withAsyncOn :: MonadConc m => Int -> m a -> (Async m a -> m b) -> m b
withAsyncOn cap action inner = withAsyncUsing (forkOn cap) action inner

-- | Like 'withAsync', but the thread is created with 'forkWithUnmask', and
-- the action receives the unmasking function that primitive supplies.
withAsyncWithUnmask :: MonadConc m => ((forall x. m x -> m x) -> m a) -> (Async m a -> m b) -> m b
withAsyncWithUnmask action inner = withAsyncUnmaskUsing forkWithUnmask action inner

-- | Combination of 'withAsyncOn' and 'withAsyncWithUnmask'.
withAsyncOnWithUnmask :: MonadConc m => Int -> ((forall x. m x -> m x) -> m a) -> (Async m a -> m b) -> m b
withAsyncOnWithUnmask cap action inner = withAsyncUnmaskUsing (forkOnWithUnmask cap) action inner
-- | Shared implementation of 'withAsync' and 'withAsyncOn'.
--
-- Behaviour:
--
-- * Forks @action@ and runs @inner@ with the resulting handle.
-- * Cancels the child when @inner@ returns normally.
-- * If @inner@ throws, cancels the child and then re-throws the exception.
--
-- Why the whole sequence is masked: otherwise an asynchronous exception
-- delivered to the caller could land either
--
-- * after the fork but before the cleanup handler is installed, or
-- * after @inner@ returns but before the final 'cancel'.
--
-- Both cases would leak the child thread. Inside the mask, @action@ and
-- @inner@ each run with the caller's original masking state restored.
withAsyncUsing :: MonadConc m => (m () -> m (ThreadId m)) -> m a -> (Async m a -> m b) -> m b
withAsyncUsing doFork action inner = do
  var <- atomically newEmptyTMVar
  mask $ \restore -> do
    tid <- doFork $ try (restore action) >>= atomically . putTMVar var
    let a = Async tid (readTMVar var)
    -- On failure, cancel the child first, then re-throw the original error.
    res <- restore (inner a) `catchAll` (\e -> cancel a >> throw e)
    cancel a
    return res
-- | Shared implementation of 'withAsyncWithUnmask' and
-- 'withAsyncOnWithUnmask'.
--
-- Behaviour:
--
-- * Forks @action@ (passing it the unmasking function from the fork
--   primitive) and runs @inner@ with the resulting handle.
-- * Cancels the child when @inner@ returns normally.
-- * If @inner@ throws, cancels the child and then re-throws the exception.
--
-- Why the whole sequence is masked: this closes the windows in which an
-- asynchronous exception could skip the cleanup and leak the child. It also
-- means the child starts masked, so its 'try' is installed before it can be
-- interrupted. The child unmasks only through the function it is given.
-- @inner@ runs with the caller's original masking state restored.
withAsyncUnmaskUsing :: MonadConc m => (((forall x. m x -> m x) -> m ()) -> m (ThreadId m)) -> ((forall x. m x -> m x) -> m a) -> (Async m a -> m b) -> m b
withAsyncUnmaskUsing doFork action inner = do
  var <- atomically newEmptyTMVar
  mask $ \restore -> do
    tid <- doFork $ \unmask -> try (action unmask) >>= atomically . putTMVar var
    let a = Async tid (readTMVar var)
    -- On failure, cancel the child first, then re-throw the original error.
    res <- restore (inner a) `catchAll` (\e -> cancel a >> throw e)
    cancel a
    return res
-- | 'catch' with the handler fixed to 'SomeException', so every exception
-- reaches it.
catchAll :: MonadConc m => m a -> (SomeException -> m a) -> m a
catchAll action handler = action `catch` handler
-- | Block until the computation finishes.
--
-- Returns its result, or re-throws the exception it raised.
wait :: MonadConc m => Async m a -> m a
wait a = atomically (waitSTM a)

-- | Transactional 'wait': retries until the outcome is available, then
-- yields the result or throws the stored exception with 'throwSTM'.
waitSTM :: MonadConc m => Async m a -> STM m a
waitSTM a = waitCatchSTM a >>= either throwSTM return
-- | Non-blocking check for an outcome.
--
-- Returns 'Nothing' while the computation is still running.
poll :: MonadConc m => Async m a -> m (Maybe (Either SomeException a))
poll a = atomically (pollSTM a)

-- | Transactional 'poll': if reading the outcome would retry, falls back to
-- 'Nothing' via 'orElse'.
pollSTM :: MonadConc m => Async m a -> STM m (Maybe (Either SomeException a))
pollSTM (Async _ outcome) = fmap Just outcome `orElse` return Nothing
-- | Like 'wait', but returns the outcome as an 'Either' instead of
-- re-throwing.
--
-- If the transaction aborts with 'BlockedIndefinitelyOnSTM', it is attempted
-- exactly once more; a second such exception propagates. (Presumably this
-- tolerates a spurious detection -- confirm against upstream async.)
waitCatch :: MonadConc m => Async m a -> m (Either SomeException a)
waitCatch a = attempt `catch` \BlockedIndefinitelyOnSTM -> attempt
  where
    attempt = atomically (waitCatchSTM a)
-- | Transactional 'waitCatch': the stored outcome, retrying until it exists.
waitCatchSTM :: MonadConc m => Async m a -> STM m (Either SomeException a)
waitCatchSTM = _asyncWait
-- | Throw 'ThreadKilled' to the computation's thread.
--
-- This does not itself wait for the thread to terminate.
cancel :: MonadConc m => Async m a -> m ()
cancel a = cancelWith a ThreadKilled

-- | Throw an arbitrary exception to the computation's thread.
cancelWith :: (MonadConc m, Exception e) => Async m a -> e -> m ()
cancelWith a e = throwTo (asyncThreadId a) e
-- | Wait for the first of several computations to finish.
--
-- Returns its handle and result; re-throws if that computation failed.
waitAny :: MonadConc m => [Async m a] -> m (Async m a, a)
waitAny asyncs = atomically (waitAnySTM asyncs)

-- | Transactional 'waitAny'. The handles are tried left to right with
-- 'orElse'. For an empty list the transaction is just 'retry'.
waitAnySTM :: MonadConc m => [Async m a] -> STM m (Async m a, a)
waitAnySTM asyncs = foldr (\a rest -> tagged a `orElse` rest) retry asyncs
  where
    tagged a = waitSTM a >>= \r -> return (a, r)
-- | Like 'waitAny', but returns the first outcome as an 'Either' instead of
-- re-throwing.
waitAnyCatch :: MonadConc m => [Async m a] -> m (Async m a, Either SomeException a)
waitAnyCatch asyncs = atomically (waitAnyCatchSTM asyncs)

-- | Transactional 'waitAnyCatch'. The handles are tried left to right with
-- 'orElse'. For an empty list the transaction is just 'retry'.
waitAnyCatchSTM :: MonadConc m => [Async m a] -> STM m (Async m a, Either SomeException a)
waitAnyCatchSTM asyncs = foldr (\a rest -> tagged a `orElse` rest) retry asyncs
  where
    tagged a = waitCatchSTM a >>= \r -> return (a, r)
-- | 'waitAny', then 'cancel' every handle.
--
-- The cancellation runs even if waiting throws.
waitAnyCancel :: MonadConc m => [Async m a] -> m (Async m a, a)
waitAnyCancel asyncs = finally (waitAny asyncs) (cancelAll asyncs)
  where
    cancelAll = mapM_ cancel

-- | 'waitAnyCatch', then 'cancel' every handle.
--
-- The cancellation runs even if waiting throws.
waitAnyCatchCancel :: MonadConc m => [Async m a] -> m (Async m a, Either SomeException a)
waitAnyCatchCancel asyncs = finally (waitAnyCatch asyncs) (cancelAll asyncs)
  where
    cancelAll = mapM_ cancel
-- | Wait for whichever of two computations finishes first.
--
-- Re-throws if that computation failed.
waitEither :: MonadConc m => Async m a -> Async m b -> m (Either a b)
waitEither left right = atomically (waitEitherSTM left right)

-- | Transactional 'waitEither'. The left handle is checked first.
waitEitherSTM :: MonadConc m => Async m a -> Async m b -> STM m (Either a b)
waitEitherSTM left right =
  fmap Left (waitSTM left) `orElse` fmap Right (waitSTM right)
-- | Like 'waitEither', but returns the first outcome as an 'Either' instead
-- of re-throwing.
waitEitherCatch
  :: MonadConc m
  => Async m a
  -> Async m b
  -> m (Either (Either SomeException a) (Either SomeException b))
waitEitherCatch left right = atomically (waitEitherCatchSTM left right)

-- | Transactional 'waitEitherCatch'. The left handle is checked first.
waitEitherCatchSTM
  :: MonadConc m
  => Async m a
  -> Async m b
  -> STM m (Either (Either SomeException a) (Either SomeException b))
waitEitherCatchSTM left right =
  fmap Left (waitCatchSTM left) `orElse` fmap Right (waitCatchSTM right)
-- | 'waitEither', then 'cancel' both handles.
--
-- The cancellation runs even if waiting throws.
waitEitherCancel :: MonadConc m => Async m a -> Async m b -> m (Either a b)
waitEitherCancel left right = waitEither left right `finally` cancelBoth
  where
    cancelBoth = cancel left >> cancel right

-- | 'waitEitherCatch', then 'cancel' both handles.
--
-- The cancellation runs even if waiting throws.
waitEitherCatchCancel
  :: MonadConc m
  => Async m a
  -> Async m b
  -> m (Either (Either SomeException a) (Either SomeException b))
waitEitherCatchCancel left right = waitEitherCatch left right `finally` cancelBoth
  where
    cancelBoth = cancel left >> cancel right
-- | 'waitEither' with the result discarded.
waitEither_ :: MonadConc m => Async m a -> Async m b -> m ()
waitEither_ left right = atomically (waitEitherSTM_ left right)

-- | Transactional 'waitEither_'.
waitEitherSTM_ :: MonadConc m => Async m a -> Async m b -> STM m ()
waitEitherSTM_ left right = waitEitherSTM left right >> return ()
-- | Wait for both computations and return both results.
--
-- Re-throws if either one fails.
waitBoth :: MonadConc m => Async m a -> Async m b -> m (a, b)
waitBoth left right = atomically (waitBothSTM left right)

-- | Transactional 'waitBoth'.
--
-- While @left@ is still pending, @right@ is also checked:
--
-- * if @right@ has already failed, its exception escapes immediately;
-- * otherwise the transaction retries.
waitBothSTM :: MonadConc m => Async m a -> Async m b -> STM m (a, b)
waitBothSTM left right =
  (waitSTM left `orElse` (waitSTM right >> retry)) >>= \a ->
    waitSTM right >>= \b ->
      return (a, b)
-- | Link a computation's failure to the calling thread.
--
-- A watcher thread (started with 'forkRepeat') waits for the outcome. If the
-- computation failed, the watcher throws that exception to the thread that
-- called 'link'. A successful result is ignored.
link :: MonadConc m => Async m a -> m ()
link (Async _ outcome) = do
  caller <- myThreadId
  void . forkRepeat $
    atomically outcome >>= either (throwTo caller) (const (return ()))
-- | Link two computations so a failure in either is thrown to the other.
--
-- A watcher thread (started with 'forkRepeat') waits for whichever finishes
-- first. If that one failed, its exception is thrown to the other one's
-- thread. A successful result is ignored.
link2 :: MonadConc m => Async m a -> Async m b -> m ()
link2 left@(Async tl _) right@(Async tr _) =
  void . forkRepeat $ waitEitherCatch left right >>= propagate
  where
    propagate (Left (Left e)) = throwTo tr e
    propagate (Right (Left e)) = throwTo tl e
    propagate _ = return ()
-- | Fork a thread that runs @action@ until it completes without throwing.
--
-- Any exception (caught as 'SomeException') simply restarts the action. The
-- loop runs masked; each attempt of @action@ runs with the caller's
-- original masking state restored.
forkRepeat :: MonadConc m => m a -> m (ThreadId m)
forkRepeat action = mask $ \restore ->
  let loop = tryAny (restore action) >>= either (const loop) (const (return ()))
      tryAny :: MonadConc n => n b -> n (Either SomeException b)
      tryAny = try
  in fork loop
-- | Run two actions at the same time and return the first to finish.
--
-- The first outcome written wins: a result is returned, an exception is
-- re-thrown. Both threads are then killed by @concurrently'@.
race :: MonadConc m => m a -> m b -> m (Either a b)
race left right = concurrently' left right firstOutcome
  where
    firstOutcome var = takeMVar var >>= either throw return

-- | 'race' with the result discarded.
race_ :: MonadConc m => m a -> m b -> m ()
race_ left right = race left right >> return ()
-- | Run two actions at the same time and return both results.
--
-- If an exception arrives before both results are in, it is re-thrown, and
-- both threads are killed by @concurrently'@.
concurrently :: MonadConc m => m a -> m b -> m (a, b)
concurrently left right = concurrently' left right (gather [])
  where
    -- Accumulate outcomes until one result from each side has arrived.
    gather [Left a, Right b] _ = return (a, b)
    gather [Right b, Left a] _ = return (a, b)
    gather acc var = takeMVar var >>= either throw (\r -> gather (r : acc) var)
-- | Core of 'race' and 'concurrently'.
--
-- Forks @left@ and @right@ into separate threads. Each thread writes its
-- outcome into the shared MVar @done@:
--
-- * @Right (Left a)@ or @Right (Right b)@ on success;
-- * @Left e@ if it threw.
--
-- @collect@ reads from @done@ as many times as it needs. Once @collect@
-- returns or throws, both threads are killed with 'killThread'.
concurrently' :: MonadConc m => m a -> m b
              -> (MVar m (Either SomeException (Either a b)) -> m r)
              -> m r
concurrently' left right collect = do
  done <- newEmptyMVar
  -- Forking inside mask (children inherit the masking state) lets each
  -- child install its catch handler before it can be interrupted. The user
  -- actions run with the caller's masking state restored.
  mask $ \restore -> do
    lid <- fork $ restore (left >>= putMVar done . Right . Left)
                    `catch` (putMVar done . Left)
    rid <- fork $ restore (right >>= putMVar done . Right . Right)
                    `catch` (putMVar done . Left)
    let stop = killThread rid >> killThread lid
    -- Kill both children whether collect returns normally or throws.
    r <- restore (collect done) `onException` stop
    stop
    return r
-- | Apply @f@ to every element, running all the resulting actions at the
-- same time (via the 'Concurrently' applicative).
--
-- Results are collected in the original structure.
mapConcurrently :: (Traversable t, MonadConc m) => (a -> m b) -> t a -> m (t b)
mapConcurrently f xs = runConcurrently (traverse (\x -> Concurrently (f x)) xs)

-- | 'mapConcurrently' with its arguments flipped.
forConcurrently :: (Traversable t, MonadConc m) => t a -> (a -> m b) -> m (t b)
forConcurrently xs f = mapConcurrently f xs