{-|
Module      : Control.Concurrent.Bag.Basic
Description : Low-level bag of tasks implementation using STM
Copyright   : (c) Bastian Holst, 2014
License     : BSD3
Maintainer  : bastianholst@gmx.de
Stability   : experimental
Portability : POSIX

This is a low-level interface to the bag of tasks, which should only be
used if it is not possible to use "Control.Concurrent.Bag.Safe" for your
task. This module allows more control but it also requires a deeper
knowledge about the implementation.
-}
module Control.Concurrent.Bag.Basic
  ( Bag ()
  , newBag
  , newBag_
  , addEval
  , addTask
  , getResult
  , writeResult
  , terminateBag
  , noMoreTasks
  , module Control.Concurrent.Bag.TaskBufferSTM )
where

import Data.Maybe (fromJust, isJust)
import Control.Concurrent
  ( ThreadId
  , forkIO
  , forkIOWithUnmask
  , killThread
  , getNumCapabilities
  , myThreadId )
import qualified Control.Concurrent.STM as STM (atomically)
import Control.Concurrent.STM
  ( STM
  , TVar
  , TChan
  , newTVar
  , newTChan
  , writeTChan
  , readTChan
  , writeTVar
  , tryReadTChan
  , isEmptyTChan
  , retry
  , readTVar )
import Control.Exception (evaluate)
import Control.Monad.Writer
import qualified Data.Map.Strict as Map

import Control.Concurrent.Bag.TaskBufferSTM

-- | The bag of tasks type.
--
--   A bag bundles its worker threads with the buffers they draw tasks from
--   and the channel that collects results.
data Bag r = Bag
  { workerStates :: Map.Map ThreadId (TVar Bool)
    -- ^ one flag per worker thread, 'True' while that worker holds a task
  , taskBuffers  :: Either (TaskBufferSTM (IO (Maybe r)))
                           (Map.Map ThreadId (TaskBufferSTM (IO (Maybe r))))
    -- ^ a single shared buffer ('Left') or one buffer per worker ('Right')
  , resultChan   :: TChan r
    -- ^ results of tasks and values passed to 'writeResult'
  , terminateVar :: TVar Bool
    -- ^ set to 'True' by 'terminateBag'
  , moreTasksVar :: TVar Bool
    -- ^ 'True' until 'noMoreTasks' or 'terminateBag' is called
  }

-- | Collect every task buffer of a 'Bag'.
--
--   This is either the single shared buffer, or the per-worker buffers in
--   ascending order of their worker's 'ThreadId'.
bufferList :: Bag r -> [TaskBufferSTM (IO (Maybe r))]
bufferList = either (: []) Map.elems . taskBuffers

-- | Build and return a new bag of tasks with a default number of worker
--   threads: one per capability as reported by 'getNumCapabilities', but
--   never fewer than one.
newBag_ :: MonadIO m =>
           BufferType              -- ^ type of the used buffer(s)
        -> Maybe (SplitFunction r) -- ^ Possible split function
                                   --
                                   --   If the function is given, we will create
                                   --   a bag with one buffer per worker
                                   --   reducing the communication between the
                                   --   workers.
        -> m (Bag r)
newBag_ buf split = do
  capabilities <- liftIO getNumCapabilities
  newBag buf split (max 1 capabilities)

-- | Build and return a new bag of tasks with the given number of worker
--   threads.
--
--   Without a split function all workers share a single buffer. With a
--   split function every worker gets its own buffer and, when that buffer
--   is empty, calls the split function on its own buffer and a non-empty
--   buffer of another worker.
newBag :: MonadIO m =>
          BufferType                -- ^ type of the used buffer(s)
       -> Maybe (SplitFunction r)   -- ^ Possible split function
                                    --
                                    --   If the function is given, we will create
                                    --   a bag with one buffer per worker
                                    --   reducing the communication between the
                                    --   workers.
       -> Int                       -- ^ number of threads
       -> m (Bag r)
newBag buf split n = do
  results <- atomically newTChan
  let newBuffer = case buf of
        Queue -> newChanBufferSTM
        Stack -> newStackBufferSTM
  -- One buffer per worker when splitting, otherwise one shared buffer. The
  -- list is cycled so that the rotating window of foreign buffers taken
  -- below wraps around.
  buffers <- cycle `liftM` if isJust split
    then replicateM n $ atomically newBuffer
    else atomically $ do
      buffer <- newBuffer
      return [buffer]
  -- Without a split function every worker receives an empty list of
  -- foreign buffers, so it never calls the split function; the error is
  -- only a safeguard with a helpful message (instead of a bare 'undefined').
  let splitf = maybe (error "Control.Concurrent.Bag.Basic.newBag: no split function given")
                     id split
  (states, bufferMap) <-
    execWriterT $ mapM_ (\i -> do
      let ownBuffer      = buffers !! i
          foreignBuffers = if isJust split
            then take n $ drop i buffers
            else []
      state <- atomically $ newTVar False
      -- Workers run unmasked even if the bag is created in a masked
      -- context, so 'terminateBag' can kill them.
      tid   <- liftIO $ forkIOWithUnmask $ \unmask ->
        unmask $ worker splitf state ownBuffer foreignBuffers results
      tell (Map.singleton tid state, Map.singleton tid ownBuffer)) [0..n-1]
  terminated <- atomically $ newTVar False
  moreTasks  <- atomically $ newTVar True
  let eitherBuffer =
        case split of
          Nothing -> Left  $ head buffers
          Just _  -> Right   bufferMap
  return $ Bag states eitherBuffer results terminated moreTasks

