{-# OPTIONS_HADDOCK not-home #-}
{-# LANGUAGE InstanceSigs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
module Hedgehog.Internal.Distributive (
    MonadTransDistributive(..)
  ) where

import           Control.Monad (join)
import           Control.Monad.Morph (MFunctor(..))
import           Control.Monad.Trans.Class (MonadTrans(..))
import           Control.Monad.Trans.Identity (IdentityT(..))
import           Control.Monad.Trans.Except (ExceptT(..), runExceptT)
import           Control.Monad.Trans.Maybe (MaybeT(..))
import qualified Control.Monad.Trans.RWS.Lazy as Lazy (RWST(..))
import qualified Control.Monad.Trans.RWS.Strict as Strict (RWST(..))
import           Control.Monad.Trans.Reader (ReaderT(..))
import qualified Control.Monad.Trans.State.Lazy as Lazy
import qualified Control.Monad.Trans.State.Strict as Strict
import qualified Control.Monad.Trans.Writer.Lazy as Lazy
import qualified Control.Monad.Trans.Writer.Strict as Strict

import           Data.Kind (Type)
import           GHC.Exts (Constraint)

------------------------------------------------------------------------
-- * MonadTransDistributive

-- | Monad transformers @g@ that can be moved from the outside of another
--   transformer @f@ to the inside of it, see 'distributeT'.
--
class MonadTransDistributive g where
  -- | The constraints needed to distribute @g@ over @f@ with base monad @m@.
  --
  --   Instances may override this; the default is given below.
  --
  type Transformer
    (f :: (Type -> Type) -> Type -> Type)
    (g :: (Type -> Type) -> Type -> Type)
    (m :: Type -> Type) :: Constraint

  -- Default: every layer involved must be a 'Monad', and @f@ must support
  -- both 'lift' ('MonadTrans') and 'hoist' ('MFunctor').
  type Transformer f g m =
    ( Monad m
    , Monad (f m)
    , Monad (g m)
    , Monad (f (g m))
    , MonadTrans f
    , MFunctor f
    )

  -- | Distribute one monad transformer over another.
  --
  --   Turns a @g@ layered over @f@ into an @f@ layered over @g@.
  --
  distributeT :: Transformer f g m => g (f m) a -> f (g m) a

-- | Distribute 'IdentityT' over @f@.
--
--   The wrapped @f m a@ is unwrapped with 'runIdentityT' and its base monad
--   is re-targeted to @IdentityT m@ with @hoist lift@. The result is then
--   re-wrapped in 'IdentityT' and lifted into @f@.
--
instance MonadTransDistributive IdentityT where
  distributeT ::
       Transformer f IdentityT m
    => IdentityT (f m) a
    -> f (IdentityT m) a
  distributeT m =
    lift . IdentityT . pure =<< hoist lift (runIdentityT m)

-- | Distribute 'MaybeT' over @f@.
--
--   'runMaybeT' exposes an @f m (Maybe a)@, whose base monad is re-targeted
--   to @MaybeT m@ with @hoist lift@. The resulting 'Maybe' is then fed back
--   into the 'MaybeT' layer via @MaybeT . pure@ and lifted into @f@.
--
instance MonadTransDistributive MaybeT where
  distributeT ::
       Transformer f MaybeT m
    => MaybeT (f m) a
    -> f (MaybeT m) a
  distributeT m =
    lift . MaybeT . pure =<< hoist lift (runMaybeT m)

-- | Distribute @'ExceptT' x@ over @f@.
--
--   'runExceptT' exposes an @f m (Either x a)@, whose base monad is
--   re-targeted to @ExceptT x m@ with @hoist lift@. The resulting 'Either' is
--   then fed back into the 'ExceptT' layer via @ExceptT . pure@ and lifted
--   into @f@.
--
instance MonadTransDistributive (ExceptT x) where
  distributeT ::
       Transformer f (ExceptT x) m
    => ExceptT x (f m) a
    -> f (ExceptT x m) a
  distributeT m =
    lift . ExceptT . pure =<< hoist lift (runExceptT m)

-- | Distribute @'ReaderT' r@ over @f@.
--
--   A 'ReaderT' action is lifted into @f@; given the environment @r@, it
--   returns (purely) the computation @runReaderT m r@ with its base monad
--   re-targeted to @ReaderT r m@ by @hoist lift@. That yields an @f@ action
--   producing another @f@ action, which 'join' collapses.
--
instance MonadTransDistributive (ReaderT r) where
  distributeT ::
       Transformer f (ReaderT r) m
    => ReaderT r (f m) a
    -> f (ReaderT r m) a
  distributeT m =
    join . lift . ReaderT $ \r ->
      pure . hoist lift $ runReaderT m r

-- | Distribute the lazy @'Lazy.WriterT' w@ over @f@.
--
--   'Lazy.runWriterT' exposes an @f m (a, w)@, whose base monad is
--   re-targeted to @WriterT w m@ with @hoist lift@. The resulting
--   value/output pair is then fed back into the writer layer via
--   @Lazy.WriterT . pure@ and lifted into @f@.
--
instance Monoid w => MonadTransDistributive (Lazy.WriterT w) where
  distributeT ::
       Transformer f (Lazy.WriterT w) m
    => Lazy.WriterT w (f m) a
    -> f (Lazy.WriterT w m) a
  distributeT m =
    lift . Lazy.WriterT . pure =<< hoist lift (Lazy.runWriterT m)

-- | Distribute the strict @'Strict.WriterT' w@ over @f@.
--
--   'Strict.runWriterT' exposes an @f m (a, w)@, whose base monad is
--   re-targeted to @WriterT w m@ with @hoist lift@. The resulting
--   value/output pair is then fed back into the writer layer via
--   @Strict.WriterT . pure@ and lifted into @f@.
--
instance Monoid w => MonadTransDistributive (Strict.WriterT w) where
  distributeT ::
       Transformer f (Strict.WriterT w) m
    => Strict.WriterT w (f m) a
    -> f (Strict.WriterT w m) a
  distributeT m =
    -- A single expression: no 'do' block is needed here.
    lift . Strict.WriterT . pure =<< hoist lift (Strict.runWriterT m)

-- | Distribute the lazy @'Lazy.StateT' s@ over @f@.
--
--   The current state is read from the inner 'Lazy.StateT' layer, the
--   original computation is run from that state with its base monad
--   re-targeted by @hoist lift@, and the state it produces is written back
--   to the inner layer before the result is returned.
--
instance MonadTransDistributive (Lazy.StateT s) where
  distributeT ::
       Transformer f (Lazy.StateT s) m
    => Lazy.StateT s (f m) a
    -> f (Lazy.StateT s m) a
  distributeT m = do
    -- Fetch the state held by the (now inner) StateT layer.
    s0 <- lift Lazy.get
    -- Run the original computation from that state.
    (a, s1) <- hoist lift (Lazy.runStateT m s0)
    -- Store the resulting state back in the inner layer.
    lift (Lazy.put s1)
    pure a

-- | Distribute the strict @'Strict.StateT' s@ over @f@.
--
--   The current state is read from the inner 'Strict.StateT' layer, the
--   original computation is run from that state with its base monad
--   re-targeted by @hoist lift@, and the state it produces is written back
--   to the inner layer before the result is returned.
--
instance MonadTransDistributive (Strict.StateT s) where
  distributeT ::
       Transformer f (Strict.StateT s) m
    => Strict.StateT s (f m) a
    -> f (Strict.StateT s m) a
  distributeT m = do
    -- Fetch the state held by the (now inner) StateT layer.
    s0 <- lift Strict.get
    -- Run the original computation from that state.
    (a, s1) <- hoist lift (Strict.runStateT m s0)
    -- Store the resulting state back in the inner layer.
    lift (Strict.put s1)
    pure a

-- | Distribute the lazy @'Lazy.RWST' r w s@ over @f@.
--
--   The environment and current state are read from the inner 'Lazy.RWST'
--   layer, the original computation is run with them (its base monad
--   re-targeted by @hoist lift@), and its result, final state and output are
--   then emitted from the inner layer.
--
instance Monoid w => MonadTransDistributive (Lazy.RWST r w s) where
  distributeT ::
       Transformer f (Lazy.RWST r w s) m
    => Lazy.RWST r w s (f m) a
    -> f (Lazy.RWST r w s m) a
  distributeT m = do
    -- ask and get combined: return the environment and state unchanged,
    -- producing no output ('mempty').
    (r, s0) <- lift . Lazy.RWST $ \r s -> pure ((r, s), s, mempty)
    (a, s1, w) <- hoist lift (Lazy.runRWST m r s0)
    -- tell and put combined: ignore the incoming environment and state, and
    -- yield the result together with the new state and the output.
    lift . Lazy.RWST $ \_ _ -> pure (a, s1, w)

-- | Distribute the strict @'Strict.RWST' r w s@ over @f@.
--
--   The environment and current state are read from the inner 'Strict.RWST'
--   layer, the original computation is run with them (its base monad
--   re-targeted by @hoist lift@), and its result, final state and output are
--   then emitted from the inner layer.
--
instance Monoid w => MonadTransDistributive (Strict.RWST r w s) where
  distributeT ::
       Transformer f (Strict.RWST r w s) m
    => Strict.RWST r w s (f m) a
    -> f (Strict.RWST r w s m) a
  distributeT m = do
    -- ask and get combined: return the environment and state unchanged,
    -- producing no output ('mempty').
    (r, s0) <- lift . Strict.RWST $ \r s -> pure ((r, s), s, mempty)
    (a, s1, w) <- hoist lift (Strict.runRWST m r s0)
    -- tell and put combined: ignore the incoming environment and state, and
    -- yield the result together with the new state and the output.
    lift . Strict.RWST $ \_ _ -> pure (a, s1, w)