{-# LANGUAGE RankNTypes, TypeOperators #-}

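-- | The Plus monad: a free combination of two monads. A value of type
--   @(m ::+ n) t@ is represented by how it can be interpreted into any
--   'MonadPlus', given handlers for the actions of @m@ and of @n@. This is
--   very similar to the coproduct of monads, but not quite the same.
--
--   Coproducts of monads are due to Lüth and Ghani, "Composing Monads Using Coproducts" (ICFP 2002), http://www.informatik.uni-bremen.de/~cxl/papers/icfp02.pdf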
module Control.Monad.PlusMonad where

import Control.Applicative
import Control.Arrow
import Control.Monad

import Control.Monad.Identity
import Control.Monad.Morph
import Control.Monad.Product
import Control.Monad.Trans

-- | @(m ::+ n) t@ is a computation producing a @t@ that may combine actions
--   of @m@ and of @n@. It is represented by its interpretation: supplied with
--   one handler translating @m@ actions and one translating @n@ actions into
--   some 'MonadPlus' @x@, it yields an @x t@.
newtype (m ::+ n) t = Plus
	{ unPlus
		:: forall x. MonadPlus x
		=> (forall u. m u -> x u) -- handler for actions of the left monad
		-> (forall u. n u -> x u) -- handler for actions of the right monad
		-> x t
	}

-- The instances below follow the canonical style: 'pure' is defined in
-- 'Applicative' and 'empty' / '<|>' in 'Alternative'. 'Monad' and 'MonadPlus'
-- rely on the class defaults (@return = pure@, @mzero = empty@,
-- @mplus = (<|>)@). Every operation is passed through to the target monad
-- @x@ chosen by whoever interprets the computation.

-- | Mapping is derived from '>>=' and 'pure'.
instance Functor (m ::+ n) where
	fmap f m = m >>= pure . f

-- | 'pure' returns the value in the target monad without invoking either
--   handler; '<*>' is derived from the 'Monad' instance.
instance Applicative (m ::+ n) where
	pure x = Plus (\_ _ -> pure x)
	(<*>) = ap

-- | Sequencing: interpret the first computation, then interpret the
--   continuation applied to its result with the same pair of handlers.
instance Monad (m ::+ n) where
	Plus f >>= g = Plus (\h i -> f h i >>= \x -> unPlus (g x) h i)

-- | Failure and choice are delegated to the target monad's 'mzero' and
--   'mplus'.
instance Alternative (m ::+ n) where
	empty = Plus (\_ _ -> mzero)
	Plus f <|> Plus g = Plus (\h i -> mplus (f h i) (g h i))

-- | 'mzero' and 'mplus' come from the class defaults, i.e. the
--   'Alternative' instance above.
instance MonadPlus (m ::+ n)

-- | Inject an action of the left monad. Interpreting the result applies the
--   left handler to the action and ignores the right handler.
inl :: m t -> (m ::+ n) t
inl m = Plus (\h _ -> h m)

-- | Inject an action of the right monad. Interpreting the result applies the
--   right handler to the action and ignores the left handler.
inr :: n t -> (m ::+ n) t
inr m = Plus (\_ i -> i m)

-- | 'lift' embeds an action of the right-hand monad; the body is the same
--   as 'inr': only the right handler is applied when interpreting.
instance MonadTrans ((::+) m) where
	lift act = Plus (\_ onRight -> onRight act)

-- | Change both component monads with a pair of natural transformations.
--   Each action is translated first and then handed to the corresponding
--   handler of the new interpretation.
mapPlus :: (forall a. m a -> m1 a) -> (forall a. n a -> n1 a) -> (m ::+ n) t -> (m1 ::+ n1) t
mapPlus toLeft toRight p =
	Plus (\onLeft onRight ->
		unPlus p (\a -> onLeft (toLeft a)) (\b -> onRight (toRight b)))

-- | 'hoist' rewrites only the right-hand monad and leaves the left one
--   untouched.
instance MFunctor ((::+) m) where
	hoist nat = mapPlus id nat

-- | Swap the two sides of a sum by exchanging the handlers passed to the
--   underlying interpretation.
comm :: (m ::+ n) t -> (n ::+ m) t
comm p = Plus (\onN onM -> unPlus p onM onN)

-- | Re-associate a left-nested sum into right-nested form. Actions of the
--   inner left sum are interpreted with the outer left handler, or routed
--   into the new right sum with 'inl'. Actions of the outermost right monad
--   are routed with 'inr'.
assoc :: ((m ::+ n) ::+ o) t -> (m ::+ (n ::+ o)) t
assoc (Plus f) = Plus (\h i -> f (\m -> unPlus m h (i . inl)) (i . inr))

-- | Re-associate a right-nested sum into left-nested form, the opposite
--   re-association to 'assoc'. Left actions are routed into the new left
--   sum with 'inl'. The inner right sum is interpreted with 'inr'-routed
--   actions for its left side and the outer right handler for its right side.
assoc1 :: (m ::+ (n ::+ o)) t -> ((m ::+ n) ::+ o) t
assoc1 (Plus f) = Plus (\h i -> f (h . inl) (\m -> unPlus m (h . inr) i))

-- | Eliminate an 'Identity' on the left. The sum is interpreted directly in
--   the right monad: 'Identity' actions become @return . runIdentity@ and
--   right actions are used as-is.
cancelLeft :: MonadPlus n => (Identity ::+ n) t -> n t
cancelLeft (Plus f) = f (return . runIdentity) id

-- | Eliminate an 'Identity' on the right. The sum is interpreted directly in
--   the left monad: left actions are used as-is and 'Identity' actions
--   become @return . runIdentity@.
cancelRight :: MonadPlus m => (m ::+ Identity) t -> m t
cancelRight (Plus f) = f id (return . runIdentity)

-- | Collapse a sum of a monad with itself by interpreting it in that monad,
--   using 'id' as the handler for both sides.
refl :: MonadPlus m => (m ::+ m) t -> m t
refl (Plus f) = f id id

-- | 'embed' substitutes every right-hand action using the given function,
--   producing a nested sum. 'assoc1' then re-associates it, and 'refl'
--   collapses the two copies of the left monad. The 'MonadPlus' constraint
--   is needed by 'refl'.
instance (MonadPlus m) => MMonad ((::+) m) where
	embed f p = mapPlus refl id (assoc1 (mapPlus id f p))

-- | Distribute a sum of monad products into a product of sums. The first
--   component keeps the first projection ('fst' of 'runProduct') of every
--   action on both sides, and the second component keeps the second
--   projection.
distr pls = Product (fstSide, sndSide)
	where
		fstSide = mapPlus (fst . runProduct) (fst . runProduct) pls
		sndSide = mapPlus (snd . runProduct) (snd . runProduct) pls