-- | Body of a worker thread; it loops forever.
--
--   Each round the worker first marks itself idle. It then, in one
--   transaction, takes a task from its own buffer or, if that yields
--   nothing, calls the split function on its own buffer and the first
--   non-empty foreign buffer. That same transaction marks it busy. The task
--   is then run and a 'Just' result is written to the result channel.
worker :: SplitFunction r                -- ^ split function
       -> TVar Bool                      -- ^ worker state, running?
       -> TaskBufferSTM (IO (Maybe r))   -- ^ own buffer
       -> [TaskBufferSTM (IO (Maybe r))] -- ^ foreign buffers
       -> TChan r                        -- ^ result channel
       -> IO ()
worker split state ownBuffer foreignBuffers results = forever $ do
  atomically $ writeTVar state False
  job <- atomically $ do
    queued <- tryReadBufferSTM ownBuffer
    -- If no task can be found anywhere, the transaction retries, which also
    -- discards this write, so the worker stays marked idle while waiting.
    writeTVar state True
    maybe (stealFrom foreignBuffers) return queued
  outcome <- job
  case outcome of
    Just value -> atomically $ writeTChan results value
    Nothing    -> return ()
 where
  -- Split the first foreign buffer that is not empty; retry if all are empty.
  stealFrom []                   = retry
  stealFrom (candidate : others) = do
    vacant <- isEmptyBufferSTM candidate
    if vacant
      then stealFrom others
      else split ownBuffer candidate

-- | Add a task to the given bag of tasks.
--
--   With a single shared buffer, the task is written there. With per-worker
--   buffers, it goes to the calling worker's own buffer. When called from a
--   thread that is not a worker, it goes to the buffer of the worker with
--   the smallest 'ThreadId'. That fallback uses 'head', so it fails if the
--   bag has no per-worker buffers at all.
addTask :: MonadIO m => Bag r -> IO (Maybe r) -> m ()
addTask bag task = do
  target <- either return pickBuffer (taskBuffers bag)
  atomically $ writeBufferSTM target task
 where
  -- Prefer the buffer owned by the current thread, if it is a worker.
  pickBuffer bufferMap = do
    self <- liftIO myThreadId
    return $ case Map.lookup self bufferMap of
      Just own -> own
      Nothing  -> head (Map.elems bufferMap)

-- | Add the evaluation of a haskell expression.
--
--   The given expression will be evaluated to weak head normal form when
--   the task is run by a worker. The evaluated value is then written to the
--   result channel.
addEval :: MonadIO m => Bag r -> r -> m ()
addEval bag e =
  -- 'evaluate' is used instead of seq-ing the IO action.
  -- ``e `seq` return (Just e)`` is itself bottom when @e@ is, so forcing
  -- the task value early (e.g. while storing it in a buffer) would evaluate
  -- @e@ in the submitting thread. With 'evaluate', @e@ is only forced when
  -- the task is executed.
  addTask bag (Just `liftM` evaluate e)

-- | Put a value into the result channel of the bag, from where it can be
--   retrieved with 'getResult'.
writeResult :: MonadIO m => Bag r -> r -> m ()
writeResult bag = atomically . writeTChan (resultChan bag)

-- | Tell the bag that there will be no more tasks from the outside.
--   Tasks that are already queued may still add new tasks.
noMoreTasks :: MonadIO m => Bag r -> m ()
noMoreTasks bag = atomically (writeTVar flag False)
  where
    flag = moreTasksVar bag

-- | Fetch the next result, produced either by a task or by 'writeResult'.
--
--   A queued result is returned as 'Just'. Otherwise, 'Nothing' is returned
--   in either of two cases:
--
--   * the bag has been terminated;
--   * 'noMoreTasks' has been called, every task buffer is empty, and no
--     worker is currently running a task.
--
--   In all other situations the call blocks and re-checks when the
--   observed state changes.
getResult :: MonadIO m => Bag r -> m (Maybe r)
getResult bag =
  atomically $ tryReadTChan (resultChan bag) >>= maybe whenDrained (return . Just)
 where
  -- No result is queued: decide between returning 'Nothing' and blocking.
  whenDrained = do
    stopped <- readTVar (terminateVar bag)
    unless stopped $ do
      acceptingTasks <- readTVar (moreTasksVar bag)
      when acceptingTasks retry
      buffersEmpty <- liftM and (mapM isEmptyBufferSTM (bufferList bag))
      unless buffersEmpty retry
      mapM_ (\busyVar -> readTVar busyVar >>= \busy -> when busy retry)
            (Map.elems (workerStates bag))
    return Nothing

-- | Terminates all threads running in the thread pool. 'terminateBag' is
--   non-blocking and therefore does not guarantee that all threads are
--   terminated after its execution. Additionally, it is not guaranteed
--   that the threads are terminated promptly. It is implemented using
--   asynchronous exceptions, so a thread is terminated once it reaches a
--   /safe point/. A safe point is where memory allocation occurs. Although
--   most tasks allocate memory now and then, there are loops that will
--   never reach a safe point.
--
--   After this call, 'getResult' returns 'Nothing' once no queued result is
--   left.
terminateBag :: MonadIO m => Bag r -> m ()
terminateBag bag = do
  atomically $ writeTVar (terminateVar bag) True
  atomically $ writeTVar (moreTasksVar bag) False
  mapM_ killDetached (Map.keys (workerStates bag))
 where
  -- 'killThread' blocks until the target receives the exception, which may
  -- never happen for a computation without safe points. It is therefore
  -- issued from a separate thread, and we do not wait for it.
  killDetached :: MonadIO m => ThreadId -> m ()
  killDetached tid = liftIO (forkIO (killThread tid)) >> return ()

-- Helper functions --

-- | Run an 'STM' transaction atomically in any monad that can lift 'IO'.
atomically :: MonadIO m => STM a -> m a
atomically transaction = liftIO (STM.atomically transaction)