{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Control.Massiv.Scheduler
  ( 
    Comp(..)
  , Scheduler(..)
  
  , withScheduler
  , withScheduler_
  
  , fromWorkerAsyncException
  , traverseConcurrently
  , traverseConcurrently_
  , traverse_
  ) where
import Control.Concurrent
import Control.Exception
import Control.Massiv.Scheduler.Computation
import Control.Massiv.Scheduler.Queue
import Control.Monad
import Control.Monad.IO.Unlift
import Data.Atomics (atomicModifyIORefCAS, atomicModifyIORefCAS_)
import Data.Foldable as F (foldl')
import Data.IORef
import Data.Traversable
-- | Bookkeeping shared between the code that schedules jobs and the workers
-- that run them.
data Jobs m a =
  Jobs
    { jobsNumWorkers :: {-# UNPACK #-} !Int
      -- ^ Number of workers draining 'jobsQueue'; also how many @Retire@
      -- markers get pushed once all jobs are done.
    , jobsQueue :: !(JQueue m a)
      -- ^ Queue holding scheduled jobs and @Retire@ markers.
    , jobsCountRef :: !(IORef Int)
      -- ^ Jobs that were scheduled but have not finished yet.
    }
  
                       
                       
  
-- | Handle given to the callback of 'withScheduler' / 'withScheduler_' for
-- submitting work.
data Scheduler m a =
  Scheduler
    { numWorkers :: {-# UNPACK #-} !Int
      -- ^ Number of workers the jobs are distributed over.
    , scheduleWork :: m a -> m ()
      -- ^ Push an action onto the job queue; it is run later by a worker
      -- (or, for 'Seq', by the calling thread).
    }
-- | Run @f@ on every element, left to right, discarding the results.
--
-- Built with a strict left fold, so the whole chain of @'*>'@ is assembled
-- before any of the actions run.
traverse_ :: (Applicative f, Foldable t) => (a -> f ()) -> t a -> f ()
traverse_ f xs = F.foldl' step (pure ()) xs
  where
    step acc x = acc *> f x
-- | Schedule @f x@ for every element @x@ of @xs@ as a separate job, wait for
-- all of them, and rebuild a structure of the same shape as @xs@ from the
-- collected results.
--
-- NOTE(review): the i-th result returned by 'withScheduler' is placed in the
-- i-th position; this assumes results come back in submission order — confirm
-- against 'flushResults'.
traverseConcurrently :: (MonadUnliftIO m, Traversable t) => Comp -> (a -> m b) -> t a -> m (t b)
traverseConcurrently comp f xs =
  (`transList` xs) <$> withScheduler comp (\scheduler -> traverse_ (scheduleWork scheduler . f) xs)
-- | Refill the shape of a 'Traversable' with elements taken left to right from
-- a list. Surplus list elements are ignored; if the list runs out, the missing
-- positions evaluate to an 'error'.
transList :: Traversable t => [a] -> t b -> t a
transList xs' shape = snd (mapAccumL fill xs' shape)
  where
    fill remaining _ =
      case remaining of
        x : rest -> (rest, x)
        _ -> error "Impossible<traverseConcurrently> - Mismatched sizes"
-- | Schedule @f x@ for every element @x@ of @xs@ as a separate job and wait
-- for all of them, discarding their results.
traverseConcurrently_ :: (MonadUnliftIO m, Foldable t) => Comp -> (a -> m b) -> t a -> m ()
traverseConcurrently_ comp f xs = withScheduler_ comp submitAll
  where
    submitAll scheduler = traverse_ (\x -> scheduleWork scheduler (f x)) xs
-- | Enqueue an action, wrapping it with 'mkJob' (used by 'withScheduler',
-- whose results are gathered with 'flushResults').
scheduleJobs :: MonadIO m => Jobs m a -> m a -> m ()
scheduleJobs jobs action = scheduleJobsWith mkJob jobs action
-- | Enqueue an action whose result is thrown away: it is wrapped in 'Job_'
-- after being 'void'ed.
scheduleJobs_ :: MonadIO m => Jobs m a -> m b -> m ()
scheduleJobs_ = scheduleJobsWith (\action -> return (Job_ (void action)))
-- | Enqueue one job built by @mkJob'@.
--
-- The pending-job counter is bumped before the job is pushed. When the job
-- completes, its result is forced to WHNF and the counter is decremented. The
-- job that brings the counter to zero pushes one @Retire@ marker per worker.
scheduleJobsWith :: MonadIO m => (m b -> m (Job m a)) -> Jobs m a -> m b -> m ()
scheduleJobsWith mkJob' Jobs {jobsQueue, jobsCountRef, jobsNumWorkers} action = do
  -- Count the job before it becomes visible to workers, so the counter cannot
  -- reach zero while this job is still outstanding.
  liftIO $ atomicModifyIORefCAS_ jobsCountRef (+ 1)
  mkJob' wrapped >>= pushJQueue jobsQueue
  where
    wrapped = do
      res <- action
      res `seq` finished
      return res
    finished = dropCounterOnZero jobsCountRef $ retireWorkersN jobsQueue jobsNumWorkers
-- | Push @n@ @Retire@ markers onto the queue (nothing when @n <= 0@).
retireWorkersN :: MonadIO m => JQueue m a -> Int -> m ()
retireWorkersN jobsQueue n = replicateM_ n (pushJQueue jobsQueue Retire)
-- | Atomically decrement the counter and run @onZero@ if the new value is 0.
dropCounterOnZero :: MonadIO m => IORef Int -> m () -> m ()
dropCounterOnZero counterRef onZero = do
  remaining <- liftIO $ atomicModifyIORefCAS counterRef decrement
  if remaining == 0 then onZero else pure ()
  where
    -- Store and return the decremented value, both forced.
    decrement !old = let !new = old - 1 in (new, new)
-- | Worker loop: pop and run jobs until 'popJQueue' yields 'Nothing'
-- (presumably a @Retire@ marker — see the Queue module), then run
-- @onRetire@.
runWorker :: MonadIO m
          => JQueue m a -- ^ Queue to drain
          -> m ()       -- ^ Action to run once the worker retires
          -> m ()
runWorker jQueue onRetire = loop
  where
    loop = popJQueue jQueue >>= maybe onRetire (>> loop)
-- | Run the callback with a fresh 'Scheduler', wait until every job it
-- scheduled has finished, and return the jobs' results, gathered with
-- 'flushResults'.
--
-- If a job throws, that exception is rethrown here and the remaining workers
-- are terminated.
withScheduler ::
     MonadUnliftIO m
  => Comp
  -- ^ Computation strategy
  -> (Scheduler m a -> m b)
  -- ^ Callback that schedules the work; its own result is ignored
  -> m [a]
withScheduler comp onScheduler =
  withSchedulerInternal comp scheduleJobs flushResults onScheduler
-- | Same as 'withScheduler', but the jobs' results are discarded and nothing
-- is collected from the queue.
withScheduler_ ::
     MonadUnliftIO m
  => Comp
  -- ^ Computation strategy
  -> (Scheduler m a -> m b)
  -- ^ Callback that schedules the work; its own result is ignored
  -> m ()
withScheduler_ comp onScheduler =
  withSchedulerInternal comp scheduleJobs_ (\_ -> pure ()) onScheduler
-- | Shared implementation of 'withScheduler' and 'withScheduler_'.
--
-- The user callback runs first, so the jobs it schedules are only queued.
-- Afterwards the workers are spawned (for 'Seq', the calling thread drains the
-- queue itself). The call then blocks on @workDoneMVar@, which is filled in
-- one of two ways:
--
-- * with 'Nothing' once every worker has retired, after which @collect@
--   produces the result;
-- * with @'Just' ('WorkerException' e)@ when a worker failed, in which case
--   @e@ is rethrown here.
--
-- If waiting ends in any exception, every spawned worker is sent
-- 'WorkerTerminateException'.
withSchedulerInternal ::
     MonadUnliftIO m
  => Comp -- ^ Computation strategy: decides the worker count and how workers are forked
  -> (Jobs m a -> m a -> m ()) -- ^ How a single job is enqueued (e.g. 'scheduleJobs' or 'scheduleJobs_')
  -> (JQueue m a -> m c) -- ^ Produces the final result from the queue once all workers have retired
  -> (Scheduler m a -> m b)
     -- ^ User callback that schedules the jobs; its result is discarded
  -> m c
withSchedulerInternal comp submitWork collect onScheduler = do
  -- Worker count: 1 for 'Seq', every capability for 'Par' and @ParN 0@,
  -- otherwise as requested.
  -- NOTE(review): @ParOn []@ gives 0 workers here unless 'Par' is a pattern
  -- synonym for it (and so matched first); with 0 workers nothing would ever
  -- fill workDoneMVar — confirm in Control.Massiv.Scheduler.Computation.
  jobsNumWorkers <-
    case comp of
      Seq      -> return 1
      Par      -> liftIO getNumCapabilities
      ParOn ws -> return $ length ws
      ParN 0   -> liftIO getNumCapabilities
      ParN n   -> return $ fromIntegral n
  -- Workers that have not retired yet; the last one to retire puts Nothing
  -- into workDoneMVar (see onRetire below).
  sWorkersCounterRef <- liftIO $ newIORef jobsNumWorkers
  jobsQueue <- newJQueue
  -- Jobs scheduled but not yet finished.
  jobsCountRef <- liftIO $ newIORef 0
  -- Nothing: all workers retired normally; Just: a worker's job failed.
  workDoneMVar <- liftIO newEmptyMVar
  let jobs = Jobs {..}
      scheduler = Scheduler {numWorkers = jobsNumWorkers, scheduleWork = submitWork jobs}
      onRetire = dropCounterOnZero sWorkersCounterRef $ liftIO (putMVar workDoneMVar Nothing)
  -- Run the user callback. No worker exists yet, so jobs are only queued here;
  -- they start running once the workers are spawned (or in doWork for 'Seq').
  _ <- onScheduler scheduler
  -- With zero jobs the job counter would never drop to zero, so no Retire
  -- markers would be pushed. Schedule a no-op job to trigger that path.
  jc <- liftIO $ readIORef jobsCountRef
  when (jc == 0) $ scheduleJobs_ jobs (pure ())
  -- Fork one worker per element of ws. Each worker runs with async exceptions
  -- unmasked and reports failures through handleWorkerException.
  let spawnWorkersWith fork ws =
        withRunInIO $ \run ->
          forM ws $ \w ->
            fork w $ \unmask ->
              catch
                (unmask $ run $ runWorker jobsQueue onRetire)
                (run . handleWorkerException jobsQueue workDoneMVar jobsNumWorkers)
      {-# INLINE spawnWorkersWith #-}
      spawnWorkers =
        case comp of
          Seq -> return []
            -- No threads are forked; doWork drains the queue on the calling thread.
          Par -> spawnWorkersWith forkOnWithUnmask [1 .. jobsNumWorkers]
          ParOn ws -> spawnWorkersWith forkOnWithUnmask ws
          ParN _ -> spawnWorkersWith (\_ -> forkIOWithUnmask) [1 .. jobsNumWorkers]
      {-# INLINE spawnWorkers #-}
      doWork = do
        when (comp == Seq) $ runWorker jobsQueue onRetire
        mExc <- liftIO $ readMVar workDoneMVar
        -- Blocks until every worker has retired (Nothing) or one of them
        -- failed (Just).
        case mExc of
          Nothing                    -> collect jobsQueue
          -- A job failed: rethrow its original exception here; the cleanup of
          -- safeBracketOnError then terminates the remaining workers.
          Just (WorkerException exc) -> liftIO $ throwIO exc
          
          
      {-# INLINE doWork #-}
  -- If doWork throws (including async exceptions delivered to this thread),
  -- send WorkerTerminateException to every spawned worker, then rethrow.
  safeBracketOnError
    spawnWorkers
    (liftIO . traverse_ (`throwTo` SomeAsyncException WorkerTerminateException))
    (const doWork)
-- | Exception handler installed around every worker.
--
-- A 'WorkerTerminateException' delivered as an async exception means the
-- scheduler itself is shutting the worker down, so it is ignored. Any other
-- exception is stored in @workDoneMVar@ (only if it is still empty). Then
-- @nWorkers - 1@ @Retire@ markers are pushed so the remaining workers can stop.
handleWorkerException ::
  MonadIO m => JQueue m a -> MVar (Maybe WorkerException) -> Int -> SomeException -> m ()
handleWorkerException jQueue workDoneMVar nWorkers exc
  | Just WorkerTerminateException <- asyncExceptionFromException exc = pure ()
  | otherwise = do
      -- Only the first failure is recorded; later ones are dropped.
      _ <- liftIO $ tryPutMVar workDoneMVar (Just (WorkerException exc))
      retireWorkersN jQueue (nWorkers - 1)
-- | Wraps the exception with which a worker's job ended.
newtype WorkerException = WorkerException SomeException
  deriving (Show)

instance Exception WorkerException where
  displayException (WorkerException exc) =
    "A worker handled a job that ended with exception: " ++ displayException exc
-- | Sent, wrapped in 'SomeAsyncException', to every worker still alive when
-- the scheduler tears the workers down after a failure or interruption.
data WorkerTerminateException = WorkerTerminateException
  deriving (Show)

instance Exception WorkerTerminateException where
  displayException = \case
    WorkerTerminateException -> "A worker was terminated by the scheduler"
-- | Extract an exception of type @e@ from one that was thrown asynchronously,
-- i.e. wrapped in 'SomeAsyncException'. This is just
-- 'asyncExceptionFromException'.
fromWorkerAsyncException :: Exception e => SomeException -> Maybe e
fromWorkerAsyncException exc = asyncExceptionFromException exc
{-# DEPRECATED fromWorkerAsyncException "In favor of `asyncExceptionFromException`" #-}
-- | Acquire a resource with @before@, use it with @thing@, and run @after@ only
-- if @thing@ throws (synchronously or asynchronously).
--
-- @before@ runs with async exceptions masked and @thing@ with the outer
-- masking state restored. The cleanup runs under 'uninterruptibleMask_' so it
-- cannot itself be interrupted. Exceptions raised by the cleanup are discarded,
-- so the original exception is the one rethrown.
safeBracketOnError :: MonadUnliftIO m => m a -> (a -> m b) -> (a -> m c) -> m c
safeBracketOnError before after thing =
  withRunInIO $ \run ->
    mask $ \restore -> do
      acquired <- run before
      outcome <- try $ restore $ run $ thing acquired
      either
        (\(failure :: SomeException) -> do
           _ :: Either SomeException b <- try $ uninterruptibleMask_ $ run $ after acquired
           throwIO failure)
        pure
        outcome