{-# LANGUAGE Safe, Rank2Types, FlexibleInstances, DeriveFunctor, TypeOperators #-}

-- | A construction combining two monads, based on the work of Luth and Ghani, "Composing Monads Using Coproducts."
module Control.Monad.PlusMonad (Composition, (::+), Dist(..), leftMap, rightMap, inl, inr, sym, commute, mapPlus, refl,
-- * Example
File, runFile, readLine) where

import qualified Control.Monad.State.Strict as Strict
import Control.Monad.State
import Control.Monad.Writer hiding (Sum)
import Control.Monad.Error
import Control.Monad.Identity
import Control.Monad.Morph
import Control.Monad.Codensity
import Control.Exception
import Control.Applicative
import Data.Functor.Compose
import Data.Functor.Sum
import Data.Functor.Yoneda
import System.IO

-- | Alternating layers of @m@ and @n@ around a final result.
--
-- A value is either a finished result (@Rtn@) or one layer of @m@ wrapping one
-- layer of @n@ wrapping the rest of the computation (@Composition@).  The
-- module's export list exposes the type without its constructors.
data Composition m n t = Composition (m (n (Composition m n t))) | Rtn t deriving Functor

-- | The combination of two monads: 'Composition' wrapped in 'Yoneda'.
--
--   The following construction on two monads is a monad provided the two monads
--   have extended distributive laws (see 'Dist', defined below).
type (m ::+ n) = Yoneda (Composition m n)

-- | An extended distributive law allows one to permute two layers.
--
-- In the laws below the type of 'dist' is abbreviated as @TS -> TST@, so @T@
-- stands for the instance type @n@ and @S@ for the applicative @m@.
--
-- Laws are:
--
-- > join . T dist = dist . join :: TTS -> TST
-- > TS join . dist . dist = dist :: TS -> TST
class Dist n where
	-- | Turn an @n@ layer over an @m@ layer into an @n@-over-@m@-over-@n@ stack.
	dist :: (Applicative m) => n (m t) -> n (m (n t))

-- Extended distributive laws for common monads.
-- Run the outer action, take a snapshot of the resulting state, and make every
-- inner result re-install that snapshot before returning.
instance Dist (StateT s Identity) where
	dist outer =
		outer >>= \inner ->
		get >>= \snapshot ->
		return (fmap (restore snapshot) inner)
		where
			restore st x = put st >> return x

-- The output of the outer action is taken out of the outer layer (which
-- becomes a plain 'return') and is emitted again by every inner result.
instance (Monoid s) => Dist (WriterT s Identity) where
	dist outer = return (fmap replay inner)
		where
			(inner, logged) = runWriter outer
			replay x = tell logged >> return x

-- Sequence the inner effects into one action and offer it as the single
-- alternative of the outer list.
instance Dist [] where
	dist actions = [sequenceA actions]

-- I/O is equipped with a trivial distributive law: the outer action is kept
-- as it is and every inner result is re-wrapped with 'return'.
instance Dist IO where
	dist action = fmap return <$> action

-- Trivial law: keep the outer layer, re-wrap every inner result with 'return'.
instance Dist Identity where
	dist action = fmap return <$> action

-- Trivial law: keep the outer layer, re-wrap every inner result with 'return'.
instance Dist Maybe where
	dist action = fmap return <$> action

-- Trivial law: keep the outer layer, re-wrap every inner result with 'return'.
instance Dist (Either t) where
	dist action = fmap return <$> action

-- Trivial law: keep the outer layer, re-wrap every inner result with 'return'.
instance (Error e) => Dist (ErrorT e Identity) where
	dist action = fmap return <$> action

-- Apply a natural transformation underneath 'Yoneda' by post-composing it
-- with the wrapped continuation-taking function.
_hoist :: (forall u. m u -> n u) -> Yoneda m t -> Yoneda n t
_hoist nat (Yoneda run) = Yoneda (\k -> nat (run k))

-- Replace the outer functor of every layer using a natural transformation;
-- the inner functor and the final result are left untouched.
_leftMap :: (Functor n, Functor x) => (forall u. m u -> n u) -> Composition m x t -> Composition n x t
_leftMap nat comp = case comp of
	Rtn x -> Rtn x
	Composition layers -> Composition (fmap (_leftMap nat) <$> nat layers)

-- Replace the inner functor of every layer using a natural transformation;
-- the outer functor and the final result are left untouched.
_rightMap :: (Functor n, Functor x) => (forall u. m u -> n u) -> Composition x m t -> Composition x n t
_rightMap _ (Rtn x) = Rtn x
_rightMap nat (Composition layers) = Composition (fmap convert layers)
	where
		-- Transform one inner layer, then recurse into what it contains.
		convert inner = _rightMap nat <$> nat inner

-- | Apply a natural transformation to the left component of a '::+',
--   replacing @m@ by @n@ in every layer.
leftMap :: (Monad m, Functor n, Functor x) => (forall u. m u -> n u) -> (m ::+ x) t -> (n ::+ x) t
leftMap nat = _hoist (_leftMap nat)

-- | Apply a natural transformation to the right component of a '::+',
--   replacing @m@ by @n@ in every layer.
rightMap :: (Monad x, Monad m, Functor n) => (forall u. m u -> n u) -> (x ::+ m) t -> (x ::+ n) t
rightMap nat = _hoist (_rightMap nat)

-- Distribute @m@ over three applicative layers at once.  The three layers are
-- packed into nested 'Compose' so that 'dist' sees a single applicative; after
-- distributing they are unpacked again and the two adjacent innermost @m@
-- layers are merged with 'join'.
distributive1 :: (Dist m, Monad m, Applicative n, Applicative x, Applicative y) => m (n (x (y (m t)))) -> m (n (x (y (m t))))
distributive1 layers = fmap unwrap (dist (fmap wrap layers))
	where
		wrap inner = Compose (fmap Compose inner)
		unwrap wrapped = fmap (fmap (fmap join) . getCompose) (getCompose wrapped)

-- Work through a 'Composition' one layer pair at a time: 'distributive1' is
-- applied to the outer @m@ layer, then to the @n@ layer beneath it, and the
-- same procedure is repeated on the remaining layers.  A finished result is
-- returned unchanged.
distributive2 :: (Dist m, Dist n, Monad m, Monad n, Applicative x) => Composition m n (x (m (n t))) -> Composition m n (x (m (n t)))
distributive2 (Rtn x) = Rtn x
distributive2 (Composition layers) =
	Composition (fmap (fmap distributive2 . distributive1) (distributive1 layers))

-- These two instances are needed to use '::+' in a nested manner.
--
-- Every inner result is first embedded in pure @m@ and @n@ layers, the whole
-- structure is pushed through 'distributive2', and the innermost @m@/@n@
-- layers are then rebuilt into a 'Composition'.
instance (Dist m, Dist n, Monad m, Monad n) => Dist (Composition m n) where
	dist comp = fmap (fmap rebuild) (distributive2 (fmap (fmap embed) comp))
		where
			embed x = return (return x)
			rebuild layers = Composition (fmap (fmap Rtn) layers)

-- Lower to the underlying functor, distribute there, and lift both the outer
-- layer and each innermost layer back into 'Yoneda'.
instance (Dist m, Functor m) => Dist (Yoneda m) where
	dist y = liftYoneda (fmap liftYoneda <$> dist (lowerYoneda y))

-- Apply 'dist' and then merge the two adjacent innermost @m@ layers it
-- produces with 'join', so the result has the same shape as the input.
distributive :: (Dist m, Monad m, Applicative n) => m (n (m t)) -> m (n (m t))
distributive layers = fmap join <$> dist layers

-- Expose the outermost @m@ and @n@ layers of a 'Composition'.  A finished
-- result has no layers, so it is wrapped in pure ones.
bringDown :: (Monad m, Monad n) => Composition m n t -> m (n (Composition m n t))
bringDown comp = case comp of
	Composition layers -> layers
	Rtn _ -> return (return comp)

-- Binding on a layered value pushes the continuation underneath both layers,
-- exposes the layers of each result with 'bringDown', and then applies
-- 'distributive' at the @n@ level and at the @m@ level before re-wrapping.
instance (Dist m, Dist n, Monad m, Monad n) => Monad (Composition m n) where
	return = Rtn
	Rtn x >>= k = k x
	Composition layers >>= k =
		Composition (fmap (fmap Composition) (distributive (fmap bindInner layers)))
		where
			-- Bind the continuation below one @n@ layer, expose the layers of
			-- each result, and distribute at the @n@ level.
			bindInner inner = distributive (fmap (\rest -> bringDown (rest >>= k)) inner)
	fail msg = Composition (fail msg)

-- Alternatives are combined with @n@'s own 'mplus' underneath the outermost
-- @m@ layer; a bare result is first wrapped in a pure @n@ layer.
--
-- In every case the left operand stays on the left of @n@'s 'mplus', so the
-- order of alternatives is preserved for monads such as lists or 'Maybe'
-- whose 'mplus' is not commutative.
instance (Dist m, Dist n, Monad m, MonadPlus n) => MonadPlus (Composition m n) where
	mzero = Composition (return mzero)
	-- Both sides layered: run both @m@ effects, then combine the two @n@
	-- layers with @n@'s 'mplus'.  (Binding through @n@ with a nested 'liftM2'
	-- instead would let an 'mzero' on one side discard the other operand.)
	mplus (Composition m) (Composition n) = Composition (liftM2 mplus m n)
	mplus (Rtn x) (Composition n) = Composition (liftM (return (Rtn x) `mplus`) n)
	-- The bare result is the right operand, so it goes on the right.
	mplus (Composition m) (Rtn x) = Composition (liftM (`mplus` return (Rtn x)) m)
	mplus (Rtn x) (Rtn y) = Composition (return (return (Rtn x) `mplus` return (Rtn y)))

