-- | This is a low-level interface to the bag of tasks. This implementation does
--   not use stm in contrast to the other implementation. This implementation
--   is provided for performance comparison. This implementation should only be
--   used if it is not possible to use /Control.Concurrent.Bag.SafeConcurrent/
--   for your task. This module allows more control but it also requires a deeper
--   knowledge about the implementation.
module Control.Concurrent.Bag.Concurrent
  ( Bag
  , newBag_
  , addEval
  , addTask
  , getResult
  , writeResult
  , terminateBag
  , noMoreTasks
  , module Control.Concurrent.Bag.TaskBuffer )
where

import Data.Maybe (fromJust, isJust)
import Control.Concurrent
  ( ThreadId
  , forkIO
  , forkIOWithUnmask
  , killThread
  , getNumCapabilities
  , myThreadId )
import Control.Concurrent.Chan
import Control.Concurrent.MVar
import Control.Monad.Writer
import Control.Concurrent.Bag.TaskBuffer

-- Message type carried by the result channel of a 'Bag'.
--
--   * 'NoResult'    is written by a worker when a task returned 'Nothing'.
--   * 'Result'      carries a value, written either by a worker when a task
--                   returned 'Just' or directly via 'writeResult'.
--   * 'NoMoreTasks' is the marker written by 'noMoreTasks'; 'getResult'
--                   re-sends it while other readers are still waiting for it.
data Result a = NoResult | Result a | NoMoreTasks

-- Handle to a bag of tasks producing results of type @r@.
data Bag r = Bag {
    -- Ids of the worker threads forked in 'newBag'; used by 'terminateBag'.
    workers      :: [ThreadId]
    -- Buffer that 'addTask' writes tasks to and the workers read tasks from.
  , taskBuffer   :: TaskBuffer (IO (Maybe r))
    -- Channel that workers, 'writeResult' and 'noMoreTasks' write to and
    -- 'getResult' reads from.
  , resultChan   :: Chan (Result r)
    -- Counter incremented by 'addTask' and 'writeResult' and decremented by
    -- 'getResult'; starts at 0.
  , taskCountVar :: MVar Int
    -- Starts as 'True'; set to 'False' by 'noMoreTasks'.
  , moreTasksVar :: MVar Bool
    -- Number of 'getResult' callers currently registered around their
    -- 'readChan' in the "more tasks may come" branch; starts at 0.
  , waitingOneMoreTasksVar :: MVar Int
  }

-- | Create a new bag whose number of worker threads equals the number of
--   capabilities reported by the runtime, but is never less than one.
newBag_ :: (MonadIO m) =>
           BufferType
        -> m (Bag r)
newBag_ buf = do
  caps <- liftIO getNumCapabilities
  newBag buf (max 1 caps)

-- | Create a new bag of tasks served by @n@ worker threads.
--
--   Every worker is forked with asynchronous exceptions unmasked and runs
--   'worker' on the shared task buffer and result channel. For @n <= 0@ no
--   worker is started. The task counter starts at 0, the "more tasks" flag
--   at 'True' and the count of waiting readers at 0.
newBag :: (MonadIO m) =>
          BufferType
       -> Int                       -- ^ Number of threads
       -> m (Bag r)
newBag buf n = liftIO $ do
  results <- newChan
  buffer  <- case buf of { Queue -> newChanBuffer; Stack -> newStackBuffer }
  -- 'replicateM' forks the workers and collects their ids in fork order;
  -- no index or writer monad is needed for that.
  tids    <- replicateM n $
    forkIOWithUnmask $ \unmask -> unmask $ worker buffer results
  taskCount          <- newMVar 0
  moreTasks          <- newMVar True
  waitingOnMoreTasks <- newMVar 0
  return $ Bag tids buffer results taskCount moreTasks waitingOnMoreTasks

-- Body of a worker thread: endlessly take a task from the buffer, run it
-- and publish its outcome on the result channel ('NoResult' for 'Nothing',
-- 'Result' for 'Just').
worker :: TaskBuffer (IO (Maybe r)) -- ^ own buffer
       -> Chan (Result r)           -- ^ result channel
       -> IO ()
worker ownBuffer results = forever $ do
  outcome <- join (readBuffer ownBuffer)
  -- '$!' makes the worker itself inspect the 'Maybe' before writing
  writeChan results $! maybe NoResult Result outcome

-- | Add a task to the bag.
--
--   The task counter is incremented first and the task is then written to
--   the task buffer, from which a worker thread picks it up. A task that
--   yields 'Nothing' produces no result value; one that yields @Just v@
--   makes @v@ available through 'getResult'.
addTask :: (MonadIO m) => Bag r -> IO (Maybe r) -> m ()
addTask bag task = liftIO $ do
  -- Store the evaluated count: a lazy @(+1)@ would pile up a chain of
  -- thunks inside the MVar, since readers do not always force the counter.
  modifyMVar_ (taskCountVar bag) (\c -> return $! c + 1)
  writeBuffer (taskBuffer bag) task

-- | Add the evaluation of a pure value as a task. The value is forced with
--   'seq' (that is, to weak head normal form) when the queued action is
--   evaluated, and is then delivered wrapped in 'Just'.
addEval :: (MonadIO m) => Bag r -> r -> m ()
addEval bag e = addTask bag forced
 where
  forced = seq e (return (Just e))

