-- | This is a low-level interface to the bag of tasks, which should only be
--   used if it is not possible to use /Control.Concurrent.Bag.Safe/ for your
--   task. This module allows more control but it also requires a deeper
--   knowledge about the implementation.
module Control.Concurrent.Bag.Basic
  ( Bag
  , newBag_
  , addEval
  , addTask
  , getResult
  , terminateBag
  , noMoreTasks
  , module Control.Concurrent.Bag.TaskBuffer )
where

import Data.Maybe (fromJust, isJust)
import Control.Concurrent
  ( ThreadId
  , forkIO
  , forkIOWithUnmask
  , killThread
  , getNumCapabilities
  , myThreadId )
import qualified Control.Concurrent.STM as STM (atomically)
import Control.Concurrent.STM
  ( STM
  , TVar
  , TChan
  , newTVar
  , newTChan
  , writeTChan
  , readTChan
  , writeTVar
  , tryReadTChan
  , isEmptyTChan
  , retry
  , readTVar )
import Control.Monad.Writer
import qualified Data.Map.Strict as Map

import Control.Concurrent.Bag.TaskBuffer

-- | A bag of tasks together with the worker threads that process it.
data Bag r = Bag
  { workerStates :: Map.Map ThreadId (TVar Bool)
    -- ^ one flag per worker thread; 'True' from the moment the worker takes
    --   a task until it has delivered that task's result
  , taskBuffers  :: Either (TaskBufferSTM (IO (Maybe r)))
                           (Map.Map ThreadId (TaskBufferSTM (IO (Maybe r))))
    -- ^ 'Left': a single buffer shared by all workers (no split function);
    --   'Right': one buffer per worker thread
  , resultChan   :: TChan r
    -- ^ results of the tasks that returned 'Just'
  , terminateVar :: TVar Bool
    -- ^ set to 'True' by 'terminateBag'
  , moreTasksVar :: TVar Bool
    -- ^ 'True' as long as tasks may still be added from outside the bag
  }

-- | Every task buffer of the bag: the shared one, or all per-worker ones.
bufferList :: Bag r -> [TaskBufferSTM (IO (Maybe r))]
bufferList = either (: []) Map.elems . taskBuffers

-- | Create a new bag of tasks with one worker thread per capability
--   (see 'getNumCapabilities'), but always at least one worker.
newBag_ :: MonadIO m =>
           BufferType
        -> Maybe (SplitFunction r)
        -> m (Bag r)
newBag_ buf split = do
  capabilities <- liftIO getNumCapabilities
  newBag buf split (max 1 capabilities)

-- | Create a new bag of tasks and fork its worker threads.
--
--   Without a split function all workers share a single buffer. With a split
--   function every worker owns a buffer. When its own buffer is empty, the
--   worker applies the split function to the first non-empty buffer of
--   another worker.
newBag :: MonadIO m =>
          BufferType
       -> Maybe (SplitFunction r)   -- ^ split function
                                    --   if this function is given use one
                                    --   buffer per thread
       -> Int                       -- ^ Number of threads
       -> m (Bag r)
newBag buf split n = do
  results <- atomically newTChan
  let newBuffer = case buf of { Queue -> newChanBuffer; Stack -> newStackBuffer }
  buffers <- case split of
    Just _  -> replicateM n $ atomically newBuffer
    Nothing -> (: []) `liftM` atomically newBuffer
  let -- Endless repetition of the buffers. Worker i owns position i and
      -- finds the buffers of the other workers at positions i+1 .. i+n-1.
      ring = cycle buffers
      -- 'worker' only applies the split function to foreign buffers. There
      -- are none without a split function, so this error is never forced.
      splitf = case split of
        Nothing -> error "Control.Concurrent.Bag.Basic.newBag: no split function given"
        Just s  -> s
      -- The worker's own buffer is deliberately excluded. The worker has just
      -- found it empty in the same transaction.
      foreignOf i = case split of
        Just _  -> take (n - 1) $ drop (i + 1) ring
        Nothing -> []
  (states, bufferMap) <- execWriterT $ mapM_ (\i -> do
      let ownBuffer = ring !! i
      state <- atomically $ newTVar False
      -- Run the worker unmasked even if 'newBag' itself runs with
      -- asynchronous exceptions masked ('forkIO' inherits the mask state).
      tid   <- liftIO $ forkIOWithUnmask $ \unmask ->
        unmask $ worker splitf state ownBuffer (foreignOf i) results
      tell (Map.singleton tid state, Map.singleton tid ownBuffer)
    ) [0 .. n - 1]
  terminated <- atomically $ newTVar False
  moreTasks  <- atomically $ newTVar True
  let eitherBuffer = case split of
        -- exactly one shared buffer was created above
        Nothing -> Left (head buffers)
        Just _  -> Right bufferMap
  return $ Bag states eitherBuffer results terminated moreTasks

