-- Work stealing scheduler and thread pools
--
-- Author: Patrick Maier
-----------------------------------------------------------------------------

{-# LANGUAGE GeneralizedNewtypeDeriving #-}  -- req'd for type 'RTS'
{-# LANGUAGE ScopedTypeVariables #-}         -- req'd for type annotations

module Control.Parallel.HdpH.Internal.Scheduler
  ( -- * abstract run-time system monad
    RTS,          -- instances: Monad, Functor
    run_,         -- :: RTSConf -> RTS () -> IO ()
    liftThreadM,  -- :: ThreadM RTS a -> RTS a
    liftSparkM,   -- :: SparkM RTS a -> RTS a
    liftCommM,    -- :: CommM a -> RTS a
    liftIO,       -- :: IO a -> RTS a

    -- * scheduler ID
    schedulerID,  -- :: RTS Int

    -- * converting and executing threads
    mkThread,     -- :: ParM RTS a -> Thread RTS
    execThread,   -- :: Thread RTS -> RTS ()

    -- * pushing sparks
    sendPUSH      -- :: Spark RTS -> NodeId -> RTS ()
  ) where

import Prelude hiding (error)
import Control.Concurrent (ThreadId, forkIO, killThread)
import Control.Monad (unless, replicateM)
import Data.Functor ((<$>))

import Control.Parallel.HdpH.Closure (unClosure)
import Control.Parallel.HdpH.Conf (RTSConf(scheds, wakeupDly))
import Control.Parallel.HdpH.Internal.Comm (CommM)
import qualified Control.Parallel.HdpH.Internal.Comm as Comm
       (myNode, send, receive, run_, waitShutdown)
import qualified Control.Parallel.HdpH.Internal.Data.Deque as Deque (emptyIO)
import qualified Control.Parallel.HdpH.Internal.Data.Sem as Sem
       (new, signalPeriodically)
import Control.Parallel.HdpH.Internal.Location
       (NodeId, dbgNone, dbgStats, dbgMsgSend, dbgMsgRcvd, error)
import qualified Control.Parallel.HdpH.Internal.Location as Location (debug)
import Control.Parallel.HdpH.Internal.Misc
       (encodeLazy, decodeLazy, ActionServer, newServer, killServer)
import Control.Parallel.HdpH.Internal.Sparkpool
       (SparkM, blockSched, getSpark, Msg(PUSH), dispatch, readPoolSize,
        readFishSentCtr, readSparkRcvdCtr, readSparkGenCtr, readMaxSparkCtr)
import qualified Control.Parallel.HdpH.Internal.Sparkpool as Sparkpool (run)
import Control.Parallel.HdpH.Internal.Threadpool
       (ThreadM, poolID, forkThreadM, stealThread, readMaxThreadCtrs)
import qualified Control.Parallel.HdpH.Internal.Threadpool as Threadpool
       (run, liftSparkM, liftCommM, liftIO)
import Control.Parallel.HdpH.Internal.Type.Par
       (ParM, unPar, Thread(Atom), Spark)


-----------------------------------------------------------------------------
-- RTS monad

-- The RTS monad hides monad stack (IO, CommM, SparkM, ThreadM) as abstract.
newtype RTS a = RTS { unRTS :: ThreadM RTS a }
                deriving (Functor, Monad)


-- Fork a new thread to execute the given 'RTS' action; the integer 'n'
-- dictates how much to rotate the thread pools (so as to avoid contention
-- due to concurrent access).
forkRTS :: Int -> RTS () -> RTS ThreadId
forkRTS n = liftThreadM . forkThreadM n . unRTS


-- Eliminate the whole RTS monad stack down the IO monad by running the given
-- RTS action 'main'; aspects of the RTS's behaviour are controlled by
-- the respective parameters in the given RTSConf.
-- NOTE: This function start various threads (for executing schedulers, 
--       a message handler, and various timeouts). On normal termination,
--       all these threads are killed. However, there is no cleanup in the 
--       event of aborting execution due to an exception. The functions
--       for doing so (see Control.Execption) all live in the IO monad.
--       Maybe they could be lifted to the RTS monad by using the monad-peel
--       package.
run_ :: RTSConf -> RTS () -> IO ()
run_ conf main = do
  let n = scheds conf
  unless (n > 0) $
    error "HdpH.Internal.Scheduler.run_: no schedulers"

  -- allocate n+1 empty thread pools (numbered from 0 to n)
  pools <- mapM (\ k -> do { pool <- Deque.emptyIO; return (k,pool) }) [0 .. n]

  -- fork nowork server (for clearing the "FISH outstanding" flag on NOWORK)
  noWorkServer <- newServer

  -- create semaphore for idle schedulers
  idleSem <- Sem.new

  -- fork wakeup server (periodically waking up racey sleeping scheds)
  wakeupServerTid <- forkIO $ Sem.signalPeriodically idleSem (wakeupDly conf)

  -- start the RTS
  Comm.run_ conf $
    Sparkpool.run conf noWorkServer idleSem $
      Threadpool.run pools $
        unRTS $
          rts n noWorkServer wakeupServerTid

    where
      -- RTS action
      rts :: Int -> ActionServer -> ThreadId -> RTS ()
      rts scheds noWorkServer wakeupServerTid = do

        -- fork message handler (accessing thread pool 0)
        handlerTid <- forkRTS 0 handler

        -- fork schedulers (each accessing thread pool k, 1 <= k <= scheds)
        schedulerTids <- mapM (\ k -> forkRTS k scheduler) [1 .. scheds]

        -- run main RTS action
        main

        -- block waiting for shutdown barrier
        liftCommM $ Comm.waitShutdown

        -- print stats
        printFinalStats

        -- kill nowork server
        liftIO $ killServer noWorkServer

        -- kill wakeup server
        liftIO $ killThread wakeupServerTid

        -- kill message handler
        liftIO $ killThread handlerTid

        -- kill schedulers
        liftIO $ mapM_ killThread schedulerTids


-- lifting lower layers
liftThreadM :: ThreadM RTS a -> RTS a
liftThreadM = RTS

liftSparkM :: SparkM RTS a -> RTS a
liftSparkM = liftThreadM . Threadpool.liftSparkM

liftCommM :: CommM a -> RTS a
liftCommM = liftThreadM . Threadpool.liftCommM

liftIO :: IO a -> RTS a
liftIO = liftThreadM . Threadpool.liftIO


-- Return scheduler ID, that is ID of scheduler's own thread pool.
schedulerID :: RTS Int
schedulerID = liftThreadM poolID


-----------------------------------------------------------------------------
-- cooperative scheduling

-- Execute the given thread until it blocks or terminates.
execThread :: Thread RTS -> RTS ()
execThread (Atom m) = m >>= maybe (return ()) execThread


-- Try to get a thread from a thread pool or the spark pool and execute it
-- until it blocks or terminates, whence repeat forever; if there is no
-- thread to execute then block the scheduler (ie. its underlying IO thread).
scheduler :: RTS ()
scheduler = getThread >>= scheduleThread


-- Execute given thread until it blocks or terminates, whence call 'scheduler'.
scheduleThread :: Thread RTS -> RTS ()
scheduleThread (Atom m) = m >>= maybe scheduler scheduleThread


-- Try to steal a thread from any thread pool (with own pool preferred);
-- if there is none, try to convert a spark from the spark pool;
-- if there is none too, block the scheduler such that the 'getThread'
-- action will be repeated on wake up.
-- NOTE: Sleeping schedulers should be woken up
--       * after new threads have been added to a thread pool,
--       * after new sparks have been added to the spark pool, and
--       * once the delay after a NOWORK message has expired.
getThread :: RTS (Thread RTS)
getThread = do
  schedID <- schedulerID
  maybe_thread <- liftThreadM stealThread
  case maybe_thread of
    Just thread -> return thread
    Nothing     -> do
      maybe_spark <- liftSparkM $ getSpark schedID
      case maybe_spark of
        Just spark -> return $ mkThread $ unClosure spark
        Nothing    -> liftSparkM blockSched >> getThread


-- Converts 'Par' computations into threads.
mkThread :: ParM RTS a -> Thread RTS
mkThread p = unPar p $ \ _c -> Atom (return Nothing)


-----------------------------------------------------------------------------
-- pushed sparks

-- Send a 'spark' via PUSH message to the given 'target' unless 'target'
-- is the current node (in which case 'spark' is executed immediately).
sendPUSH :: Spark RTS -> NodeId -> RTS ()
sendPUSH spark target = do
  here <- liftCommM Comm.myNode
  if target == here
    then do
      -- short cut PUSH msg locally
      execSpark spark
    else do
      -- construct and send PUSH message
      let msg = PUSH spark :: Msg RTS
      debug dbgMsgSend $
        show msg ++ " ->> " ++ show target
      liftCommM $ Comm.send target $ encodeLazy msg


-- Handle a PUSH message by converting the spark into a thread and
-- executing it immediately.
handlePUSH :: Msg RTS -> RTS ()
handlePUSH (PUSH spark) = execSpark spark


-- Execute a spark (by converting it to a thread and executing).
execSpark :: Spark RTS -> RTS ()
execSpark spark = execThread $ mkThread $ unClosure spark


-----------------------------------------------------------------------------
-- message handler; only PUSH messages are actually handled here in this
-- module, other messages are relegated to module Sparkpool.

-- Message handler, running continously (in its own thread) receiving
-- and handling messages (some of which may unblock threads or create sparks)
-- as they arrive.
handler :: RTS ()
handler = do
  msg <- decodeLazy <$> liftCommM Comm.receive
  sparks <- liftSparkM readPoolSize
  debug dbgMsgRcvd $
    ">> " ++ show msg ++ " #sparks=" ++ show sparks
  case msg of
    PUSH _ -> handlePUSH msg
    _      -> liftSparkM $ dispatch msg
  handler


-----------------------------------------------------------------------------
-- auxiliary stuff

-- Print stats (#sparks, threads, FISH, ...) at appropriate debug level.
-- TODO: Log time elapsed since RTS is up
printFinalStats :: RTS ()
printFinalStats = do
  fishes <- liftSparkM $ readFishSentCtr
  scheds <- liftSparkM $ readSparkRcvdCtr
  sparks <- liftSparkM $ readSparkGenCtr
  max_sparks  <- liftSparkM $ readMaxSparkCtr
  maxs_threads <- liftThreadM $ readMaxThreadCtrs
  debug dbgStats $ "#SPARK=" ++ show sparks ++ "   " ++
                   "max_SPARK=" ++ show max_sparks ++ "   " ++
                   "max_THREAD=" ++ show maxs_threads
  debug dbgStats $ "#FISH_sent=" ++ show fishes ++ "   " ++
                   "#SCHED_rcvd=" ++ show scheds

debug :: Int -> String -> RTS ()
debug level message = liftIO $ Location.debug level message