{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE Trustworthy #-}
{-# LANGUAGE ScopedTypeVariables #-}
module BroadcastChan.Conduit.Internal (parMapM, parMapM_) where

import Control.Monad ((>=>))
import Control.Monad.Trans.Resource (MonadResource)
import qualified Control.Monad.Trans.Resource as Resource
import qualified Control.Monad.Trans.Resource.Internal as ResourceI
import Control.Monad.Trans.Class (lift)
import Control.Monad.IO.Unlift (MonadUnliftIO, UnliftIO(..), askUnliftIO)
import Data.Acquire (ReleaseType(..), allocateAcquire, mkAcquireType)
import Data.Conduit (ConduitM, (.|), awaitForever, yield)
import qualified Data.Conduit.List as C
import Data.Foldable (traverse_)
import Data.Void (Void)

import BroadcastChan.Extra (BracketOnError(..), Handler, ThreadBracket(..))
import qualified BroadcastChan.Extra as Extra

-- | Allocate a resource with @alloc@, register it in the enclosing
-- 'MonadResource' scope, and run @work@. The resource is handed to @clean@
-- only on the error path.
--
-- The release action built with 'mkAcquireType' calls @clean@ for
-- 'ReleaseException' only; 'ReleaseNormal' and 'ReleaseEarly' do nothing.
--
-- If @work@ throws, the registration is left in place. @clean@ therefore runs
-- if the enclosing scope is torn down with 'ReleaseException'.
--
-- If @work@ returns normally, the registration is dropped immediately with
-- 'Resource.release'. That function reports 'ReleaseEarly', so @clean@ is
-- not run. Dropping the key here, rather than discarding it, keeps repeated
-- calls inside one long-lived 'ResourceT' scope from piling up stale cleanup
-- entries that would otherwise live until 'runResourceT' exits.
bracketOnError :: MonadResource m => IO a -> (a -> IO ()) -> m r -> m r
bracketOnError alloc clean work = do
    (key, _) <- allocateAcquire (mkAcquireType alloc cleanup)
    result <- work
    -- Success path: unregister without invoking 'clean'.
    Resource.release key
    return result
  where
    -- Only an exceptional teardown of the enclosing scope runs @clean@.
    cleanup x ReleaseException = clean x
    cleanup _ _ = return ()

-- | Build a conduit that runs @workFun@ on each upstream element using a
-- pool of worker threads, and yields the results downstream.
--
-- Neither the order in which elements are processed nor the order in which
-- results are emitted is deterministic; __do NOT__ rely on either!
parMapM
    :: (MonadResource m, MonadUnliftIO m)
    => Handler m a
    -- ^ Exception handler
    -> Int
    -- ^ Number of parallel threads to use
    -> (a -> m b)
    -- ^ Function to run in parallel
    -> ConduitM a b m ()
parMapM hnd threads workFun = do
    unlift <- lift askUnliftIO
    internalState <- Resource.liftResourceT Resource.getInternalState

    Bracket{allocate, cleanup, action} <-
        Extra.runParallelWith
            (workerBracket internalState)
            (Left yield)
            (Extra.mapHandler (unliftIO unlift) hnd)
            threads
            (unliftIO unlift . workFun)
            body

    -- Run the pipeline; the workers' 'cleanup' is registered for the error
    -- path only (see 'bracketOnError').
    bracketOnError allocate cleanup action
  where
    -- Tie every forked worker to the caller's 'ResourceT' state. A fork calls
    -- 'ResourceI.stateAlloc' on setup and 'ResourceI.stateCleanup' on exit,
    -- with 'ReleaseNormal' or 'ReleaseException' depending on how it ends.
    workerBracket :: Resource.InternalState -> ThreadBracket
    workerBracket st = ThreadBracket
        { setupFork = ResourceI.stateAlloc st
        , cleanupFork = ResourceI.stateCleanup ReleaseNormal st
        , cleanupForkError = ResourceI.stateCleanup ReleaseException st
        }

    -- First pass up to @threads@ inputs to @buffer@. Then, for each remaining
    -- input, run @process@ and yield its result whenever it is 'Just'.
    body :: Monad m => (a -> m ()) -> (a -> m (Maybe b)) -> ConduitM a b m ()
    body buffer process =
        (C.isolate threads .| C.mapM_ buffer)
            >> awaitForever (lift . process >=> traverse_ yield)

-- | Build a conduit sink that runs @workFun@ on every upstream element using
-- a pool of worker threads, discarding the results.
--
-- Elements are __NOT__ guaranteed to be processed in a deterministic order!
parMapM_
    :: (MonadResource m, MonadUnliftIO m)
    => Handler m a
    -- ^ Exception handler
    -> Int
    -- ^ Number of parallel threads to use
    -> (a -> m ())
    -- ^ Function to run in parallel
    -> ConduitM a Void m ()
parMapM_ hnd threads workFun = do
    UnliftIO runInIO <- lift askUnliftIO
    internalState <- Resource.liftResourceT Resource.getInternalState

    workers <- Extra.runParallelWith_
        (forkBracket internalState)
        (Extra.mapHandler runInIO hnd)
        threads
        (runInIO . workFun)
        C.mapM_

    -- Workers' 'cleanup' is only wired to the error path (see
    -- 'bracketOnError').
    bracketOnError (allocate workers) (cleanup workers) (action workers)
  where
    -- Each forked worker calls 'ResourceI.stateAlloc' on the caller's
    -- 'ResourceT' state when set up. When it finishes it calls
    -- 'ResourceI.stateCleanup' with 'ReleaseNormal' or 'ReleaseException'.
    forkBracket :: Resource.InternalState -> ThreadBracket
    forkBracket st = ThreadBracket
        { setupFork = ResourceI.stateAlloc st
        , cleanupFork = ResourceI.stateCleanup ReleaseNormal st
        , cleanupForkError = ResourceI.stateCleanup ReleaseException st
        }