{-# LANGUAGE CPP #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE ConstraintKinds #-}

module Test.Hspec.Core.Runner.JobQueue (
  MonadIO
, Job
, Concurrency(..)
, JobQueue
, withJobQueue
, enqueueJob
) where

import           Prelude ()
import           Test.Hspec.Core.Compat hiding (Monad)
import qualified Test.Hspec.Core.Compat as M

import           Control.Concurrent
import           Control.Concurrent.Async (Async, AsyncCancelled(..), async, waitCatch, asyncThreadId)

import           Control.Monad.IO.Class (liftIO)
import qualified Control.Monad.IO.Class as M

-- | Local stand-in for 'M.Monad' that also lists the 'Functor' and
-- 'Applicative' superclasses explicitly, for compatibility with GHC < 7.10.1
-- where they were not implied by 'M.Monad'.
type Monad m = (M.Monad m, Applicative m, Functor m)

-- | 'M.MonadIO' combined with the compatibility 'Monad' constraint defined
-- just above; this is the @MonadIO@ exported from this module.
type MonadIO m = (M.MonadIO m, Monad m)

-- | A unit of work running in monad @m@, reporting progress of type @p@ and
-- producing a result of type @r@.  The job is handed a callback it may call
-- any number of times to report progress.
type Job m p r = (p -> m ()) -> m r

-- | Scheduling mode requested when a job is handed to 'enqueueJob'.
data Concurrency
  = Sequential
    -- ^ The job's action does not begin until the action returned by
    -- 'enqueueJob' is run, and it does not occupy a slot of the queue's
    -- shared semaphore.
  | Concurrent
    -- ^ The job's action begins in the background as soon as a slot of the
    -- queue's shared semaphore is free.

-- | Shared state of a job queue created by 'withJobQueue'.
data JobQueue = JobQueue {
  _semaphore :: Semaphore
  -- ^ bounds how many 'Concurrent' jobs run their action at the same time
, _cancelQueue :: CancelQueue
  -- ^ every job started through this queue, so it can be cancelled on exit
}

-- | A semaphore represented by its two operations.
data Semaphore = Semaphore {
  _wait :: IO ()
  -- ^ run by a worker before it starts its job; may block
, _signal :: IO ()
  -- ^ run by a worker after its job has finished
}

-- | Handles of all jobs started so far, most recently started first.  The
-- results are discarded ('void'); only the handles are needed to cancel the
-- jobs.
type CancelQueue = IORef [Async ()]

-- | Create a job queue that lets at most @concurrency@ 'Concurrent' jobs run
-- at the same time and pass it to the given action.  When the action
-- finishes, normally or by an exception, every job that was started through
-- the queue is cancelled, and the queue waits for all of them to terminate.
withJobQueue :: Int -> (JobQueue -> IO a) -> IO a
withJobQueue concurrency = bracket new cancelAll
  where
    new :: IO JobQueue
    new = JobQueue <$> newSemaphore concurrency <*> newIORef []

    cancelAll :: JobQueue -> IO ()
    cancelAll (JobQueue _ cancelQueue) = readIORef cancelQueue >>= cancelMany

    -- Deliver the cancellation to every job first and only then wait, so
    -- that waiting for one job does not delay cancelling the next one.
    cancelMany :: [Async a] -> IO ()
    cancelMany jobs = do
      mapM_ notifyCancel jobs
      mapM_ waitCatch jobs

    -- Throw 'AsyncCancelled' to the job's thread without waiting for the
    -- job to finish.
    notifyCancel :: Async a -> IO ()
    notifyCancel = flip throwTo AsyncCancelled . asyncThreadId

-- | Create a 'Semaphore' backed by a 'QSem' holding @capacity@ initial slots.
--
-- NOTE(review): 'newQSem' rejects a negative capacity, and a capacity of 0
-- would block every 'Concurrent' job forever; callers presumably pass a
-- positive value — confirm at the call site of 'withJobQueue'.
newSemaphore :: Int -> IO Semaphore
newSemaphore capacity = do
  sem <- newQSem capacity
  return $ Semaphore (waitQSem sem) (signalQSem sem)

-- | Hand a job to the queue.
--
-- A 'Concurrent' job starts in the background as soon as the queue's
-- semaphore has a free slot.  A 'Sequential' job does not start before the
-- returned action is run.  In both cases the returned action forwards the
-- job's progress reports to its callback and finally yields the job's
-- outcome, with any exception the job threw captured as 'Left'.
enqueueJob :: MonadIO m => JobQueue -> Concurrency -> Job IO progress a -> IO (Job m progress (Either SomeException a))
enqueueJob (JobQueue sem cancelQueue) concurrency = case concurrency of
  Sequential -> runSequentially cancelQueue
  Concurrent -> runConcurrently sem cancelQueue

-- | Start a job on a background thread that is held back until the returned
-- action is run.
--
-- The job goes through 'runConcurrently' with a private one-shot
-- 'Semaphore': acquiring it waits on an initially empty barrier 'MVar', and
-- releasing it does nothing.  The returned action first fills the barrier,
-- which lets the job start, and then waits for the job exactly like a
-- concurrent one.  Because the queue's shared semaphore is not used,
-- sequential jobs do not count against the queue's concurrency limit.
runSequentially :: forall m progress a. MonadIO m => CancelQueue -> Job IO progress a -> IO (Job m progress (Either SomeException a))
runSequentially cancelQueue action = do
  barrier <- newEmptyMVar
  let
    wait :: IO ()
    wait = takeMVar barrier

    signal :: m ()
    signal = liftIO $ putMVar barrier ()

  job <- runConcurrently (Semaphore wait pass) cancelQueue action
  return $ \ notifyPartial -> signal >> job notifyPartial

-- | What a worker publishes to its waiter: either the most recent progress
-- report, or the marker that the job has finished.  The second type
-- parameter is not used by any constructor; it only ties a value to the
-- result type of the job it belongs to.
data Partial p r
  = Partial p
  | Done

-- | Start a job on a background thread that holds a slot of the given
-- 'Semaphore' while the job runs, and return an action that waits for it.
--
-- The returned action keeps passing progress reports to its callback until
-- the job has finished.  It then yields the job's outcome, with any
-- exception captured as 'Left'.  Progress travels through a single 'MVar'
-- that is overwritten ('replaceMVar'), so a report the waiter has not picked
-- up yet is superseded by a newer one.
--
-- NOTE(review): the returned action consumes the final 'Done' marker, so it
-- appears to be meant to be run at most once.
runConcurrently :: forall m progress a. MonadIO m => Semaphore -> CancelQueue -> Job IO progress a -> IO (Job m progress (Either SomeException a))
runConcurrently (Semaphore wait signal) cancelQueue action = do
  result :: MVar (Partial progress a) <- newEmptyMVar
  let
    -- 'done' is installed outside of 'bracket_' so that 'Done' is published
    -- even when the worker is cancelled while still blocked in 'wait'.
    -- Otherwise a waiter would block on 'result' forever instead of
    -- receiving the cancellation as 'Left'.
    -- 'interruptible' presumably guards against the worker inheriting an
    -- (interruptibly) masked state from its creator.
    worker :: IO a
    worker = bracket_ wait signal (interruptible (action partialResult)) `finally` done
      where
        partialResult :: progress -> IO ()
        partialResult = replaceMVar result . Partial

        done :: IO ()
        done = replaceMVar result Done

    pushOnCancelQueue :: Async a -> IO ()
    pushOnCancelQueue = modifyIORef cancelQueue . (:) . void

  -- 'bracket' guarantees the new job is registered for cancellation even if
  -- an asynchronous exception arrives right after it was started.
  job <- bracket (async worker) pushOnCancelQueue return

  let
    waitForResult :: (progress -> m ()) -> m (Either SomeException a)
    waitForResult notifyPartial = do
      r <- liftIO (takeMVar result)
      case r of
        Partial progress -> notifyPartial progress >> waitForResult notifyPartial
        Done -> liftIO $ waitCatch job

  return waitForResult

-- | Put a value into an 'MVar', discarding whatever it currently holds.
--
-- NOTE(review): the take and the put are two separate steps.  This is only
-- safe when the caller is the sole writer of the 'MVar', as the worker in
-- 'runConcurrently' is.  A concurrent 'putMVar' in between would make the
-- final 'putMVar' block.
replaceMVar :: MVar a -> a -> IO ()
replaceMVar mvar p = tryTakeMVar mvar >> putMVar mvar p