{-# LANGUAGE Safe #-}
module Control.Concurrent.PooledIO.Monad where

import Control.Concurrent.MVar (MVar, newEmptyMVar, takeMVar, putMVar, readMVar)
import Control.Concurrent (forkIO, getNumCapabilities)
import Control.DeepSeq (NFData, force, ($!!))
import Control.Exception (SomeException, evaluate, finally, throwIO, try)

import qualified Control.Monad.Trans.State as MS
import qualified Control.Monad.Trans.Reader as MR
import qualified Control.Monad.Trans.Class as MT
import Control.Monad.IO.Class (MonadIO, liftIO)

import Control.Monad (replicateM_)
import Control.Functor.HT (void)


-- | Scheduling monad used by 'fork' and 'runLimited'.
--
-- The reader environment is the 'MVar' on which every forked thread
-- signals that it has finished.  The state counts the thread slots that
-- have not been used yet.
type T = MR.ReaderT (MVar ()) (MS.StateT Int IO)


-- | Schedule @act@ to run in its own thread, respecting the pool limit.
--
-- If an unused slot is left (the state counter is positive), this call
-- consumes it.  Otherwise it first waits until a previously forked thread
-- signals completion on the shared 'MVar'.  The worker thread evaluates
-- the result of @act@ to normal form.
--
-- The returned action blocks until the worker has finished and then yields
-- the result.  If @act@, or deep evaluation of its result, throws an
-- exception, the returned action re-throws that exception.  It does not
-- block forever on an empty 'MVar'.  The returned action may be run any
-- number of times.
fork :: (NFData a) => IO a -> T (IO a)
fork act = do
   complete <- MR.ask
   initial <- MT.lift MS.get
   if initial>0
     then MT.lift $ MS.put (initial-1)
     else liftIO $ takeMVar complete
   liftIO $ do
      -- Fixes the exception type so that 'try' catches everything.
      let tryAny :: IO b -> IO (Either SomeException b)
          tryAny = try
      result <- newEmptyMVar
      -- The result MVar is filled on both success and failure.  Otherwise
      -- a failing action would leave every reader blocked indefinitely.
      forkFinally complete $
         putMVar result =<< tryAny (evaluate . force =<< act)
      -- Use readMVar rather than takeMVar so the value stays available and
      -- running the returned action a second time does not block.
      return $ either throwIO return =<< readMVar result

-- | Start @act@ in a new thread and put @()@ into @mvar@ once the thread
-- ends, whether @act@ returns normally or throws.  The new thread's
-- 'ThreadId' is discarded.
--
-- This is a local helper and is not 'Control.Concurrent.forkFinally' from
-- base.
forkFinally :: MVar () -> IO () -> IO ()
forkFinally mvar act =
   let signalDone = putMVar mvar ()
   in  forkIO (act `finally` signalDone) >> return ()

-- | Pass the current number of capabilities (from 'getNumCapabilities')
-- as the first argument to @run@.
--
-- For example, @withNumCapabilities runLimited@ sizes the pool by the
-- number of capabilities.
withNumCapabilities :: (Int -> a -> IO b) -> a -> IO b
withNumCapabilities run acts =
   flip run acts =<< getNumCapabilities

-- | Run a 'T' computation so that at most @maxThreads@ threads started by
-- 'fork' are running at once.  Before returning, wait until every forked
-- thread has signalled completion.
--
-- A @maxThreads@ below 1 is treated as 1.  With no slots at all, the first
-- 'fork' would wait on the completion 'MVar' for a thread that was never
-- started, and would block forever.
runLimited :: Int -> T a -> IO a
runLimited maxThreads m = do
   complete <- newEmptyMVar
   (result, unusedSlots) <-
      MS.runStateT (MR.runReaderT m complete) slots
   -- Each consumed slot belongs to one forked thread whose completion
   -- signal has not been collected yet.  Signals taken inside 'fork' are
   -- offset by the extra threads forked after the slots ran out.
   replicateM_ (slots - unusedSlots) $ takeMVar complete
   return result
   where
      slots = max 1 maxThreads