-- | Main loop of a worker thread.
--
--   The worker takes a task from its own buffer. If that buffer is empty, it
--   applies the split function to the first non-empty foreign buffer. It
--   blocks ('retry') while all of these buffers are empty. After running a
--   task, any 'Just' result is written to the result channel.
--
--   The state variable is set to 'True' in the same transaction that obtains
--   a task. It is reset to 'False' only at the top of the next iteration, so
--   the worker counts as busy until its result has been written.
--
--   NOTE(review): exceptions thrown by a task are not caught here. They end
--   this worker thread and leave its state at 'True'.
worker :: SplitFunction r  -- ^ split function
       -> TVar Bool          -- ^ worker state, running?
       -> TaskBufferSTM (IO (Maybe r))   -- ^ own buffer
       -> [TaskBufferSTM (IO (Maybe r))] -- ^ foreign buffers
       -> TChan r            -- ^ result channel
       -> IO ()
worker split state ownBuffer foreignBuffers results = forever $ do
  atomically $ writeTVar state False
  task   <- atomically $ do
    mTask <- tryReadBufferSTM ownBuffer
    -- rolled back together with everything else if we end up in 'retry'
    writeTVar state True
    case mTask of
      Just t  -> return t
      Nothing -> stealFrom foreignBuffers
  result <- task
  maybe (return ()) (atomically . writeTChan results) result
 where
  -- Apply the split function to the first non-empty foreign buffer, or
  -- block until one of the inspected buffers changes.
  stealFrom []       = retry
  stealFrom (f : fs) = do
    isEmpty <- isEmptyBufferSTM f
    if isEmpty
      then stealFrom fs
      else split ownBuffer f

-- | Add a task to the bag.
--
--   With a single shared buffer, the task goes into that buffer. With one
--   buffer per worker, the buffer depends on the calling thread:
--
--   * called from a worker thread: that worker's own buffer;
--   * called from any other thread: the buffer of the worker with the
--     smallest 'ThreadId'.
addTask :: MonadIO m => Bag r -> IO (Maybe r) -> m ()
addTask bag task = do
  buffer <- case taskBuffers bag of
    Left shared -> return shared
    Right bufferMap -> do
      tid <- liftIO myThreadId
      return $ case Map.lookup tid bufferMap of
        Just own -> own
        Nothing  -> case Map.elems bufferMap of
          (first : _) -> first
          []          -> error "Control.Concurrent.Bag.Basic.addTask: the bag has no worker buffers"
  atomically $ writeBufferSTM buffer task

-- | Add a plain value to the bag as a task that yields it.
--
--   The value is forced to weak head normal form when the task itself is
--   evaluated. Presumably this happens in the worker thread that runs the
--   task, not in the caller.
addEval :: MonadIO m => Bag r -> r -> m ()
addEval bag e = addTask bag forced
  where
    forced = seq e (return (Just e))

-- | Announce that no further tasks will be added from outside the bag.
--   Tasks that are already queued may still add new tasks.
noMoreTasks :: MonadIO m => Bag r -> m ()
noMoreTasks = atomically . flip writeTVar False . moreTasksVar

-- | Take the next result out of the bag.
--
--   Returns @Just r@ as soon as a result is available. Returns 'Nothing' when
--   the result channel is empty and either of the following holds:
--
--   * the bag has been terminated, or
--   * no more tasks will come from outside, every buffer is empty and no
--     worker is busy.
--
--   Otherwise the call blocks.
getResult :: MonadIO m => Bag r -> m (Maybe r)
getResult bag = atomically $ tryReadTChan (resultChan bag) >>= maybe drained (return . Just)
 where
  -- The order of the reads below matches the original checks.
  drained = do
    stopped <- readTVar (terminateVar bag)
    unless stopped $ do
      readTVar (moreTasksVar bag) >>= flip when retry
      allEmpty <- and `liftM` mapM isEmptyBufferSTM (bufferList bag)
      unless allEmpty retry
      let awaitIdle var = readTVar var >>= \busy -> when busy retry
      mapM_ awaitIdle (Map.elems (workerStates bag))
    return Nothing

-- | Terminates all threads running in the thread pool. 'terminateBag' is
--   non-blocking and therefore does not guarantee that all threads are
--   terminated after its execution. Additionally, it is not guaranteed
--   that the threads are terminated promptly.
--
--   It is implemented using asynchronous exceptions, so a thread is only
--   terminated once it reaches a /safe point/. A safe point is where memory
--   allocation occurs. Although most tasks allocate memory now and then,
--   there are loops that never reach a safe point.
--
--   Once this has run, 'getResult' returns 'Nothing' whenever the result
--   channel is empty.
terminateBag :: MonadIO m => Bag r -> m ()
terminateBag bag = do
  -- Set both flags in one transaction, so no observer can see the bag
  -- terminated while it still claims that more tasks may come.
  atomically $ do
    writeTVar (terminateVar bag) True
    writeTVar (moreTasksVar bag) False
  -- killThread blocks until the exception is delivered, which only happens
  -- at a safe point. Some calculations never reach one, so each kill runs in
  -- its own thread and we continue immediately.
  mapM_ (liftIO . forkIO . killThread) (Map.keys $ workerStates bag)

-- Helper functions --
-- | Run an STM transaction atomically from any 'MonadIO' monad.
atomically :: MonadIO m => STM a -> m a
atomically transaction = liftIO (STM.atomically transaction)