{-# LANGUAGE BangPatterns              #-}
{-# LANGUAGE RecordWildCards           #-}
{-# LANGUAGE ScopedTypeVariables       #-}
module Data.Massiv.Core.Scheduler
  ( Scheduler
  , numWorkers
  , scheduleWork
  , withScheduler
  , withScheduler'
  , withScheduler_
  , divideWork
  , divideWork_
  ) where
import           Control.Concurrent           (ThreadId, forkOnWithUnmask,
                                               getNumCapabilities, killThread)
import           Control.Concurrent.MVar
import           Control.DeepSeq
import           Control.Exception            (SomeException, catch, mask,
                                               mask_, throwIO, try,
                                               uninterruptibleMask_)
import           Control.Monad                (void, forM)
import           Control.Monad.Primitive      (RealWorld)
import           Data.IORef                   (IORef, atomicModifyIORef',
                                               newIORef, readIORef)
import           Data.Massiv.Core.Index.Class (Index (totalElem))
import           Data.Massiv.Core.Iterator    (loop)
import           Data.Primitive.Array         (Array, MutableArray, indexArray,
                                               newArray, unsafeFreezeArray,
                                               writeArray)
import           System.IO.Unsafe             (unsafePerformIO)
import           System.Mem.Weak
-- | A message placed on the worker job queue.
data Job
  = Job (IO ())
    -- ^ An action a worker should execute before looking for more work.
  | Retire
    -- ^ Tells the worker that picks it up to stop its processing loop.
-- | Handle used to queue up jobs that produce results of type @a@ and to
-- track their completion on a pool of 'Workers'.
data Scheduler a = Scheduler
  { jobsCountIORef :: !(IORef Int)
    -- ^ Number of jobs scheduled so far. 'scheduleWork' uses the value before
    -- the increment as the index of the job's slot in the results array.
  , jobQueueMVar :: !(MVar [Job])
    -- ^ Jobs accumulated by 'scheduleWork', most recently scheduled first.
  , resultsMVar :: !(MVar (MutableArray RealWorld a))
    -- ^ Array into which every job writes its result. It is filled in by
    -- 'withScheduler' once all jobs have been submitted.
  , workers :: !Workers
    -- ^ The pool of workers that executes the jobs.
  , numCapabilities :: {-# UNPACK #-} !Int
    -- ^ Number of worker threads in the pool.
  }
-- | Number of worker threads available to the scheduler.
numWorkers :: Scheduler a -> Int
numWorkers scheduler = numCapabilities scheduler
-- | A pool of worker threads together with the channels used to talk to them.
data Workers = Workers
  { workerThreadIds :: ![ThreadId]
    -- ^ Identifiers of the threads that make up the pool.
  , workerJobDone :: !(MVar (Maybe SomeException))
    -- ^ Completion channel: 'Nothing' is put for every job that finished,
    -- @'Just' exc@ when a worker thread caught an exception.
  , workerJobQueue :: !(MVar [Job])
    -- ^ Queue from which the worker threads take their jobs.
  }
-- | Queue an action for later execution by the scheduler's workers.
--
-- The job is assigned the next free index (the job counter's value before it
-- is incremented). When the job runs, the action's result is written to that
-- index of the results array and a completion signal ('Nothing') is put on
-- the workers' done channel while the results array is still held.
--
-- Jobs are prepended to the scheduler's pending list; nothing is executed by
-- this function itself.
scheduleWork :: Scheduler a -- ^ Scheduler to queue the job with.
             -> IO a -- ^ Action whose result should be collected.
             -> IO ()
scheduleWork scheduler jobAction =
  modifyMVar_ (jobQueueMVar scheduler) $ \queued -> do
    slot <-
      atomicModifyIORef' (jobsCountIORef scheduler) (\count -> (count + 1, count))
    let signalDone = putMVar (workerJobDone (workers scheduler)) Nothing
        storeResult result =
          withMVar (resultsMVar scheduler) $ \results ->
            writeArray results slot result >> signalDone
    return (Job (jobAction >>= storeResult) : queued)
-- | Placeholder stored in every slot of a freshly allocated results array.
-- Evaluating it means that a result was read before the corresponding job
-- wrote it, so it fails with a descriptive error naming this module.
uninitialized :: a
uninitialized = error "Data.Massiv.Core.Scheduler: uncomputed job result"
-- | A bracket with separate handlers for normal and exceptional exit.
--
-- The resource is acquired with asynchronous exceptions masked; the main
-- action runs with the caller's masking state restored. Exactly one of the
-- two handlers is then run uninterruptibly:
--
-- * on success the success handler runs and the main action's result is
--   returned;
-- * on any exception the error handler receives that exception and the
--   resource. Exceptions thrown by the error handler itself are discarded so
--   that the original exception is the one rethrown.
bracketWithException :: forall a b c d .
  IO a -- ^ Acquire the resource.
  -> (a -> IO b) -- ^ Handler run after the main action succeeded.
  -> (SomeException -> a -> IO c) -- ^ Handler run after the main action failed.
  -> (a -> IO d) -- ^ Main action using the resource.
  -> IO d
bracketWithException before afterSuccess afterError thing =
  mask $ \restore -> do
    resource <- before
    outcome <- try (restore (thing resource))
    case outcome of
      Right result -> result <$ uninterruptibleMask_ (afterSuccess resource)
      Left exc -> do
        -- A failing cleanup must not hide the exception that caused it.
        void
          (try (uninterruptibleMask_ (afterError exc resource)) :: IO (Either SomeException c))
        throwIO (exc :: SomeException)
-- | Acquire a pool of workers, let @submitJobs@ queue up work through the
-- supplied 'Scheduler', run all queued jobs on the pool and collect their
-- results.
--
-- Returns the number of scheduled jobs together with an array holding each
-- job's result at the index at which the job was scheduled.
--
-- When the list of capabilities is empty the shared global pool is used,
-- provided it is not currently taken by another caller; in every other case a
-- private pool is hired for the duration of this call. A private pool is
-- retired on success and killed when an exception escapes. The global pool is
-- handed back on success and replaced by a freshly hired one on failure. An
-- exception reported by a worker is rethrown in the calling thread.
withScheduler :: [Int] -- ^ Capabilities to run workers on; with an empty
                       -- list one worker per available capability is used.
              -> (Scheduler a -> IO b) -- ^ Action that schedules the jobs;
                                       -- its result is discarded.
              -> IO (Int, Array a)
withScheduler wss submitJobs = do
  jobsCountIORef <- newIORef 0
  jobQueueMVar <- newMVar []
  resultsMVar <- newEmptyMVar
  -- Workers tagged with 'Nothing' belong to this call only and are therefore
  -- retired or killed by the handlers below.
  let hirePrivateWorkers = do
        privateWorkers <- hireWorkers wss
        return (Nothing, privateWorkers)
  bracketWithException
    (do mWeakWorkers <-
          if null wss
            then tryTakeMVar globalWorkersMVar
            else return Nothing
        case mWeakWorkers of
          Nothing -> hirePrivateWorkers
          Just weak -> do
            mGlobalWorkers <- deRefWeak weak
            case mGlobalWorkers of
              Just globalWorkers -> return (Just weak, globalWorkers)
              Nothing -> do
                -- The global pool has already been finalized. Putting the dead
                -- weak pointer back (and tagging replacement workers as
                -- global) would leak the replacement threads on every call,
                -- so install a fresh global pool for future callers and serve
                -- this call with a private one.
                hireWeakWorkers globalWorkersMVar >>= putMVar globalWorkersMVar
                hirePrivateWorkers)
    (\(mWeakWorkers, workers) ->
       case mWeakWorkers of
         Nothing ->
           -- One 'Retire' message per thread shuts the private pool down.
           putMVar (workerJobQueue workers) $
           replicate (length (workerThreadIds workers)) Retire
         Just weak -> putMVar globalWorkersMVar weak)
    (\_ (mWeakWorkers, workers) ->
       case mWeakWorkers of
         Nothing -> mapM_ killThread (workerThreadIds workers)
         Just weakWorkers -> do
           -- Running the finalizer kills the current global workers; a new
           -- pool takes their place.
           finalize weakWorkers
           newWeakWorkers <- hireWeakWorkers globalWorkersMVar
           putMVar globalWorkersMVar newWeakWorkers)
    (\(_, workers) -> do
       let scheduler =
             Scheduler {numCapabilities = length $ workerThreadIds workers, ..}
       _ <- submitJobs scheduler
       jobCount <- readIORef jobsCountIORef
       -- The results array can only be allocated once the number of jobs is
       -- known, i.e. after all of them have been submitted.
       marr <- newArray jobCount uninitialized
       putMVar resultsMVar marr
       jobQueue <- takeMVar jobQueueMVar
       -- Jobs were accumulated newest first; hand them over in scheduling order.
       putMVar (workerJobQueue workers) $ reverse jobQueue
       waitTillDone scheduler
       arr <- unsafeFreezeArray marr
       return (jobCount, arr))
-- | Same as 'withScheduler', but returns the job results as a list. The list
-- is assembled by consing the array elements from the last index down to
-- index @0@ (this relies on 'loop' stepping from its start value for as long
-- as the condition holds), so results appear in scheduling order.
withScheduler' :: [Int] -> (Scheduler a -> IO b) -> IO [a]
withScheduler' wss submitJobs = collect <$> withScheduler wss submitJobs
  where
    collect (jobCount, results) =
      loop (jobCount - 1) (>= 0) (subtract 1) [] $ \ix rest ->
        indexArray results ix : rest
-- | Same as 'withScheduler', but the job count and the job results are
-- discarded.
withScheduler_ :: [Int] -> (Scheduler a -> IO b) -> IO ()
withScheduler_ wss submitJobs = () <$ withScheduler wss submitJobs
-- | Same as 'divideWork', but the list of job results is discarded.
divideWork_ :: Index ix
            => [Int] -> ix -> (Scheduler a -> Int -> Int -> Int -> IO b) -> IO ()
divideWork_ wss sz submit = () <$ divideWork wss sz submit
-- | Split work over @totalElem sz@ elements between the workers of a
-- scheduler and return the results of all scheduled jobs as a list.
--
-- The submitting action receives, in this order: the scheduler, the chunk
-- length (total number of elements divided by the number of workers, rounded
-- towards zero), the total number of elements, and the index at which the
-- slack begins (chunk length times the number of workers).
--
-- When there are no elements at all the submitting action is never called
-- and an empty list is returned.
divideWork :: Index ix
           => [Int] -- ^ Capabilities to run workers on, see 'withScheduler''.
           -> ix -- ^ Size whose total number of elements is to be divided.
           -> (Scheduler a -> Int -> Int -> Int -> IO b) -- ^ Submitting action.
           -> IO [a]
divideWork wss sz submit =
  case totalElem sz of
    0 -> return []
    totalLength ->
      withScheduler' wss $ \scheduler -> do
        let workerCount = numWorkers scheduler
            !chunkLength = totalLength `quot` workerCount
            !slackStart = chunkLength * workerCount
        submit scheduler chunkLength totalLength slackStart
-- | Block until a completion signal has been received for every scheduled
-- job. The number of jobs is read once up front. If an exception is received
-- on the completion channel instead, it is rethrown immediately in the
-- calling thread.
waitTillDone :: Scheduler a -> IO ()
waitTillDone scheduler = readIORef (jobsCountIORef scheduler) >>= await
  where
    doneSignal = workerJobDone (workers scheduler)
    -- Counts the outstanding jobs down, one taken signal at a time.
    await 0 = return ()
    await outstanding =
      takeMVar doneSignal >>= maybe (await (outstanding - 1)) throwIO
-- | Processing loop of a single worker thread.
--
-- The worker takes the job list out of the shared queue and inspects its
-- head:
--
-- * for a 'Job' the remaining jobs are put back first, so that other workers
--   can proceed, then the job is executed and the loop continues;
-- * for 'Retire' the remaining messages are put back and the loop ends;
-- * for an empty list nothing is put back, so the worker blocks on the next
--   take until a new job list is supplied.
runWorker :: MVar [Job] -> IO ()
runWorker jobsMVar =
  takeMVar jobsMVar >>= \pending ->
    case pending of
      [] -> runWorker jobsMVar
      Retire : remaining -> putMVar jobsMVar remaining
      Job action : remaining -> do
        putMVar jobsMVar remaining
        action
        runWorker jobsMVar
-- | Start a pool of worker threads, one per supplied capability number. With
-- an empty list one worker is started for each of the capabilities reported
-- by 'getNumCapabilities'.
--
-- All workers share a single, initially empty, job queue and run 'runWorker'
-- on it. Threads are forked with asynchronous exceptions masked and unmask
-- them only around the worker loop; any exception caught from the loop is
-- put on the pool's completion channel as @'Just' exc@. The list of thread
-- identifiers is fully evaluated before the pool is returned.
hireWorkers :: [Int] -> IO Workers
hireWorkers wss = do
  capabilities <-
    case wss of
      [] -> fmap (\wNum -> [0 .. wNum - 1]) getNumCapabilities
      _  -> return wss
  jobQueue <- newEmptyMVar
  jobDone <- newEmptyMVar
  let spawn capability =
        mask_ $
        forkOnWithUnmask capability $ \unmask ->
          unmask (runWorker jobQueue) `catch` \(exc :: SomeException) ->
            unmask (putMVar jobDone (Just exc))
  threadIds <- mapM spawn capabilities
  threadIds `deepseq`
    return
      Workers
        { workerThreadIds = threadIds
        , workerJobDone = jobDone
        , workerJobQueue = jobQueue
        }
-- | Process-wide pool of workers, created the first time this value is
-- demanded. The 'MVar' holds a weak pointer to the pool and is itself the key
-- of that weak pointer. An empty 'MVar' means the pool is currently taken.
--
-- The @NOINLINE@ pragma is required because of 'unsafePerformIO': the
-- initialisation must not be duplicated by inlining.
globalWorkersMVar :: MVar (Weak Workers)
globalWorkersMVar = unsafePerformIO $ do
  pool <- newEmptyMVar
  hireWeakWorkers pool >>= putMVar pool
  return pool
{-# NOINLINE globalWorkersMVar #-}
-- | Hire one worker per capability and wrap the pool in a weak pointer keyed
-- on the supplied value. The weak pointer's finalizer kills all of the
-- pool's threads.
--
-- NOTE(review): the "System.Mem.Weak" documentation warns that weak pointers
-- keyed on ordinary boxed values may be finalized earlier than expected;
-- confirm that this is acceptable for the keys passed by callers.
hireWeakWorkers :: key -> IO (Weak Workers)
hireWeakWorkers k = do
  workers <- hireWorkers []
  let dismissAll = mapM_ killThread (workerThreadIds workers)
  mkWeak k workers (Just dismissAll)