{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# language Safe #-}
-- |
-- Module       : Control.Monad.Trans.Wedge
-- Copyright    : (c) 2020-2021 Emily Pillmore
-- License      : BSD-3-Clause
--
-- Maintainer   : Emily Pillmore <emilypi@cohomolo.gy>
-- Stability    : Experimental
-- Portability  : Non-portable
--
-- This module contains utilities for the monad transformer
-- for the pointed coproduct.
--
module Control.Monad.Trans.Wedge
( -- * Monad transformer
  WedgeT(runWedgeT)
  -- ** Combinators
, mapWedgeT
) where


import Data.Wedge
import Control.Applicative (liftA2)
import Control.Monad.Writer
import Control.Monad.Reader
import Control.Monad.State.Class
import Control.Monad.Except
import Control.Monad.RWS

-- | A monad transformer for the pointed coproduct ('Wedge'),
-- parameterized by:
--
--   * @a@ - the value on the left (carried by 'Here')
--   * @m@ - the underlying monad, whose action produces a @'Wedge' a b@
--   * @b@ - the value on the right (carried by 'There')
--
-- This monad transformer is similar to 'ExceptT',
-- except with the possibility of an empty unital value ('Nowhere').
-- Binding continues only with a 'There' result; 'Here' and 'Nowhere'
-- short-circuit the rest of the computation.
--
newtype WedgeT a m b = WedgeT { runWedgeT :: m (Wedge a b) }

-- | Transform the underlying computation of a 'WedgeT' with the given
-- function. The function may change the underlying monad as well as both
-- the left ('Here') and right ('There') types of the 'Wedge' it produces.
--
-- * @'runWedgeT' ('mapWedgeT' f m) = f ('runWedgeT' m)@
--
mapWedgeT :: (m (Wedge a b) -> n (Wedge c d)) -> WedgeT a m b -> WedgeT c n d
mapWedgeT f = WedgeT . f . runWedgeT


instance Functor f => Functor (WedgeT a f) where
  -- Map the function over the 'Wedge' produced inside the underlying
  -- functor. Only a 'There' value is affected; 'Here' and 'Nowhere'
  -- keep their shape.
  fmap g (WedgeT inner) = WedgeT (fmap (fmap g) inner)

instance (Semigroup a, Applicative f) => Applicative (WedgeT a f) where
  -- Lift a value with the 'Wedge' applicative's 'pure', then with the
  -- underlying applicative's 'pure'.
  pure x = WedgeT (pure (pure x))

  -- Combine both sides pointwise with the 'Wedge' applicative, lifted
  -- through the underlying functor.
  --
  -- NOTE(review): both underlying effects always run here, even when the
  -- function side yields 'Nowhere' or 'Here'. The 'Monad' instance's '>>='
  -- skips the continuation in those cases, so @(<*>)@ and 'ap' may differ
  -- in their effects. Making them agree would require a 'Monad' constraint
  -- on this instance.
  WedgeT mf <*> WedgeT mx = WedgeT (liftA2 (<*>) mf mx)

instance (Semigroup a, Monad m) => Monad (WedgeT a m) where
  return = pure

  -- Run the underlying action and inspect its 'Wedge'. Only a 'There'
  -- value is passed on to the continuation. 'Nowhere' and 'Here' stop the
  -- computation and are rebuilt at the new result type.
  WedgeT m >>= k = WedgeT (m >>= step)
    where
      step Nowhere   = return Nowhere
      step (Here e)  = return (Here e)
      step (There x) = runWedgeT (k x)

instance (MonadReader r m, Semigroup t) => MonadReader r (WedgeT t m) where
  -- Read the underlying monad's environment; the result is lifted.
  ask = lift ask

  -- Run the wrapped computation under a modified environment, delegating
  -- to the underlying monad's 'local'.
  local = mapWedgeT . local

instance (MonadWriter w m, Semigroup t) => MonadWriter w (WedgeT t m) where
  -- Append output in the underlying monad; the result is lifted.
  tell w = lift (tell w)

  -- Capture the output written by the wrapped computation. The output is
  -- only paired with a 'There' result; 'Here' and 'Nowhere' results are
  -- passed through and the captured output is dropped.
  listen = mapWedgeT (fmap attach . listen)
    where
      attach (wedge, w) = case wedge of
        Nowhere -> Nowhere
        Here t  -> Here t
        There a -> There (a, w)

  -- Apply the output-modifying function carried by a 'There' result. For
  -- 'Here' and 'Nowhere' the output is left unchanged ('id').
  pass = mapWedgeT (pass . fmap detach)
    where
      detach = \case
        Nowhere      -> (Nowhere, id)
        Here t       -> (Here t, id)
        There (a, f) -> (There a, f)

instance (MonadState s m, Semigroup t) => MonadState s (WedgeT t m) where
  -- Each operation runs in the underlying monad and is lifted, so its
  -- result ends up in the 'There' case.
  get = lift get
  put = lift . put

  -- Forward 'state' as a single underlying operation. Without this, the
  -- class default implements 'state' as a separate 'get' followed by a
  -- 'put'. That bypasses the underlying monad's own 'state', and any
  -- atomicity or efficiency it provides.
  state = lift . state

-- The instance body is empty: it is assembled from this module's
-- 'MonadReader', 'MonadWriter' and 'MonadState' instances for 'WedgeT'.
instance (MonadRWS r w s m, Semigroup t) => MonadRWS r w s (WedgeT t m)

-- NOTE(review): transformers >= 0.6 gives 'MonadTrans' the superclass
-- @forall m. Monad m => Monad (t m)@. In this module, @Monad (WedgeT a m)@
-- also needs @Semigroup a@, so this instance may not compile against those
-- versions. Confirm against the transformers version in use.
instance MonadTrans (WedgeT a) where
  -- Run the action in the underlying monad and tag its result as 'There'.
  lift m = WedgeT (There <$> m)

instance (MonadError e m, Semigroup e) => MonadError e (WedgeT e m) where
  -- Raise the error in the underlying monad. The 'Here' mapping only fixes
  -- the result type; for a lawful underlying instance, 'throwError'
  -- produces no value for it to wrap.
  throwError = WedgeT . fmap Here . throwError

  -- Handle errors raised in the underlying monad by running the handler's
  -- computation. A 'Here' result is an ordinary value and is not routed to
  -- the handler.
  catchError action handler =
    mapWedgeT (\inner -> catchError inner (runWedgeT . handler)) action