{-# LANGUAGE DataKinds #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE FlexibleContexts #-}

-- | Wrapper around 'parallel' that uses a semaphore to limit the number of concurrently running threads.

module Test.Sandwich.ParallelN (parallelN) where

import Control.Concurrent.QSem
import Control.Exception.Safe
import Control.Monad
import Control.Monad.IO.Class
import Control.Monad.Trans.Control (MonadBaseControl)
import Test.Sandwich.Contexts
import Test.Sandwich.Types.Spec



-- | Wrapper around 'parallel' that caps concurrency at @n@ running slots.
--
-- A 'QSem' is introduced around @children@ (see 'introduceParallelSemaphore')
-- under the 'parallelSemaphore' label. Each child is then wrapped with
-- 'aroundEach' so that it takes one unit of the semaphore before running and
-- gives it back afterwards. The release goes through 'bracket_', so the unit
-- is returned even if the wrapped action throws.
parallelN :: (
  MonadBaseControl IO m, MonadIO m, MonadMask m
  ) => Int -> SpecFree (LabelValue "parallelSemaphore" QSem :> context) m () -> SpecFree context m ()
parallelN n children =
  introduceParallelSemaphore n (parallel (aroundEach "Take parallel semaphore" withRunSlot children))
  where
    -- Hold one unit of the shared semaphore for the duration of 'action',
    -- discarding the action's result.
    withRunSlot action = do
      sem <- getContext parallelSemaphore
      let acquire = liftIO (waitQSem sem)
          release = liftIO (signalQSem sem)
      bracket_ acquire release (void action)

-- | Context label under which 'parallelN' stores the 'QSem' that bounds
-- how many children may run at once.
parallelSemaphore :: Label "parallelSemaphore" QSem
parallelSemaphore = Label

-- | Introduce a fresh 'QSem' under the 'parallelSemaphore' label for @children@.
--
-- The semaphore starts with @max 1 n@ units. 'newQSem' rejects a negative
-- initial quantity, and a semaphore with zero units can never be acquired
-- (every 'waitQSem' would block forever). Non-positive limits are therefore
-- treated as 1, i.e. children run one at a time instead of hanging.
--
-- The teardown action is a no-op: a 'QSem' holds no external resources.
introduceParallelSemaphore :: (
  MonadIO m, MonadBaseControl IO m
  ) => Int -> SpecFree (LabelValue "parallelSemaphore" QSem :> context) m () -> SpecFree context m ()
introduceParallelSemaphore n =
  introduce "Introduce parallel semaphore" parallelSemaphore
    (liftIO $ newQSem (max 1 n))
    (const $ return ())