module Control.Concurrent.HierarchyInternal where
import Control.Concurrent (ThreadId, forkIOWithUnmask,
                           killThread, myThreadId)
import Control.Concurrent.MVar (MVar, newEmptyMVar, newMVar,
                                putMVar, readMVar)
import Control.Concurrent.STM.TVar (TVar, modifyTVar', newTVarIO,
                                    readTVarIO)
import Control.Exception (AsyncException (ThreadKilled),
                          catch, finally, mask_, throwIO)
import Control.Monad.STM (atomically)
import Data.Foldable (for_, traverse_)
import Data.Map.Strict (Map, delete, elems, empty, insert,
                        keys)
-- | Completion signal of a thread forked by 'newChild': an 'MVar' that is
-- created empty and is filled with @()@ by 'cleanup' as that thread's last
-- step.
newtype FinishMarker = FinishMarker (MVar ())
    deriving Eq

-- | Mutable registry of threads forked with 'newChild', keyed by 'ThreadId'
-- and mapping each thread to its 'FinishMarker'.  A thread removes its own
-- entry during 'cleanup'.
newtype ThreadMap = ThreadMap (TVar (Map ThreadId FinishMarker))
-- | Allocate a fresh 'ThreadMap' with no threads registered in it.
newThreadMap :: IO ThreadMap
newThreadMap = fmap ThreadMap (newTVarIO empty)
-- | Fork a thread running @action@ and register it in @brothers@.
--
-- The forked thread is handed its own, initially empty 'ThreadMap' in which
-- it can register children of its own.  When @action@ ends, either normally
-- or by an exception, the thread runs 'cleanup'.  'cleanup' tears down that
-- child map, removes the thread from @brothers@ and fills its 'FinishMarker'.
--
-- Returns the 'ThreadId' of the forked thread.
newChild
    :: ThreadMap
    -> (ThreadMap -> IO ())
    -> IO ThreadId
newChild brothers@(ThreadMap bMap) action = do
    finishMarker <- FinishMarker <$> newEmptyMVar
    children <- newThreadMap
    -- Registration gate: filled by the parent only after the child has been
    -- inserted into @brothers@.  The child waits on it before running
    -- @action@, so the child's 'cleanup' (which deletes it from @brothers@)
    -- can never run before the 'insert' below.  Without the gate, a child
    -- whose action finishes immediately could delete itself first.  The
    -- later insert would then leave a stale entry for a dead thread in
    -- @brothers@ forever.
    registered <- newEmptyMVar
    mask_ $ do
        child <- forkIOWithUnmask $ \unmask ->
            unmask (readMVar registered >> action children)
                `finally` cleanup finishMarker brothers children
        -- Neither step blocks here: the transaction never retries, and
        -- 'registered' is empty and filled only by this thread.  So under
        -- 'mask_' the child is always both registered and released.
        atomically $ modifyTVar' bMap (insert child finishMarker)
        putMVar registered ()
        return child
-- | Send 'killThread' to every thread currently registered in the given
-- 'ThreadMap'.  Then block until every thread still registered afterwards
-- has filled its 'FinishMarker' (the last step of 'cleanup').
--
-- NOTE(review): the map is re-read after the kill pass.  A thread registered
-- concurrently between the two reads is therefore waited on without having
-- been killed.  Confirm that callers never add children while this runs.
killThreadHierarchy
    :: ThreadMap
    -> IO ()
killThreadHierarchy (ThreadMap registry) = do
    readTVarIO registry >>= mapM_ killThread . keys
    readTVarIO registry >>= mapM_ awaitFinish . elems
  where
    awaitFinish :: FinishMarker -> IO ()
    awaitFinish (FinishMarker done) = readMVar done
-- | Variant of 'killThreadHierarchy' used by 'cleanup'.  It kills every
-- thread registered in the given 'ThreadMap' and waits until each thread
-- still registered afterwards has filled its 'FinishMarker'.
--
-- 'cleanup' runs as the 'finally' handler in 'newChild', i.e. with async
-- exceptions masked.  Even so, the blocking 'killThread' / 'readMVar' calls
-- remain interruptible, and the calling thread may itself be sent
-- 'ThreadKilled' while it is tearing down.  Each such call is therefore
-- retried once on 'ThreadKilled' so the teardown can still finish.
killThreadHierarchyInternal
    :: ThreadMap
    -> IO ()
killThreadHierarchyInternal (ThreadMap children) = do
    currentChildren <- readTVarIO children
    for_ (keys currentChildren) $ \child ->
        retryOnceOnThreadKilled (killThread child)
    remainingChildren <- readTVarIO children
    for_ (elems remainingChildren) $ \(FinishMarker marker) ->
        retryOnceOnThreadKilled (readMVar marker)
  where
    -- Run @act@ again if 'ThreadKilled' interrupts it.  Any other
    -- 'AsyncException' (e.g. 'UserInterrupt') is rethrown unchanged rather
    -- than failing with a pattern-match error in the handler.
    retryOnceOnThreadKilled :: IO () -> IO ()
    retryOnceOnThreadKilled act =
        act `catch` \e -> case e of
            ThreadKilled -> act
            _            -> throwIO (e :: AsyncException)
-- | Teardown run by a thread forked with 'newChild' once its action ends.
--
-- It runs three steps in order:
--
-- 1. Kill and await the thread's own sub-hierarchy (@children@).
-- 2. Remove the calling thread from its siblings' 'ThreadMap'.
-- 3. Fill the thread's 'FinishMarker', releasing anyone waiting on it in
--    'killThreadHierarchy' or 'killThreadHierarchyInternal'.
cleanup :: FinishMarker -> ThreadMap -> ThreadMap -> IO ()
cleanup (FinishMarker done) (ThreadMap siblings) children = do
    killThreadHierarchyInternal children
    myThreadId >>= atomically . modifyTVar' siblings . delete
    putMVar done ()