{-# LANGUAGE LambdaCase #-}
module Grisette.Experimental.MonadParallelUnion
( MonadParallelUnion (..),
)
where
import Control.DeepSeq (NFData, force)
import Control.Monad.Except (ExceptT (ExceptT), runExceptT)
import Control.Monad.Identity (IdentityT (IdentityT, runIdentityT))
import qualified Control.Monad.RWS.Lazy as RWSLazy
import qualified Control.Monad.RWS.Strict as RWSStrict
import Control.Monad.Reader (ReaderT (ReaderT, runReaderT))
import qualified Control.Monad.State.Lazy as StateLazy
import qualified Control.Monad.State.Strict as StateStrict
import Control.Monad.Trans.Maybe (MaybeT (MaybeT, runMaybeT))
import qualified Control.Monad.Writer.Lazy as WriterLazy
import qualified Control.Monad.Writer.Strict as WriterStrict
import Control.Parallel.Strategies (rpar, rseq, runEval)
import Grisette.Internal.Core.Control.Monad.Union (MonadUnion)
import Grisette.Internal.Core.Control.Monad.UnionM (UnionM, underlyingUnion)
import Grisette.Internal.Core.Data.Class.Mergeable (Mergeable)
import Grisette.Internal.Core.Data.Class.SimpleMergeable (mrgIf)
import Grisette.Internal.Core.Data.Class.TryMerge (TryMerge, tryMerge)
import Grisette.Internal.Core.Data.Union (Union (UnionIf, UnionSingle))
-- | Monads that offer a parallel variant of monadic bind over symbolic
-- unions.
--
-- @'parBindUnion' x f@ has the same type shape as @x '>>=' f@. It additionally
-- requires the result type to be 'Mergeable' and 'NFData'. For 'UnionM', the
-- continuation is run on the branches of every symbolic @if@ node in parallel,
-- and the results are merged back together. The transformer instances in this
-- module lift that behaviour through their own layer.
class (MonadUnion m, TryMerge m) => MonadParallelUnion m where
  -- | Parallel bind.
  --
  -- 'Mergeable' is required so that branch results can be recombined.
  -- 'NFData' is required so that each branch can be fully evaluated before
  -- it is merged.
  parBindUnion :: (Mergeable b, NFData b) => m a -> (a -> m b) -> m b
-- | Lifts 'parBindUnion' through 'MaybeT'.
--
-- The underlying monad's 'parBindUnion' performs the bind. A 'Nothing'
-- produced by the first computation is returned as-is, and the continuation
-- is not called.
instance (MonadParallelUnion m) => MonadParallelUnion (MaybeT m) where
  parBindUnion (MaybeT x) f =
    MaybeT $
      x `parBindUnion` \case
        Nothing -> return Nothing
        Just a -> runMaybeT $ f a
  {-# INLINE parBindUnion #-}
-- | Lifts 'parBindUnion' through 'ExceptT'.
--
-- A 'Left' error from the first computation is propagated unchanged, and the
-- continuation is not called. The underlying bind produces @'Either' e b@,
-- which is why the error type must be 'Mergeable' and 'NFData'.
instance (MonadParallelUnion m, Mergeable e, NFData e) => MonadParallelUnion (ExceptT e m) where
  parBindUnion (ExceptT x) f =
    ExceptT $
      x `parBindUnion` \case
        Left e -> return $ Left e
        Right a -> runExceptT $ f a
  {-# INLINE parBindUnion #-}
-- | Lifts 'parBindUnion' through the lazy 'StateLazy.StateT'.
--
-- The state produced by the first computation is threaded into the
-- continuation. The @(result, state)@ pair is matched with an irrefutable
-- pattern, mirroring the laziness of the lazy state monad's own bind. The
-- underlying bind produces @(b, s)@, so the state must be 'Mergeable' and
-- 'NFData'.
instance (MonadParallelUnion m, Mergeable s, NFData s) => MonadParallelUnion (StateLazy.StateT s m) where
  parBindUnion (StateLazy.StateT x) f = StateLazy.StateT $ \s ->
    x s `parBindUnion` \case
      ~(a, s') -> StateLazy.runStateT (f a) s'
  {-# INLINE parBindUnion #-}
-- | Lifts 'parBindUnion' through the strict 'StateStrict.StateT'.
--
-- The state produced by the first computation is threaded into the
-- continuation. The @(result, state)@ pair is matched strictly, unlike the
-- lazy-state instance. The underlying bind produces @(b, s)@, so the state
-- must be 'Mergeable' and 'NFData'.
instance (MonadParallelUnion m, Mergeable s, NFData s) => MonadParallelUnion (StateStrict.StateT s m) where
  parBindUnion (StateStrict.StateT x) f = StateStrict.StateT $ \s ->
    x s `parBindUnion` \case
      (a, s') -> StateStrict.runStateT (f a) s'
  {-# INLINE parBindUnion #-}
-- | Lifts 'parBindUnion' through the lazy 'WriterLazy.WriterT'.
--
-- The first computation and then the continuation are bound with the
-- underlying 'parBindUnion'. Their outputs are combined left-to-right with
-- '<>'. The pairs are matched with irrefutable patterns, mirroring the
-- laziness of the lazy writer's own bind.
instance (MonadParallelUnion m, Mergeable w, Monoid w, NFData w) => MonadParallelUnion (WriterLazy.WriterT w m) where
  parBindUnion (WriterLazy.WriterT x) f =
    WriterLazy.WriterT $
      x `parBindUnion` \case
        ~(a, w) ->
          WriterLazy.runWriterT (f a) `parBindUnion` \case
            ~(b, w') -> return (b, w <> w')
  {-# INLINE parBindUnion #-}
-- | Lifts 'parBindUnion' through the strict 'WriterStrict.WriterT'.
--
-- The first computation and then the continuation are bound with the
-- underlying 'parBindUnion'. Their outputs are combined left-to-right with
-- '<>'. Unlike the lazy-writer instance, the @(result, output)@ pairs are
-- matched strictly.
instance (MonadParallelUnion m, Mergeable w, Monoid w, NFData w) => MonadParallelUnion (WriterStrict.WriterT w m) where
  parBindUnion (WriterStrict.WriterT x) f =
    WriterStrict.WriterT $
      x `parBindUnion` \case
        (a, w) ->
          WriterStrict.runWriterT (f a) `parBindUnion` \case
            (b, w') -> return (b, w <> w')
  {-# INLINE parBindUnion #-}
-- | Lifts 'parBindUnion' through 'ReaderT'.
--
-- The same environment @r@ is passed to both the first computation and the
-- continuation.
--
-- NOTE(review): the bind itself does not use @Mergeable r@ or @NFData r@.
-- They may be required by the superclass instances for 'ReaderT'; confirm
-- that before relaxing them.
instance (MonadParallelUnion m, Mergeable r, NFData r) => MonadParallelUnion (ReaderT r m) where
  parBindUnion (ReaderT x) f = ReaderT $ \r ->
    x r `parBindUnion` \a -> runReaderT (f a) r
  {-# INLINE parBindUnion #-}
-- | Lifts 'parBindUnion' through 'IdentityT'.
--
-- The wrapper is removed from the continuation's result, which is then passed
-- through 'tryMerge' before being handed to the underlying bind.
instance (MonadParallelUnion m) => MonadParallelUnion (IdentityT m) where
  parBindUnion (IdentityT x) f =
    IdentityT $ x `parBindUnion` (tryMerge . runIdentityT . f)
  {-# INLINE parBindUnion #-}
-- | Lifts 'parBindUnion' through the strict 'RWSStrict.RWST'.
--
-- - The environment @r@ is passed unchanged to both computations.
-- - The state is threaded from the first computation into the continuation.
-- - The two written outputs are combined left-to-right with '<>'.
--
-- The result triples are matched strictly.
instance
  (MonadParallelUnion m, Mergeable s, Mergeable r, Mergeable w, Monoid w, NFData r, NFData w, NFData s) =>
  MonadParallelUnion (RWSStrict.RWST r w s m)
  where
  parBindUnion x f = RWSStrict.RWST $ \r s ->
    RWSStrict.runRWST x r s `parBindUnion` \case
      (a, s', w) ->
        RWSStrict.runRWST (f a) r s' `parBindUnion` \case
          (b, s'', w') -> return (b, s'', w <> w')
  {-# INLINE parBindUnion #-}
-- | Lifts 'parBindUnion' through the lazy 'RWSLazy.RWST'.
--
-- - The environment @r@ is passed unchanged to both computations.
-- - The state is threaded from the first computation into the continuation.
-- - The two written outputs are combined left-to-right with '<>'.
--
-- The result triples are matched with irrefutable patterns, mirroring the
-- laziness of the lazy RWS monad's own bind.
instance
  (MonadParallelUnion m, Mergeable s, Mergeable r, Mergeable w, Monoid w, NFData r, NFData w, NFData s) =>
  MonadParallelUnion (RWSLazy.RWST r w s m)
  where
  parBindUnion x f = RWSLazy.RWST $ \r s ->
    RWSLazy.runRWST x r s `parBindUnion` \case
      ~(a, s', w) ->
        RWSLazy.runRWST (f a) r s' `parBindUnion` \case
          ~(b, s'', w') -> return (b, s'', w <> w')
  {-# INLINE parBindUnion #-}
-- | Entry point of the parallel bind on a raw 'Union'.
--
-- A 'UnionSingle' has nothing to parallelise. The continuation is applied
-- directly and its result is passed through 'tryMerge'. Every other union is
-- handed to the recursive worker @parBindUnion'@.
parBindUnion'' :: (Mergeable b, NFData b) => Union a -> (a -> UnionM b) -> UnionM b
parBindUnion'' (UnionSingle a) f = tryMerge $ f a
parBindUnion'' u f = parBindUnion' u f
-- | Recursive worker for the parallel bind on 'Union'.
--
-- A 'UnionSingle' leaf is passed straight to the continuation. For a
-- 'UnionIf' node, each sub-union is processed recursively inside its own
-- 'rpar' spark and forced to normal form with 'force'. The two results are
-- then recombined with 'mrgIf' under the node's condition.
parBindUnion' :: (Mergeable b, NFData b) => Union a -> (a -> UnionM b) -> UnionM b
parBindUnion' (UnionSingle a) f = f a
parBindUnion' (UnionIf _ _ cond ifTrue ifFalse) f = runEval $ do
  -- Both branches are sparked before either is waited on, so they can be
  -- evaluated in parallel.
  l <- rpar $ force $ parBindUnion' ifTrue f
  r <- rpar $ force $ parBindUnion' ifFalse f
  l' <- rseq l
  r' <- rseq r
  rseq $ mrgIf cond l' r'
{-# INLINE parBindUnion' #-}
-- | Parallel bind for 'UnionM'.
--
-- Extracts the underlying 'Union' with 'underlyingUnion' and runs the union
-- parallel bind on it. Each branch of every symbolic @if@ node is evaluated
-- in its own spark, and the results are recombined with 'mrgIf'.
instance MonadParallelUnion UnionM where
  parBindUnion u f = parBindUnion'' (underlyingUnion u) f
  {-# INLINE parBindUnion #-}