module Control.Concurrent.Priority.Queue
    (Queue,
     TaskHandle,
     QueueOrder(..),
     QueueConfigurationRecord(..),
     fair_queue_configuration, fast_queue_configuration,
     newQueue,
     taskPriority,
     taskQueue,
     pendingTasks,
     isTopOfQueue,
     hasCompleted,
     putTask,
     pullTask,
     pullFromTop,
     pullSpecificTasks,
     dispatchTasks,
     flushQueue,
     load)
    where

import Control.Concurrent.MVar (newEmptyMVar, putMVar, takeMVar)
import Control.Monad
import Data.List as List (sort,sortBy,groupBy,drop)
import Data.Maybe
import Data.Ord
import Data.Unique
import GHC.Conc

import Data.Heap as Heap

-- | A prioritized 'Queue'.  Prioritization is least-first: tasks with smaller priority values
-- are pulled before tasks with larger ones, i.e. larger values are nicer.
--
-- A 'Queue' is not associated with any working thread, therefore, it is the client\'s responsibility to make sure that every pushed
-- task is also pulled, or the 'Queue' will stall.  There are several ways to accomplish this:
--
-- * Call 'pullTask' at least once for every call to 'putTask'.
--
-- * Use 'dispatchTasks' to push every task.
--
-- * Use 'flushQueue' whenever the 'Queue' is not empty.
--
-- The type carries no datatype context: @data (Ord a) => ...@ relies on the deprecated
-- @DatatypeContexts@ extension, and every operation that needs @Ord a@ already requires it.
data Queue a = Queue {
    -- | Tuning options supplied to 'newQueue'.
    queue_configuration :: !(QueueConfigurationRecord a),
    -- | Identity of this queue; the 'Eq' and 'Ord' instances compare only this.
    queue_unique :: Unique,
    -- | Tasks waiting to be pulled, kept in a min-heap ordered by (priority, insertion index).
    pending_tasks :: TVar (MinHeap (TaskHandle a)),
    -- | Source of per-task insertion indices; see 'putTask'.
    task_counter :: TVar Integer }

-- | Tie-break order for tasks that share the same priority.
data QueueOrder
    = FIFO  -- ^ Tasks put earlier are pulled earlier.
    | FILO  -- ^ Tasks put later are pulled earlier.

-- | Configuration options for a 'Queue'.  A 'Queue' blocks on a number of predicates when dispatching a job.
-- Generally, 'fair_queue_configuration' should work well for long-running batch jobs and 'fast_queue_configuration'
-- should work for rapid-paced jobs.
--
-- * A single STM predicate for the entire 'Queue'.  This blocks the entire 'Queue' until the predicate is satisfied.
--
-- * An STM predicate parameterized by priority.  This blocks a single priority level, and the 'Queue' will skip all tasks at that priority.
--
-- * Each task is itself an STM transaction, and can block itself.
--
-- * Pure constraints on priority and ordering inversion.
--
-- If a task is blocked by its priority predicate or by its own action, it is skipped and the next task is attempted,
-- in priority order.  Skipped tasks count against 'allowed_ordering_inversion'.
--
-- The type carries no datatype context: @data (Ord a) => ...@ relies on the deprecated @DatatypeContexts@
-- extension, and the functions that need @Ord a@ already require it.
data QueueConfigurationRecord a = QueueConfigurationRecord {
    -- | A predicate that must hold before any task may be pulled from a 'Queue'.
    -- It is run first in every 'pullTask'; if it retries, the whole pull blocks.
    queue_predicate :: STM (),
    -- | A predicate that must hold before any task at the given priority level may be pulled from a 'Queue'.
    -- If it retries, every pending task at that priority is skipped for the current pull.
    priority_indexed_predicate :: (a -> STM ()),
    -- | Constrains the greatest allowed difference between the priority of the top-of-queue task and the priority of a task to be pulled.
    -- It is applied as @allowed_priority_inversion top candidate@; when it returns 'False' the pull blocks instead of trying later tasks.
    allowed_priority_inversion :: a -> a -> Bool,
    -- | The greatest allowed difference between the ideal prioritized FILO/FIFO ordering of tasks and the actual ordering of tasks:
    -- once a single pull has skipped more than this many tasks, it blocks.
    -- Setting this too high can introduce a lot of overhead in the presence of a lot of short-running tasks.
    -- Setting this to zero turns off the predicate failover feature, i.e. only the top of queue task will ever be pulled.
    allowed_ordering_inversion :: Int,
    -- | Should the 'Queue' run in FILO or FIFO order.  Ordering takes place after prioritization, and won't have much effect if priorities are very fine-grained.
    queue_order :: !QueueOrder }

