{-# LANGUAGE DeriveDataTypeable, QuasiQuotes, Rank2Types, ScopedTypeVariables #-}
-- FlexibleInstances, TypeSynonymInstances #-}

-- | Module with the internal workhorse for the library, 'parallelTasks'.  You only
-- need to use this module if you want to alter 'ExtendedParTaskOpts', which allows
-- you to redirect the logging output or store information about task timing.
module Control.Concurrent.ParallelTasks.Base (ExtendedParTaskOpts(..), ParTaskOpts(..), SimpleParTaskOpts(..), TaskOutcome(Success, TookTooLong), defaultExtendedParTaskOpts, defaultParTaskOpts, parallelTasks) where

import Control.Applicative ((<$), (<$>))
import Control.Concurrent (forkIO, getNumCapabilities, killThread, myThreadId, threadDelay)
import Control.Concurrent.STM (TVar, atomically, newTVarIO, readTVar, retry, writeTVar)
import Control.Concurrent.STM.TMVar (newTMVar, putTMVar, takeTMVar)
import Control.Exception.Base (Exception, bracket_, evaluate, handle, onException, throwTo)
import Control.Monad (replicateM_, when)
import Control.Monad.IO.Class (MonadIO, liftIO)
import Data.Maybe (isJust)
import Data.String.Here.Interpolated (i)
import Data.Time.Clock (UTCTime, addUTCTime, diffUTCTime, getCurrentTime)
import Data.Time.Format (formatTime)
import Data.Typeable
import qualified Data.Vector.Mutable as V (IOVector, new, write)
import System.IO (Handle, hFlush, hPutStrLn, stderr)
import System.Locale (defaultTimeLocale)

-- | Basic options for 'parallelTasks': how many workers to use and how often to
-- report progress.
data SimpleParTaskOpts = SimpleParTaskOpts
  { numberWorkers :: Maybe Int
    -- ^ Number of worker threads to use.  When 'Nothing', one worker per
    -- capability is started (see @getNumCapabilities@).
  , printProgress :: Maybe Int
    -- ^ How often to print the number of tasks remaining.  E.g. @Just 100@ prints
    -- a message roughly after every 100 tasks.  When this is 'Just', the start
    -- and finish times are printed as well.
  , printEstimate :: Maybe Int
    -- ^ How often to print an estimate of the completion time.  E.g. @Just 100@
    -- prints an estimate roughly after every 100 tasks.
  }

-- | Options controlling the general running of parallel tasks.  The @m@ parameter
-- is the monad (which must be an instance of 'MonadIO') in which the tasks will be
-- run, and the @a@ parameter is the return value of the tasks.
data ParTaskOpts m a = ParTaskOpts
  { simpleOpts :: SimpleParTaskOpts
    -- ^ The simple options (worker count and progress reporting).
  , wrapWorker :: forall r. m (m r -> IO r)
    -- ^ Action yielding a function that runs the @m@ monad on top of 'IO'.  The
    -- returned function is run at least once per worker, possibly many times in
    -- parallel, so it must support that and should clean up after itself.  For
    -- plain 'IO' a suitable value is simply @return id@.
  , timeLimit :: Maybe (Integer, a)
    -- ^ When 'Just', the number of microseconds each task may run before it is
    -- assumed not to complete and is killed off.  For a killed task, the second
    -- component of the pair is the value stored in the results vector.
  }
                          
-- | Advanced options controlling the behaviour of parallel tasks.  The @m@ parameter
-- is the monad that the tasks execute in, and the @a@ parameter is the output value
-- of the tasks, which is also the element type of the results vector returned by
-- 'parallelTasks'.
data ExtendedParTaskOpts m a = ExtendedParTaskOpts {
  -- | Core options
  coreOpts :: ParTaskOpts m a,
  -- | Function that supplies a handle to an inner block to write messages to.
  -- To use stdout or stderr, you can just supply @($ stdout)@.  To write to a file,
  -- use @withFile \"blah\" WriteMode@.
  printTo :: forall r. (Handle -> IO r) -> IO r,
  -- | Callback invoked after a task has finished (or been killed for taking too
  -- long).  By the time it runs, the task's value has already been written into the
  -- results vector.  Arguments are (in order):
  --
  -- * Time that the task took (in seconds).  For a killed task, this is the time
  --   that elapsed before the time limit was declared exceeded.
  --
  -- * Index of the task in the original tasks list (which is also the index of its
  --   result in the results vector)
  --
  -- * The outcome of the task
  --
  -- If @Just@ a String is returned, it is logged to the handle supplied by 'printTo'
  afterFinish :: Double -> Int -> TaskOutcome -> IO (Maybe String)
  }

-- | Value indicating whether a task successfully completed, or was killed off for
-- taking too long.  Derives 'Eq' and 'Show' so that 'afterFinish' callbacks can
-- compare and log outcomes.
data TaskOutcome
  = Success      -- ^ The task ran to completion and its own result was stored.
  | TookTooLong  -- ^ The task exceeded the time limit; the default value from
                 -- 'timeLimit' was stored instead.
  deriving (Eq, Show)

-- | Thrown asynchronously at a worker by its watchdog thread when the current task
-- exceeds the time limit.  Internal to this module (not exported).
data TookTooLongException = TookTooLongException
  deriving (Show, Typeable)

instance Exception TookTooLongException

-- | Like 'threadDelay', but takes an 'Integer' number of microseconds.  Delays
-- longer than @maxBound :: Int@ are performed as a sequence of maximal
-- 'threadDelay' calls followed by one for the remainder.
threadDelay' :: Integer -> IO ()
threadDelay' remaining =
  if remaining <= cap
    then threadDelay (fromInteger remaining)
    else threadDelay maxBound >> threadDelay' (remaining - cap)
  where
    cap :: Integer
    cap = fromIntegral (maxBound :: Int)

-- | Default extended options: messages are written to 'stderr', and a message is
-- logged whenever a task is killed for exceeding the time limit.  Tasks that
-- complete successfully produce no message.
defaultExtendedParTaskOpts :: MonadIO m => ParTaskOpts m a -> ExtendedParTaskOpts m a
defaultExtendedParTaskOpts opts =
  ExtendedParTaskOpts
    { coreOpts = opts
    , printTo = ($ stderr)
    , afterFinish = reportKilled
    }
  where
    reportKilled _elapsed taskIndex outcome = case outcome of
      TookTooLong -> return (Just [i|*** Killed task ${taskIndex} for taking too long|])
      Success -> return Nothing

-- | Default parallel task options for tasks running directly in 'IO'.  One worker
-- is used per capability, there is no time limit, a progress message is printed
-- roughly every 50 tasks and an estimated completion time roughly every 200.
defaultParTaskOpts :: ParTaskOpts IO a
defaultParTaskOpts =
  ParTaskOpts
    { simpleOpts = simple
    , wrapWorker = return id
    , timeLimit = Nothing
    }
  where
    simple =
      SimpleParTaskOpts
        { numberWorkers = Nothing
        , printProgress = Just 50
        , printEstimate = Just 200
        }

-- | Runs the given set of computations in parallel, and once they are all finished, returns their results.
-- Note that they won't all be run in parallel from the start; rather, a set of
-- workers will be spawned that work their way through the (potentially large) set of jobs.
--
-- The result of the task at position @n@ of the input list is stored at index @n@
-- of the returned vector.  At least one worker is always started, even if
-- 'numberWorkers' is @Just 0@ (or negative).
parallelTasks :: MonadIO m => ExtendedParTaskOpts m a -> [m a] -> m (V.IOVector a)
parallelTasks _ [] = liftIO $ V.new 0
parallelTasks opts tasks = wrapWorker core >>= \run -> liftIO $ printTo opts $ \h -> do
  let numTasks = length tasks
      simple = simpleOpts core
      printStartEnd = isJust $ printProgress simple
  -- Clamp to at least one worker.  With zero workers no task would ever run, the
  -- live-worker counter would start at zero, and the uninitialised vector would be
  -- returned straight away.
  numWorkers <- max 1 <$> maybe getNumCapabilities return (numberWorkers simple)
  vValue <- V.new numTasks
  -- Shared queue: (number of tasks not yet taken, remaining (index, task) pairs).
  tvWork <- newTVarIO (numTasks, zip [0..] tasks)
  -- Number of live workers; waitForWorkers returns once this reaches zero.
  tvDone <- newTVarIO numWorkers
  startTime <- getCurrentTime
  when printStartEnd $
    hPrintTime h [i|Total tasks: ${numTasks}, starting at: |]
  -- Serialises log lines coming from several threads onto the one handle.
  hMutex <- atomically $ newTMVar ()
  let safeLog s = bracket_ (atomically $ takeTMVar hMutex) (atomically $ putTMVar hMutex ()) (hPutStrLnFlush h s)
      -- Store the value first, then run the user's callback and log its message.
      storeResult t n x o = do
        V.write vValue n x
        afterFinish opts t n o >>= maybe (return ()) safeLog
  replicateM_ numWorkers $ forkIO $
    worker (timeLimit core) run storeResult tvWork tvDone
  waitForWorkers simple safeLog startTime numTasks tvWork tvDone
  when printStartEnd $
    hPrintTime h "Finished at: "
  return vValue
  where
    core = coreOpts opts

-- | Writes @msg@ immediately followed by the current time (as 'show'n 'UTCTime')
-- as one line to the handle, then flushes the handle.
hPrintTime :: Handle -> String -> IO ()
hPrintTime h msg = do
  now <- getCurrentTime
  hPutStrLnFlush h (msg ++ show now)

-- | Primarily waits for the last TVar (the number of live workers) to hit zero, but
-- in the mean time prints details of the tasks remaining and the estimated
-- completion time.
--
-- "Tasks remaining" counts tasks that no worker has picked up yet.  Progress and
-- ETA intervals that are not positive are treated as disabled.  Nothing is logged
-- on a wake-up where there is nothing to report.
waitForWorkers :: SimpleParTaskOpts -> (String -> IO ()) -> UTCTime -> Int -> TVar (Int, _x) -> TVar Int -> IO ()
waitForWorkers opts safeLog startTime totalTasks tvWork tvDone = go totalTasks totalTasks
  where
    -- A non-positive interval would make 'due' true on every wake-up, turning the
    -- loop into a busy-wait that floods the log (and, for the ETA, divides by zero
    -- before any task has been taken).
    positiveInterval :: Maybe Int -> Maybe Int
    positiveInterval mn = mn >>= \n -> if n > 0 then Just n else Nothing

    progressEvery, etaEvery :: Maybe Int
    progressEvery = positiveInterval (printProgress opts)
    etaEvery = positiveInterval (printEstimate opts)

    -- Due once at least @n@ more tasks have been taken since the last report.
    due :: Maybe Int -> Int -> Int -> Bool
    due every lastReported remaining = maybe False (\n -> remaining <= lastReported - n) every

    go :: Int -> Int -> IO ()
    go lastProgress lastETA = do
      (workersRemaining, tasksRemaining, timeToPrintProgress, timeToPrintETA) <- atomically $ do
        tasksRemaining <- fst <$> readTVar tvWork
        workersRemaining <- readTVar tvDone
        let timeToPrintProgress = due progressEvery lastProgress tasksRemaining
            timeToPrintETA = due etaEvery lastETA tasksRemaining
        -- Block (STM retry) until something is due or every worker has finished.
        when (workersRemaining > 0 && not timeToPrintETA && not timeToPrintProgress) retry
        return (workersRemaining, tasksRemaining, timeToPrintProgress, timeToPrintETA)
      curTime <- getCurrentTime
      let message = concat
            [if timeToPrintProgress then [i|Tasks remaining: ${tasksRemaining} |] else ""
            ,if timeToPrintETA then [i|ETA: ${eta startTime totalTasks curTime tasksRemaining}|] else ""]
      -- The final wake-up (all workers done) usually has nothing due, so don't log
      -- a blank line for it.
      when (not (null message)) $ safeLog message
      when (workersRemaining > 0) $
        go (if timeToPrintProgress then tasksRemaining else lastProgress)
           (if timeToPrintETA then tasksRemaining else lastETA)

-- | Formats an estimate of the remaining run time, e.g.
-- @"42 seconds (12:34:56 2013-05-01)"@.  The estimate extrapolates the average time
-- per task taken so far over the tasks remaining.  The finish time is formatted
-- from a 'UTCTime', so it is shown in UTC.  Returns @"unknown"@ when no task has
-- been taken yet, since there is then no rate to extrapolate from.
eta :: UTCTime -> Int -> UTCTime -> Int -> String
eta startTime totalTasks curTime tasksRemaining
  | tasksTaken <= 0 = "unknown"
  | otherwise = show secondsLeft ++ " seconds (" ++ showTime timeFinish ++ ")"
  where
    tasksTaken = totalTasks - tasksRemaining
    timeSoFar = curTime `diffUTCTime` startTime
    timePerTask = timeSoFar / fromIntegral tasksTaken
    timeLeft = timePerTask * fromIntegral tasksRemaining
    timeFinish = timeLeft `addUTCTime` curTime
    -- 'show' on a NominalDiffTime already appends an "s" unit (e.g. "12.5s"), which
    -- used to produce "12.5s seconds"; report whole seconds instead.
    secondsLeft = round timeLeft :: Integer

    showTime :: UTCTime -> String
    showTime = formatTime defaultTimeLocale "%T %F"


-- | Repeatedly picks next job from queue and executes it.
--
-- Each iteration atomically pops the next @(index, task)@ pair from @tvWork@,
-- decrementing its count of tasks not yet taken.  The task is run through @run@
-- and its result is handed to @store@ with 'Success'.  When the queue is empty,
-- the worker decrements @tvDone@ and stops.  If the worker dies with any exception
-- other than 'TookTooLongException', @tvDone@ is also decremented (via
-- 'onException') and the exception is re-thrown.
--
-- With a time limit, each task gets its own watchdog thread.  If the watchdog fires,
-- it stores the default value with 'TookTooLong' and forks a replacement worker.
-- It then throws 'TookTooLongException' at this worker, which exits quietly
-- without touching @tvDone@ (the replacement takes over its slot).
--
-- NOTE(review): completion and timeout are not mutually exclusive.  The task may
-- finish just as the watchdog fires.  By the time 'killThread' stops the watchdog,
-- it may already have stored 'TookTooLong' and forked a replacement.  Then @store@
-- is called twice for one index, and one more worker is alive than @tvDone@ counts.
-- That can let 'waitForWorkers' return early.  Also, if a task throws, its
-- watchdog is never killed and will still fire later.  An atomic "claimed" flag
-- shared by a worker and its watchdog would close both gaps.
worker :: forall m a. MonadIO m =>
          -- Limit in microseconds for the computations:
          Maybe (Integer, a) ->
          -- Function for running the monad:
          (m () -> IO ()) ->
          -- Function for storing the result of the computation:
          (Double -> Int -> a -> TaskOutcome -> IO ()) -> 
          -- Variable to grab the next work item from:
          TVar (Int, [(Int, m a)]) ->
          -- Variable to decrement when the worker finishes:
          TVar Int ->
          IO ()
worker mlimit run store tvWork tvDone = case mlimit of
    -- Function application binds tighter than the infix 'onException', so 'handle'
    -- swallows a TookTooLongException before it could trigger 'finish'.
    Just _limit -> handle (\TookTooLongException -> return ()) 
                          (run $ go True) `onException` finish
                          -- We do not call finish when we receive a TookTooLongException,
                          -- because another worker will have replaced us
    Nothing -> (run $ go True) `onException` finish
  where
    -- Record that this worker has stopped (decrement the live-worker count).
    finish = atomically $ readTVar tvDone >>= writeTVar tvDone . pred
    
    go :: Bool -> m ()
    -- The previous iteration found the queue empty: this worker is done.
    go False = liftIO finish
    go True = do nextWork <- liftIO $ atomically $ do
                          -- Pop the next (index, task) pair, keeping the count in step.
                          (count, work) <- readTVar tvWork
                          case work of
                             [] -> return Nothing
                             (w:ws) -> Just w <$ writeTVar tvWork (pred count, ws)
                 -- Run the task (if any); False means the queue was empty.
                 keepGoing <- case (nextWork, mlimit) of
                   (Nothing, _) -> return False
                   (Just (n, work), Nothing) -> do
                     -- No time limit: just run the task and store its result.
                     (x, t) <- withTime work
                     liftIO $ store t n x Success
                     return True
                   (Just (n, work), Just (limit, def)) -> do
                     myId <- liftIO myThreadId
                     start <- liftIO getCurrentTime
                     -- Fork a watchdog to watch for if we have been executing too long:
                     theirId <- liftIO $ forkIO $ do
                       threadDelay' limit
                       end <- getCurrentTime
                       store (end `timeDiff` start) n def TookTooLong
                       -- We fork a new worker, because we cannot guarantee we will ever kill
                       -- the old one (e.g. if it is blocked, or not allocating memory)
                       -- TODO there is a slim chance that the worker can start and then we are killed
                       -- first:
                       _ <- forkIO $ worker mlimit run store tvWork tvDone
                       -- This call may block (effectively) forever, waiting for the other thread to die:
                       throwTo myId TookTooLongException
                     -- If the watchdog fires during this call, the TookTooLongException
                     -- it throws interrupts the task.
                     (x, t) <- withTime work
                     liftIO $ killThread theirId -- Kill the watchdog
                     liftIO $ store t n x Success
                     return True
                 go keepGoing

-- | Like 'hPutStrLn', but flushes the handle afterwards so the line appears
-- immediately.  Usable from any 'MonadIO' monad.
hPutStrLnFlush :: MonadIO m => Handle -> String -> m ()
hPutStrLnFlush hdl line = liftIO $ do
  hPutStrLn hdl line
  hFlush hdl

-- | Runs an action and pairs its result with the wall-clock time it took, in
-- seconds.  The duration is forced with 'evaluate' before returning, so no thunk
-- holding the two timestamps is left behind.
withTime :: MonadIO m => m a -> m (a, Double)
withTime action = do
  before <- liftIO getCurrentTime
  result <- action
  after <- liftIO getCurrentTime
  elapsed <- liftIO . evaluate $ timeDiff after before
  return (result, elapsed)
  
-- | Seconds elapsed from @start@ to @end@ as a 'Double' (negative when @end@ is
-- earlier than @start@).
timeDiff :: UTCTime -> UTCTime -> Double
timeDiff end start = realToFrac (diffUTCTime end start)