Copyright (C) 2011 CE Matthew Farkas-Dyck

This library is free software; it may be modified and/or redistributed under the terms of the GNU Lesser General Public License, as published by the Free Software Foundation, either version 3 of the License, or, optionally, any later version.

This library is distributed in the hope that it may be useful, but WITH NO WARRANTY, not even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more details.

A copy of the License ought to have come along with this program. Otherwise, see <http://www.gnu.org/licenses/>.

{-# LANGUAGE UndecidableInstances #-}

module Data.Tree (Tree (..), flatten, levels, unfoldTree, unfoldTreeM) where

import Prelude (Ord, Read, Show);

import Control.Applicative;
import Control.Monad hiding (mapM);
import Data.Bool;
import Data.Data;
import Data.Eq;
import Data.Foldable;
import Data.Function;
import Data.Functor;
import Data.Monoid;
import Data.Semigroup (Semigroup (..));
import Data.Traversable;
import Util ((>>*));

-- |Multiway (rose) tree, parameterised over the container @v@ that holds
-- each node's children; @Tree [] a@ is the classic list-based rose tree.
data Tree v a
  = Node a (v (Tree v a))  -- ^ a label, then the container of child subtrees
  ;

-- Structural equality: two nodes are equal when their labels are equal and
-- their child containers compare equal (which recurses into the subtrees).
-- NOTE(review): the @Eq (v (Tree v a))@ context is why the module enables
-- UndecidableInstances; under Haskell2010 it may also need FlexibleContexts.
instance (Eq a, Eq (v (Tree v a))) => Eq (Tree v a) where {
  Node x ss == Node y ts = x == y && ss == ts;
};

-- Applies the function to every label; the shape of the tree (and of every
-- child container) is left unchanged.
instance (Functor v) => Functor (Tree v) where {
  fmap f (Node x ts) = Node (f x) (fmap (fmap f) ts);
};

-- 'pure' builds a leaf: a node whose child container is 'empty'.
--
-- @Node f ss <*> t@ labels the new root with @f@ applied to @t@'s root label.
-- Its children are @t@'s children with @f@ mapped over them, followed
-- (combined with '<|>') by each function subtree in @ss@ applied to the whole
-- of @t@.
--
-- 'Functor' @v@ is not listed separately: it follows from 'Alternative' @v@.
instance (Alternative v) => Applicative (Tree v) where {
  pure x = Node x empty;
  Node f ss <*> t@(Node x ts) =
    Node (f x) (fmap (fmap f) ts <|> fmap (<*> t) ss);
};

-- Pre-order fold: the root label is combined first, then each child subtree
-- in the order the child container is folded.
instance (Foldable v) => Foldable (Tree v) where {
  foldMap f (Node x ts) = f x `mappend` foldMap (foldMap f) ts;
};

-- Pre-order traversal: the effect for the root label runs before the effects
-- for the child subtrees, which run in the order the child container is
-- traversed.
instance (Traversable v) => Traversable (Tree v) where {
  traverse f (Node x ts) = Node <$> f x <*> traverse (traverse f) ts;
};

-- |Labels of the tree in pre-order: the root label, then the labels of each
-- child subtree in the order the child container is folded.
--
-- An accumulating parameter means each label is consed exactly once, giving
-- linear time. A @concatMap@-based definition would re-copy every subtree's
-- list once per enclosing level.
flatten :: (Foldable v) => Tree v a -> [a];
flatten t = squish t []
  where squish (Node x ts) xs = x : foldr squish xs ts;

-- |Pre-order labels of the tree with duplicate labels removed.
--
-- Each node's label is checked against the (already de-duplicated) labels
-- that follow its subtree. If the label is found there, the node is dropped
-- together with its whole subtree. So among repeated labels it is generally
-- the later occurrence that survives.
--
-- WARNING: dropping the subtree is only correct if equal labels imply equal
-- subtrees, i.e. @x == y@ implies @Node x _ == Node y _@. Every check is a
-- linear 'elem', so the cost is quadratic in the number of labels.
flattenub :: (Eq a, Foldable v) => Tree v a -> [a];
flattenub = flip go []
  where go (Node x ts) rest
          | x `elem` rest = rest
          | otherwise     = x : foldr go rest ts;

-- |List wrapper whose monoid combines element-wise ("zip-longest"). Elements
-- at the same index are 'mappend'ed, and the surplus tail of the longer list
-- is kept unchanged. 'levels' uses it to merge the per-depth results of
-- sibling subtrees.
newtype Zip a = Zip { unZip :: [a] };

-- |Element-wise 'mappend' of two 'Zip's, keeping the longer list's tail.
zipAppend :: Monoid a => Zip a -> Zip a -> Zip a;
zipAppend (Zip xs0) (Zip ys0) = Zip (go xs0 ys0)
  where go [] ys = ys
        go xs [] = xs
        go (x:xs) (y:ys) = x `mappend` y : go xs ys;

-- Required as the superclass of 'Monoid' since base-4.11 (GHC 8.4).
instance Monoid a => Semigroup (Zip a) where {
  (<>) = zipAppend;
};

-- 'mappend' is spelled out instead of left to its default, so the instance
-- also works with base versions whose 'Monoid' has no default for it.
instance Monoid a => Monoid (Zip a) where {
  mempty = Zip [];
  mappend = zipAppend;
};

-- |Labels grouped by depth. The first element is the root label lifted with
-- 'pure'. Element @n@ gathers the labels at depth @n@ from every child
-- subtree, merged with 'mappend' in the order the child container is folded.
levels :: (Applicative v, Foldable v, Monoid (v a)) => Tree v a -> [v a];
levels (Node x ts) = pure x : deeper
  where deeper = unZip (foldMap (Zip . levels) ts);

-- |Grow a tree from a seed. @f@ turns a seed into a node label and a
-- container of child seeds. Construction is lazy: the 'Node' is produced
-- before @f@ is applied to its seed.
unfoldTree :: (Functor v) => (b -> (a, v b)) -> b -> Tree v a;
unfoldTree f = grow
  where grow seed = Node label (fmap grow seeds)
          where (label, seeds) = f seed;

-- |Grow a tree from a seed inside a monad. A node's action @f seed@ runs
-- before the actions for its children, and each child subtree is built in
-- full (via 'mapM') before the next one starts, i.e. depth-first.
-- NOTE(review): assumes Util's @>>*@ maps a function over a monadic result
-- (a flipped fmap/liftM); confirm against Util.
unfoldTreeM :: (Monad m, Traversable v) => (b -> m (a, v b)) -> b -> m (Tree v a);
unfoldTreeM f = grow
  where grow seed = do { (label, seeds) <- f seed;
                         mapM grow seeds >>* Node label };