module Control.Concurrent.AsyncManager where
import Control.Concurrent
import Control.Concurrent.Async
import Control.Concurrent.STM
import Control.Concurrent.MVar
import Data.IORef
import Control.Monad
import Data.Foldable (toList)
import Data.HashMap.Strict (HashMap)
import qualified Data.HashMap.Strict as H
import Data.Monoid
import Data.Maybe
import Control.Applicative
import Control.Arrow (second)
import GHC.Conc
-- | Existential wrapper around an 'Async' that hides its result type, so
-- asyncs with different result types can be stored in a single container.
data AnyAsync = forall a. AnyAsync (Async a)
-- | A registry of asyncs together with a list of nested managers.
data AsyncManager = AsyncManager
  { childrenThreads  :: MVar (HashMap ThreadId AnyAsync)
    -- ^ Asyncs registered with this manager, keyed by their 'ThreadId'.
  , childrenManagers :: MVar [AsyncManager]
    -- ^ Child managers attached to this manager.
  }
-- | Create a manager that tracks no asyncs and has no child managers.
newAsyncManager :: IO AsyncManager
newAsyncManager = do
  threads  <- newMVar H.empty
  managers <- newMVar []
  return (AsyncManager threads managers)
-- | Create a fresh manager and prepend it to the given manager's list of
-- children. The new child manager is returned.
newChildManager :: AsyncManager -> IO AsyncManager
newChildManager parent =
  modifyMVar (childrenManagers parent) $ \siblings -> do
    fresh <- newAsyncManager
    return (fresh : siblings, fresh)
-- | Register an async with the manager, keyed by the async's 'ThreadId'.
-- An existing entry under the same key is replaced.
insert :: AsyncManager -> Async a -> IO ()
insert manager handle =
  modifyMVar_ (childrenThreads manager) $ \tracked ->
    return (H.insert (asyncThreadId handle) (AnyAsync handle) tracked)
-- | Cancel every async registered with the manager and empty its thread map,
-- then recursively 'clear' each child manager and drop the list of children.
--
-- The cancellations and the recursive clears run inside the respective
-- 'modifyMVar_' callbacks.
clear :: AsyncManager -> IO ()
clear manager = do
  modifyMVar_ (childrenThreads manager) $ \tracked -> do
    mapM_ cancelAny (toList tracked)
    return H.empty
  modifyMVar_ (childrenManagers manager) $ \children -> do
    mapM_ clear children
    return []
  where
    -- Unwrap the existential so 'cancel' can be applied to the hidden type.
    cancelAny :: AnyAsync -> IO ()
    cancelAny (AnyAsync handle) = cancel handle
-- | Number of asyncs registered with the manager plus the number of its
-- direct child managers. Asyncs held by child managers are not counted.
count :: AsyncManager -> IO Int
count manager =
  (+) <$> fmap H.size (readMVar (childrenThreads manager))
      <*> fmap length (readMVar (childrenManagers manager))
-- | Drop asyncs that have already finished from the manager's thread map,
-- keeping only those that are still running. Child managers are not touched.
--
-- 'poll' yields 'Nothing' while an async has not completed, so those are the
-- entries to retain. The earlier predicate used 'isJust', which kept the
-- finished asyncs and discarded the live ones.
compact :: AsyncManager -> IO ()
compact (AsyncManager ref _) = modifyMVar_ ref $ \tracked ->
  H.fromList <$> filterM stillRunning (H.toList tracked)
  where
    -- True when the wrapped async has not produced a result or exception yet.
    stillRunning :: (ThreadId, AnyAsync) -> IO Bool
    stillRunning (_, AnyAsync handle) = isNothing <$> poll handle
-- | Cancel the given async, then remove its entry from the manager's
-- thread map.
cancelWithManager :: AsyncManager
                  -> Async a
                  -> IO ()
cancelWithManager manager handle = do
  cancel handle
  let tid = asyncThreadId handle
  modifyMVar_ (childrenThreads manager) $ \tracked ->
    return (H.delete tid tracked)
-- | Start the action with 'async' and register the resulting 'Async' with
-- the manager. The 'Async' is returned to the caller.
asyncWithManager :: AsyncManager
                 -> IO a
                 -> IO (Async a)
asyncWithManager manager action = do
  handle <- async action
  handle <$ insert manager handle
-- | Like 'asyncWithManager', but also passes the async's thread and the
-- given string to 'labelThread' once the async has been started and
-- registered.
labelAsyncWithManager :: AsyncManager
                      -> String
                      -> IO a
                      -> IO (Async a)
labelAsyncWithManager manager name action =
  asyncWithManager manager action >>= \handle -> do
    labelThread (asyncThreadId handle) name
    return handle