-- | A queue tuned for high throughput and fairness when processing moderate to long running tasks.
--
-- It has no queue-wide or per-priority predicate and tolerates any priority inversion.
-- A pull may skip up to @5 * 'numCapabilities'@ tasks, and equal-priority tasks run in FIFO order.
fair_queue_configuration :: (Ord a) => QueueConfigurationRecord a
fair_queue_configuration = QueueConfigurationRecord
    { queue_predicate            = return ()
    , priority_indexed_predicate = \_ -> return ()
    , allowed_priority_inversion = \_ _ -> True
    , allowed_ordering_inversion = 5 * numCapabilities
    , queue_order                = FIFO
    }

-- | A queue tuned for high responsiveness and low priority inversion, but may have poorer long-term throughput and potential to
-- starve some tasks compared to 'fair_queue_configuration'.
--
-- It starts from 'fair_queue_configuration' and changes three things:
--
-- * Only a task whose priority equals the top-of-queue priority may be pulled.
--
-- * At most 'numCapabilities' tasks may be skipped per pull.
--
-- * Equal-priority tasks run newest first (FILO).
fast_queue_configuration :: (Ord a) => QueueConfigurationRecord a
fast_queue_configuration = fair_queue_configuration
    { allowed_priority_inversion = \top candidate -> top == candidate
    , allowed_ordering_inversion = numCapabilities
    , queue_order                = FILO
    }

-- | Two 'Queue's are equal exactly when they share the same 'Unique' identity.
instance (Ord a) => Eq (Queue a) where
    left == right = queue_unique left == queue_unique right

-- | Orders 'Queue's by their 'Unique' identity, consistently with the 'Eq' instance.
instance (Ord a) => Ord (Queue a) where
    compare = comparing queue_unique

-- | A handle to a task that has been put on a 'Queue' with 'putTask'.
data TaskHandle a = TaskHandle
    { task_action        :: STM ()     -- ^ Run inside the transaction that pulls this task.
    , is_top_of_queue    :: TVar Bool  -- ^ True while this task heads its queue; updated on every put and pull.
    , has_completed      :: TVar Bool  -- ^ Set to True by 'pullTask' when this task is pulled.
    , task_counter_index :: !Integer   -- ^ Insertion index; breaks ties between equal priorities.
    , task_priority      :: !a         -- ^ Priority; smaller values are pulled first.
    , task_queue         :: Queue a    -- ^ The queue this task was put on.
    }

-- | Tasks are equal when their priority, insertion index and owning 'Queue' all match (see 'taskOrd').
-- @Ord a@ already implies @Eq a@, so no separate @Eq a@ constraint is needed.
instance (Ord a) => Eq (TaskHandle a) where
    l == r = taskOrd l == taskOrd r

-- | Orders tasks by priority, then insertion index, then owning 'Queue' (see 'taskOrd').
-- A queue's min-heap yields its tasks in this order.
instance (Ord a) => Ord (TaskHandle a) where
    compare = comparing taskOrd

-- | Sort key shared by the 'Eq' and 'Ord' instances of 'TaskHandle':
-- priority first, then insertion index, then the owning 'Queue'.
taskOrd :: TaskHandle a -> (a,Integer,Queue a)
taskOrd t = (prio, index, owner)
  where
    prio  = task_priority t
    index = task_counter_index t
    owner = task_queue t

