{-# LANGUAGE MultiParamTypeClasses, FlexibleInstances #-}
module Control.Concurrent.Pool
  (Task (..),
   Pool,
   newPool, newPoolIO,
   isPoolWaiting,
   queue,
   noMoreTasks, noMoreTasksIO,
   readResult,
   resultsReader,
   waitFor, waitForIO,
   waitForTasks,
   terminatePool, terminatePoolIO
  ) where

import Control.Monad
import Control.Monad.State
import Control.Monad.Reader
import Control.Monad.Trans
import Control.Concurrent
import Control.Concurrent.STM
import qualified Data.Map as M

-- | A monadic computation that can be run in IO once it is given an
-- argument of type @a@, producing a result of type @r@.
class (Monad m, MonadIO m) => Task m a r where
  -- | Execute the computation in IO, supplying the given argument.
  runTask :: m r -> a -> IO r

-- | Plain IO actions take the unit value as their (ignored) argument.
instance Task IO () r where
  runTask action () = action

-- | Reader computations receive the argument as their environment.
instance Task (ReaderT a IO) a r where
  runTask = runReaderT

-- | State computations start from the argument as their initial state;
-- the final state is discarded and only the result is kept.
instance Task (StateT a IO) a r where
  runTask = evalStateT

-- | All state shared between a pool's owner and its worker threads.
--
-- The 'Task' constraint is placed on the functions that use a pool rather
-- than on the type itself: datatype contexts are not part of Haskell 2010
-- and GHC rejects them unless the deprecated DatatypeContexts extension is
-- enabled.
data Pool m a r = Pool {
  poolThreads :: [ThreadId],                     -- ^ List of worker threads
  threadsStates :: TVar (M.Map ThreadId Bool),   -- ^ For each worker thread: is it running some task currently?
  poolQueue :: TChan (Integer, m r, a),          -- ^ Queue of pending tasks: (task number, computation, argument)
  returnResults :: Bool,                         -- ^ Should pool return tasks' results?
  poolResults :: TChan (Integer, r),             -- ^ Queue of (task number, result) pairs (not used if returnResults == False)
  poolLastTask :: TVar Integer,                  -- ^ Number of the most recently queued task (0 if none yet)
  poolDone :: TVar Bool,                         -- ^ Set once the owner has announced that no more tasks will be queued
  doneTasks :: TVar [Integer]                    -- ^ Numbers of done tasks; not used when returnResults == False
  }

-- | Create a new pool of worker threads from within any 'Task' monad.
-- This is a lifted wrapper around 'newPoolIO'.
newPool :: Task m a r
        => Int             -- ^ Number of threads in the pool
        -> Bool            -- ^ Should pool return tasks' results?
        -> m (Pool m a r)
newPool n = liftIO . newPoolIO n

-- | Create a new pool of worker threads, in IO monad.
--
-- All shared queues and variables are allocated first; the workers are
-- forked last, so everything they use already exists when they start.
-- The task counter starts at 0 and no thread states are recorded yet.
newPoolIO :: Task m a r
        => Int             -- ^ Number of threads in the pool
        -> Bool            -- ^ Should pool return tasks' results?
        -> IO (Pool m a r)
newPoolIO n ret = do
  taskQueue   <- newTChanIO
  resultQueue <- newTChanIO
  lastTask    <- newTVarIO 0
  states      <- newTVarIO M.empty
  finished    <- newTVarIO False
  completed   <- newTVarIO []
  workers     <- replicateM n
                   (forkIO (worker states taskQueue ret resultQueue completed))
  return Pool
    { poolThreads   = workers
    , threadsStates = states
    , poolQueue     = taskQueue
    , returnResults = ret
    , poolResults   = resultQueue
    , poolLastTask  = lastTask
    , poolDone      = finished
    , doneTasks     = completed
    }

-- | Record in the shared state map whether the calling thread is busy
-- ('True') or idle ('False'), overwriting any previous entry for it.
setMyState :: TVar (M.Map ThreadId Bool) -> Bool -> IO ()
setMyState var busy = do
  me <- myThreadId
  atomically $ readTVar var >>= writeTVar var . M.insert me busy

-- | Body of a worker thread: loops forever, taking the next task from the
-- queue, running it, and (when requested) publishing its result and
-- recording its number as done.
worker :: Task m a r
                 => TVar (M.Map ThreadId Bool)  -- ^ States of all threads
                 -> TChan (Integer, m r, a)     -- ^ Channel where to read tasks from
                 -> Bool                        -- ^ Send results to output queue?
                 -> TChan (Integer, r)          -- ^ Channel where to write results to
                 -> TVar [Integer]              -- ^ List of done tasks
                 -> IO ()
worker var chan ret res doneVar = do
  me <- myThreadId
  forever $ do
    setMyState var False
    -- Take the task and mark this thread busy in ONE transaction. Doing it
    -- in two steps leaves a window where the task has already left the
    -- queue while this thread still looks idle, so a caller checking for
    -- "queue empty and all workers idle" could conclude too early that
    -- everything has finished.
    (n, m, x) <- atomically $ do
        job <- readTChan chan
        states <- readTVar var
        writeTVar var (M.insert me True states)
        return job
    putStrLn $ ">>> Starting task #" ++ show n
    y <- runTask m x
    putStrLn $ ">>> Task #" ++ show n ++ " done."
    -- Publish the result while still marked busy, so that once this thread
    -- is seen idle its result is already in the result queue.
    when ret $
        atomically $ do
            writeTChan res (n, y)
            done <- readTVar doneVar
            writeTVar doneVar (n : done)

-- | Report whether no worker thread is currently marked as running a task.
-- Also returns 'True' while no worker has recorded a state yet.
-- NOTE(review): only worker states are consulted; tasks still sitting in
-- the queue are not taken into account.
isPoolWaiting :: Task m a r
              => Pool m a r         -- ^ The pool
              -> IO Bool
isPoolWaiting pool =
  atomically $ do
    anyBusy <- fmap (or . M.elems) (readTVar (threadsStates pool))
    return (not anyBusy)

-- | Put a new task into the pool's queue and return its number.
--
-- Task numbers start at 1 and grow by one for every queued task.
queue :: Task m a r
      => Pool m a r             -- ^ Pool of threads
      -> m r                    -- ^ Task (monadic computation)
      -> a                      -- ^ Argument for that computation
      -> m Integer              -- ^ Returns a number of task in the pool
queue pool m x = liftIO $ atomically $ do
  -- Numbering and enqueueing happen in a single transaction. Tasks therefore
  -- sit in the queue in the same order as their numbers even when several
  -- threads queue concurrently, and no number is consumed without its task
  -- actually being queued.
  let counter = poolLastTask pool
  was <- readTVar counter
  let next = was + 1
  -- Store the evaluated value so repeated queueing does not build a chain
  -- of unevaluated additions inside the TVar.
  writeTVar counter $! next
  writeTChan (poolQueue pool) (next, m, x)
  return next

-- | Tell the pool that no new tasks will be queued.
-- This is a lifted wrapper around 'noMoreTasksIO'.
noMoreTasks :: Task m a r => Pool m a r -> m ()
noMoreTasks = liftIO . noMoreTasksIO

-- | Tell the pool that no new tasks will be queued, in IO monad.
-- This sets the pool's 'poolDone' flag, which the pool's waiting
-- functions consult before reporting that all work is finished.
noMoreTasksIO :: Task m a r => Pool m a r -> IO ()
noMoreTasksIO = atomically . flip writeTVar True . poolDone

-- | Read the next result from the pool, blocking until one is available.
-- This only makes sense for a pool that returns results: otherwise nothing
-- is ever written to the result queue and this blocks indefinitely.
readResult :: Task m a r
           => Pool m a r         -- ^ Pool
           -> m (Integer, r)     -- ^ Returns (number of task, task's result)
readResult = liftIO . atomically . readTChan . poolResults

-- | Read results from the pool forever, passing each task number and result
-- to the given action; the action's own return value is discarded.
-- Since this never returns normally, it is usually run in a separate thread
-- (e.g. via 'forkIO'). This only makes sense for a pool that returns results.
resultsReader :: Task m a r => Pool m a r -> (Integer -> r -> IO b) -> IO ()
resultsReader pool fn =
  forever $ atomically (readTChan (poolResults pool)) >>= uncurry fn

-- | Wait until all tasks have ended. This is a lifted wrapper around
-- 'waitForIO'.
waitFor :: Task m a r => Pool m a r -> m ()
waitFor = liftIO . waitForIO

-- | Wait until all tasks have ended, in IO monad.
--
-- Returns once all three hold at the same moment:
--
--   * the owner has called 'noMoreTasks' / 'noMoreTasksIO';
--   * the task queue is empty;
--   * no worker is marked as running a task.
--
-- The queue check matters. Without it, this could return while tasks are
-- still queued but not yet claimed (for example before any worker has
-- recorded its state, when the state map is empty). Instead of polling,
-- the transaction uses 'retry' and is re-run whenever one of the variables
-- it read changes.
waitForIO :: Task m a r => Pool m a r -> IO ()
waitForIO pool = atomically $ do
  done <- readTVar (poolDone pool)
  queueEmpty <- isEmptyTChan (poolQueue pool)
  states <- readTVar (threadsStates pool)
  let allIdle = not (or (M.elems states))
  unless (done && queueEmpty && allIdle) retry

-- | Wait until every task whose number appears in the given list is done.
--
-- Instead of polling every millisecond (and printing debug output on every
-- poll), this blocks in a single STM transaction that is re-run whenever
-- 'doneTasks' changes.
--
-- NOTE(review): workers only record finished task numbers when the pool was
-- created to return results. On a pool with @returnResults == False@ the
-- awaited numbers never show up, so this blocks indefinitely.
waitForTasks :: Task m a r => Pool m a r -> [Integer] -> m ()
waitForTasks pool tasks = liftIO $ atomically $ do
  done <- readTVar (doneTasks pool)
  unless (all (`elem` done) tasks) retry

-- | Kill every worker thread of the pool. This is a lifted wrapper around
-- 'terminatePoolIO'.
terminatePool :: Task m a r => Pool m a r -> m ()
terminatePool = liftIO . terminatePoolIO

-- | Kill every worker thread of the pool, in IO monad, by sending each one
-- 'killThread'. Afterwards no thread of this pool reads the task queue, so
-- tasks that have not been started will not run.
terminatePoolIO :: Task m a r => Pool m a r -> IO ()
terminatePoolIO pool = forM_ (poolThreads pool) killThread