-- |
-- Keeps track of dispatched work so that callers can tell when all of it is done.
-- Notifies registered checkpoints once all previously dispatched work has completed.
module Server.ResponseController
  ( ResponseController
  , new
  , dispatch
  , setCheckpointAndWait
  ) where

import           Control.Concurrent
import           Control.Concurrent.SizedChan
import           Control.Monad                  ( void
                                                , when
                                                )
import           Data.IORef

-- | Bookkeeping for dispatched work: two counters and a queue of pending
-- checkpoints.
data ResponseController = ResponseController
  { -- | The number of work items dispatched
    dispatchedCount :: IORef Int
  ,
    -- | The number of work items completed
    completedCount  :: IORef Int
  ,
    -- | A channel of "Checkpoints" that are still waiting to be met
    checkpointChan  :: SizedChan Checkpoint
  }

-- | A "Checkpoint" is a target count paired with a callback; the callback is invoked once the count is "met"
type Checkpoint = (Int, () -> IO ())

-- | Constructs a new 'ResponseController' with both counters at zero and an
-- empty checkpoint channel.
new :: IO ResponseController
new = do
  dispatched  <- newIORef 0
  completed   <- newIORef 0
  checkpoints <- newSizedChan
  return ResponseController
    { dispatchedCount = dispatched
    , completedCount  = completed
    , checkpointChan  = checkpoints
    }

-- | Registers one unit of dispatched work and returns a callback; invoke the
-- callback to signal that the work has completed.
--
-- Calling 'dispatch' bumps 'dispatchedCount'. Invoking the returned callback
-- bumps 'completedCount' and then fires, in queue order, every checkpoint at
-- the head of 'checkpointChan' whose target count has been reached.
--
-- This function and the returned callback are meant to be non-blocking.
dispatch :: ResponseController -> IO (() -> IO ())
dispatch controller = do
  -- bump `dispatchedCount` atomically: a plain `modifyIORef'` is a separate
  -- read and write, so concurrent updates could lose an increment
  atomicModifyIORef' (dispatchedCount controller) (\n -> (succ n, ()))
  return $ \() -> do
    -- work completed, bump `completedCount` and keep the value we produced so
    -- the comparison below uses exactly this completion's count
    completed <- atomicModifyIORef' (completedCount controller)
                                    (\n -> let n' = succ n in (n', n'))
    fireMetCheckpoints completed
 where
  -- Fires and removes checkpoints from the head of the channel for as long as
  -- their target count is at or below the given completed count. Using `<=`
  -- and looping means that several checkpoints sharing one target, or a
  -- checkpoint whose target was already passed, cannot get stuck at the head
  -- of the queue.
  --
  -- NOTE(review): the peek and the read are two separate channel operations;
  -- if completion callbacks can run on several threads at once, two of them
  -- could observe the same head checkpoint. Confirm that completions are
  -- serialized, or guard this section with a lock.
  fireMetCheckpoints :: Int -> IO ()
  fireMetCheckpoints completed = do
    -- see if there's any Checkpoint
    pending <- tryPeekSizedChan (checkpointChan controller)
    case pending of
      -- no checkpoints, do nothing
      Nothing                 -> return ()
      -- a checkpoint is set, see if it is met
      Just (target, callback) -> when (target <= completed) $ do
        -- invoke the callback and remove the checkpoint
        callback ()
        void $ readSizedChan (checkpointChan controller)
        -- the next checkpoint in line may be met as well
        fireMetCheckpoints completed

-- | Expects a callback, which will be invoked once all work dispatched BEFORE
-- this call has been completed.
--
-- If the completed count has already caught up with the dispatched count, the
-- callback is invoked immediately; otherwise a checkpoint carrying the current
-- dispatched count is written to 'checkpointChan'.
--
-- This function is non-blocking
setCheckpoint :: ResponseController -> (() -> IO ()) -> IO ()
setCheckpoint controller callback = do
  dispatched <- readIORef (dispatchedCount controller)
  completed  <- readIORef (completedCount controller)
  -- see if the previously dispatched work has been completed; `>=` rather
  -- than `==` so that a target which has already been passed is never queued
  -- to wait for a count that will not occur again
  if completed >= dispatched
    then callback ()
    -- construct a Checkpoint from `dispatchedCount` and write it to the channel
    else writeSizedChan (checkpointChan controller) (dispatched, callback)

-- | The blocking version of `setCheckpoint`: registers a checkpoint whose
-- callback fills an 'MVar', then blocks on that 'MVar' until the callback has
-- been invoked.
setCheckpointAndWait :: ResponseController -> IO ()
setCheckpointAndWait controller = do
  done <- newEmptyMVar
  -- the checkpoint callback receives `()` and stores it in the MVar
  setCheckpoint controller (putMVar done)
  -- wait here until the checkpoint has been met
  takeMVar done