{-# LANGUAGE UndecidableInstances #-}

{-|
Module      : Control.Monad.Trans.Tree
Description : Implementation of a non-deterministic tree monad.
Copyright   : (c) Nathan Bedell, 2021
License     : MIT
Maintainer  : nbedell@tulane.edu

This module contains the definition of a monad transformer for monadic trees.

Note that this implementation is still experimental: the monad laws for these
instances have /not/ been formally proven, so this module should be used with caution.
-}

module Control.Monad.Trans.Tree where

import Data.Functor
import Control.Monad.IO.Class
import Control.Monad.Trans.Class
import Control.Monad.Error.Class
import Control.Monad.Reader.Class
import Control.Monad.State.Class
import Control.Monad.Writer.Class

-- | One already-evaluated level of a monadic tree.
--
-- * @n@ — type of the label stored at each internal node.
-- * @f@ — container holding a node's children (for example a list); the
--   'Monad' instance of 'TreeT' requires it to be 'Traversable'.
-- * @m@ — underlying monad; a node's children only become available by
--   running the @m@ computation stored in the node.
-- * @a@ — type of the values stored at the leaves.
data TreeM n f m a = 
    Leaf a 
    -- ^ A leaf holding a result value.
  | Node n (m (f (TreeM n f m a)))
    -- ^ An internal node: a label together with an effectful computation
    -- that yields the node's children.

-- | Monad transformer for trees whose branching happens inside the
-- underlying monad @m@.
--
-- 'runTreeT' runs only the computation producing the root level; the
-- children of a resulting 'Node' are obtained by running the computation
-- stored inside that node.
newtype TreeT n f m a = TreeT { runTreeT :: m (TreeM n f m a) }

-- | Monadic bind for 'TreeT' (used as '>>=' by the 'Monad' instance).
--
-- Runs the root computation of the first tree, then:
--
-- * a 'Leaf' holding @a@ is replaced by the tree produced by @k a@;
-- * a 'Node' keeps its label, and its child computation is extended so
--   that the same substitution is applied to every child once that child
--   computation is eventually run.
--
-- The child computation of a 'Node' is deliberately /not/ executed here.
-- Running it while building the root (as an earlier version did) moves
-- every level's effects one level up, so @x \`bind\` return@ would perform
-- effects that @x@ alone does not, breaking the right-identity law.
bind :: (Monad m, Traversable f) => TreeT n f m a -> (a -> TreeT n f m b) -> TreeT n f m b
bind x k = TreeT (runTreeT x >>= go)
  where
    -- Rebuild one evaluated level of the tree.
    go (Leaf a)      = runTreeT (k a)
    -- Defer the children: they are only produced (and substituted) when the
    -- new node's child computation is run.
    go (Node l kids) = return (Node l (kids >>= traverse go))

-- | Applies a function to every leaf value. Node labels are left unchanged,
-- and each node's child computation is mapped over without being run.
instance (Functor f, Functor m) => Functor (TreeM n f m) where
    fmap g = go
      where
        go (Leaf v)      = Leaf (g v)
        -- Map inside the effect @m@, then over the container @f@, recursing
        -- into each child level.
        go (Node l kids) = Node l (fmap (fmap go) kids)

-- | Applies a function to every leaf value of the tree.
--
-- Reuses the 'Functor' instance of 'TreeM' on the root level instead of
-- repeating its leaf/node case analysis; nothing beyond the root
-- computation's own effect structure is touched.
instance (Functor f, Functor m) => Functor (TreeT n f m) where
    fmap g = TreeT . fmap (fmap g) . runTreeT

-- | 'pure' builds a single effect-free leaf.
--
-- '<*>' is defined through the monadic bind: every function leaf @g@ of the
-- first tree is replaced by the second tree with @g@ applied to its leaves.
instance (Monad m, Traversable f) => Applicative (TreeT n f m) where
    pure = TreeT . pure . Leaf
    tf <*> tx =
        tf >>= \g ->
        tx >>= \v ->
        pure (g v)

-- | Sequencing substitutes each leaf by the continuation's tree; internal
-- nodes keep their labels. See 'bind' for the details.
instance (Traversable f, Monad m) => Monad (TreeT n f m) where
    (>>=) = bind

-- | Lifts an underlying computation to a tree consisting of a single leaf
-- holding its result.
--
-- The @Traversable f@ context is needed with transformers >= 0.6, where
-- 'MonadTrans' has the quantified superclass
-- @forall m. Monad m => Monad (t m)@; 'TreeT' only has a 'Monad' instance
-- when @f@ is 'Traversable'.
instance Traversable f => MonadTrans (TreeT n f) where
    lift = TreeT . fmap Leaf

-- | Runs an 'IO' action through the underlying monad and returns its result
-- as a single leaf.
instance (MonadIO m, Traversable f) => MonadIO (TreeT n f m) where
    liftIO = TreeT . liftIO . fmap Leaf

-- | State operations act on the state of the underlying monad; each one
-- yields a single leaf.
--
-- 'state' is lifted directly, as mtl's own lifting instances do. Otherwise the
-- class default would split it into separately lifted 'get' and 'put' steps
-- sequenced through 'TreeT'.
instance (MonadState s m, Traversable f) => MonadState s (TreeT n f m) where
    get   = lift get
    put   = lift . put
    state = lift . state

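-- TODO: Define these instances. The operations that take an @m@ computation
-- as an argument (local, listen, pass, catchError) presumably also need to be
-- applied to the child computations stored in every 'Node', not only to the
-- root, so a plain 'lift' is not sufficient for them.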
{-
instance (MonadReader r m, Traversable f) => MonadReader r (TreeT n f m) where
    ask = lift ask
    local = undefined

instance (MonadWriter w m, Traversable f) => MonadWriter w (TreeT n f m) where
    tell   = lift . tell
    listen = undefined
    pass = undefined

instance (MonadError e m, Traversable f) => MonadError e (TreeT n f m) where
    throwError = lift . throwError
    catchError = undefined
-}