-- | True iff this task is poised at the top of its 'Queue'.
isTopOfQueue :: TaskHandle a -> STM Bool
isTopOfQueue = readTVar . is_top_of_queue

-- | True once 'pullTask' has pulled this task, which runs its action in that same transaction.
hasCompleted :: TaskHandle a -> STM Bool
hasCompleted = readTVar . has_completed

-- | The priority this task was put on its 'Queue' with.
taskPriority :: TaskHandle a -> a
taskPriority handle = task_priority handle

-- | The 'Queue' this task was put on.
taskQueue :: TaskHandle a -> Queue a
taskQueue handle = task_queue handle

-- | All tasks currently waiting in the 'Queue'.
--
-- NOTE(review): the list comes from 'Heap.toList', whose element order is presumably unspecified.
-- Confirm before relying on priority order.
pendingTasks :: (Ord a) => Queue a -> STM [TaskHandle a]
pendingTasks q =
    do heap <- readTVar (pending_tasks q)
       return (Heap.toList heap)

-- | Create a new, empty 'Queue' governed by the given configuration.
-- Each call yields a 'Queue' with a fresh 'Unique' identity.
newQueue :: (Ord a) => QueueConfigurationRecord a -> IO (Queue a)
newQueue config =
    do heapVar    <- newTVarIO empty
       counterVar <- newTVarIO 0
       identity   <- newUnique
       return $ Queue { queue_configuration = config
                      , queue_unique        = identity
                      , pending_tasks       = heapVar
                      , task_counter        = counterVar }

-- | Put a task with its priority value onto this queue.  Returns a handle to the task.
--
-- The task gets the queue's current counter value as its insertion index.
-- The counter then moves up (FIFO) or down (FILO), so within one priority the min-heap yields
-- the oldest task first (FIFO) or the newest task first (FILO).
putTask :: (Ord a) => Queue a -> a -> STM () -> STM (TaskHandle a)
putTask q prio actionSTM =
    do index <- readTVar counterVar
       writeTVar counterVar (advance index)
       topFlag  <- newTVar False
       doneFlag <- newTVar False
       let newTask = TaskHandle { task_action        = actionSTM
                                , is_top_of_queue    = topFlag
                                , has_completed      = doneFlag
                                , task_counter_index = index
                                , task_priority      = prio
                                , task_queue         = q }
       watchingTopOfQueue q $
           do heap <- readTVar heapVar
              writeTVar heapVar (insert newTask heap)
       return newTask
  where
    counterVar = task_counter q
    heapVar    = pending_tasks q
    -- Direction of the insertion counter depends on the configured tie-break order.
    advance n = case queue_order (queue_configuration q) of
                    FIFO -> n + 1
                    FILO -> n - 1

-- | The number of tasks currently pending on this 'Queue'.
load :: (Ord a) => Queue a -> STM Int
load q =
    do heap <- readTVar (pending_tasks q)
       return (size heap)

-- | Pull and commit a task from this 'Queue'.
--
-- The whole pull runs in the caller's transaction:
--
-- * It requires 'queue_predicate' to succeed.
--
-- * It picks the first eligible task in priority order, running that task's action as part of the search.
--
-- * It removes the chosen task from the queue and marks it completed.
--
-- The pull retries (blocks) when no task is eligible.
pullTask :: (Ord a) => Queue a -> STM (TaskHandle a)
pullTask q = watchingTopOfQueue q pullOne
  where
    config  = queue_configuration q
    heapVar = pending_tasks q
    pullOne =
        do queue_predicate config
           pending <- readTVar heapVar
           (task, remaining) <- pullTask_ config empty pending
           writeTVar heapVar remaining
           writeTVar (has_completed task) True
           return task

