{-# LANGUAGE CPP #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
module UnliftIO.Internals.Async where

import Control.Applicative
import Control.Concurrent (threadDelay, getNumCapabilities)
import qualified Control.Concurrent as C
import Control.Concurrent.Async (Async)
import qualified Control.Concurrent.Async as A
import Control.Concurrent.STM
import Control.Exception (Exception, SomeException)
import Control.Monad (forever, liftM, unless, void, (>=>))
import Control.Monad.IO.Unlift
import Data.Foldable (for_, traverse_)
import Data.Typeable (Typeable)
import Data.IORef (IORef, readIORef, atomicWriteIORef, newIORef, atomicModifyIORef')
import qualified UnliftIO.Exception as UE

-- The Conc implementation below must not go through the smart async
-- exception handling of UnliftIO.Exception: it is low-level enough that it
-- needs to explicitly throw async exceptions synchronously. Hence the plain
-- Control.Exception import.
import qualified Control.Exception as E
import GHC.Generics (Generic)

#if MIN_VERSION_base(4,9,0)
import Data.Semigroup
#else
import Data.Monoid hiding (Alt)
#endif
import Data.Foldable (Foldable, toList)
import Data.Traversable (Traversable, for, traverse)

-- | 'A.async' for any 'MonadUnliftIO' monad.
--
-- @since 0.1.0.0
async :: MonadUnliftIO m => m a -> m (Async a)
async action = withRunInIO (\runInIO -> A.async (runInIO action))

-- | 'A.asyncBound' for any 'MonadUnliftIO' monad.
--
-- @since 0.1.0.0
asyncBound :: MonadUnliftIO m => m a -> m (Async a)
asyncBound action = withRunInIO (\runInIO -> A.asyncBound (runInIO action))

-- | 'A.asyncOn' for any 'MonadUnliftIO' monad. The 'Int' is passed through
-- to 'A.asyncOn' unchanged.
--
-- @since 0.1.0.0
asyncOn :: MonadUnliftIO m => Int -> m a -> m (Async a)
asyncOn cap action =
  withRunInIO (\runInIO -> A.asyncOn cap (runInIO action))

-- | 'A.asyncWithUnmask' for any 'MonadUnliftIO' monad. The unmasking
-- function supplied by @async@ is lifted into @m@ before being handed to
-- the callback.
--
-- @since 0.1.0.0
asyncWithUnmask
  :: MonadUnliftIO m
  => ((forall b. m b -> m b) -> m a)
  -> m (Async a)
asyncWithUnmask inner =
  withRunInIO $ \runInIO ->
    A.asyncWithUnmask
      (\unmask -> runInIO (inner (liftIO . unmask . runInIO)))

-- | 'A.asyncOnWithUnmask' for any 'MonadUnliftIO' monad. The unmasking
-- function supplied by @async@ is lifted into @m@ before being handed to
-- the callback.
--
-- @since 0.1.0.0
asyncOnWithUnmask
  :: MonadUnliftIO m
  => Int
  -> ((forall b. m b -> m b) -> m a)
  -> m (Async a)
asyncOnWithUnmask cap inner =
  withRunInIO $ \runInIO ->
    A.asyncOnWithUnmask cap
      (\unmask -> runInIO (inner (liftIO . unmask . runInIO)))

-- | 'A.withAsync' for any 'MonadUnliftIO' monad.
--
-- @since 0.1.0.0
withAsync :: MonadUnliftIO m => m a -> (Async a -> m b) -> m b
withAsync action inner =
  withRunInIO $ \runInIO ->
    A.withAsync (runInIO action) (\handle -> runInIO (inner handle))

-- | 'A.withAsyncBound' for any 'MonadUnliftIO' monad.
--
-- @since 0.1.0.0
withAsyncBound :: MonadUnliftIO m => m a -> (Async a -> m b) -> m b
withAsyncBound action inner =
  withRunInIO $ \runInIO ->
    A.withAsyncBound (runInIO action) (\handle -> runInIO (inner handle))

-- | 'A.withAsyncOn' for any 'MonadUnliftIO' monad.
--
-- @since 0.1.0.0
withAsyncOn :: MonadUnliftIO m => Int -> m a -> (Async a -> m b) -> m b
withAsyncOn cap action inner =
  withRunInIO $ \runInIO ->
    A.withAsyncOn cap (runInIO action) (\handle -> runInIO (inner handle))

-- | 'A.withAsyncWithUnmask' for any 'MonadUnliftIO' monad.
--
-- @since 0.1.0.0
withAsyncWithUnmask
  :: MonadUnliftIO m
  => ((forall c. m c -> m c) -> m a)
  -> (Async a -> m b)
  -> m b
withAsyncWithUnmask action inner =
  withRunInIO $ \runInIO ->
    A.withAsyncWithUnmask
      (\unmask -> runInIO (action (liftIO . unmask . runInIO)))
      (\handle -> runInIO (inner handle))

-- | 'A.withAsyncOnWithUnmask' for any 'MonadUnliftIO' monad.
--
-- @since 0.1.0.0
withAsyncOnWithUnmask
  :: MonadUnliftIO m
  => Int
  -> ((forall c. m c -> m c) -> m a)
  -> (Async a -> m b)
  -> m b
withAsyncOnWithUnmask cap action inner =
  withRunInIO $ \runInIO ->
    A.withAsyncOnWithUnmask cap
      (\unmask -> runInIO (action (liftIO . unmask . runInIO)))
      (\handle -> runInIO (inner handle))

-- | 'A.wait' lifted to any 'MonadIO' monad.
--
-- @since 0.1.0.0
wait :: MonadIO m => Async a -> m a
wait handle = liftIO (A.wait handle)

-- | 'A.poll' lifted to any 'MonadIO' monad.
--
-- @since 0.1.0.0
poll :: MonadIO m => Async a -> m (Maybe (Either SomeException a))
poll handle = liftIO (A.poll handle)

-- | 'A.waitCatch' lifted to any 'MonadIO' monad.
--
-- @since 0.1.0.0
waitCatch :: MonadIO m => Async a -> m (Either SomeException a)
waitCatch handle = liftIO (A.waitCatch handle)

-- | Lifted 'A.cancel'.
--
-- @since 0.1.0.0
cancel :: MonadIO m => Async a -> m ()
cancel handle = liftIO (A.cancel handle)

-- | 'A.uninterruptibleCancel' lifted to any 'MonadIO' monad.
--
-- @since 0.1.0.0
uninterruptibleCancel :: MonadIO m => Async a -> m ()
uninterruptibleCancel handle = liftIO (A.uninterruptibleCancel handle)

-- | 'A.cancelWith' lifted to any 'MonadIO' monad. The exception is first
-- passed through 'UE.toAsyncException' to ensure async exception safety.
--
-- @since 0.1.0.0
cancelWith :: (Exception e, MonadIO m) => Async a -> e -> m ()
cancelWith handle exc =
  liftIO $ A.cancelWith handle $ UE.toAsyncException exc

-- | 'A.waitAny' lifted to any 'MonadIO' monad.
--
-- @since 0.1.0.0
waitAny :: MonadIO m => [Async a] -> m (Async a, a)
waitAny handles = liftIO (A.waitAny handles)

-- | 'A.waitAnyCatch' lifted to any 'MonadIO' monad.
--
-- @since 0.1.0.0
waitAnyCatch :: MonadIO m => [Async a] -> m (Async a, Either SomeException a)
waitAnyCatch handles = liftIO (A.waitAnyCatch handles)

-- | 'A.waitAnyCancel' lifted to any 'MonadIO' monad.
--
-- @since 0.1.0.0
waitAnyCancel :: MonadIO m => [Async a] -> m (Async a, a)
waitAnyCancel handles = liftIO (A.waitAnyCancel handles)

-- | 'A.waitAnyCatchCancel' lifted to any 'MonadIO' monad.
--
-- @since 0.1.0.0
waitAnyCatchCancel :: MonadIO m => [Async a] -> m (Async a, Either SomeException a)
waitAnyCatchCancel handles = liftIO (A.waitAnyCatchCancel handles)

-- | 'A.waitEither' lifted to any 'MonadIO' monad.
--
-- @since 0.1.0.0
waitEither :: MonadIO m => Async a -> Async b -> m (Either a b)
waitEither left right = liftIO $ A.waitEither left right

-- | 'A.waitEitherCatch' lifted to any 'MonadIO' monad.
--
-- @since 0.1.0.0
waitEitherCatch
  :: MonadIO m
  => Async a
  -> Async b
  -> m (Either (Either SomeException a) (Either SomeException b))
waitEitherCatch left right = liftIO $ A.waitEitherCatch left right

-- | 'A.waitEitherCancel' lifted to any 'MonadIO' monad.
--
-- @since 0.1.0.0
waitEitherCancel :: MonadIO m => Async a -> Async b -> m (Either a b)
waitEitherCancel left right = liftIO $ A.waitEitherCancel left right

-- | 'A.waitEitherCatchCancel' lifted to any 'MonadIO' monad.
--
-- @since 0.1.0.0
waitEitherCatchCancel
  :: MonadIO m
  => Async a
  -> Async b
  -> m (Either (Either SomeException a) (Either SomeException b))
waitEitherCatchCancel left right =
  liftIO $ A.waitEitherCatchCancel left right

-- | Lifted 'A.waitEither_'.
--
-- @since 0.1.0.0
waitEither_ :: MonadIO m => Async a -> Async b -> m ()
waitEither_ left right = liftIO $ A.waitEither_ left right

-- | 'A.waitBoth' lifted to any 'MonadIO' monad.
--
-- @since 0.1.0.0
waitBoth :: MonadIO m => Async a -> Async b -> m (a, b)
waitBoth left right = liftIO $ A.waitBoth left right

-- | 'A.link' lifted to any 'MonadIO' monad.
--
-- @since 0.1.0.0
link :: MonadIO m => Async a -> m ()
link handle = liftIO (A.link handle)

-- | 'A.link2' lifted to any 'MonadIO' monad.
--
-- @since 0.1.0.0
link2 :: MonadIO m => Async a -> Async b -> m ()
link2 left right = liftIO $ A.link2 left right

-- | 'A.race' for any 'MonadUnliftIO' monad.
--
-- @since 0.1.0.0
race :: MonadUnliftIO m => m a -> m b -> m (Either a b)
race left right =
  withRunInIO (\runInIO -> A.race (runInIO left) (runInIO right))

-- | 'A.race_' for any 'MonadUnliftIO' monad.
--
-- @since 0.1.0.0
race_ :: MonadUnliftIO m => m a -> m b -> m ()
race_ left right =
  withRunInIO (\runInIO -> A.race_ (runInIO left) (runInIO right))

-- | 'A.concurrently' for any 'MonadUnliftIO' monad.
--
-- @since 0.1.0.0
concurrently :: MonadUnliftIO m => m a -> m b -> m (a, b)
concurrently left right =
  withRunInIO (\runInIO -> A.concurrently (runInIO left) (runInIO right))

-- | 'A.concurrently_' for any 'MonadUnliftIO' monad.
--
-- @since 0.1.0.0
concurrently_ :: MonadUnliftIO m => m a -> m b -> m ()
concurrently_ left right =
  withRunInIO (\runInIO -> A.concurrently_ (runInIO left) (runInIO right))

-- | Unlifted counterpart of 'A.Concurrently': a wrapper around an action
-- in @m@ whose 'Applicative' instance combines actions with 'concurrently'.
--
-- @since 0.1.0.0
newtype Concurrently m a = Concurrently
  { runConcurrently :: m a
  }

-- | @since 0.1.0.0
instance Monad m => Functor (Concurrently m) where
  fmap f wrapped = Concurrently (liftM f (runConcurrently wrapped))

-- | @since 0.1.0.0
instance MonadUnliftIO m => Applicative (Concurrently m) where
  pure x = Concurrently (return x)
  Concurrently fs <*> Concurrently as =
    Concurrently (liftM applyPair (concurrently fs as))
    where
      applyPair (f, a) = f a

-- | Composing two unlifted 'Concurrently' values using 'Alternative' is
-- equivalent to using the 'race' combinator: the asynchronous sub-routine that
-- returns a value first is the one that gets its value returned, the slowest
-- sub-routine gets cancelled and its thread is killed.
--
-- @since 0.1.0.0
instance MonadUnliftIO m => Alternative (Concurrently m) where
  -- | Care should be taken when using the 'empty' value of the 'Alternative'
  -- interface: it never finishes, looping forever on @threadDelay maxBound@.
  -- The reason behind this implementation is that any other computation
  -- raced against it with '(<|>)' will finish first. This implementation is
  -- less than ideal, and in a perfect world, we would have a typeclass
  -- family that allows '(<|>)' but not 'empty'.
  --
  -- @since 0.1.0.0
  empty = Concurrently $ liftIO (forever (threadDelay maxBound))
  Concurrently as <|> Concurrently bs =
    Concurrently $ liftM (either id id) (race as bs)

--------------------------------------------------------------------------------
#if MIN_VERSION_base(4,9,0)
--------------------------------------------------------------------------------
-- | Only defined by @async@ for @base >= 4.9@.
--
-- @since 0.1.0.0
instance (MonadUnliftIO m, Semigroup a) => Semigroup (Concurrently m a) where
  (<>) = liftA2 (<>)

-- | @since 0.1.0.0
instance (Semigroup a, Monoid a, MonadUnliftIO m) => Monoid (Concurrently m a) where
  mempty = pure mempty
  mappend = (<>)
--------------------------------------------------------------------------------
#else
--------------------------------------------------------------------------------
-- | @since 0.1.0.0
instance (Monoid a, MonadUnliftIO m) => Monoid (Concurrently m a) where
  mempty = pure mempty
  mappend = liftA2 mappend
--------------------------------------------------------------------------------
#endif
--------------------------------------------------------------------------------

-- | Similar to 'mapConcurrently' but with arguments flipped
--
-- @since 0.1.0.0
forConcurrently :: MonadUnliftIO m => Traversable t => t a -> (a -> m b) -> m (t b)
forConcurrently = flip mapConcurrently
{-# INLINE forConcurrently #-}

-- | Similar to 'mapConcurrently_' but with arguments flipped
--
-- @since 0.1.0.0
forConcurrently_ :: MonadUnliftIO m => Foldable f => f a -> (a -> m b) -> m ()
forConcurrently_ = flip mapConcurrently_
{-# INLINE forConcurrently_ #-}

-- | Unlifted 'A.replicateConcurrently'. Runs the action @cnt@ times
-- concurrently and collects the results; a count below 1 yields an empty
-- list without running the action, and a count of exactly 1 runs the action
-- directly instead of going through 'mapConcurrently'.
--
-- @since 0.1.0.0
#if MIN_VERSION_base(4,7,0)
-- This branch previously had no signature at all, leaving the binding to
-- type inference; the signature below is the type that was being inferred.
replicateConcurrently :: (MonadUnliftIO m) => Int -> m a -> m [a]
#else
replicateConcurrently :: (Functor m, MonadUnliftIO m) => Int -> m a -> m [a]
#endif
replicateConcurrently cnt m =
  case compare cnt 1 of
    LT -> pure []
    EQ -> (:[]) <$> m
    GT -> mapConcurrently id (replicate cnt m)
{-# INLINE replicateConcurrently #-}

-- | Unlifted 'A.replicateConcurrently_'. Same as 'replicateConcurrently'
-- but the results are discarded.
--
-- @since 0.1.0.0
#if MIN_VERSION_base(4,7,0)
replicateConcurrently_ :: (Applicative m, MonadUnliftIO m) => Int -> m a -> m ()
#else
replicateConcurrently_ :: (MonadUnliftIO m) => Int -> m a -> m ()
#endif
replicateConcurrently_ cnt m =
  case compare cnt 1 of
    LT -> pure ()
    EQ -> void m
    GT -> mapConcurrently_ id (replicate cnt m)
{-# INLINE replicateConcurrently_ #-}

-- Conc uses GHC features that are not supported in versions <= ghc-7.10,
-- so we only export/use it when we have a higher version.

--------------------------------------------------------------------------------
#if MIN_VERSION_base(4,8,0)
--------------------------------------------------------------------------------

-- | Executes a 'Traversable' container of items concurrently, it uses the 'Flat'
-- type internally.
--
-- @since 0.1.0.0
mapConcurrently :: MonadUnliftIO m => Traversable t => (a -> m b) -> t a -> m (t b)
mapConcurrently f t = withRunInIO $ \run -> runFlat $ traverse
  (FlatApp . FlatAction . run . f)
  t
{-# INLINE mapConcurrently #-}

-- | Executes a 'Traversable' container of items concurrently, it uses the 'Flat'
-- type internally. This function ignores the results.
--
-- @since 0.1.0.0
mapConcurrently_ :: MonadUnliftIO m => Foldable f => (a -> m b) -> f a -> m ()
mapConcurrently_ f t = withRunInIO $ \run -> runFlat $ traverse_
  (FlatApp . FlatAction . run . f)
  t
{-# INLINE mapConcurrently_ #-}


-- More efficient Conc implementation

-- | A more efficient alternative to 'Concurrently', which reduces the
-- number of threads that need to be forked. For more information, see
-- [this blog post](https://www.fpcomplete.com/blog/transformations-on-applicative-concurrent-computations/).
-- This is provided as a separate type to @Concurrently@ as it has a slightly different API.
--
-- Use the 'conc' function to construct values of type 'Conc', and
-- 'runConc' to execute the composed actions. You can use the
-- @Applicative@ instance to run different actions and wait for all of
-- them to complete, or the @Alternative@ instance to wait for the
-- first thread to complete.
--
-- In the event of a runtime exception thrown by any of the children
-- threads, or an asynchronous exception received in the parent
-- thread, all threads will be killed with an 'A.AsyncCancelled'
-- exception and the original exception rethrown. If multiple
-- exceptions are generated by different threads, there are no
-- guarantees on which exception will end up getting rethrown.
--
-- For many common use cases, you may prefer using helper functions in
-- this module like 'mapConcurrently'.
--
-- There are some intentional differences in behavior to
-- @Concurrently@:
--
-- * Children threads are always launched in an unmasked state, not
--   the inherited state of the parent thread.
--
-- Note that it is a programmer error to use the @Alternative@
-- instance in such a way that there are no alternatives to an empty,
-- e.g. @runConc (empty <|> empty)@. In such a case, a 'ConcException'
-- will be thrown. If there was an @Alternative@ in the standard
-- libraries without @empty@, this library would use it instead.
--
-- @since 0.2.9.0
data Conc m a where
  Action :: m a -> Conc m a
  Apply :: Conc m (v -> a) -> Conc m v -> Conc m a
  LiftA2 :: (x -> y -> a) -> Conc m x -> Conc m y -> Conc m a

  -- Purely an optimization: a known value needs no thread of its own.
  Pure :: a -> Conc m a

  -- A dedicated constructor that explicitly ignores the first result looked
  -- like a possible optimization. It turned out not to help much: a TMVar is
  -- still needed below to learn when the thread has completed.
  --
  -- Then :: Conc m a -> Conc m b -> Conc m b

  Alt :: Conc m a -> Conc m a -> Conc m a
  Empty :: Conc m a

deriving instance Functor m => Functor (Conc m)
-- fmap f (Action routine) = Action (fmap f routine)
-- fmap f (LiftA2 g x y)   = LiftA2 (fmap f g) x y
-- fmap f (Pure val)       = Pure (f val)
-- fmap f (Alt a b)        = Alt (fmap f a) (fmap f b)
-- fmap f Empty            = Empty

-- | Wrap an action as a 'Conc' value. Combine such values through the
-- typeclass instances (typically 'Applicative' and 'Alternative') and
-- execute the result with 'runConc'.
--
-- @since 0.2.9.0
conc :: m a -> Conc m a
conc = Action
{-# INLINE conc #-}

-- | Execute a 'Conc' value: it is first flattened with 'flatten' and the
-- flattened form is then run on multiple threads by 'runFlat'.
--
-- @since 0.2.9.0
runConc :: MonadUnliftIO m => Conc m a -> m a
runConc = flatten >=> \flat -> liftIO (runFlat flat)
{-# INLINE runConc #-}

-- | @since 0.2.9.0
instance MonadUnliftIO m => Applicative (Conc m) where
  pure = Pure
  {-# INLINE pure #-}

  -- How an 'Applicative' expression expands into a tree:
  --
  --   downloadA :: IO String
  --   downloadB :: IO String
  --
  --   (f <$> conc downloadA <*> conc downloadB <*> pure 123)
  --
  --     (((f <$> a) <*> b) <*> c))
  --          (1)       (2)    (3)
  --
  --   (1)
  --     Action (fmap f downloadA)
  --   (2)
  --     Apply (Action (fmap f downloadA)) (Action downloadB)
  --   (3)
  --     Apply (Apply (Action (fmap f downloadA)) (Action downloadB))
  --           (Pure 123)
  --
  (<*>) = Apply
  {-# INLINE (<*>) #-}

  -- See the note on Then in the data declaration.
  -- (*>) = Then
#if MIN_VERSION_base(4,11,0)
  liftA2 = LiftA2
  {-# INLINE liftA2 #-}
#endif

  first *> second = LiftA2 (const id) first second
  {-# INLINE (*>) #-}

-- | @since 0.2.9.0
instance MonadUnliftIO m => Alternative (Conc m) where
  -- Ugly: we would rather not provide 'empty' at all.
  empty = Empty
  {-# INLINE empty #-}

  (<|>) = Alt
  {-# INLINE (<|>) #-}

#if MIN_VERSION_base(4, 11, 0)
-- | @since 0.2.9.0
instance (MonadUnliftIO m, Semigroup a) => Semigroup (Conc m a) where
  (<>) = liftA2 (<>)
  {-# INLINE (<>) #-}
#endif

-- | @since 0.2.9.0
instance (Monoid a, MonadUnliftIO m) => Monoid (Conc m a) where
  mempty = pure mempty
  {-# INLINE mempty #-}

  mappend = liftA2 mappend
  {-# INLINE mappend #-}

-------------------------
-- Conc implementation --
-------------------------

-- Data types for flattening out the original @Conc@ into a simplified
-- view. Goals:
--
-- * We want to get rid of the Empty data constructor. We don't want
--   it anyway, it's only there because of the Alternative typeclass.
--
-- * We want to ensure that there is no nesting of Alt data
--   constructors. There is a bookkeeping overhead to each time we
--   need to track raced threads, and we want to minimize that
--   bookkeeping.
--
-- * We want to ensure that, when racing, we're always racing at least
--   two threads.
--
-- * We want to simplify down to IO.
-- | Flattened structure, either Applicative or Alternative
data Flat a
  = FlatApp !(FlatApp a)
  -- | Flattened Alternative. Has at least 2 entries, which must be
  -- FlatApp (no nesting of FlatAlts).
  | FlatAlt !(FlatApp a) !(FlatApp a) ![FlatApp a]

deriving instance Functor Flat
-- fmap f (FlatApp a) =
--   FlatApp (fmap f a)
-- fmap f (FlatAlt (FlatApp a) (FlatApp b) xs) =
--   FlatAlt (FlatApp (fmap f a)) (FlatApp (fmap f b)) (map (fmap f) xs)

-- Every combination of two 'Flat' values is represented as a 'FlatLiftA2'
-- node wrapped in 'FlatApp'.
instance Applicative Flat where
  pure = FlatApp . pure
  (<*>) f a = FlatApp (FlatLiftA2 id f a)
#if MIN_VERSION_base(4,11,0)
  liftA2 f a b = FlatApp (FlatLiftA2 f a b)
#endif

-- | Flattened Applicative. No Alternative stuff directly in here, but may be in
-- the children. Notice this type doesn't have a type parameter for monadic
-- contexts, it hardwires the base monad to IO given concurrency relies
-- eventually on that.
--
-- @since 0.2.9.0
data FlatApp a where
  FlatPure   :: a -> FlatApp a
  FlatAction :: IO a -> FlatApp a
  FlatApply  :: Flat (v -> a) -> Flat v -> FlatApp a
  FlatLiftA2 :: (x -> y -> a) -> Flat x -> Flat y -> FlatApp a

deriving instance Functor FlatApp

-- Both operands are wrapped back into 'Flat' (via 'FlatApp') because the
-- 'FlatApply' and 'FlatLiftA2' constructors take 'Flat' children.
instance Applicative FlatApp where
  pure = FlatPure
  (<*>) mf ma = FlatApply (FlatApp mf) (FlatApp ma)
#if MIN_VERSION_base(4,11,0)
  liftA2 f a b = FlatLiftA2 f (FlatApp a) (FlatApp b)
#endif

-- | Things that can go wrong in the structure of a 'Conc'. These are
-- /programmer errors/.
--
-- @since 0.2.9.0
data ConcException
  = EmptyWithNoAlternative
  deriving (Generic, Show, Typeable, Eq, Ord)
instance E.Exception ConcException

-- | Simple difference list, for nicer types below
type DList a = [a] -> [a]

-- | Append two difference lists (function composition).
dlistConcat :: DList a -> DList a -> DList a
dlistConcat = (.)
{-# INLINE dlistConcat #-}

-- | Prepend a single element to a difference list.
dlistCons :: a -> DList a -> DList a
dlistCons a as = dlistSingleton a `dlistConcat` as
{-# INLINE dlistCons #-}

-- | Append all difference lists in the given list, in order.
dlistConcatAll :: [DList a] -> DList a
dlistConcatAll = foldr (.) id
{-# INLINE dlistConcatAll #-}

-- | Materialize a difference list by applying it to the empty list.
dlistToList :: DList a -> [a]
dlistToList = ($ [])
{-# INLINE dlistToList #-}

-- | A difference list holding exactly one element.
dlistSingleton :: a -> DList a
dlistSingleton a = (a:)
{-# INLINE dlistSingleton #-}

-- | The empty difference list.
dlistEmpty :: DList a
dlistEmpty = id
{-# INLINE dlistEmpty #-}

-- | Turn a 'Conc' into a 'Flat'. Note that thanks to the ugliness of
-- 'empty', this may fail, e.g. @flatten Empty@.
--
-- Every 'Action' is converted to 'IO' with the unlift function obtained from
-- 'withRunInIO'. An 'Empty' outside of an 'Alt', or an 'Alt' whose branches
-- contribute no entries at all, throws 'EmptyWithNoAlternative'.
--
-- @since 0.2.9.0
flatten :: forall m a. MonadUnliftIO m => Conc m a -> m (Flat a)
flatten c0 = withRunInIO $ \run -> do

  -- why not app?
  -- 'both' converts any 'Conc' node into a 'Flat'.
  let both :: forall k. Conc m k -> IO (Flat k)
      both Empty = E.throwIO EmptyWithNoAlternative
      both (Action m) = pure $ FlatApp $ FlatAction $ run m
      both (Apply cf ca) = do
        f <- both cf
        a <- both ca
        pure $ FlatApp $ FlatApply f a
      both (LiftA2 f ca cb) = do
        a <- both ca
        b <- both cb
        pure $ FlatApp $ FlatLiftA2 f a b
      both (Alt ca cb) = do
        -- Collect the entries of both sides; nested Alts are merged into
        -- this one list and Empty contributes nothing.
        a <- alt ca
        b <- alt cb
        case dlistToList (a `dlistConcat` b) of
          []    -> E.throwIO EmptyWithNoAlternative
          -- A single remaining entry needs no race at all.
          [x]   -> pure $ FlatApp x
          x:y:z -> pure $ FlatAlt x y z
      both (Pure a) = pure $ FlatApp $ FlatPure a

      -- Returns a difference list for cheaper concatenation
      -- 'alt' collects the entries found under a (possibly nested) Alt.
      alt :: forall k. Conc m k -> IO (DList (FlatApp k))
      alt Empty = pure dlistEmpty
      alt (Apply cf ca) = do
        f <- both cf
        a <- both ca
        pure (dlistSingleton $ FlatApply f a)
      alt (Alt ca cb) = do
        a <- alt ca
        b <- alt cb
        pure $ a `dlistConcat` b
      alt (Action m) = pure (dlistSingleton $ FlatAction (run m))
      alt (LiftA2 f ca cb) = do
        a <- both ca
        b <- both cb
        pure (dlistSingleton $ FlatLiftA2 f a b)
      alt (Pure a) = pure (dlistSingleton $ FlatPure a)

  both c0

-- | Run a @Flat a@ on multiple threads.
--
-- A lone 'FlatAction' or 'FlatPure' is executed directly on the calling
-- thread. Anything else forks threads via the local @go@, waits for either a
-- result or the first recorded exception, kills every forked thread that has
-- not finished yet, waits until all of them have been counted as finished,
-- and only then returns the result or rethrows the exception.
runFlat :: Flat a -> IO a

-- Silly, simple optimizations
runFlat (FlatApp (FlatAction io)) = io
runFlat (FlatApp (FlatPure x)) = pure x

-- Start off with all exceptions masked so we can install proper cleanup.
runFlat f0 = E.uninterruptibleMask $ \restore -> do
  -- How many threads have been spawned and finished their task? We need to
  -- ensure we kill all child threads and wait for them to die.
  resultCountVar <- newTVarIO 0

  -- Forks off as many threads as necessary to run the given Flat a,
  -- and returns:
  --
  -- + An STM action that will block until completion and return the
  --   result.
  --
  -- + The IDs of all forked threads. These need to be tracked so they
  --   can be killed (either when an exception is thrown, or when one
  --   of the alternatives completes first).
  --
  -- It would be nice to have the returned STM action return an Either
  -- and keep the SomeException values somewhat explicit, but in all
  -- my testing this absolutely kills performance. Instead, we're
  -- going to use a hack of providing a TMVar to fill up with a
  -- SomeException when things fail.
  --
  -- TODO: Investigate why performance degradation on Either
  let go :: forall a.
            TMVar E.SomeException
         -> Flat a
         -> IO (STM a, DList C.ThreadId)
      -- A pure value needs no thread.
      go _excVar (FlatApp (FlatPure x)) = pure (pure x, dlistEmpty)
      go excVar (FlatApp (FlatAction io)) = do
        resVar <- newEmptyTMVarIO
        tid <- C.forkIOWithUnmask $ \restore1 -> do
          -- Only the user action runs under restore1; the bookkeeping below
          -- runs in the masking state the thread was forked with.
          res <- E.try $ restore1 io
          atomically $ do
            -- Count this thread as finished whether it succeeded or failed.
            modifyTVar' resultCountVar (+ 1)
            case res of
              -- tryPutTMVar: only the first exception is recorded.
              Left e  -> void $ tryPutTMVar excVar e
              Right x -> putTMVar resVar x
        pure (readTMVar resVar, dlistSingleton tid)
      go excVar (FlatApp (FlatApply cf ca)) = do
        (f, tidsf) <- go excVar cf
        (a, tidsa) <- go excVar ca
        pure (f <*> a, tidsf `dlistConcat` tidsa)
      go excVar (FlatApp (FlatLiftA2 f a b)) = do
        (a', tidsa) <- go excVar a
        (b', tidsb) <- go excVar b
        pure (liftA2 f a' b', tidsa `dlistConcat` tidsb)

      go excVar0 (FlatAlt x y z) = do
        -- As soon as one of the children finishes, we need to kill the siblings,
        -- we're going to create our own excVar here to pass to the children, so
        -- we can prevent the ThreadKilled exceptions we throw to the children
        -- here from propagating and taking down the whole system.
        excVar <- newEmptyTMVarIO
        resVar <- newEmptyTMVarIO
        pairs <- traverse (go excVar . FlatApp) (x:y:z)
        let (blockers, workerTids) = unzip pairs

        -- Fork a helper thread to wait for the first child to
        -- complete, or for one of them to die with an exception so we
        -- can propagate it to excVar0.
        helperTid <- C.forkIOWithUnmask $ \restore1 -> do
          -- The fold tries each child's blocker in list order and falls
          -- back to reading the local excVar.
          eres <- E.try $ restore1 $ atomically $ foldr
            (\blocker rest -> (Right <$> blocker) <|> rest)
            (Left <$> readTMVar excVar)
            blockers

          atomically $ do
            -- The helper thread is itself counted in resultCountVar.
            modifyTVar' resultCountVar (+ 1)
            case eres of
              -- NOTE: The child threads are spawned from @traverse go@ call above, they
              -- are _not_ children of this helper thread, and helper thread doesn't throw
              -- synchronous exceptions, so, any exception that the try above would catch
              -- must be an async exception.
              -- We were killed by an async exception, do nothing.
              Left (_ :: E.SomeException) -> pure ()
              -- Child thread died, propagate it
              Right (Left e)              -> void $ tryPutTMVar excVar0 e
              -- Successful result from one of the children
              Right (Right res)           -> putTMVar resVar res

          -- And kill all of the threads
          for_ workerTids $ \tids' ->
            -- NOTE: Replacing A.AsyncCancelled with KillThread as the
            -- 'A.AsyncCancelled' constructor is not exported in older versions
            -- of the async package
            -- for_ (tids' []) $ \workerTid -> E.throwTo workerTid A.AsyncCancelled
            for_ (dlistToList tids') $ \workerTid -> C.killThread workerTid

        pure ( readTMVar resVar
             , helperTid `dlistCons` dlistConcatAll workerTids
             )

  excVar <- newEmptyTMVarIO
  (getRes, tids0) <- go excVar f0
  let tids = dlistToList tids0
      tidCount = length tids
      -- True once every forked thread has bumped resultCountVar; a count
      -- above tidCount is treated as an internal invariant violation.
      allDone count =
        if count > tidCount
          then error ("allDone: count ("
                      <> show count
                      <> ") should never be greater than tidCount ("
                      <> show tidCount
                      <> ")")
          else count == tidCount

  -- Automatically retry if we get killed by a
  -- BlockedIndefinitelyOnSTM. For more information, see:
  --
  -- + https:\/\/github.com\/simonmar\/async\/issues\/14
  -- + https:\/\/github.com\/simonmar\/async\/pull\/15
  --
  let autoRetry action =
        action `E.catch`
        \E.BlockedIndefinitelyOnSTM -> autoRetry action

  -- Restore the original masking state while blocking and catch
  -- exceptions to allow the parent thread to be killed early.
  res <- E.try $ restore $ autoRetry $ atomically $
         (Left <$> readTMVar excVar) <|>
         (Right <$> getRes)

  count0 <- atomically $ readTVar resultCountVar
  unless (allDone count0) $ do
    -- Kill all of the threads
    -- NOTE: Replacing A.AsyncCancelled with KillThread as the
    -- 'A.AsyncCancelled' constructor is not exported in older versions
    -- of the async package
    -- for_ tids $ \tid -> E.throwTo tid A.AsyncCancelled
    for_ tids $ \tid -> C.killThread tid

    -- Wait for all of the threads to die. We're going to restore the original
    -- masking state here, just in case there's a bug in the cleanup code of a
    -- child thread, so that we can be killed by an async exception. We decided
    -- this is a better behavior than hanging indefinitely and wait for a SIGKILL.
    restore $ atomically $ do
      count <- readTVar resultCountVar
      -- retries until resultCountVar has increased to the threadId count returned by go
      check $ allDone count

  -- Return the result or throw an exception. Yes, we could use
  -- either or join, but explicit pattern matching is nicer here.
  case res of
    -- Parent thread was killed with an async exception
    Left e          -> E.throwIO (e :: E.SomeException)
    -- Some child thread died
    Right (Left e)  -> E.throwIO e
    -- Everything worked!
    Right (Right x) -> pure x
{-# INLINEABLE runFlat #-}

--------------------------------------------------------------------------------
#else
--------------------------------------------------------------------------------

-- | Unlifted 'A.mapConcurrently'.
--
-- @since 0.1.0.0
mapConcurrently :: MonadUnliftIO m => Traversable t => (a -> m b) -> t a -> m (t b)
mapConcurrently f t =
  withRunInIO (\runInIO -> A.mapConcurrently (runInIO . f) t)
{-# INLINE mapConcurrently #-}

-- | 'A.mapConcurrently_' for any 'MonadUnliftIO' monad.
--
-- @since 0.1.0.0
mapConcurrently_ :: MonadUnliftIO m => Foldable f => (a -> m b) -> f a -> m ()
mapConcurrently_ f t =
  withRunInIO (\runInIO -> A.mapConcurrently_ (runInIO . f) t)
{-# INLINE mapConcurrently_ #-}

--------------------------------------------------------------------------------
#endif
--------------------------------------------------------------------------------

-- | Like 'mapConcurrently' from async, but instead of one thread per
-- element, the work is shared among a fixed-size pool of threads. This is
-- useful in scenarios where resource consumption is bounded and for use
-- cases where too many concurrent tasks aren't allowed.
--
-- === __Example usage__
--
-- @
-- import Say
--
-- action :: Int -> IO Int
-- action n = do
--   tid <- myThreadId
--   sayString $ show tid
--   threadDelay (2 * 10^6) -- 2 seconds
--   return n
--
-- main :: IO ()
-- main = do
--   yx \<- pooledMapConcurrentlyN 5 (\\x -\> action x) [1..5]
--   print yx
-- @
--
-- On executing you can see that five threads have been spawned:
--
-- @
-- \$ ./pool
-- ThreadId 36
-- ThreadId 38
-- ThreadId 40
-- ThreadId 42
-- ThreadId 44
-- [1,2,3,4,5]
-- @
--
--
-- Let's modify the above program such that there are fewer threads
-- than the number of items in the list:
--
-- @
-- import Say
--
-- action :: Int -> IO Int
-- action n = do
--   tid <- myThreadId
--   sayString $ show tid
--   threadDelay (2 * 10^6) -- 2 seconds
--   return n
--
-- main :: IO ()
-- main = do
--   yx \<- pooledMapConcurrentlyN 3 (\\x -\> action x) [1..5]
--   print yx
-- @
-- On executing you can see that only three threads are active in total:
--
-- @
-- \$ ./pool
-- ThreadId 35
-- ThreadId 37
-- ThreadId 39
-- ThreadId 35
-- ThreadId 39
-- [1,2,3,4,5]
-- @
--
-- @since 0.2.10
pooledMapConcurrentlyN :: (MonadUnliftIO m, Traversable t)
                      => Int -- ^ Max. number of threads. Should not be less than 1.
                      -> (a -> m b) -> t a -> m (t b)
pooledMapConcurrentlyN numProcs f xs =
  withRunInIO (\runInIO -> pooledMapConcurrentlyIO numProcs (runInIO . f) xs)

-- | Same as 'pooledMapConcurrentlyN', with the pool size taken from
-- 'getNumCapabilities'. Usually this is useful for CPU bound tasks.
--
-- @since 0.2.10
pooledMapConcurrently :: (MonadUnliftIO m, Traversable t) => (a -> m b) -> t a -> m (t b)
pooledMapConcurrently f xs =
  withRunInIO $ \runInIO ->
    getNumCapabilities >>= \caps ->
      pooledMapConcurrentlyIO caps (runInIO . f) xs

-- | 'pooledMapConcurrentlyN' with the function and container arguments
-- swapped.
--
-- @since 0.2.10
pooledForConcurrentlyN :: (MonadUnliftIO m, Traversable t)
                      => Int -- ^ Max. number of threads. Should not be less than 1.
                      -> t a -> (a -> m b) -> m (t b)
pooledForConcurrentlyN numProcs xs f = pooledMapConcurrentlyN numProcs f xs

-- | Same as 'pooledForConcurrentlyN', with the pool size taken from
-- 'getNumCapabilities'. Usually this is useful for CPU bound tasks.
--
-- @since 0.2.10
pooledForConcurrently :: (MonadUnliftIO m, Traversable t) => t a -> (a -> m b) -> m (t b)
pooledForConcurrently xs f = pooledMapConcurrently f xs

-- Validates the pool size, then delegates to 'pooledMapConcurrentlyIO''.
-- A pool size below 1 is reported via 'error'.
pooledMapConcurrentlyIO :: Traversable t => Int -> (a -> IO b) -> t a -> IO (t b)
pooledMapConcurrentlyIO numProcs f xs
  | numProcs < 1 = error "pooledMapconcurrentlyIO: number of threads < 1"
  | otherwise = pooledMapConcurrentlyIO' numProcs f xs

-- | Performs the actual pooling for the tasks. Each of the pooled threads
-- repeatedly takes the next job off the front of the queue and runs the
-- task on it, and stops once it finds the queue empty.
pooledConcurrently
  :: Int -- ^ Max. number of threads. Should not be less than 1.
  -> IORef [a] -- ^ Task queue. These are required as inputs for the jobs.
  -> (a -> IO ()) -- ^ The task which will be run concurrently (but
                  -- will be pooled properly).
  -> IO ()
pooledConcurrently numProcs jobsVar f =
  replicateConcurrently_ numProcs worker
  where
    -- One pooled thread: pop a job atomically, run it, repeat.
    worker = do
      next <- atomicModifyIORef' jobsVar popJob
      case next of
        Nothing -> return ()
        Just job -> f job >> worker

    -- Split the queue into its head (if any) and the remaining jobs.
    popJob [] = ([], Nothing)
    popJob (job : rest) = (rest, Just job)

-- Pooled traversal without the pool-size check: allocates one result slot
-- per input, lets the pool fill the slots, then reads them back in the
-- container's original shape.
pooledMapConcurrentlyIO'
  :: Traversable t
  => Int  -- ^ Max. number of threads. Should not be less than 1.
  -> (a -> IO b)
  -> t a
  -> IO (t b)
pooledMapConcurrentlyIO' numProcs f xs = do
  -- Pair every input with an output slot of its own...
  slots <- traverse allocSlot xs
  -- ...queue up all the (input, slot) pairs...
  queue <- newIORef (toList slots)
  -- ...let `numProcs` threads drain the queue, each writing its result into
  -- the slot that belongs to the job...
  pooledConcurrently numProcs queue runJob
  -- ...and finally collect the contents of every slot.
  traverse (readIORef . snd) slots
  where
    allocSlot input = do
      ref <- newIORef (error "pooledMapConcurrentlyIO': empty IORef")
      return (input, ref)

    runJob (input, ref) = do
      result <- f input
      atomicWriteIORef ref result

-- Result-discarding pooled traversal without the pool-size check.
pooledMapConcurrentlyIO_' :: Foldable t => Int -> (a -> IO ()) -> t a -> IO ()
pooledMapConcurrentlyIO_' numProcs f jobs =
  newIORef (toList jobs) >>= \queue -> pooledConcurrently numProcs queue f

-- Validates the pool size, then delegates to 'pooledMapConcurrentlyIO_''
-- with the task's result thrown away. A pool size below 1 is reported via
-- 'error'.
pooledMapConcurrentlyIO_ :: Foldable t => Int -> (a -> IO b) -> t a -> IO ()
pooledMapConcurrentlyIO_ numProcs f xs
  | numProcs < 1 = error "pooledMapconcurrentlyIO_: number of threads < 1"
  | otherwise = pooledMapConcurrentlyIO_' numProcs (void . f) xs

-- | 'pooledMapConcurrentlyN' with the results discarded.
--
-- @since 0.2.10
pooledMapConcurrentlyN_ :: (MonadUnliftIO m, Foldable f)
                        => Int -- ^ Max. number of threads. Should not be less than 1.
                        -> (a -> m b) -> f a -> m ()
pooledMapConcurrentlyN_ numProcs f t =
  withRunInIO (\runInIO -> pooledMapConcurrentlyIO_ numProcs (runInIO . f) t)

-- | Like 'pooledMapConcurrently' but with the return value discarded.
--
-- @since 0.2.10
pooledMapConcurrently_ :: (MonadUnliftIO m, Foldable f) => (a -> m b) -> f a -> m ()
pooledMapConcurrently_ f t =
  withRunInIO $ \runInIO ->
    getNumCapabilities >>= \caps ->
      pooledMapConcurrentlyIO_ caps (runInIO . f) t

-- | 'pooledMapConcurrently_' with the function and container arguments
-- swapped.
--
-- @since 0.2.10
pooledForConcurrently_ :: (MonadUnliftIO m, Foldable f) => f a -> (a -> m b) -> m ()
pooledForConcurrently_ t f = pooledMapConcurrently_ f t

-- | 'pooledMapConcurrentlyN_' with the function and container arguments
-- swapped.
--
-- @since 0.2.10
pooledForConcurrentlyN_ :: (MonadUnliftIO m, Foldable t)
                        => Int -- ^ Max. number of threads. Should not be less than 1.
                        -> t a -> (a -> m b) -> m ()
pooledForConcurrentlyN_ numProcs t f = pooledMapConcurrentlyN_ numProcs f t

-- | Pooled version of 'replicateConcurrently'. Performs the action in
-- the pooled threads. A count below 1 yields an empty list without
-- touching the pool.
--
-- @since 0.2.10
pooledReplicateConcurrentlyN :: (MonadUnliftIO m)
                             => Int -- ^ Max. number of threads. Should not be less than 1.
                             -> Int -- ^ Number of times to perform the action.
                             -> m a -> m [a]
pooledReplicateConcurrentlyN numProcs cnt task
  | cnt < 1 = return []
  | otherwise = pooledMapConcurrentlyN numProcs (const task) [1..cnt]

-- | Same as 'pooledReplicateConcurrentlyN', with the pool size taken from
-- 'getNumCapabilities'. Usually this is useful for CPU bound tasks.
--
-- @since 0.2.10
pooledReplicateConcurrently :: (MonadUnliftIO m)
                            => Int -- ^ Number of times to perform the action.
                            -> m a -> m [a]
pooledReplicateConcurrently cnt task
  | cnt < 1 = return []
  | otherwise = pooledMapConcurrently (const task) [1..cnt]

-- | Pooled version of 'replicateConcurrently_'. Performs the action in
-- the pooled threads. A count below 1 does nothing.
--
-- @since 0.2.10
pooledReplicateConcurrentlyN_ :: (MonadUnliftIO m)
                              => Int -- ^ Max. number of threads. Should not be less than 1.
                              -> Int -- ^ Number of times to perform the action.
                              -> m a -> m ()
pooledReplicateConcurrentlyN_ numProcs cnt task
  | cnt < 1 = return ()
  | otherwise = pooledMapConcurrentlyN_ numProcs (const task) [1..cnt]

-- | Same as 'pooledReplicateConcurrentlyN_', with the pool size taken from
-- 'getNumCapabilities'. Usually this is useful for CPU bound tasks.
--
-- @since 0.2.10
pooledReplicateConcurrently_ :: (MonadUnliftIO m)
                             => Int -- ^ Number of times to perform the action.
                             -> m a -> m ()
pooledReplicateConcurrently_ cnt task
  | cnt < 1 = return ()
  | otherwise = pooledMapConcurrently_ (const task) [1..cnt]