-- Applicative structure defined in terms of the 'Monad' instance: the function
-- computation is run first, then the argument computation.
instance (Dist m, Dist n, Monad m, Monad n) => Applicative (Composition m n) where
	pure x = return x
	fs <*> xs = fs >>= \f -> xs >>= \x -> return (f x)

-- Alternative structure delegating to the 'MonadPlus' instance.
instance (Dist m, Dist n, Monad m, MonadPlus n) => Alternative (Composition m n) where
	empty = mzero
	left <|> right = mplus left right

-- Lift an @n@ action as a single layer: a pure @m@ layer around the action,
-- whose result is marked as finished.
instance (Monad m) => MonadTrans (Composition m) where
	lift action = Composition (return (Rtn <$> action))

-- I/O is performed in the inner monad @n@ and then lifted into the composition.
instance (Dist m, Dist n, Monad m, Monad n, MonadIO n) => MonadIO (Composition m n) where
	liftIO io = lift (liftIO io)

-- | Inject an action of the left monad into the '::+' type.  The action
--   becomes a single @m@ layer over a pure @n@ layer.
inl :: (Dist m, Dist n, Monad m, Monad n) => m t -> (m ::+ n) t
inl action = lift (Composition (wrap <$> action))
	where
		wrap x = return (Rtn x)

