{-# LANGUAGE BangPatterns        #-}
{-# LANGUAGE RankNTypes          #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Experimenter.ConcurrentIO
    ( doFork
    , collectForkResult
    , mapConurrentIO
    ) where

import           Control.Concurrent     (forkIO, yield)
import           Control.Concurrent.STM
import           Control.DeepSeq
import           Control.Monad          (void)
import           Data.IORef
import           Data.Maybe             (fromJust)


-- | Map an 'IO' function over a list, running the individual calls on
-- separate threads.
--
-- A fresh counter of currently running workers is created, starting at 0,
-- and the actual work is delegated to @mapConurrentIO'@. The first argument
-- is the limit that counter is compared against before a new worker is
-- started. Results are returned in the order of the input list.
mapConurrentIO :: (NFData b) => Int -> (a -> IO b) -> [a] -> IO [b]
mapConurrentIO maxNr f xs =
  newTMVarIO 0 >>= \running -> mapConurrentIO' running maxNr f xs

-- | Worker behind 'mapConurrentIO'.
--
-- For every list element a worker is started with 'doFork' once a slot is
-- available; the remaining elements are scheduled before the result of the
-- current one is collected, so the output keeps the input order.
--
-- * @tmVar@ holds the number of workers that currently occupy a slot. It is
--   expected to be full at all times (it is only ever taken and refilled
--   inside a single transaction).
-- * @maxNr@ is the maximum number of simultaneously occupied slots. Values
--   below 1 are treated as 1, because otherwise no worker could ever start.
--
-- Waiting for a free slot is done with an STM 'check', which suspends the
-- calling thread until the counter changes instead of polling it. Testing
-- the limit and claiming the slot happen in the same transaction.
--
-- NOTE(review): if @f x@ throws, the slot is not released (the release is
-- sequenced after @f x@) and, as far as 'doFork' shows, no result is ever
-- written for that element, so collecting it would not terminate.
mapConurrentIO' :: (NFData b) => TMVar Int -> Int -> (a -> IO b) -> [a] -> IO [b]
mapConurrentIO' _ _ _ [] = return []
mapConurrentIO' tmVar maxNr f (x:xs) = do
  acquire
  xThread <- doFork $ do
    v <- f x
    -- Fully evaluate the result while the slot is still held, so the limit
    -- also bounds the evaluation work and not only the IO part.
    r <- return $!! v
    release
    return r
  xs' <- mapConurrentIO' tmVar maxNr f xs
  x' <- collectForkResult xThread
  return (x' : xs')
  where
    -- Effective limit; guarantees that at least one worker may run.
    limit :: Int
    limit = max 1 maxNr
    -- Block until fewer than 'limit' slots are taken, then claim one.
    acquire :: IO ()
    acquire =
      atomically $ do
        nr <- takeTMVar tmVar
        check (nr < limit)
        putTMVar tmVar $! nr + 1
    -- Give a slot back; this wakes up a thread blocked in 'acquire'.
    release :: IO ()
    release =
      atomically $ do
        nr <- takeTMVar tmVar
        putTMVar tmVar $! nr - 1

-- | Run an action on a new thread created with 'forkIO' and return a
-- reference through which its result can be obtained later (see
-- 'collectForkResult').
--
-- The reference initially contains 'NotReady'. When the action finishes,
-- its result is evaluated to normal form /on the forked thread/ and only
-- then stored as 'Ready'. The evaluation is forced with '$!' before the
-- write; writing the unevaluated @Ready (force v)@ expression would defer
-- all of the work to whichever thread reads the reference.
--
-- NOTE(review): no exception handler is installed. If the action throws,
-- the reference is never updated and stays 'NotReady'.
doFork :: NFData a => IO a -> IO (IORef (ThreadState a))
doFork f = do
  ref <- newIORef NotReady
  void $ forkIO $ do
    v <- f
    -- 'Ready' has a strict field, so evaluating it to WHNF runs 'force v'.
    -- The atomic write publishes the finished value to other threads.
    atomicWriteIORef ref $! Ready (force v)
  return ref

-- | Wait for the result behind a reference and return it.
--
-- The reference is read repeatedly: while it still contains 'NotReady' the
-- calling thread calls 'yield' and tries again; as soon as it contains
-- 'Ready' the carried value is returned.
collectForkResult :: IORef (ThreadState a) -> IO a
collectForkResult ref = readIORef ref >>= step
  where
    -- Decide what to do with the state that was just read.
    step NotReady    = yield >> collectForkResult ref
    step (Ready res) = return res

-- | Completion state of a computation started with 'doFork'.
--
-- 'NotReady' is the value 'doFork' stores initially; 'Ready' carries the
-- result once the forked action has finished. The payload field is strict,
-- so evaluating a 'Ready' value to weak head normal form also evaluates the
-- carried result to weak head normal form.
data ThreadState a = NotReady | Ready !a