module Control.Concurrent.HierarchyInternal where
import           Control.Concurrent.Lifted      (ThreadId, forkWithUnmask,
                                                 killThread, myThreadId)
import           Control.Concurrent.MVar.Lifted (MVar, modifyMVar, newEmptyMVar,
                                                 newMVar, putMVar, readMVar,
                                                 takeMVar)
import           Control.Exception.Lifted       (finally, mask_)
import           Control.Monad.Base             (MonadBase)
import           Control.Monad.Trans.Control    (MonadBaseControl)
import           Data.Map.Strict                (Map, delete, empty, insert,
                                                 toList)
-- | Completion signal for a child thread.
--
-- The wrapped 'MVar' starts empty ('newFinishMarker') and is filled with @()@
-- by 'markFinish'; 'waitFinish' blocks until that has happened. Two markers
-- are equal exactly when they wrap the same 'MVar'.
newtype FinishMarker
    = FinishMarker (MVar ())
    deriving (Eq)
-- | Registry of child threads: each registered 'ThreadId' is paired with the
-- 'FinishMarker' its 'cleanup' will set. Entries are added by 'newChild' and
-- removed by 'cleanup'; the surrounding 'MVar' serialises those updates.
newtype ThreadMap
    = ThreadMap (MVar (Map ThreadId FinishMarker))
-- | Allocate a 'ThreadMap' with no registered threads.
newThreadMap :: MonadBase IO m => m ThreadMap
newThreadMap = do
    registry <- newMVar empty
    return (ThreadMap registry)
-- | Fork a child thread running @action@ and register it in @brothers@.
--
-- The child is handed a fresh, empty 'ThreadMap' for its own children. When
-- its action ends, whether normally or by exception, 'cleanup' kills the
-- child's descendants, removes the child from @brothers@ and sets its
-- 'FinishMarker'.
newChild
    :: MonadBaseControl IO m
    => ThreadMap            -- ^ map the new child is registered in
    -> (ThreadMap -> m ())  -- ^ body of the child; receives its own children map
    -> m ThreadId           -- ^ id of the forked child
newChild brothers@(ThreadMap bMap) action = do
    finishMarker <- newFinishMarker
    children <- newThreadMap
    -- The forked thread inherits the masked state, so its 'finally' handler
    -- is installed before any async exception can reach it; only 'unmask'
    -- re-enables delivery around the user action.
    mask_ $ modifyMVar bMap $ \currentBrothers -> do
        -- bMap is held while forking. The child's 'cleanup' must take bMap to
        -- remove itself, so it cannot run before the child is registered here
        -- (a fast child could otherwise be deleted before being inserted and
        -- leave a stale entry). 'modifyMVar' puts the map back if forking
        -- throws.
        child <- forkWithUnmask $ \unmask ->
            unmask (action children) `finally` cleanup finishMarker brothers children
        let updated = insert child finishMarker currentBrothers
        -- Store an evaluated map rather than a thunk chain that would keep
        -- the ThreadIds of long-gone children reachable.
        updated `seq` return (updated, child)
-- | Kill every thread registered in the given 'ThreadMap', then block until
-- each of those threads has set its 'FinishMarker'.
--
-- For threads started with 'newChild', the marker is set at the end of
-- 'cleanup', after the thread's own children have been torn down, so this
-- waits for whole subtrees.
--
-- Only the threads present in the snapshot read at the start are killed and
-- waited for. NOTE(review): children registered concurrently after the
-- snapshot are left running.
killThreadHierarchy
    :: MonadBase IO m
    => ThreadMap    -- ^ children to terminate
    -> m ()
killThreadHierarchy (ThreadMap children) = do
    currentChildren <- readMVar children
    mapM_ (killThread . fst) $ toList currentChildren
    -- Wait on the markers of exactly the threads killed above. Re-reading the
    -- map would skip a child that already removed itself but has not yet set
    -- its marker, and could block forever on a child registered after the
    -- snapshot (one that was never killed).
    mapM_ (waitFinish . snd) $ toList currentChildren
-- | Allocate a 'FinishMarker' in the unfinished (empty) state.
newFinishMarker :: MonadBase IO m => m FinishMarker
newFinishMarker = fmap FinishMarker newEmptyMVar
-- | Signal completion by filling the marker's 'MVar' with @()@.
--
-- NOTE(review): 'putMVar' blocks while the 'MVar' is full, so a second call
-- on the same marker would block; callers must set each marker only once.
markFinish :: MonadBase IO m => FinishMarker -> m ()
markFinish marker = case marker of
    FinishMarker done -> putMVar done ()
-- | Block until 'markFinish' has been called on this marker.
--
-- 'readMVar' leaves the value in place, so any number of waiters are
-- released by a single 'markFinish'.
waitFinish :: MonadBase IO m => FinishMarker -> m ()
waitFinish marker = readMVar done
  where
    FinishMarker done = marker
-- | Finaliser that 'newChild' attaches to each child thread. It runs in the
-- child when the child's action ends, whether normally or by exception.
--
-- 1. Kills the child's own children and waits for them ('killThreadHierarchy').
-- 2. Removes the current thread from its parent's map.
-- 3. Sets the child's 'FinishMarker', releasing any 'waitFinish' callers.
cleanup
    :: MonadBase IO m
    => FinishMarker  -- ^ marker of the finishing thread
    -> ThreadMap     -- ^ map the finishing thread is registered in
    -> ThreadMap     -- ^ the finishing thread's own children
    -> m ()
cleanup finishMarker (ThreadMap brotherMap) children = do
    killThreadHierarchy children
    myThread <- myThreadId
    brothers <- takeMVar brotherMap
    -- Store the updated map evaluated. A lazy 'delete' thunk would keep the
    -- chain of earlier insert/delete thunks, and the ThreadIds they reference,
    -- reachable until something next forces the map.
    putMVar brotherMap $! delete myThread brothers
    markFinish finishMarker