{-|
Module      : Control.Concurrent.Bag.Task
Description : Task implementations
Copyright   : (c) Bastian Holst, 2014
License     : BSD3
Maintainer  : bastianholst@gmx.de
Stability   : experimental
Portability : POSIX

Task implementations 'Interruptible' and 'TaskIO'. These can be used by some of
the bag of tasks interfaces such as "Control.Concurrent.Bag.Implicit"
and "Control.Concurrent.Bag.ImplicitConcurrent".
-}
{-# LANGUAGE MultiParamTypeClasses, TypeSynonymInstances, FlexibleInstances #-}
module Control.Concurrent.Bag.Task
  ( TaskIO
  , runTaskIO
  , addTaskIO
  , writeResult
  , Interruptible (..)
  , runInterrupted
  , runInterruptible
  , WriteResult (..)
  , AddTask (..) )
where

import Control.Applicative
import Control.Monad.Reader
import Control.Monad.Writer
import Control.Concurrent (forkIOWithUnmask, forkIO, threadDelay)
import Control.Concurrent.Bag.TaskBuffer
import Control.Concurrent.MVar
import Control.Concurrent (killThread)
import Control.Exception
import qualified Control.Concurrent.Bag.Basic as Basic

-- | A monad in which tasks can be specified.
--
--   'TaskIO' is an instance of 'MonadIO', so 'liftIO' is available to perform
--   arbitrary 'IO' actions inside a task. A task may or may not return a
--   value; if it returns one, this value is written back as a result.
--   New tasks are added to the bag with 'addTaskIO' and further results can
--   be written back with 'writeResult'.
--
--   The parameter /r/ is the result type of the corresponding bag.
--
--   In contrast to 'Interruptible', the evaluation order is similar to that
--   of the 'IO' monad and tasks added by 'addTaskIO' are added immediately.
newtype TaskIO r a = TaskIO
  { -- | The wrapped reader computation. Its environment is the pair of
    --   callbacks handed to 'runTaskIO': first the function that adds a
    --   task, second the function that writes back a result.
    getTaskReader :: ReaderT (AddTask r, WriteResult r) IO a
  }

-- | Maps over the value produced by a task by mapping over the wrapped
--   reader computation.
instance Functor (TaskIO r) where
  fmap f (TaskIO body) = TaskIO (fmap f body)

-- | Applicative structure lifted from the wrapped reader computation; the
--   function task is run before the argument task.
instance Applicative (TaskIO r) where
  pure x = TaskIO (pure x)
  TaskIO ff <*> TaskIO fx = TaskIO (ff <*> fx)

-- | Sequencing of tasks: both steps run in the same reader environment, so a
--   continuation sees the same add-task and write-result callbacks.
instance Monad (TaskIO r) where
  return x = TaskIO (return x)
  task >>= next = TaskIO $ do
    x <- getTaskReader task
    getTaskReader (next x)

-- | Embeds a plain 'IO' action into a task; the action ignores the reader
--   environment.
instance MonadIO (TaskIO r) where
  liftIO = TaskIO . liftIO

-- | Function to write back a result of type /r/. This is the shape of the
--   result-writing callback that 'runTaskIO' expects and that 'writeResult'
--   invokes.
type WriteResult r = r -> IO ()
-- | Function to add a task to the bag of tasks. A task is given as an 'IO'
--   action that may produce a result ('Just') or none ('Nothing'). This is
--   the shape of the task-adding callback that 'runTaskIO' expects and that
--   'addTaskIO' invokes.
type AddTask     r = IO (Maybe r) -> IO ()

-- | Execute a task as a plain 'IO' action.
--
--   The two callbacks are paired up and become the reader environment in
--   which the task runs; 'addTaskIO' and 'writeResult' look them up there.
runTaskIO :: TaskIO r (Maybe r) -- ^ The task to be run
          -> AddTask r          -- ^ Function to add a new task
          -> WriteResult r      -- ^ Function to write back a result
          -> IO (Maybe r)       -- ^ The value returned by the task, if any
runTaskIO (TaskIO body) addTask addResult = runReaderT body env
  where
    -- Order matters: first the add-task callback, then the result writer.
    env = (addTask, addResult)

-- | Add a task to the bag of tasks from within another task.
--
--   The new task is handed to the add-task callback right away, packaged as
--   an 'IO' action that runs it with the same two callbacks as the current
--   task.
addTaskIO :: TaskIO r (Maybe r) -> TaskIO r ()
addTaskIO task = TaskIO (ask >>= enqueueWith)
  where
    -- The environment is (add-task callback, write-result callback).
    enqueueWith (enqueue, write) =
      liftIO (enqueue (runTaskIO task enqueue write))

-- | Write back a result from within a task.
--
--   This calls the write-result callback (the second component of the
--   reader environment) with the given value.
writeResult :: r -> TaskIO r ()
writeResult x = TaskIO $ do
  write <- asks snd
  liftIO (write x)

-- | A type to specify interruptible tasks.
--
--   Interruptible tasks are tasks that can be interrupted and resumed later.
--   Basically this means that the evaluating thread may be killed in
--   between. Side effects are therefore not allowed in this code: every
--   interruptible task has to be purely functional, in contrast to 'TaskIO'
--   tasks. Otherwise this is similar to 'TaskIO'.
data Interruptible r
  = NoResult
    -- ^ The task finishes without a result.
  | OneResult r
    -- ^ The task finishes with the given result.
  | AddInterruptibles [Interruptible r]
    -- ^ The task has no result of its own but adds the given tasks to the
    --   bag.

-- | Run the given interruptible task.
--
--   Run with this function, the task will not be interrupted.
--
--   * 'NoResult' yields 'Nothing'.
--
--   * 'OneResult' evaluates its payload with 'evaluate' and yields it
--     wrapped in 'Just'.
--
--   * 'AddInterruptibles' forces the spine of the task list, adds every
--     element to the bag as a new (uninterrupted) task and yields 'Nothing'.
runInterruptible :: Interruptible r -> TaskIO r (Maybe r)
runInterruptible cur =
  case cur of
    NoResult    -> return Nothing
    OneResult r -> liftIO $ Just <$> evaluate r
    AddInterruptibles inters -> do
      -- Walk the whole list first, so that the work hidden in its structure
      -- is done by this task before any sub-task is handed to the bag.
      forced <- liftIO (evaluateList inters)
      -- mapM_ instead of mapM: the list of unit results is never used.
      mapM_ (addTaskIO . runInterruptible) forced
      return Nothing

-- | Run the given interruptible task with interruptions.
--
--   The task is evaluated in a freshly forked worker thread. A second thread
--   kills the worker after a fixed time slice. If the worker was killed
--   before it finished, the very same task is added to the bag again and
--   'Nothing' is returned; otherwise the outcome is handled like in
--   'runInterruptible', except that sub-tasks are again run with
--   interruptions.
--
--   Note: When using this function, it may be possible that some tasks are
--   never completed, because the time to produce an intermediate result is
--   longer than the interruption frequency.
--
--   NOTE(review): the worker reports an interruption for /any/ exception
--   raised during evaluation, not only for the kill signal, so a task whose
--   evaluation throws looks like it would be re-added over and over --
--   confirm whether this is intended.
runInterrupted :: Interruptible r
                 -> TaskIO r (Maybe r)
runInterrupted cur = do
  resultVar <- liftIO newEmptyMVar
  -- The worker starts with asynchronous exceptions masked and only unmasks
  -- them (via 'restore') while evaluating. Whatever happens, exactly one
  -- value ends up in 'resultVar': 'Nothing' if evaluation was aborted by an
  -- exception, 'Just' the evaluated task otherwise.
  tid <- liftIO $ uninterruptibleMask_ $ forkIOWithUnmask $ \restore -> do
    evaluated <- restore evaluateStep `onException` putMVar resultVar Nothing
    putMVar resultVar $ Just evaluated
  -- The stopper ends the worker's time slice.
  stopper <- liftIO $ forkIO $ threadDelay timeSliceMicros >> killThread tid
  outcome <- liftIO $ takeMVar resultVar
  -- The worker has reported back, so the stopper is no longer needed.
  liftIO $ killThread stopper
  case outcome of
    Nothing -> do
      -- Interrupted: put the unchanged task back into the bag.
      addTaskIO $ runInterrupted cur
      return Nothing
    Just NoResult      -> return Nothing
    Just (OneResult r) -> return $ Just r
    Just (AddInterruptibles inters) -> do
      -- mapM_ instead of mapM: the list of unit results is never used.
      mapM_ (addTaskIO . runInterrupted) inters
      return Nothing
  where
    -- Length of the time slice granted to the worker, in microseconds
    -- (the unit expected by 'threadDelay').
    timeSliceMicros :: Int
    timeSliceMicros = 1000

    -- The interruptible part of the work: evaluate the current task one
    -- level deep and rebuild it from the evaluated pieces.
    evaluateStep =
      case cur of
        NoResult                 -> return NoResult
        OneResult r              -> OneResult <$> evaluate r
        AddInterruptibles inters -> AddInterruptibles <$> evaluateList inters

-- | Evaluate the complete structure (the spine) of a list.
--
--   Running the returned action walks the list from front to back, forcing
--   every cons cell, so an exception hidden in the list structure is raised
--   here. The elements themselves are not evaluated. The result is a list
--   with the same elements in the same order.
evaluateList :: [a] -> IO [a]
evaluateList []     = return []
evaluateList (x:xs) = (x :) <$> evaluateList xs