-- | You can imagine the workers as a room full of translators: you pass
--   pages in one language in on one side and expect translated pages to come
--   out the other side. You'll get faster results if there's more than one
--   translator in the room, but you still need the pages to come out the
--   other side in the order their equivalents came in.
-- 
--   You want the benefits of multiple translators working at the same time,
--   but you don't want to (or aren't able to) hold all the pages in memory
--   and sort them after the fact.
-- 
--   'forkOrderlyWorkers' could be compared to 'mapConcurrently' from 'async'
--   except you don't need to hold all input or output values in memory as a
--   list. In addition, the number of elements doesn't need to be finite.

{-# LANGUAGE
     LambdaCase
   #-}
   -- , ScopedTypeVariables

module Control.Concurrent.OrderlyWorkers (
     forkOrderlyWorkers
   ) where

import Control.Concurrent (forkIO, killThread)
import Control.Concurrent.MVar
import Control.Concurrent.STM
-- We don't really 'need' STM here probably, but this is a robust bounded queue implementation (also consider unagi-chan):
-- import Control.Concurrent.STM.TBQueue
import Control.Exception (SomeException, catch, mask, onException, throwIO)
import Control.Monad (replicateM_, void)


-- | End-of-stream marker. It travels through the internal queues as
--   @Left Finished@ to tell the worker and consumer loops that no further
--   items will follow.
data Finished = Finished

-- | Run @workerFunction@ on every value yielded by the producer, using a pool
--   of worker threads, and hand the results to the consumer in the same order
--   in which their inputs were produced.
--
--   'Nothing' for the producer indicates there's no more input.
--
--   'forkOrderlyWorkers' doesn't return until the producer has returned a
--   'Nothing' and all the remaining 'Just' values have been processed and
--   their values given to the consumer.
--   In other words, until all the work has been done.
--
--   A @numWorkerThreads@ smaller than 1 is treated as 1; with no workers (and
--   zero-capacity queues) the call could never make progress.
--
--   If the producer, the worker function or the consumer throws, the first
--   such exception is rethrown in the calling thread and every helper thread
--   is killed. Outputs that had not yet reached the consumer at that point
--   are dropped. The helper threads are also killed if the calling thread
--   itself receives an asynchronous exception while waiting.
forkOrderlyWorkers :: Int -> (input -> IO output) -> IO (Maybe input) -> (output -> IO ()) -> IO ()
forkOrderlyWorkers numWorkerThreads workerFunction aProducer bConsumer = do

   -- Filled with 'Right' by the consumer once everything has been delivered,
   -- or with 'Left' by whichever helper thread fails first.
   outcomeVar <- newEmptyMVar :: IO (MVar (Either SomeException ()))

   -- TODO: how large should the queues be?
   aQueue <- newTBQueueIO (fromIntegral workerCount)
   bQueue <- newTBQueueIO (fromIntegral workerCount)

   let loops =
          producerLoop aQueue bQueue
             : consumerLoop bQueue outcomeVar
             : replicate workerCount (workerLoop aQueue)

   -- Fork under 'mask' so that no asynchronous exception can arrive between
   -- creating the threads and installing the clean-up below. Each thread
   -- runs its loop in the caller's original masking state via 'restore'.
   mask $ \restore -> do
      let spawn loop =
             forkIO $ restore loop `catch` \e ->
                -- Report only the first failure; later ones (including the
                -- 'ThreadKilled' we send during clean-up) are ignored.
                void $ tryPutMVar outcomeVar (Left (e :: SomeException))
      threadIds <- traverse spawn loops
      -- Wait for completion or the first failure:
      outcome <- restore (takeMVar outcomeVar)
         `onException` mapM_ killThread threadIds
      -- Killing an already finished thread is a no-op, so this is safe on
      -- the success path too (workers may still be reading their marker).
      mapM_ killThread threadIds
      either throwIO pure outcome

 where
   -- At least one worker is needed to make progress.
   workerCount = max 1 numWorkerThreads

   -- producerLoop :: TBQueue (Either Finished (input, MVar output)) -> TBQueue (Either Finished (MVar output)) -> IO ()
   producerLoop aQueue bQueue =
      aProducer >>= \case
         Nothing -> do
            -- One end marker per worker, then one for the consumer. Separate
            -- transactions, so we only ever wait for a single free slot
            -- rather than for the whole input queue to drain.
            replicateM_ workerCount $
               atomically $ writeTBQueue aQueue (Left Finished)
            atomically $ writeTBQueue bQueue (Left Finished)
         Just a -> do
            -- The same MVar goes to a worker (to fill) and to the consumer
            -- (to read); the consumer's queue order is the input order.
            resultVar <- newEmptyMVar
            atomically $ do
               writeTBQueue aQueue $ Right (a, resultVar)
               writeTBQueue bQueue $ Right resultVar
            producerLoop aQueue bQueue

   -- consumerLoop :: TBQueue (Either Finished (MVar output)) -> MVar (Either SomeException ()) -> IO ()
   consumerLoop bQueue outcomeVar =
      atomically (readTBQueue bQueue) >>= \case
         Left Finished -> void $ tryPutMVar outcomeVar (Right ())
         Right bVar -> do
            result <- takeMVar bVar -- This might block till something's put in the MVar
            bConsumer result
            consumerLoop bQueue outcomeVar

   -- workerLoop :: TBQueue (Either Finished (input, MVar output)) -> IO ()
   workerLoop aQueue =
      atomically (readTBQueue aQueue) >>= \case
         Left Finished -> pure ()
         Right (val, var) -> do
            result <- workerFunction val
            -- Force the result to WHNF here so that the evaluation cost is
            -- paid on the worker thread and not on the consumer thread.
            putMVar var $! result
            workerLoop aQueue