-- | Search helper behind 'pullTask'.  It examines pending tasks in heap order (priority first, then insertion index).
-- It commits the first task whose priority predicate and own STM action both succeed, all within the caller's transaction.
--
-- @faltered_tasks@ holds the tasks already passed over during this search.
-- @untried_tasks@ holds the tasks not yet examined.
-- Returns the pulled task together with a heap of every other task, both passed-over and untried.
--
-- The search blocks with 'retry' when any of these holds:
--
-- * More than 'allowed_ordering_inversion' tasks have been passed over.
--
-- * No untried task remains.
--
-- * 'allowed_priority_inversion' rejects the candidate's priority.
pullTask_ :: (Ord a) => QueueConfigurationRecord a -> MinHeap (TaskHandle a) -> MinHeap (TaskHandle a) -> STM (TaskHandle a,MinHeap (TaskHandle a))
pullTask_ config faltered_tasks untried_tasks =
    -- Too many tasks passed over already: block rather than stray further from ideal ordering.
    do when (Heap.size faltered_tasks > allowed_ordering_inversion config) retry
       -- Take the best untried task; block if none remain.
       (task,rest) <- maybe retry return $ view untried_tasks
       -- Reference priority: the smallest among passed-over tasks, or this task's own when none were passed over.
       let top_prio = taskPriority $ maybe task fst $ view $ faltered_tasks
       unless (allowed_priority_inversion config top_prio (taskPriority task)) retry
       -- Priority predicate failed: pass over this task and every following task of the same priority.
       let predicateFailed = do let (same_prios,remaining_prios) = Heap.span ((== (task_priority task)) . task_priority) rest
                                pullTask_ config (insert task faltered_tasks `union` fromList same_prios) remaining_prios
       -- The task's own action retried: pass over just this task.
       let taskFailed = do pullTask_ config (insert task faltered_tasks) rest
       -- Probe the priority predicate; `orElse` turns its retry into False instead of blocking the pull.
       prio_ok <- ((priority_indexed_predicate config $ task_priority task) >> return True) `orElse` (return False)
       case prio_ok of
           False -> predicateFailed
           -- Run the task's action; on success return it plus all other tasks, otherwise keep searching.
           True -> (task_action task >> return (task,faltered_tasks `union` rest)) `orElse` taskFailed

-- | Pull this task from the top of a 'Queue', if it is already there.
-- If this task is top-of-queue, but its predicates fail, then 'pullFromTop' may instead pull a lower-priority 'TaskHandle'.
-- If the task has already completed, it is returned immediately without pulling anything.
-- Otherwise the transaction retries until the task is at the top of its 'Queue'.
pullFromTop :: (Ord a) => TaskHandle a -> STM (TaskHandle a)
pullFromTop task =
    do completed <- hasCompleted task
       if completed
           then return task
           else do atTop <- isTopOfQueue task
                   unless atTop retry
                   pullTask (taskQueue task)

-- | Don't return until the given 'TaskHandle' has been pulled from its associated 'Queue'.
-- This doesn't guarantee that the 'TaskHandle' will ever be pulled, even when the 'TaskHandle' and 'Queue' are both viable.
-- You must concurrently arrange for every other 'TaskHandle' associated with the same 'Queue' to be pulled, or the 'Queue' will stall.
-- Each attempt uses 'pullFromTop', so other tasks may be pulled along the way when this one cannot run.
pullSpecificTask :: (Ord a) => TaskHandle a -> IO ()
pullSpecificTask task = loop
  where
    loop = do pulled <- atomically (pullFromTop task)
              if pulled == task then return () else loop