-- | Announce that no further tasks will be added from the outside; tasks
--   that are already queued may still add new tasks.
--
--   The "more tasks" flag is cleared first and only then is the
--   'NoMoreTasks' marker written to the result channel.
noMoreTasks :: MonadIO m => Bag r -> m ()
noMoreTasks bag = liftIO $
  modifyMVar_ (moreTasksVar bag) (\_ -> return False)
    >> writeChan (resultChan bag) NoMoreTasks

-- | Write a result directly to the result channel, without going through
--   the task buffer. The task counter is incremented beforehand, matching
--   the decrement 'getResult' performs for every message it consumes.
writeResult :: MonadIO m => Bag r -> r -> m ()
writeResult bag x = liftIO $ do
  -- Store the evaluated count: a lazy @(+1)@ would pile up a chain of
  -- thunks inside the MVar, since readers do not always force the counter.
  modifyMVar_ (taskCountVar bag) (\c -> return $! c + 1)
  writeChan (resultChan bag) (Result x)

-- | Fetch the next result from the bag.
--
--   Returns @Just v@ for a 'Result' message read from the result channel.
--   'NoResult' messages are skipped by retrying. 'Nothing' is returned only
--   when the "more tasks" flag is 'False' (see 'noMoreTasks') and the task
--   counter is not positive. The call may block in 'readChan' until a
--   message arrives.
getResult :: (MonadIO m) => Bag r -> m (Maybe r)
getResult bag = liftIO $ do
  moreTasks <- readMVar (moreTasksVar bag)
  if moreTasks
    then do
      -- Tasks may still be added from outside, so the counter is decremented
      -- unconditionally. 'taskCount' is the value /before/ the decrement.
      taskCount <- modifyMVar (taskCountVar bag) (\c -> return (c-1, c))
      -- Register as a reader that may be waiting for the 'NoMoreTasks' marker.
      modifyMVar_ (waitingOneMoreTasksVar bag) (return . (+1))
      result <- readChan (resultChan bag)
      -- Deregister; 'waiting' is the number of readers still registered.
      waiting <- modifyMVar (waitingOneMoreTasksVar bag) (\c -> return (c-1,c-1))
      case result of
        NoMoreTasks -> do
          -- when other threads wait on this message, send it again
          when (waiting > 0) (writeChan (resultChan bag) NoMoreTasks)
          -- If the counter was not positive before our decrement, undo the
          -- decrement and retry; 'noMoreTasks' clears the flag before sending
          -- the marker, so the retry goes through the other branch.
          -- Otherwise a message is still accounted for: read it.
          if taskCount <= 0
            then modifyMVar_ (taskCountVar bag) (return . (+1)) >> getResult bag
            else readOneResult
        -- The consumed message carried no value; start over for the next one.
        NoResult    -> getResult bag
        Result v    -> return $ Just v
    else do
      -- No more tasks from outside: only read when the counter says that a
      -- message is still outstanding. The MVar is held between 'takeMVar'
      -- and 'putMVar'.
      taskCount <- takeMVar (taskCountVar bag)
      if taskCount <= 0
        then do
          -- Nothing outstanding: restore the counter and report the end.
          putMVar (taskCountVar bag) taskCount
          return Nothing
        else do
          -- Account for the message about to be consumed.
          putMVar (taskCountVar bag) (taskCount - 1)
          readOneResult
 where
  -- Read one message from the result channel, skipping (and, if needed,
  -- re-sending) 'NoMoreTasks' markers. The counter is not touched here.
  readOneResult = do
    result <- readChan (resultChan bag)
    case result of
      NoMoreTasks -> do
        -- when other threads wait on this message, send it again
        -- the value in the MVar does not increase anymore, because the 'moreTasksVar'
        -- is set to False before this message is sent. After this, no new threads
        -- will start to wait on this message.
        waiting <- readMVar (waitingOneMoreTasksVar bag)
        when (waiting > 0) (writeChan (resultChan bag) NoMoreTasks)
        readOneResult
      -- The consumed message carried no value; go through 'getResult' again
      -- so that the counter is checked and decremented for the next message.
      NoResult    -> getResult bag
      Result v    -> return $ Just v

-- | Terminate all worker threads of the bag, after first calling
--   'noMoreTasks' on it.
--
--   'terminateBag' does not block, so it gives no guarantee that the
--   threads are already gone when it returns, nor that they terminate
--   promptly. Termination uses asynchronous exceptions, which a thread only
--   receives at a /safe point/, i.e. where memory allocation occurs. Most
--   tasks allocate now and then, but a loop that never allocates never
--   reaches a safe point.
terminateBag :: MonadIO m => Bag r -> m ()
terminateBag bag =
  noMoreTasks bag >> liftIO (mapM_ killInBackground (workers bag))
 where
  -- 'killThread' blocks until the target has received the exception, which
  -- may never happen for a computation without safe points. Each kill is
  -- therefore issued from its own forked thread so the caller can continue.
  killInBackground :: ThreadId -> IO ()
  killInBackground = void . forkIO . killThread