{-|
Module      : Control.Concurrent.Bag.Concurrent
Description : Low-level bag of tasks implementation without using STM
Copyright   : (c) Bastian Holst, 2014
License     : BSD3
Maintainer  : bastianholst@gmx.de
Stability   : experimental
Portability : POSIX

This is a low-level interface to the bag of tasks. In contrast to the other
implementation, this one does not use STM. It is provided for performance
comparison and should only be used if it is not possible to use
/Control.Concurrent.Bag.SafeConcurrent/ for your task. This module allows more
control, but it also requires a deeper knowledge of the implementation.
-}
module Control.Concurrent.Bag.Concurrent
  ( Bag
  , newBag
  , newBag_
  , addEval
  , addTask
  , getResult
  , writeResult
  , terminateBag
  , noMoreTasks
  , module Control.Concurrent.Bag.TaskBuffer )
where

import Data.Maybe (fromJust, isJust)
import Control.Concurrent
  ( ThreadId
  , forkIO
  , forkIOWithUnmask
  , killThread
  , getNumCapabilities
  , myThreadId )
import Control.Concurrent.Chan
import Control.Concurrent.MVar
import Control.Monad.Writer
import Control.Concurrent.Bag.TaskBuffer

-- | Message type carried on a bag's result channel.
--
--   Besides actual values it also encodes "task finished without a value"
--   and the end-of-input marker.
data Result a
  = NoResult    -- ^ a task finished without producing a value
  | Result a    -- ^ a value produced by a task or handed in via 'writeResult'
  | NoMoreTasks -- ^ marker posted by 'noMoreTasks'

-- | The bag of tasks type.
--
--   Bundles the worker threads with the shared state they and the client
--   communicate over.
data Bag r = Bag
  { workers :: [ThreadId]
    -- ^ ids of the worker threads forked by 'newBag'; killed by 'terminateBag'
  , taskBuffer :: TaskBuffer (IO (Maybe r))
    -- ^ pending tasks; filled by 'addTask', drained by the workers
  , resultChan :: Chan (Result r)
    -- ^ channel on which workers, 'writeResult' and 'noMoreTasks' post
    --   messages that 'getResult' consumes
  , taskCountVar :: MVar Int
    -- ^ counter incremented by 'addTask' and 'writeResult' and decremented
    --   by 'getResult' whenever it claims a result
  , moreTasksVar :: MVar Bool
    -- ^ 'True' until 'noMoreTasks' has been called
  , waitingOneMoreTasksVar :: MVar Int
    -- ^ number of 'getResult' calls that entered while 'moreTasksVar' was
    --   'True' and are currently blocked reading 'resultChan'
  }

-- | Create a new bag whose thread count is the number of capabilities
--   reported by the runtime, but never less than one.
newBag_ :: (MonadIO m) =>
           BufferType -- ^ type of the used buffer(s)
        -> m (Bag r)
newBag_ buf = do
  capabilities <- liftIO getNumCapabilities
  newBag buf (max 1 capabilities)

-- | Build and return a new bag of tasks.
--
--   Forks @n@ worker threads that all read tasks from one shared buffer and
--   post their outcome to one shared result channel. A non-positive @n@
--   forks no workers at all, in which case queued tasks are never executed.
newBag :: (MonadIO m) =>
          BufferType -- ^ type of the used buffer(s)
       -> Int        -- ^ Number of threads
       -> m (Bag r)
newBag buf n = liftIO $ do
  results <- newChan
  buffer  <- case buf of
    Queue -> newChanBuffer
    Stack -> newStackBuffer
  -- The workers are explicitly unmasked so that the asynchronous exception
  -- sent by 'terminateBag' can reach them even when 'newBag' itself is
  -- called from a masked context.
  tids <- replicateM n $
    forkIOWithUnmask $ \unmask -> unmask (worker buffer results)
  taskCount          <- newMVar 0
  moreTasks          <- newMVar True
  waitingOnMoreTasks <- newMVar 0
  return $ Bag tids buffer results taskCount moreTasks waitingOnMoreTasks

-- | Body of a worker thread.
--
--   Loops forever: takes the next task from the buffer, runs it and posts
--   its outcome to the result channel ('NoResult' for 'Nothing', 'Result'
--   for 'Just').
worker :: TaskBuffer (IO (Maybe r)) -- ^ own buffer
       -> Chan (Result r)           -- ^ result channel
       -> IO ()
worker ownBuffer results = forever runNext
 where
  -- Execute a single task and report what it produced.
  runNext = do
    task    <- readBuffer ownBuffer
    outcome <- task
    writeChan results (maybe NoResult Result outcome)

-- | Add a task to the given bag of tasks.
--
--   The task counter is bumped before the task is queued, so the counter
--   already accounts for the task by the time a worker can pick it up.
addTask :: (MonadIO m) => Bag r -> IO (Maybe r) -> m ()
addTask bag task = liftIO $ do
  -- Force the new count before storing it; a lazy @(+1)@ would leave a
  -- growing chain of thunks in the MVar until the counter is next inspected.
  modifyMVar_ (taskCountVar bag) (\c -> return $! c + 1)
  writeBuffer (taskBuffer bag) task