-- | Don't return until the given 'TaskHandle's have been pulled from their associated 'Queue's.
-- This doesn't guarantee that the 'TaskHandle's will ever be pulled, even when the 'TaskHandle's and 'Queue's are viable.
-- You must concurrently arrange for every other 'TaskHandle' associated with the same 'Queue' to be pulled, or the 'Queue' will stall.
-- 'pullSpecificTasks' can handle lists of 'TaskHandle's that are distributed among several 'Queue's, as well as 'TaskHandle's that have
-- already completed or complete concurrently from another thread.
--
-- Tasks are grouped by 'Queue', and each group is pulled in ascending task order.
-- The first group is pulled on the calling thread; every other group gets its own thread, and the caller waits for all of them.
pullSpecificTasks :: (Ord a) => [TaskHandle a] -> IO ()
pullSpecificTasks tasks =
    case per_queue_groups of
        [] -> return ()
        (local_group : remote_groups) ->
            do done_vars <- mapM spawnGroup remote_groups
               pullGroup local_group
               mapM_ takeMVar done_vars
  where
    -- One ascending-sorted list of tasks per distinct 'Queue'.
    per_queue_groups = map sort $ groupBy (\x y -> taskQueue x == taskQueue y) $ sortBy (comparing taskQueue) tasks
    -- Pull every task of one group in order; mapM_ avoids building an unused result list.
    pullGroup tasks_of_queue = mapM_ pullSpecificTask tasks_of_queue
    -- Pull a group on a new thread; the returned MVar is filled when that thread finishes.
    -- NOTE(review): if the forked group throws, its MVar is never filled and the caller blocks
    -- (GHC may then raise BlockedIndefinitelyOnMVar here).
    spawnGroup tasks_of_queue =
        do done_var <- newEmptyMVar
           _ <- forkIO (pullGroup tasks_of_queue >> putMVar done_var ())
           return done_var

-- | \"Fire and forget\" some tasks on a separate thread.
-- Each task is put on its 'Queue' in its own transaction on the calling thread.
-- A new thread then waits for all of them to be pulled (see 'pullSpecificTasks').
-- The handles are returned right after that thread is forked.
dispatchTasks :: (Ord a) => [(Queue a,a,STM ())] -> IO [TaskHandle a]
dispatchTasks task_records =
    do handles <- forM task_records enqueue
       _ <- forkIO (pullSpecificTasks handles)
       return handles
  where
    enqueue (q, prio, actionSTM) = atomically (putTask q prio actionSTM)

-- | Process a 'Queue' until it is empty.
-- Each transaction reads the pending count and, if it is non-zero, pulls one task.
-- Looping stops once a transaction observes zero pending tasks.
-- A transaction blocks while tasks are pending but none is currently eligible.
flushQueue :: (Ord a) => Queue a -> IO ()
flushQueue q = loop
  where
    loop = do remaining <- atomically step
              unless (remaining == 0) loop
    step = do n <- load q
              when (n > 0) (pullTask q >> return ())
              return n

-- | Overwrite the top-of-queue flag of the task currently at the head of the queue, returning the flag's previous value.
-- On an empty queue nothing is written and 'True' is returned.
setTopOfQueue :: (Ord a) => Queue a -> Bool -> STM Bool
setTopOfQueue q flag =
    do heap <- readTVar (pending_tasks q)
       maybe (return True) (swapFlag . fst) (view heap)
  where
    swapFlag top =
        do previous <- readTVar (is_top_of_queue top)
           writeTVar (is_top_of_queue top) flag
           return previous

-- | Run an STM action that may change which task heads the queue, keeping the top-of-queue flags current.
-- Before the action, the current head's flag is cleared; after it, whichever task then heads the queue gets its flag set.
-- Calls 'error' if the head's flag was already clear on entry.
-- NOTE(review): this presumably happens only when the action nests inside another 'watchingTopOfQueue' on the same queue,
-- e.g. a task action that calls 'putTask' on its own queue.
watchingTopOfQueue :: (Ord a) => Queue a -> STM b -> STM b
watchingTopOfQueue q actionSTM =
    do wasTop <- setTopOfQueue q False
       when (not wasTop) $ error "watchingTopOfQueue: not reentrant"
       result <- actionSTM
       _ <- setTopOfQueue q True
       return result