{-# OPTIONS_HADDOCK not-home #-}
{-# LANGUAGE InstanceSigs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
module Hedgehog.Internal.Distributive (
    MonadTransDistributive(..)
  ) where

import           Control.Monad (join)
import           Control.Monad.Morph (MFunctor(..))
import           Control.Monad.Trans.Class (MonadTrans(..))
import           Control.Monad.Trans.Identity (IdentityT(..))
import           Control.Monad.Trans.Except (ExceptT(..), runExceptT)
import           Control.Monad.Trans.Maybe (MaybeT(..))
import qualified Control.Monad.Trans.RWS.Lazy as Lazy (RWST(..))
import qualified Control.Monad.Trans.RWS.Strict as Strict (RWST(..))
import           Control.Monad.Trans.Reader (ReaderT(..))
import qualified Control.Monad.Trans.State.Lazy as Lazy
import qualified Control.Monad.Trans.State.Strict as Strict
import qualified Control.Monad.Trans.Writer.Lazy as Lazy
import qualified Control.Monad.Trans.Writer.Strict as Strict

import           Data.Kind (Type)
import           GHC.Exts (Constraint)

------------------------------------------------------------------------
-- * MonadTransDistributive

-- | Monad transformers @g@ that can be pushed underneath another monad
-- transformer @f@, swapping the order of the two layers around a base
-- monad @m@.
class MonadTransDistributive g where
  -- | The constraints 'distributeT' needs in order to move @g@ beneath @f@
  -- over the base monad @m@.
  type Transformer
         (f :: (Type -> Type) -> Type -> Type)
         (g :: (Type -> Type) -> Type -> Type)
         (m :: Type -> Type)
         :: Constraint

  -- Default: every stack involved is a monad, and the outer transformer @f@
  -- supports both 'lift' and 'hoist'.
  type Transformer f g m =
    ( Monad m
    , Monad (f m)
    , Monad (g m)
    , Monad (f (g m))
    , MonadTrans f
    , MFunctor f
    )

  -- | Distribute one monad transformer over another: turn @g@ wrapped around
  -- @f m@ into @f@ wrapped around @g m@.
  distributeT :: Transformer f g m => g (f m) a -> f (g m) a

instance MonadTransDistributive IdentityT where
  -- Run the unwrapped @f m@ computation with its base monad lifted into
  -- 'IdentityT' (via 'hoist'), then re-wrap each result value in the inner
  -- 'IdentityT' layer and 'lift' it back into @f@.
  distributeT m =
    lift . IdentityT . pure =<< hoist lift (runIdentityT m)

instance MonadTransDistributive MaybeT where
  -- Run the unwrapped @f m (Maybe a)@ computation with its base monad lifted
  -- into 'MaybeT', then feed the resulting @Maybe a@ back into the inner
  -- 'MaybeT' layer, so a 'Nothing' becomes a failure of @MaybeT m@ rather
  -- than of @f@.
  distributeT m =
    lift . MaybeT . pure =<< hoist lift (runMaybeT m)

instance MonadTransDistributive (ExceptT x) where
  -- Run the unwrapped @f m (Either x a)@ computation with its base monad
  -- lifted into 'ExceptT', then feed the resulting @Either x a@ back into the
  -- inner 'ExceptT' layer, so a 'Left' is raised in @ExceptT x m@.
  distributeT m =
    lift . ExceptT . pure =<< hoist lift (runExceptT m)

instance MonadTransDistributive (ReaderT r) where
  -- The lifted inner 'ReaderT' action reads the environment and returns the
  -- original computation, specialised to that environment and with its base
  -- monad lifted into @ReaderT r m@; 'join' then runs that returned
  -- @f (ReaderT r m)@ action.
  distributeT m =
    join . lift . ReaderT $ \env ->
      pure . hoist lift $ runReaderT m env

instance Monoid w => MonadTransDistributive (Lazy.WriterT w) where
  -- Run the unwrapped @f m (a, w)@ computation with its base monad lifted
  -- into the writer, then re-emit the collected @(a, w)@ pair through the
  -- inner 'Lazy.WriterT' layer.
  distributeT m =
    lift . Lazy.WriterT . pure =<< hoist lift (Lazy.runWriterT m)

instance Monoid w => MonadTransDistributive (Strict.WriterT w) where
  -- Run the unwrapped @f m (a, w)@ computation with its base monad lifted
  -- into the writer, then re-emit the collected @(a, w)@ pair through the
  -- inner 'Strict.WriterT' layer. (No @do@: this is a single expression.)
  distributeT m =
    lift . Strict.WriterT . pure =<< hoist lift (Strict.runWriterT m)

instance MonadTransDistributive (Lazy.StateT s) where
  -- Read the current state from the inner 'Lazy.StateT' layer, run the
  -- original computation from that state inside @f@ (base monad lifted into
  -- @StateT s m@), write the final state back, and return the result.
  distributeT m = do
    s       <- lift Lazy.get
    (a, s') <- hoist lift (Lazy.runStateT m s)
    lift (Lazy.put s')
    pure a

instance MonadTransDistributive (Strict.StateT s) where
  -- Read the current state from the inner 'Strict.StateT' layer, run the
  -- original computation from that state inside @f@ (base monad lifted into
  -- @StateT s m@), write the final state back, and return the result.
  distributeT m = do
    s       <- lift Strict.get
    (a, s') <- hoist lift (Strict.runStateT m s)
    lift (Strict.put s')
    pure a

instance Monoid w => MonadTransDistributive (Lazy.RWST r w s) where
  -- Fetch the environment and state from the inner 'Lazy.RWST' layer, run the
  -- original computation with them inside @f@ (base monad lifted into
  -- @RWST r w s m@), then emit its output and final state back into the inner
  -- layer.
  distributeT m = do
    -- ask and get combined (leaves state unchanged, writes 'mempty')
    (r, s0)    <- lift . Lazy.RWST $ \env st -> pure ((env, st), st, mempty)
    (a, s1, w) <- hoist lift (Lazy.runRWST m r s0)
    -- tell and put combined
    lift . Lazy.RWST $ \_ _ -> pure (a, s1, w)

instance Monoid w => MonadTransDistributive (Strict.RWST r w s) where
  -- Fetch the environment and state from the inner 'Strict.RWST' layer, run
  -- the original computation with them inside @f@ (base monad lifted into
  -- @RWST r w s m@), then emit its output and final state back into the inner
  -- layer.
  distributeT m = do
    -- ask and get combined (leaves state unchanged, writes 'mempty')
    (r, s0)    <- lift . Strict.RWST $ \env st -> pure ((env, st), st, mempty)
    (a, s1, w) <- hoist lift (Strict.runRWST m r s0)
    -- tell and put combined
    lift . Strict.RWST $ \_ _ -> pure (a, s1, w)