{-# LANGUAGE LambdaCase #-}

-- |
-- Module      :   Grisette.Experimental.MonadParallelUnion
-- Copyright   :   (c) Sirui Lu 2023
-- License     :   BSD-3-Clause (see the LICENSE file)
--
-- Maintainer  :   siruilu@cs.washington.edu
-- Stability   :   Experimental
-- Portability :   GHC only
module Grisette.Experimental.MonadParallelUnion
  ( MonadParallelUnion (..),
  )
where

import Control.DeepSeq (NFData, force)
import Control.Monad.Except (ExceptT (ExceptT), runExceptT)
import Control.Monad.Identity (IdentityT (IdentityT, runIdentityT))
import qualified Control.Monad.RWS.Lazy as RWSLazy
import qualified Control.Monad.RWS.Strict as RWSStrict
import Control.Monad.Reader (ReaderT (ReaderT, runReaderT))
import qualified Control.Monad.State.Lazy as StateLazy
import qualified Control.Monad.State.Strict as StateStrict
import Control.Monad.Trans.Maybe (MaybeT (MaybeT, runMaybeT))
import qualified Control.Monad.Writer.Lazy as WriterLazy
import qualified Control.Monad.Writer.Strict as WriterStrict
import Control.Parallel.Strategies (rpar, rseq, runEval)
import Grisette.Internal.Core.Control.Monad.Union (MonadUnion)
import Grisette.Internal.Core.Control.Monad.UnionM (UnionM, underlyingUnion)
import Grisette.Internal.Core.Data.Class.Mergeable (Mergeable)
import Grisette.Internal.Core.Data.Class.SimpleMergeable (mrgIf)
import Grisette.Internal.Core.Data.Class.TryMerge (TryMerge, tryMerge)
import Grisette.Internal.Core.Data.Union (Union (UnionIf, UnionSingle))

-- | Union monads that provide a bind which may explore the branches of a
-- union in parallel.
--
-- Together with the @QualifiedDo@ extension and the
-- "Grisette.Qualified.ParallelUnionDo" module, the paths of a computation can
-- be executed in parallel and their results merged:
--
-- > :set -XQualifiedDo -XOverloadedStrings
-- > import Grisette
-- > import qualified Grisette.Qualified.ParallelUnionDo as P
-- > P.do
-- >   x <- mrgIf "a" (return 1) (return 2) :: UnionM Int
-- >   return $ x + 1
-- >
-- > -- {If a 2 3}
class (MonadUnion m, TryMerge m) => MonadParallelUnion m where
  -- | Bind-like operation. The result type must be both 'Mergeable' and
  -- 'NFData'.
  parBindUnion ::
    (Mergeable b, NFData b) =>
    m a ->
    (a -> m b) ->
    m b

-- | Lifts 'parBindUnion' through 'MaybeT' by binding on the underlying
-- @m (Maybe a)@ with the base monad's 'parBindUnion'.
instance (MonadParallelUnion m) => MonadParallelUnion (MaybeT m) where
  parBindUnion (MaybeT x) f =
    MaybeT $
      x `parBindUnion` \case
        -- A failed computation skips the continuation.
        Nothing -> return Nothing
        -- Otherwise run the continuation and expose its underlying action.
        Just x'' -> runMaybeT $ f x''
  {-# INLINE parBindUnion #-}

-- | Lifts 'parBindUnion' through 'ExceptT' by binding on the underlying
-- @m (Either e a)@ with the base monad's 'parBindUnion'.
instance (MonadParallelUnion m, Mergeable e, NFData e) => MonadParallelUnion (ExceptT e m) where
  parBindUnion (ExceptT x) f =
    ExceptT $
      x `parBindUnion` \case
        -- An error is passed through unchanged; the continuation is skipped.
        Left e -> return $ Left e
        -- A success feeds the continuation.
        Right x'' -> runExceptT $ f x''
  {-# INLINE parBindUnion #-}

-- | Lifts 'parBindUnion' through the lazy 'StateLazy.StateT': the first action
-- is run on the incoming state, and the continuation is run on the state it
-- produced, using the base monad's 'parBindUnion'.
instance (MonadParallelUnion m, Mergeable s, NFData s) => MonadParallelUnion (StateLazy.StateT s m) where
  parBindUnion (StateLazy.StateT x) f = StateLazy.StateT $ \s ->
    x s `parBindUnion` \case
      -- Irrefutable pattern: the result pair is not forced by the match.
      ~(a, s') -> StateLazy.runStateT (f a) s'
  {-# INLINE parBindUnion #-}

-- | Lifts 'parBindUnion' through the strict 'StateStrict.StateT': the first
-- action is run on the incoming state, and the continuation is run on the
-- state it produced, using the base monad's 'parBindUnion'.
instance (MonadParallelUnion m, Mergeable s, NFData s) => MonadParallelUnion (StateStrict.StateT s m) where
  parBindUnion (StateStrict.StateT x) f = StateStrict.StateT $ \s ->
    x s `parBindUnion` \case
      -- Strict pattern, in contrast with the lazy StateT instance.
      (a, s') -> StateStrict.runStateT (f a) s'
  {-# INLINE parBindUnion #-}

-- | Lifts 'parBindUnion' through the lazy 'WriterLazy.WriterT'. Both the first
-- action and the continuation are bound with the base monad's 'parBindUnion',
-- and their outputs are combined with '<>' (first output on the left).
instance (MonadParallelUnion m, Mergeable s, Monoid s, NFData s) => MonadParallelUnion (WriterLazy.WriterT s m) where
  parBindUnion (WriterLazy.WriterT x) f =
    WriterLazy.WriterT $
      x `parBindUnion` \case
        -- Irrefutable patterns: the result pairs are not forced by the match.
        ~(a, w) ->
          WriterLazy.runWriterT (f a) `parBindUnion` \case
            ~(b, w') -> return (b, w <> w')
  {-# INLINE parBindUnion #-}

-- | Lifts 'parBindUnion' through the strict 'WriterStrict.WriterT'. Both the
-- first action and the continuation are bound with the base monad's
-- 'parBindUnion', and their outputs are combined with '<>' (first output on
-- the left).
instance (MonadParallelUnion m, Mergeable s, Monoid s, NFData s) => MonadParallelUnion (WriterStrict.WriterT s m) where
  parBindUnion (WriterStrict.WriterT x) f =
    WriterStrict.WriterT $
      x `parBindUnion` \case
        -- Strict patterns, in contrast with the lazy WriterT instance.
        (a, w) ->
          WriterStrict.runWriterT (f a) `parBindUnion` \case
            (b, w') -> return (b, w <> w')
  {-# INLINE parBindUnion #-}

-- | Lifts 'parBindUnion' through 'ReaderT': the same environment is supplied
-- to both the first action and the continuation, which are bound with the
-- base monad's 'parBindUnion'.
instance (MonadParallelUnion m, Mergeable a, NFData a) => MonadParallelUnion (ReaderT a m) where
  parBindUnion (ReaderT x) f = ReaderT $ \a ->
    -- @a@ is the environment; @a'@ is the value produced by the first action.
    x a `parBindUnion` \a' -> runReaderT (f a') a
  {-# INLINE parBindUnion #-}

-- | Lifts 'parBindUnion' through 'IdentityT' by unwrapping, binding with the
-- base monad's 'parBindUnion', and re-wrapping. 'tryMerge' is applied to the
-- result of the continuation before it is handed to the base bind.
instance (MonadParallelUnion m) => MonadParallelUnion (IdentityT m) where
  parBindUnion (IdentityT x) f =
    IdentityT $ x `parBindUnion` (tryMerge . runIdentityT . f)
  {-# INLINE parBindUnion #-}

-- | Lifts 'parBindUnion' through the strict 'RWSStrict.RWST'. The environment
-- is shared by both steps, the state is threaded from the first action into
-- the continuation, and the two outputs are combined with '<>' (first output
-- on the left). Both steps use the base monad's 'parBindUnion'.
instance
  (MonadParallelUnion m, Mergeable s, Mergeable r, Mergeable w, Monoid w, NFData r, NFData w, NFData s) =>
  MonadParallelUnion (RWSStrict.RWST r w s m)
  where
  parBindUnion m k = RWSStrict.RWST $ \r s ->
    RWSStrict.runRWST m r s `parBindUnion` \case
      -- Strict patterns, in contrast with the lazy RWST instance.
      (a, s', w) ->
        RWSStrict.runRWST (k a) r s' `parBindUnion` \case
          (b, s'', w') -> return (b, s'', w <> w')
  {-# INLINE parBindUnion #-}

-- | Lifts 'parBindUnion' through the lazy 'RWSLazy.RWST'. The environment is
-- shared by both steps, the state is threaded from the first action into the
-- continuation, and the two outputs are combined with '<>' (first output on
-- the left). Both steps use the base monad's 'parBindUnion'.
instance
  (MonadParallelUnion m, Mergeable s, Mergeable r, Mergeable w, Monoid w, NFData r, NFData w, NFData s) =>
  MonadParallelUnion (RWSLazy.RWST r w s m)
  where
  parBindUnion m k = RWSLazy.RWST $ \r s ->
    RWSLazy.runRWST m r s `parBindUnion` \case
      -- Irrefutable patterns: the result triples are not forced by the match.
      ~(a, s', w) ->
        RWSLazy.runRWST (k a) r s' `parBindUnion` \case
          ~(b, s'', w') -> return (b, s'', w <> w')
  {-# INLINE parBindUnion #-}

-- | Entry point used by the 'UnionM' instance of 'parBindUnion'.
--
-- A union holding a single value applies the continuation directly and passes
-- the result through 'tryMerge'. Every other union is delegated to
-- 'parBindUnion''.
parBindUnion'' :: (Mergeable b, NFData b) => Union a -> (a -> UnionM b) -> UnionM b
parBindUnion'' (UnionSingle a) f = tryMerge $ f a
parBindUnion'' u f = parBindUnion' u f

-- | Recursive worker that applies the continuation to every leaf of a 'Union'.
--
-- For a single value the continuation is applied directly. For a conditional
-- node, both branches are processed recursively: each is fully evaluated with
-- 'force' under 'rpar', both results are awaited with 'rseq', and they are
-- then combined with 'mrgIf' under the node's condition.
parBindUnion' :: (Mergeable b, NFData b) => Union a -> (a -> UnionM b) -> UnionM b
parBindUnion' (UnionSingle a') f' = f' a'
parBindUnion' (UnionIf _ _ cond ifTrue ifFalse) f' = runEval $ do
  -- Spark both branches before waiting on either of them.
  l <- rpar $ force $ parBindUnion' ifTrue f'
  r <- rpar $ force $ parBindUnion' ifFalse f'
  -- Wait for both branch results before merging.
  l' <- rseq l
  r' <- rseq r
  rseq $ mrgIf cond l' r'
{-# INLINE parBindUnion' #-}

-- | Base instance: extracts the underlying 'Union' with 'underlyingUnion' and
-- hands it to 'parBindUnion'''.
instance MonadParallelUnion UnionM where
  parBindUnion = parBindUnion'' . underlyingUnion
  {-# INLINE parBindUnion #-}