-- | Inject an action of the right monad into the '::+' type.  The action
--   becomes a single @n@ layer under a pure @m@ layer.
inr :: (Dist m, Dist n, Monad m, Monad n) => n t -> (m ::+ n) t
inr = lift . Composition . return . fmap Rtn

-- Collapse a composition of a monad with itself by running every layer in
-- order inside that monad.
_sym :: (Monad m) => Composition m m t -> m t
_sym (Rtn x) = return x
_sym (Composition layers) = do
	inner <- layers
	rest <- inner
	_sym rest

-- | Extract the underlying action from a '::+' of a monad with itself.
sym :: (Monad m) => (m ::+ m) t -> m t
sym = _sym . lowerYoneda

-- Re-associate the layers so that @n@ becomes the outer functor: given an @n@
-- layer over an @m@-then-@n@ composition, pair it with the following @m@ layer
-- (or a pure @m@ layer when the composition is already finished) and continue
-- with the next @n@ layer.
_commute :: (Monad m, Functor n) => n (Composition m n t) -> Composition n m t
_commute layers = Composition (fmap flipLayer layers)
	where
		flipLayer (Composition inner) = fmap _commute inner
		flipLayer (Rtn x) = return (Rtn x)

-- | '::+' is commutative: the two monads can be swapped.  The composition is
--   first placed under a pure @n@ layer so that @n@ can become the outer one.
commute :: (Monad m, Monad n) => (m ::+ n) t -> (n ::+ m) t
commute = _hoist (\comp -> _commute (return comp))

-- | Apply one natural transformation to each component of a '::+': the right
--   component is mapped first with 'rightMap', then the left with 'leftMap'.
mapPlus :: (Monad m, Monad n, Functor m1, Functor n1) => (forall u. m u -> m1 u) -> (forall u. n u -> n1 u) -> (m ::+ n) t -> (m1 ::+ n1) t
mapPlus f g plus = leftMap f (rightMap g plus)

-- | 'sym' restricted to monads that are also 'MonadPlus'.
refl :: (MonadPlus m) => (m ::+ m) t -> m t
refl plus = sym plus

---------------------------------------

-- | Example of an IO-performing ADT.
--
-- A 'File' computation is a 'StateT' over 'IO' whose state is a 'Handle'.
newtype File t = File (StateT Handle IO t) deriving Functor

-- | Run a 'File' computation against the file at the given path.
--
-- The file is opened in 'ReadMode' and its handle becomes the initial state.
-- 'bracket' closes the handle whether the computation returns or throws, and
-- leaves no window between opening the file and installing the cleanup in
-- which an asynchronous exception could leak the handle.
runFile :: File t -> FilePath -> IO t
runFile (File m) path = bracket (openFile path ReadMode) hClose (evalStateT m)

-- | Read one line from the handle held in the 'File' state.
readLine :: File String
readLine = File (get >>= lift . hGetLine)

-- Monad structure borrowed from the wrapped 'StateT': each method unwraps,
-- delegates, and re-wraps.
instance Monad File where
	return x = File (return x)
	File act >>= k = File (act >>= unwrap . k)
		where
			unwrap (File inner) = inner
	fail msg = File (fail msg)

-- Applicative structure defined in terms of the 'Monad' instance: the function
-- computation is run first, then the argument computation.
instance Applicative File where
	pure x = return x
	fs <*> xs = fs >>= \f -> xs >>= \x -> return (f x)

-- Run the outer action, remember the handle held in the state afterwards, and
-- make every inner result re-install that handle before returning.
instance Dist File where
	dist outer =
		outer >>= \inner ->
		File get >>= \hdl ->
		return (fmap (restore hdl) inner)
		where
			restore h x = File (put h) >> return x