-- | Queue the evaluation of a Haskell expression as a task.
--
--   The expression is evaluated to weak head normal form and then returned
--   as the task's result.
addEval :: (MonadIO m) => Bag r -> r -> m ()
addEval bag e = addTask bag evalTask
 where
  -- Forcing the task forces @e@ first, then yields it as the result.
  evalTask = e `seq` return (Just e)

-- | Announce that no further tasks will be added from the outside.
--
--   Tasks that are already queued may still add new tasks. The flag is
--   cleared first and only then is the 'NoMoreTasks' marker posted on the
--   result channel.
noMoreTasks :: MonadIO m => Bag r -> m ()
noMoreTasks bag = liftIO (clearFlag >> announce)
 where
  clearFlag = modifyMVar_ (moreTasksVar bag) (\_ -> return False)
  announce  = writeChan (resultChan bag) NoMoreTasks

-- | Write back a result to be read by 'getResult'.
--
--   The task counter is incremented as well, because 'getResult' decrements
--   it for every result it claims; the extra result is thereby accounted for.
writeResult :: MonadIO m => Bag r -> r -> m ()
writeResult bag x = liftIO $ do
  -- Force the new count before storing it; a lazy @(+1)@ would leave a
  -- growing chain of thunks in the MVar until the counter is next inspected.
  modifyMVar_ (taskCountVar bag) (\c -> return $! c + 1)
  writeChan (resultChan bag) (Result x)

-- | Get the next result from the bag.
--
--   Returns @'Just' v@ for a value that was produced by a task or handed in
--   via 'writeResult'. Tasks that finished without a value ('NoResult') are
--   skipped. Returns 'Nothing' once 'noMoreTasks' has been called and the
--   task counter is not positive, i.e. there is no counted task or result
--   left to claim. Otherwise this call blocks reading the result channel.
getResult :: (MonadIO m) => Bag r -> m (Maybe r)
getResult bag = liftIO $ do
  moreTasks <- readMVar (moreTasksVar bag)
  if moreTasks
    then do
      -- Tasks may still arrive: claim one result slot up front. 'taskCount'
      -- is the counter value from BEFORE the decrement.
      taskCount <- modifyMVar (taskCountVar bag) (\c -> return (c-1, c))
      -- Register as a reader that may have to be woken by 'NoMoreTasks'.
      modifyMVar_ (waitingOneMoreTasksVar bag) (return . (+1))
      result <- readChan (resultChan bag)
      -- Deregister; 'waiting' is the number of readers still registered.
      waiting <- modifyMVar (waitingOneMoreTasksVar bag) (\c -> return (c-1,c-1))
      case result of
        NoMoreTasks -> do
          -- when other threads wait on this message, send it again
          when (waiting > 0) (writeChan (resultChan bag) NoMoreTasks)
          if taskCount <= 0
            -- There was nothing to claim: undo the decrement and start over,
            -- re-reading 'moreTasksVar'.
            then modifyMVar_ (taskCountVar bag) (return . (+1)) >> getResult bag
            -- A slot was legitimately claimed: keep reading until its
            -- outcome shows up.
            else readOneResult
        -- The claimed task produced no value; try to claim another one.
        NoResult    -> getResult bag
        Result v    -> return $ Just v
    else do
      -- No more external tasks: the counter is held (MVar emptied) while
      -- deciding, so check and decrement happen as one step.
      taskCount <- takeMVar (taskCountVar bag)
      if taskCount <= 0
        then do
          -- Nothing left to claim; restore the counter unchanged.
          putMVar (taskCountVar bag) taskCount
          return Nothing
        else do
          -- Claim one slot and wait for its outcome.
          putMVar (taskCountVar bag) (taskCount - 1)
          readOneResult
 where
  -- Read from the result channel until an actual outcome arrives for the
  -- slot claimed by the caller, passing over 'NoMoreTasks' markers.
  readOneResult = do
    result <- readChan (resultChan bag)
    case result of
      NoMoreTasks -> do
        -- when other threads wait on this message, send it again
        -- the value in the MVar does not increase anymore, because the 'moreTasksVar'
        -- is set to False before this message is sent. After this, no new threads
        -- will start to wait on this message.
        waiting <- readMVar (waitingOneMoreTasksVar bag)
        when (waiting > 0) (writeChan (resultChan bag) NoMoreTasks)
        readOneResult
      -- The claimed task produced no value; go back and try to claim another.
      NoResult    -> getResult bag
      Result v    -> return $ Just v

-- | Terminate all threads running in the thread pool.
--
--   'noMoreTasks' is called first, then every worker is sent a kill request.
--   'terminateBag' is non-blocking and therefore does not guarantee that all
--   threads have terminated when it returns. It is also not guaranteed that
--   the threads terminate promptly: termination is implemented with
--   asynchronous exceptions, so a thread only stops once it reaches a
--   /safe point/. A safe point is where memory allocation occurs. Although
--   most tasks allocate memory now and then, there are loops that never
--   reach a safe point.
terminateBag :: MonadIO m => Bag r -> m ()
terminateBag bag = do
  noMoreTasks bag
  liftIO $ mapM_ killInBackground (workers bag)
 where
  -- 'killThread' blocks until the exception has been received, which only
  -- happens at a safe point (memory allocation). Some computations never
  -- reach one, so each kill runs in its own thread to let us continue here.
  killInBackground :: ThreadId -> IO ()
  killInBackground = void . forkIO . killThread