module Control.Concurrent.PooledIO.Monad where

import Control.Concurrent.MVar (MVar, newEmptyMVar, takeMVar, putMVar)
import Control.Concurrent (forkIO, getNumCapabilities)
import Control.DeepSeq (NFData, deepseq)
import Control.Exception (SomeException, finally, throwIO, try)

import qualified Control.Monad.Trans.State as MS
import qualified Control.Monad.Trans.Reader as MR
import qualified Control.Monad.Trans.Class as MT
import Control.Monad.IO.Class (MonadIO, liftIO)

import Control.Monad (replicateM_)
import Control.Functor.HT (void)


-- | Monad in which pooled threads are started.
--
-- * The reader environment is the @MVar ()@ on which every thread started
--   via 'forkFinally' puts @()@ when it terminates.
--
-- * The 'Int' state is the number of threads that 'fork' may still start
--   without first waiting for a running thread to complete;
--   'runLimited' initialises it.
type T = MR.ReaderT (MVar ()) (MS.StateT Int IO)


-- | Start @act@ in a new thread and return an action that waits for its
-- result.
--
-- While the counter in the state is positive, one unit of it is consumed
-- and the thread is started immediately. Once the counter has reached
-- zero, 'fork' first blocks until some running thread signals completion
-- on the shared 'MVar'.
--
-- The result of @act@ is fully evaluated ('deepseq') in the worker thread.
-- If @act@ or the evaluation of its result raises an exception, the
-- exception is stored and re-thrown by the returned action, so the caller
-- neither blocks forever on an 'MVar' that is never filled nor loses the
-- original error. The exception is also re-raised in the worker thread,
-- so the thread still terminates with it, as it did before the exception
-- was forwarded.
--
-- The returned action takes the result out of its 'MVar'; it is meant to
-- be run once.
fork :: (NFData a) => IO a -> T (IO a)
fork act = do
   complete <- MR.ask
   initial <- MT.lift MS.get
   if initial>0
     then MT.lift $ MS.put (initial-1)
     else liftIO $ takeMVar complete
   liftIO $ do
      result <- newEmptyMVar
      -- The signature pins the exception type that 'try' below catches.
      let rethrow :: Either SomeException b -> IO b
          rethrow = either throwIO return
      forkFinally complete $ do
         outcome <- try $ act >>= \r -> deepseq r $ return r
         -- Hand the outcome to the waiting side first ...
         putMVar result outcome
         -- ... then let a failure terminate the worker as before.
         either throwIO (const $ return ()) outcome
      return $ takeMVar result >>= rethrow

-- | Run the given action in a new thread and put @()@ into the 'MVar'
-- once the action has ended, whether it returned normally or raised an
-- exception. The identifier of the new thread is discarded.
forkFinally :: MVar () -> IO () -> IO ()
forkFinally done job = do
   _ <- forkIO (job `finally` putMVar done ())
   return ()

-- | Query the current number of capabilities with 'getNumCapabilities'
-- and pass it as first argument to @run@, followed by @acts@.
withNumCapabilities :: (Int -> a -> IO b) -> a -> IO b
withNumCapabilities run acts =
   getNumCapabilities >>= \caps -> run caps acts

-- | Run a pooled computation with at most @maxThreads@ threads started
-- by 'fork' running at the same time.
--
-- The state counter is initialised with the pool size and a fresh, empty
-- 'MVar' serves as completion signal. After the computation has finished,
-- one completion signal is awaited for every unit of the counter that was
-- consumed, i.e. for every thread that is still outstanding.
--
-- A pool size below 1 is treated as 1. With a non-positive counter the
-- very first 'fork' would wait for a completion signal although no thread
-- has been started that could ever send one.
runLimited :: Int -> T a -> IO a
runLimited maxThreads m = do
   let poolSize = max 1 maxThreads
   complete <- newEmptyMVar
   (result, uninitialized) <-
      MS.runStateT (MR.runReaderT m complete) poolSize
   -- poolSize-uninitialized slots were handed out and must be collected.
   replicateM_ (poolSize-uninitialized) $ takeMVar complete
   return result