module Data.TreeSeq.Strict where

import Control.Applicative (Applicative(..))
import Control.Monad (Monad(..))
import Data.Eq (Eq(..))
import Data.Foldable (Foldable(..))
import Data.Function (($), (.))
import Data.Functor (Functor(..), (<$>))
import Data.Monoid (Monoid(..))
import Data.Ord (Ord(..))
import Data.Semigroup (Semigroup(..))
import Data.Sequence (Seq, ViewL(..))
import Data.Traversable (Traversable(..))
import Text.Show (Show(..))
import qualified Data.List as List
import qualified Data.Sequence as Seq

-- * Type 'Tree'
-- | A rose tree: a node carrying a value of type @a@ together with a
-- sequence of sub-trees.
--
-- Both fields are strict, so building a 'Tree' evaluates its value and the
-- sequence of its children to weak head normal form.
data Tree a = Tree
  { unTree   :: !a          -- ^ Value stored at this node.
  , subTrees :: !(Trees a)  -- ^ Children of this node, in order.
  }
  deriving (Eq, Ord, Show)
-- | Applies a function to the value of every node; the shape of the tree
-- is left untouched.
instance Functor Tree where
  fmap f = go
    where
      -- Rebuild the node with the mapped value, then recurse into each child.
      go (Tree x kids) = Tree (f x) (go <$> kids)
-- | 'pure' builds a node without sub-trees.
--
-- For '<*>', the root function is applied to the root value of the argument
-- tree.  The children of the result are the argument's children mapped with
-- the root function, followed by every sub-tree of functions applied to the
-- whole argument tree.
instance Applicative Tree where
  pure x = Tree x Seq.empty
  Tree f fs <*> arg@(Tree x xs) = Tree (f x) (mapped <> applied)
    where
      -- Root function pushed through the argument's children.
      mapped = fmap f <$> xs
      -- Each sub-tree of functions applied to the complete argument.
      applied = (<*> arg) <$> fs
-- | Binding replaces the root with the tree produced by the continuation
-- for the root value.  That tree's children come first, followed by the
-- original children, each bound with the same continuation.
instance Monad Tree where
  return = pure
  Tree a ts >>= f =
    case f a of
      Tree b bs -> Tree b (bs <> ((>>= f) <$> ts))
-- | Folds the value of a node before the values of its sub-trees.
instance Foldable Tree where
  foldMap f = go
    where
      -- Node value first, then the children folded with the same worker.
      go (Tree x kids) = mappend (f x) (foldMap go kids)
-- | Runs the effect of a node before the effects of its sub-trees and
-- rebuilds a tree of the same shape from the results.
instance Traversable Tree where
  traverse f = go
    where
      go (Tree x kids) = Tree <$> f x <*> traverse go kids
  sequenceA (Tree fx kids) = fmap Tree fx <*> traverse sequenceA kids

-- | Builds a tree made of a single node with no sub-trees.
tree0 :: a -> Tree a
tree0 x = Tree { unTree = x, subTrees = Seq.empty }

-- | 'True' when the node has no sub-trees.
isTree0 :: Tree a -> Bool
isTree0 = Seq.null . subTrees

-- | 'True' when the node has at least one sub-tree.
isTreeN :: Tree a -> Bool
isTreeN = not . Seq.null . subTrees

-- * Type 'Trees'
-- | A sequence of 'Tree's; this is the type of a node's 'subTrees'.
type Trees a = Seq (Tree a)

-- | Renders a tree as multi-line text: the lines produced by 'pretty',
-- each terminated by a newline.
prettyTree :: Show a => Tree a -> String
prettyTree t = List.unlines (pretty t)

-- | Renders every tree with 'prettyTree', in order, appending one extra
-- newline after each rendering, and concatenates the results.
-- An empty sequence yields the empty string.
prettyTrees :: Show a => Trees a -> String
prettyTrees ts = List.concat [prettyTree t <> "\n" | t <- toList ts]

-- | Renders a tree as a list of lines.
--
-- The first line is the 'show'n value of the root.  Each sub-tree is then
-- introduced by a @|@ line and drawn recursively, with a branch marker
-- prepended to its first line and an indentation prepended to its
-- remaining lines.
pretty :: Show a => Tree a -> [String]
pretty (Tree a ts0) = show a : go (Seq.viewl ts0)
  where
    -- Draw the remaining children, looking at the leftmost one first.
    go Seq.EmptyL = []
    go (t :< rest)
      -- Last child: closing branch marker, blank continuation below it.
      | Seq.null rest = "|" : prefixed "`- " "   " (pretty t)
      -- Inner child: keep the vertical bar running down to the next sibling.
      | otherwise = "|" : (prefixed "+- " "|  " (pretty t) <> go (Seq.viewl rest))

    -- Prepend one prefix to the first line and another to every later line.
    prefixed firstPrefix restPrefix ls =
      List.zipWith (<>) (firstPrefix : List.repeat restPrefix) ls