module Control.TreeT (TreeT(..)) where
import Data.Tree
import Data.Monoid

import Control.Monad.Trans.Class
import Control.Monad.IO.Class
import Control.Monad
import Control.Applicative

-- | A monad transformer whose computations produce a rose tree ('Tree') of
-- results inside the base monad @m@.
--
-- Declared as a @newtype@: there is exactly one constructor with one field,
-- so wrapping and unwrapping cost nothing at run time.  The 'runTreeT' field
-- is exported together with the constructor through @TreeT(..)@.
newtype TreeT m a = TreeT
  { runTreeT :: m (Tree a)  -- ^ Unwrap to the underlying tree-producing action.
  }

-- | Wrap a tree-producing action; a plain alias for the 'TreeT' constructor.
treeT0 :: m (Tree a) -> TreeT m a
treeT0 = TreeT

-- | Lift a function on unwrapped actions to a function on 'TreeT' values:
-- unwrap the argument, apply the function, wrap the result.
treeT1 :: (m (Tree a) -> n (Tree b)) -> TreeT m a -> TreeT n b
treeT1 f = treeT0 . f . runTreeT

-- | Lift a two-argument function on unwrapped actions to one on 'TreeT'
-- values: both arguments are unwrapped and the result is wrapped.
treeT2
  :: (m (Tree a) -> n (Tree b) -> o (Tree c))
  -> TreeT m a
  -> TreeT n b
  -> TreeT o c
treeT2 f = treeT1 . f . runTreeT

-- | Monadic sequencing for 'TreeT'.
--
-- 'return' yields a single node with no children.
--
-- @t >>= k@ runs @t@ in the base monad and then rebuilds its tree node by
-- node.  For a node labelled @x@:
--
-- 1. @k x@ is run first; its root label becomes the label of the new node;
-- 2. the original subtrees are then rebuilt recursively, left to right;
-- 3. the new node's children are the rebuilt original subtrees followed by
--    the children of the tree produced by @k x@.
--
-- NOTE(review): this child order is the reverse of the one used by the
-- 'Monad' instance of 'Tree' in "Data.Tree" -- confirm it is intentional.
instance Monad m => Monad (TreeT m) where
   return a = TreeT (return (Node a []))
   t >>= k = TreeT (runTreeT t >>= rebuild)
     where
       -- Rebuild one node: continuation on the label first, then the subtrees.
       rebuild (Node x subtrees) = do
           Node y fresh <- runTreeT (k x)
           rebound <- mapM rebuild subtrees
           return (Node y (rebound ++ fresh))

-- | Lifting a base-monad action runs it and places its result in a single
-- node with no children.
instance MonadTrans TreeT where
  lift action = treeT0 $ do
    a <- action
    return (Node a [])

-- | Mapping a function over a 'TreeT' maps it over every label of the tree
-- produced by the underlying action: the outer 'fmap' works in the base
-- functor @m@, the inner one in 'Tree'.
instance Functor m => Functor (TreeT m) where
    fmap f = treeT1 (fmap (fmap f))

-- | Applicative structure for 'TreeT'.
--
-- 'pure' yields a single node with no children inside the base applicative.
--
-- '<*>' runs the function action and then the argument action, each exactly
-- once, and combines the two resulting trees with @apTree@.  The children of
-- a combined node are ordered the same way as in '>>=' for this type: first
-- the function node's subtrees (each applied to the whole argument tree),
-- then the argument node's subtrees (mapped with the function label).
--
-- The 'Applicative' instance of 'Tree' orders these two groups the other way
-- round, so reusing it here made '<*>' disagree with 'Control.Monad.ap' on
-- the resulting tree even for an effect-free base monad.
--
-- Note that effects are still sequenced applicatively: the argument action
-- runs once, not once per node of the function tree as it would under
-- 'Control.Monad.ap'.
instance Applicative m => Applicative (TreeT m) where
    pure = treeT0 . pure . pure
    (<*>) = treeT2 (liftA2 apTree)
      where
        -- Pure tree application using the child order of this type's '>>='.
        apTree (Node f fs) tx@(Node x xs) =
            Node (f x) (map (`apTree` tx) fs ++ map (fmap f) xs)

-- | An 'IO' action is first lifted into the base monad with its own
-- 'liftIO' and then into 'TreeT' with 'lift'.
instance MonadIO m => MonadIO (TreeT m) where
    liftIO io = lift (liftIO io)