module Wobsurv.Util.MasterThread where

import BasePrelude hiding (forkFinally)
import Control.Monad.Trans.Reader
import qualified BasePrelude
import qualified STMContainers.Set as Set
import qualified Wobsurv.Util.PartialHandler as H

-- |
-- A monad which adds the ability to fork slave threads
-- that are bound to their master thread:
-- when the master terminates with an exception (e.g., because it was killed),
-- its slaves get killed too.
-- Exceptions raised in the slave threads are also rethrown in the master thread,
-- so they don't get lost.
--
-- It is a reader over the 'Context' of the thread executing the computation.
type MasterThread =
  ReaderT Context IO

-- |
-- The registry of slave threads belonging to one master:
-- a transactional set of the ids of the threads
-- which were forked from that master and have not yet deregistered themselves.
type Context =
  (Set.Set ThreadId)

-- |
-- A shorthand alias for 'MasterThread'.
type MT = 
  MasterThread

-- |
-- Execute a 'MasterThread' computation in the calling thread.
--
-- A fresh, empty registry of slaves is created for the computation.
-- If the computation terminates with any exception,
-- every thread found in the registry is killed,
-- the call blocks until the registry becomes empty,
-- and only then the original exception is rethrown.
-- On normal completion the result is returned and the registry is left alone.
run :: MT a -> IO a
run mt =
  do
    slaves <- atomically Set.new
    runReaderT mt slaves `onException` terminate slaves
  where
    terminate slaves =
      do
        -- Snapshot the registry and kill every registered slave.
        victims <-
          atomically $ Set.foldM (\acc tid -> return (tid : acc)) [] slaves
        traverse_ killThread victims
        -- Block until every slave has removed itself from the registry.
        atomically $ do
          isEmpty <- Set.null slaves
          unless isEmpty retry

-- |
-- Fork a slave thread executing @main@, register it in the context of
-- the calling thread and return its 'ThreadId'.
--
-- When the slave terminates, for whatever reason, it does the following:
--
-- 1. Runs @finalizer@.
--
-- 2. Takes the exception which terminated @main@ or, failing that,
--    the one raised by @finalizer@, and feeds it to the handler composed of
--    'H.onThreadKilled' and 'H.rethrowTo'
--    (presumably: 'ThreadKilled' is ignored,
--    anything else is rethrown in the forking thread).
--
-- 3. Kills the slaves it has forked itself and waits for them to deregister.
--
-- 4. Removes itself from the context of the forking thread.
--
-- Forking and registration happen with asynchronous exceptions masked,
-- and the slave does not start executing @main@ before it is registered.
-- This way a slave can neither be left running unregistered,
-- nor deregister itself before it got registered
-- (which would leave a dead thread in the registry forever
-- and block everyone who waits for the registry to become empty).
forkFinally :: MT () -> IO () -> MT ThreadId
forkFinally main finalizer =
  ReaderT $ \context -> do
    -- Fail before spawning anything if the context is bottom,
    -- as is the case when run with 'runWithoutForking'.
    _ <- evaluate context
    thread <- myThreadId
    slaveContext <- atomically $ Set.new
    -- A gate which holds the slave back until it is registered in the context.
    registered <- newEmptyMVar
    let
      onDeath :: Either SomeException () -> IO ()
      onDeath r =
        do
          -- NOTE(review): this runs masked, but not uninterruptibly so.
          -- An asynchronous exception delivered at one of the blocking points
          -- below would abort the cleanup before the deregistration.
          -- Finalization and rethrowing of exceptions into the master thread:
          do
            r' <- try $ finalizer
            forM_ (left r <|> left r') $ 
              H.toTotal $ H.onThreadKilled (return ()) <> H.rethrowTo thread
          -- Context management:
          do
            traverse_ killThread =<< do
              atomically $ Set.foldM (\l -> return . (: l)) [] slaveContext
            slaveThread <- myThreadId
            -- Ensures that it waits for all slaves to die 
            -- before informing the master that it died itself.
            -- And so on recursively.
            atomically $ do
              Set.null slaveContext >>= \case
                True -> Set.delete slaveThread context
                False -> retry
        where
          left = either Just (const Nothing)
    -- This is 'BasePrelude.forkFinally' inlined,
    -- so that the registration can share its masked section:
    -- none of the actions performed under the mask block,
    -- hence the forking thread cannot be interrupted halfway through.
    mask $ \restore -> do
      slaveThread <- forkIO $
        try (restore (readMVar registered >> runReaderT main slaveContext))
          >>= onDeath
      atomically $ Set.insert slaveThread context
      -- Only now let the slave proceed to @main@.
      putMVar registered ()
      return slaveThread

-- |
-- Fork a slave thread which needs no finalization.
-- The same as 'forkFinally' with a finalizer that does nothing.
fork :: MT () -> MT ThreadId
fork main =
  forkFinally main noFinalizer
  where
    noFinalizer = return ()

-- |
-- Execute a 'MasterThread' computation which is expected to fork no slaves.
--
-- No registry of slaves is allocated:
-- the context handed to the computation is an error value,
-- so evaluating it, which forking requires, fails with that error.
runWithoutForking :: MT a -> IO a
runWithoutForking mt =
  runReaderT mt forbiddenContext
  where
    forbiddenContext =
      error "Attempt to fork when run with 'runWithoutForking'"