-- | N-ary trees.

{-# LANGUAGE FlexibleInstances, TypeSynonymInstances #-}
module Math.Combinat.Trees.Nary 
  (      
    -- * Types
    module Data.Tree
  , Tree(..)
    -- * Regular trees  
  , ternaryTrees
  , regularNaryTrees
  , semiRegularTrees
  , countTernaryTrees
  , countRegularNaryTrees
    -- * \"derivation trees\"
  , derivTrees
    -- * ASCII drawings
  , asciiTreeVertical_
  , asciiTreeVertical
  , asciiTreeVerticalLeavesOnly
    -- * Graphviz drawing
  , Dot
  , graphvizDotTree  
  , graphvizDotForest
    -- * Classifying nodes
  , classifyTreeNode
  , isTreeLeaf  , isTreeNode
  , isTreeLeaf_ , isTreeNode_
  , treeNodeNumberOfChildren 
    -- * Counting nodes
  , countTreeNodes
  , countTreeLeaves
  , countTreeLabelsWith
  , countTreeNodesWith 
    -- * Left and right spines
  , leftSpine  , leftSpine_
  , rightSpine , rightSpine_
  , leftSpineLength , rightSpineLength
    -- * Unique labels
  , addUniqueLabelsTree
  , addUniqueLabelsForest
  , addUniqueLabelsTree_
  , addUniqueLabelsForest_
    -- * Labelling by depth
  , labelDepthTree
  , labelDepthForest
  , labelDepthTree_
  , labelDepthForest_
    -- * Labelling by number of children
  , labelNChildrenTree
  , labelNChildrenForest
  , labelNChildrenTree_
  , labelNChildrenForest_
    
  ) where


--------------------------------------------------------------------------------

import Data.Tree
import Data.List

import Control.Applicative

--import Control.Monad.State
import Control.Monad.Trans.State
import Data.Traversable (traverse)

import Math.Combinat.Sets                  ( listTensor )
import Math.Combinat.Partitions.Multiset   ( partitionMultiset )
import Math.Combinat.Compositions          ( compositions )
import Math.Combinat.Numbers               ( factorial, binomial )

import Math.Combinat.Trees.Graphviz ( Dot , graphvizDotForest , graphvizDotTree )

import Math.Combinat.Classes
import Math.Combinat.ASCII as ASCII
import Math.Combinat.Helper

--------------------------------------------------------------------------------

-- | Number of internal (non-leaf) nodes: a node without children counts as
-- zero, every other node counts as one plus the counts of its subtrees.
instance HasNumberOfNodes (Tree a) where
  numberOfNodes = go where
    go (Node _ subforest) = case subforest of
      [] -> 0
      -- sum' comes from Math.Combinat.Helper (presumably a strict sum -- TODO confirm)
      _  -> 1 + sum' (map go subforest)

-- | Number of leaves: a node without children counts as one leaf, any other
-- node contributes the leaves of its subtrees.
instance HasNumberOfLeaves (Tree a) where
  numberOfLeaves = go where
    go (Node _ subforest) = case subforest of
      [] -> 1
      -- sum' comes from Math.Combinat.Helper (presumably a strict sum -- TODO confirm)
      _  -> sum' (map go subforest)

--------------------------------------------------------------------------------

-- | @regularNaryTrees d n@ returns the list of (rooted) trees on @n@ nodes where each
-- node has exactly @d@ children. Note that the leaves do not count in @n@.
-- Naive algorithm.
regularNaryTrees 
  :: Int         -- ^ degree = number of children of each node
  -> Int         -- ^ number of nodes
  -> [Tree ()]
regularNaryTrees d = go where
  go :: Int -> [Tree ()]
  -- zero internal nodes: the only tree is a single leaf
  go 0 = [ Node () [] ]
  -- the root uses up one node; the remaining @n-1@ nodes are distributed
  -- among the @d@ subtrees in every way 'compositions' enumerates, and
  -- 'listTensor' combines the per-subtree choices
  go n = [ Node () cs
         | is <- compositions d (n-1) 
         , cs <- listTensor (map go is)
         ]
  
-- | Ternary trees on @n@ nodes (synonym for @regularNaryTrees 3@)
ternaryTrees :: Int -> [Tree ()]  
ternaryTrees = regularNaryTrees 3

-- | We have 
--
-- > length (regularNaryTrees d n) == countRegularNaryTrees d n == \frac {1} {(d-1)n+1} \binom {dn} {n} 
--
-- Both arguments are converted to 'Integer' before the formula is evaluated,
-- so the intermediate binomial coefficient is computed at arbitrary precision.
countRegularNaryTrees :: (Integral a, Integral b) => a -> b -> Integer
countRegularNaryTrees d n = binomial (dd*nn) nn `div` ((dd-1)*nn+1) where
  dd, nn :: Integer
  dd = fromIntegral d
  nn = fromIntegral n

-- | @\# = \\frac {1} {2n+1} \\binom {3n} {n}@
countTernaryTrees :: Integral a => a -> Integer  
countTernaryTrees = countRegularNaryTrees (3::Int)

--------------------------------------------------------------------------------

-- | All trees on @n@ nodes where the number of children of every node is
-- an element of the given set. Example:
--
-- > autoTabulate RowMajor (Right 5) $ map asciiTreeVertical 
-- >                                 $ map labelNChildrenTree_ 
-- >                                 $ semiRegularTrees [2,3] 2
-- >
-- > [ length $ semiRegularTrees [2,3] n | n<-[0..] ] == [1,2,10,66,498,4066,34970,312066,2862562,26824386,...]
--
-- The latter sequence is A027307 in OEIS: <https://oeis.org/A027307>
--
-- Remark: clearly, we have
--
-- > semiRegularTrees [d] n == regularNaryTrees d n
--
-- Calls 'error' if the given set contains a number smaller than 1.
-- 
semiRegularTrees 
  :: [Int]         -- ^ set of allowed number of children
  -> Int           -- ^ number of nodes
  -> [Tree ()]
-- with no allowed degree only the single-leaf tree (zero nodes) exists
semiRegularTrees []    n = if n == 0 then [Node () []] else []
semiRegularTrees dset_ n
  | all (>= 1) dset_ = go n
  | otherwise        = error "semiRegularTrees: expecting a list of positive integers"
  where
    -- sorted, duplicate-free version of the input set
    -- ('head' is safe here: 'group' never produces empty groups)
    dset :: [Int]
    dset = map head (group (sort dset_))
    
    go :: Int -> [Tree ()]
    go 0 = [ Node () [] ]
    -- pick the root's degree, split the remaining @m-1@ nodes among its
    -- subtrees, then combine the per-subtree possibilities
    go m = [ Node () cs
           | d  <- dset
           , is <- compositions d (m-1) 
           , cs <- listTensor (map go is)
           ]
           
{- 

NOTES:

A006318 = [ length $ semiRegularTrees [1,2] n | n<-[0..] ] == [1,2,6,22,90,394,1806,8558,41586,206098,1037718.. ]
??      = [ length $ semiRegularTrees [1,3] n | n<-[0..] ] == [1,2,8,44,280,1936,14128,107088,834912,6652608 .. ]
??      = [ length $ semiRegularTrees [1,4] n | n<-[0..] ] == [1,2,10,74,642,6082,60970,635818,6826690

A027307 = [ length $ semiRegularTrees [2,3] n | n<-[0..] ] == [1,2,10,66,498,4066,34970,312066,2862562,26824386,...]
A219534 = [ length $ semiRegularTrees [2,4] n | n<-[0..] ] == [1,2,12,100,968,10208,113792,1318832 ..]
??      = [ length $ semiRegularTrees [2,5] n | n<-[0..] ] == [1,2,14,142,1690,21994,303126,4348102 ..]

A144097 = [ length $ semiRegularTrees [3,4] n | n<-[0..] ] == [1,2,14,134,1482,17818,226214,2984206,40503890..]

A107708 = [ length $ semiRegularTrees [1,2,3]   n | n<-[0..] ] == [1,3,18,144,1323,13176,138348,1507977 .. ]
??      = [ length $ semiRegularTrees [1,2,3,4] n | n<-[0..] ] == [1,4,40,560,9120,161856,3036800,59242240 .. ] 

-}
             
--------------------------------------------------------------------------------

-- | Vertical ASCII drawing of a tree, without labels. Example:
--
-- > autoTabulate RowMajor (Right 5) $ map asciiTreeVertical_ $ regularNaryTrees 2 4 
--
-- Nodes are denoted by @\@@, leaves by @*@.
--
asciiTreeVertical_ :: Tree a -> ASCII
asciiTreeVertical_ tree = ASCII.asciiFromLines (go tree) where
  -- renders a subtree as a block of text lines
  go :: Tree b -> [String]
  go (Node _ cs) = case cs of
    [] -> ["-*"]
    _  -> concat $ mapWithFirstLast f $ map go cs
    
  -- Decorates the block of one child. The two flags are supplied by
  -- 'mapWithFirstLast' (presumably "is first child" / "is last child",
  -- judging by the helper's name -- TODO confirm).
  f :: Bool -> Bool -> [String] -> [String] 
  f _  _  []     = []    -- 'go' never yields an empty block; kept for totality
  f bf bl (l:ls) = (branch ++ l) : map (indent ++) ls ++ gap
    where
      indent = if bl then "  " else "| "
      gap    = if bl then []   else ["| "]
      branch
        | bf        = "@-"     -- first child carries the node marker
        | bl        = "\\-"    -- last (but not first) child closes the bracket
        | otherwise = "+-"

-- | Unlabelled trees are drawn with 'asciiTreeVertical_'.
instance DrawASCII (Tree ()) where
  ascii = asciiTreeVertical_

-- | Prints all labels. Example:
-- 
-- > asciiTreeVertical $ addUniqueLabelsTree_ $ (regularNaryTrees 3 9) !! 666
--
-- Nodes are denoted by @(label)@, leaves by @label@.
--
asciiTreeVertical :: Show a => Tree a -> ASCII
asciiTreeVertical tree = ASCII.asciiFromLines (go tree) where
  -- renders a subtree as a block of text lines
  go :: Show b => Tree b -> [String]
  go (Node x cs) = case cs of
    [] -> ["-- " ++ show x]
    _  -> concat $ mapWithFirstLast (f (show x)) $ map go cs
    
  -- Decorates the block of one child of a node whose shown label is @label@.
  -- The two flags are supplied by 'mapWithFirstLast' (presumably "is first
  -- child" / "is last child", judging by the helper's name -- TODO confirm).
  f :: String -> Bool -> Bool -> [String] -> [String] 
  f _     _  _  []     = []    -- 'go' never yields an empty block; kept for totality
  f label bf bl (l:ls) = (branch ++ l) : map (indent ++) ls ++ gap
    where
      -- padding as wide as the label, so the children line up under it
      spaces = map (const ' ') label
      dashes = map (const '-') spaces
      indent = if bl then "  " ++ spaces ++ "  " else " |" ++ spaces ++ "  "
      gap    = if bl then []                     else [" |" ++ spaces ++ "  "]
      branch
        | bf        = "-(" ++ label  ++ ")-"   -- first child carries the label
        | bl        = " \\" ++ dashes ++ "--"
        | otherwise = " +" ++ dashes ++ "--"

-- | Prints the labels for the leaves, but not for the nodes.
-- Nodes are drawn as @\@@, leaves as their shown label.
asciiTreeVerticalLeavesOnly :: Show a => Tree a -> ASCII
asciiTreeVerticalLeavesOnly tree = ASCII.asciiFromLines (go tree) where
  -- renders a subtree as a block of text lines
  go :: Show b => Tree b -> [String]
  go (Node x cs) = case cs of
    [] -> ["- " ++ show x]
    _  -> concat $ mapWithFirstLast f $ map go cs
    
  -- Decorates the block of one child. The two flags are supplied by
  -- 'mapWithFirstLast' (presumably "is first child" / "is last child",
  -- judging by the helper's name -- TODO confirm).
  f :: Bool -> Bool -> [String] -> [String] 
  f _  _  []     = []    -- 'go' never yields an empty block; kept for totality
  f bf bl (l:ls) = (branch ++ l) : map (indent ++) ls ++ gap
    where
      indent = if bl then "  " else "| "
      gap    = if bl then []   else ["| "]
      branch
        | bf        = "@-"     -- first child carries the node marker
        | bl        = "\\-"    -- last (but not first) child closes the bracket
        | otherwise = "+-"
  
--------------------------------------------------------------------------------
  
-- | The leftmost spine (the second element of the pair is the leaf node).
-- The first component lists the labels of the non-leaf nodes met while always
-- descending into the first child, starting with the root.
leftSpine  :: Tree a -> ([a],a)
leftSpine (Node x cs) = case cs of
  []    -> ([], x)
  (c:_) -> let (xs, y) = leftSpine c in (x:xs, y)

-- | The rightmost spine (the second element of the pair is the leaf node).
-- The first component lists the labels of the non-leaf nodes met while always
-- descending into the last child, starting with the root.
rightSpine  :: Tree a -> ([a],a)
rightSpine (Node x cs) = case cs of
  [] -> ([], x)
  -- 'last' is safe: this branch is only reached for a non-empty child list
  _  -> let (xs, y) = rightSpine (last cs) in (x:xs, y)

-- | The leftmost spine without the leaf node
leftSpine_  :: Tree a -> [a]
leftSpine_ (Node x cs) = case cs of
  []    -> []
  (c:_) -> x : leftSpine_ c

-- | The rightmost spine without the leaf node
rightSpine_ :: Tree a -> [a] 
rightSpine_ (Node x cs) = case cs of
  [] -> []
  -- 'last' is safe: this branch is only reached for a non-empty child list
  _  -> x : rightSpine_ (last cs)

-- | The length (number of edges) on the left spine 
--
-- > leftSpineLength tree == length (leftSpine_ tree)
--
leftSpineLength  :: Tree a -> Int  
leftSpineLength = go 0 where
  go :: Int -> Tree b -> Int
  go n (Node _ cs) = case cs of
    []    -> n
    -- force the counter at each step so no chain of thunks builds up
    (c:_) -> let n' = n + 1 in n' `seq` go n' c
  
-- | The length (number of edges) on the right spine 
--
-- > rightSpineLength tree == length (rightSpine_ tree)
--
rightSpineLength :: Tree a -> Int  
rightSpineLength = go 0 where
  go :: Int -> Tree b -> Int
  go n (Node _ cs) = case cs of
    [] -> n
    -- 'last' is safe here (non-empty child list); the counter is forced at
    -- each step so no chain of thunks builds up
    _  -> let n' = n + 1 in n' `seq` go n' (last cs)

--------------------------------------------------------------------------------

-- | 'Left' is leaf, 'Right' is node. In both cases the payload is the label
-- of the root of the given tree; a leaf is a node without children.
classifyTreeNode :: Tree a -> Either a a
classifyTreeNode (Node x []) = Left  x
classifyTreeNode (Node x _ ) = Right x

-- | Returns the root label if the root has no children, 'Nothing' otherwise.
isTreeLeaf :: Tree a -> Maybe a  
isTreeLeaf (Node x []) = Just x
isTreeLeaf (Node _ _ ) = Nothing

-- | Returns the root label if the root has at least one child, 'Nothing' otherwise.
isTreeNode :: Tree a -> Maybe a  
isTreeNode (Node _ []) = Nothing
isTreeNode (Node x _ ) = Just x

-- | 'True' exactly when the root has no children.
isTreeLeaf_ :: Tree a -> Bool  
isTreeLeaf_ (Node _ []) = True
isTreeLeaf_ (Node _ _ ) = False
  
-- | 'True' exactly when the root has at least one child.
isTreeNode_ :: Tree a -> Bool  
isTreeNode_ (Node _ []) = False
isTreeNode_ (Node _ _ ) = True

-- | The number of (direct) children of the root of the given tree.
treeNodeNumberOfChildren :: Tree a -> Int
treeNodeNumberOfChildren (Node _ cs) = length cs

--------------------------------------------------------------------------------
-- counting

-- | Counts the internal nodes of a tree, that is, the nodes having at least
-- one child. Leaves are not counted.
countTreeNodes :: Tree a -> Int
countTreeNodes (Node _ cs) = case cs of
  [] -> 0
  _  -> 1 + sum (map countTreeNodes cs)

-- | Counts the leaves of a tree, that is, the nodes without children.
countTreeLeaves :: Tree a -> Int
countTreeLeaves (Node _ cs) = case cs of
  [] -> 1
  _  -> sum (map countTreeLeaves cs)

-- | Counts the labels (of nodes and leaves alike) satisfying the predicate.
countTreeLabelsWith :: (a -> Bool) -> Tree a -> Int
countTreeLabelsWith f = length . filter f . flatten

-- | Counts the subtrees (the tree itself and every descendant, leaves
-- included) satisfying the predicate. The predicate sees the whole subtree
-- rooted at the position being tested.
countTreeNodesWith :: (Tree a -> Bool) -> Tree a -> Int
countTreeNodesWith f = go where
  go node@(Node _ cs) = here + sum (map go cs)
    where here = if f node then 1 else 0

--------------------------------------------------------------------------------

-- | Adds unique labels to the nodes (including leaves) of a 'Tree'.
-- Implemented by labelling the one-element forest containing the tree.
addUniqueLabelsTree :: Tree a -> Tree (a,Int) 
addUniqueLabelsTree tree = case addUniqueLabelsForest [tree] of
  (t:_) -> t
  -- labelling maps over the forest, so a one-tree forest yields one tree
  []    -> error "addUniqueLabelsTree: impossible: empty forest"

-- | Adds unique labels to the nodes (including leaves) of a 'Forest'.
-- The labels are consecutive integers starting from 1, handed out in
-- traversal order: tree by tree, and within a tree in the order of the
-- 'Traversable' instance of 'Tree'.
addUniqueLabelsForest :: Forest a -> Forest (a,Int) 
addUniqueLabelsForest forest = evalState (traverse (traverse attach) forest) 1 where
  -- pairs a value with the current counter, then bumps the counter
  attach :: b -> State Int (b, Int)
  attach x = do
    i <- get
    put (i+1)
    return (x,i)

-- | Like 'addUniqueLabelsTree', but the original labels are dropped and only
-- the unique integer labels are kept.
addUniqueLabelsTree_ :: Tree a -> Tree Int
addUniqueLabelsTree_ = fmap snd . addUniqueLabelsTree  

-- | Like 'addUniqueLabelsForest', but the original labels are dropped and only
-- the unique integer labels are kept.
addUniqueLabelsForest_ :: Forest a -> Forest Int
addUniqueLabelsForest_ = map (fmap snd) . addUniqueLabelsForest

--------------------------------------------------------------------------------
    
-- | Attaches the depth to each node. The depth of the root is 0. 
labelDepthTree :: Tree a -> Tree (a,Int) 
labelDepthTree = worker 0 where
  -- every child sits one level deeper than its parent
  worker :: Int -> Tree b -> Tree (b, Int)
  worker depth (Node label subtrees) =
    Node (label, depth) (map (worker (depth+1)) subtrees)

-- | Applies 'labelDepthTree' to every tree of the forest; each root gets depth 0.
labelDepthForest :: Forest a -> Forest (a,Int) 
labelDepthForest = map labelDepthTree
    
-- | Like 'labelDepthTree', but the original labels are replaced by the depths.
labelDepthTree_ :: Tree a -> Tree Int
labelDepthTree_ = fmap snd . labelDepthTree

-- | Like 'labelDepthForest', but the original labels are replaced by the depths.
labelDepthForest_ :: Forest a -> Forest Int 
labelDepthForest_ = map (fmap snd) . labelDepthForest

--------------------------------------------------------------------------------

-- | Attaches the number of children to each node. 
-- Leaves therefore get the number 0.
labelNChildrenTree :: Tree a -> Tree (a,Int)
labelNChildrenTree (Node x subforest) = 
  Node (x, length subforest) (map labelNChildrenTree subforest)
  
-- | Applies 'labelNChildrenTree' to every tree of the forest.
labelNChildrenForest :: Forest a -> Forest (a,Int) 
labelNChildrenForest = map labelNChildrenTree

-- | Like 'labelNChildrenTree', but the original labels are replaced by the
-- number of children.
labelNChildrenTree_ :: Tree a -> Tree Int
labelNChildrenTree_ = fmap snd . labelNChildrenTree

-- | Like 'labelNChildrenForest', but the original labels are replaced by the
-- number of children.
labelNChildrenForest_ :: Forest a -> Forest Int 
labelNChildrenForest_ = map (fmap snd) . labelNChildrenForest
    
--------------------------------------------------------------------------------

-- | Computes the set of equivalence classes of rooted trees (in the 
-- sense that the leaves of a node are /unordered/) 
-- with @n = length xs@ leaves where the set of heights of 
-- the leaves matches the given set of numbers @xs@. 
-- The height is defined as the number of /edges/ from the leaf to the root. 
--
-- TODO: better name?
derivTrees :: [Int] -> [Tree ()]
-- the worker counts heights shifted up by one, hence the increment
derivTrees xs = derivTrees' (map (+1) xs)

-- | Worker of 'derivTrees'; it is called with every height incremented by one.
--
-- * No heights at all: no tree.
--
-- * A single number @n@: for @n >= 1@ the only tree is a chain of @n@
--   vertices (so @n-1@ edges); otherwise there is no tree.
--
-- * Otherwise, if all numbers are positive, the multiset is split in every
--   way 'partitionMultiset' enumerates, each part is solved recursively with
--   its numbers decremented by one, and the per-part results are combined
--   under a fresh root via 'listTensor'. If some number is not positive the
--   result is empty, which also ends the recursion.
derivTrees' :: [Int] -> [Tree ()]
derivTrees' [] = []
derivTrees' [n]
  | n >= 1    = [unfoldTree step 1]
  | otherwise = []
  where 
    -- vertex number @k@ of the chain gets one child until the @n@-th is reached
    step :: Int -> ((), [Int])
    step k = if k < n then ((), [k+1]) else ((), [])
derivTrees' ks
  | all (> 0) ks =
      [ Node () sub 
      | part <- partitionMultiset ks
      , sub  <- listTensor (map below part)
      ] 
  | otherwise    = []
  where
    -- trees for one part, one level further down
    below :: [Int] -> [Tree ()]
    below xs = derivTrees' (map (subtract 1) xs)

--------------------------------------------------------------------------------