-- | N-ary trees.

{-# LANGUAGE FlexibleInstances    #-}
{-# LANGUAGE TypeSynonymInstances #-}
module Math.Combinat.Trees.Nary
  (
    -- * Types
    module Data.Tree
  , Tree(..)
    -- * Regular trees
  , ternaryTrees
  , regularNaryTrees
  , semiRegularTrees
  , countTernaryTrees
  , countRegularNaryTrees
    -- * \"derivation trees\"
  , derivTrees
    -- * ASCII drawings
  , asciiTreeVertical_
  , asciiTreeVertical
  , asciiTreeVerticalLeavesOnly
    -- * Graphviz drawing
  , Dot
  , graphvizDotTree
  , graphvizDotForest
    -- * Classifying nodes
  , classifyTreeNode
  , isTreeLeaf  , isTreeNode
  , isTreeLeaf_ , isTreeNode_
  , treeNodeNumberOfChildren
    -- * Counting nodes
  , countTreeNodes
  , countTreeLeaves
  , countTreeLabelsWith
  , countTreeNodesWith
    -- * Left and right spines
  , leftSpine  , leftSpine_
  , rightSpine , rightSpine_
  , leftSpineLength , rightSpineLength
    -- * Unique labels
  , addUniqueLabelsTree
  , addUniqueLabelsForest
  , addUniqueLabelsTree_
  , addUniqueLabelsForest_
    -- * Labelling by depth
  , labelDepthTree
  , labelDepthForest
  , labelDepthTree_
  , labelDepthForest_
    -- * Labelling by number of children
  , labelNChildrenTree
  , labelNChildrenForest
  , labelNChildrenTree_
  , labelNChildrenForest_

  ) where


--------------------------------------------------------------------------------

import           Data.List
import           Data.Tree

import           Control.Applicative

--import Control.Monad.State
import           Control.Monad.Trans.State
import           Data.Traversable                  (traverse)

import           Math.Combinat.Compositions        (compositions)
import           Math.Combinat.Numbers             (binomial, factorial)
import           Math.Combinat.Partitions.Multiset (partitionMultiset)
import           Math.Combinat.Sets                (listTensor)

import           Math.Combinat.Trees.Graphviz      (Dot, graphvizDotForest,
                                                    graphvizDotTree)

import           Math.Combinat.ASCII               as ASCII
import           Math.Combinat.Classes
import           Math.Combinat.Helper

--------------------------------------------------------------------------------

-- | 'numberOfNodes' counts only the internal nodes of a tree (those with at
-- least one child); leaves are not counted. Delegates to 'countTreeNodes'
-- so that the two definitions cannot drift apart.
instance HasNumberOfNodes (Tree a) where
  numberOfNodes = countTreeNodes

-- | 'numberOfLeaves' counts the nodes without children. Delegates to
-- 'countTreeLeaves' so that the two definitions cannot drift apart.
instance HasNumberOfLeaves (Tree a) where
  numberOfLeaves = countTreeLeaves

--------------------------------------------------------------------------------

-- | @regularNaryTrees d n@ returns the list of (rooted) trees on @n@ nodes where each
-- node has exactly @d@ children. Note that the leaves do not count in @n@.
-- Naive algorithm: the @n-1@ nodes below the root are distributed over the
-- @d@ subtrees in every possible way, and the subtrees are built recursively.
regularNaryTrees
  :: Int         -- ^ degree = number of children of each node
  -> Int         -- ^ number of nodes
  -> [Tree ()]
regularNaryTrees d = trees
  where
    -- all trees with exactly k internal nodes
    trees :: Int -> [Tree ()]
    trees 0 = [Node () []]
    trees k = do
      sizes    <- compositions d (k - 1)
      children <- listTensor (map trees sizes)
      return (Node () children)

-- | Ternary trees on @n@ nodes (synonym for @regularNaryTrees 3@).
ternaryTrees :: Int -> [Tree ()]
ternaryTrees n = regularNaryTrees 3 n

-- | The number of regular @d@-ary trees with @n@ (internal) nodes. We have
--
-- > length (regularNaryTrees d n) == countRegularNaryTrees d n == \frac {1} {(d-1)n+1} \binom {dn} {n}
--
countRegularNaryTrees :: (Integral a, Integral b) => a -> b -> Integer
countRegularNaryTrees d n = numerator `div` denominator
  where
    dd, nn :: Integer
    dd = toInteger d
    nn = toInteger n
    numerator   = binomial (dd * nn) nn
    denominator = (dd - 1) * nn + 1

-- | The number of ternary trees on @n@ nodes:
-- @\# = \\frac {1} {2n+1} \\binom {3n} {n}@
-- (this is @countRegularNaryTrees 3 n@).
countTernaryTrees :: Integral a => a -> Integer
countTernaryTrees n = countRegularNaryTrees (3 :: Int) n

--------------------------------------------------------------------------------

-- | All trees on @n@ nodes where the number of children of all nodes is
-- an element of the given set. Example:
--
-- > autoTabulate RowMajor (Right 5) $ map asciiTreeVertical
-- >                                 $ map labelNChildrenTree_
-- >                                 $ semiRegularTrees [2,3] 2
-- >
-- > [ length $ semiRegularTrees [2,3] n | n<-[0..] ] == [1,2,10,66,498,4066,34970,312066,2862562,26824386,...]
--
-- The latter sequence is A027307 in OEIS: <https://oeis.org/A027307>
--
-- Remark: clearly, we have
--
-- > semiRegularTrees [d] n == regularNaryTrees d n
--
-- With an empty set, only the single-leaf tree is produced (and only for
-- @n == 0@). If the smallest element of the set is less than 1, this
-- calls 'error'.
--
semiRegularTrees
  :: [Int]         -- ^ set of allowed number of children
  -> Int           -- ^ number of nodes
  -> [Tree ()]
semiRegularTrees [] n
  | n == 0    = [Node () []]
  | otherwise = []
semiRegularTrees dset_ n
  | minimum dset >= 1 = build n
  | otherwise         = error "semiRegularTrees: expecting a list of positive integers"
  where
    -- the allowed child counts, sorted ascending, duplicates removed
    dset :: [Int]
    dset = map head (group (sort dset_))

    -- all trees with exactly k internal nodes
    build :: Int -> [Tree ()]
    build 0 = [Node () []]
    build k = concatMap (withRootDegree k) dset

    -- the trees on k internal nodes whose root has exactly d children
    withRootDegree :: Int -> Int -> [Tree ()]
    withRootDegree k d =
      [ Node () cs
      | sizes <- compositions d (k - 1)
      , cs    <- listTensor (map build sizes)
      ]

{-

NOTES:

A006318 = [ length $ semiRegularTrees [1,2] n | n<-[0..] ] == [1,2,6,22,90,394,1806,8558,41586,206098,1037718.. ]
??      = [ length $ semiRegularTrees [1,3] n | n<-[0..] ] == [1,2,8,44,280,1936,14128,107088,834912,6652608 .. ]
??      = [ length $ semiRegularTrees [1,4] n | n<-[0..] ] == [1,2,10,74,642,6082,60970,635818,6826690 ..]

A027307 = [ length $ semiRegularTrees [2,3] n | n<-[0..] ] == [1,2,10,66,498,4066,34970,312066,2862562,26824386,...]
A219534 = [ length $ semiRegularTrees [2,4] n | n<-[0..] ] == [1,2,12,100,968,10208,113792,1318832 ..]
??      = [ length $ semiRegularTrees [2,5] n | n<-[0..] ] == [1,2,14,142,1690,21994,303126,4348102 ..]

A144097 = [ length $ semiRegularTrees [3,4] n | n<-[0..] ] == [1,2,14,134,1482,17818,226214,2984206,40503890..]

A107708 = [ length $ semiRegularTrees [1,2,3]   n | n<-[0..] ] == [1,3,18,144,1323,13176,138348,1507977 .. ]
??      = [ length $ semiRegularTrees [1,2,3,4] n | n<-[0..] ] == [1,4,40,560,9120,161856,3036800,59242240 .. ]

-}

--------------------------------------------------------------------------------

-- | Vertical ASCII drawing of a tree, without labels. Example:
--
-- > autoTabulate RowMajor (Right 5) $ map asciiTreeVertical_ $ regularNaryTrees 2 4
--
-- Nodes are denoted by @\@@, leaves by @*@.
--
asciiTreeVertical_ :: Tree a -> ASCII
asciiTreeVertical_ = ASCII.asciiFromLines . render
  where
    -- the lines of the drawing of one subtree
    render :: Tree b -> [String]
    render (Node _ [])       = ["-*"]
    render (Node _ children) = concat (mapWithFirstLast decorate (map render children))

    -- attach the drawing of one child to its parent: the first line gets a
    -- branch marker, the remaining lines are indented, and a connector line
    -- is appended unless this is the last child
    decorate :: Bool -> Bool -> [String] -> [String]
    decorate isFirst isLast (top : rest) =
      (branch ++ top) : map (indent ++) rest ++ gap
      where
        indent = if isLast then "  " else "| "
        gap    = if isLast then []   else [indent]
        branch
          | isLast && not isFirst = "\\-"
          | isFirst               = "@-"
          | otherwise             = "+-"

-- | Unlabelled trees are drawn with 'asciiTreeVertical_'.
instance DrawASCII (Tree ()) where
  ascii tree = asciiTreeVertical_ tree

-- | Prints all labels. Example:
--
-- > asciiTreeVertical $ addUniqueLabelsTree_ $ (regularNaryTrees 3 9) !! 666
--
-- Nodes are denoted by @(label)@, leaves by @label@.
--
asciiTreeVertical :: Show a => Tree a -> ASCII
asciiTreeVertical = ASCII.asciiFromLines . render
  where
    -- the lines of the drawing of one subtree
    render :: Show b => Tree b -> [String]
    render (Node x [])       = ["-- " ++ show x]
    render (Node x children) =
      concat (mapWithFirstLast (decorate (show x)) (map render children))

    -- attach the drawing of one child to a parent labelled @label@; the
    -- parent's label is printed on the first child's branch, and the other
    -- branches are padded to the label's width
    decorate :: String -> Bool -> Bool -> [String] -> [String]
    decorate label isFirst isLast (top : rest) =
      (branch ++ top) : map (indent ++) rest ++ gap
      where
        spaces = ' ' <$ label
        dashes = '-' <$ spaces
        indent
          | isLast    = "  " ++ spaces ++ "  "
          | otherwise = " |" ++ spaces ++ "  "
        gap = if isLast then [] else [indent]
        branch
          | isLast && not isFirst = " \\" ++ dashes ++ "--"
          | isFirst               = "-(" ++ label ++ ")-"
          | otherwise             = " +" ++ dashes ++ "--"

-- | Prints the labels for the leaves, but not for the nodes.
-- Leaves are drawn as @- label@; the branch markers match 'asciiTreeVertical_'.
asciiTreeVerticalLeavesOnly :: Show a => Tree a -> ASCII
asciiTreeVerticalLeavesOnly = ASCII.asciiFromLines . render
  where
    -- the lines of the drawing of one subtree
    render :: Show b => Tree b -> [String]
    render (Node x [])       = ["- " ++ show x]
    render (Node _ children) = concat (mapWithFirstLast decorate (map render children))

    -- attach the drawing of one child to its (unlabelled) parent
    decorate :: Bool -> Bool -> [String] -> [String]
    decorate isFirst isLast (top : rest) =
      (branch ++ top) : map (indent ++) rest ++ gap
      where
        indent = if isLast then "  " else "| "
        gap    = if isLast then []   else [indent]
        branch
          | isLast && not isFirst = "\\-"
          | isFirst               = "@-"
          | otherwise             = "+-"

--------------------------------------------------------------------------------

-- | The leftmost spine, obtained by always following the first child.
-- The first component lists the labels of the non-leaf nodes on that path
-- (root first); the second component is the label of the leaf where it ends.
leftSpine  :: Tree a -> ([a],a)
leftSpine (Node x [])          = ([], x)
leftSpine (Node x (child : _)) = let (xs, y) = leftSpine child in (x : xs, y)

-- | The rightmost spine, obtained by always following the last child.
-- The first component lists the labels of the non-leaf nodes on that path
-- (root first); the second component is the label of the final leaf.
rightSpine  :: Tree a -> ([a],a)
rightSpine (Node x children)
  | null children = ([], x)
  | otherwise     = (x : above, leaf)
  where
    (above, leaf) = rightSpine (last children)

-- | The leftmost spine without the leaf node: the labels of the non-leaf
-- nodes met when repeatedly following the first child, root first.
leftSpine_  :: Tree a -> [a]
leftSpine_ (Node _ [])          = []
leftSpine_ (Node x (child : _)) = x : leftSpine_ child

-- | The rightmost spine without the leaf node: the labels of the non-leaf
-- nodes met when repeatedly following the last child, root first.
rightSpine_ :: Tree a -> [a]
rightSpine_ (Node x children)
  | null children = []
  | otherwise     = x : rightSpine_ (last children)

-- | The length (number of edges) on the left spine
--
-- > leftSpineLength tree == length (leftSpine_ tree)
--
leftSpineLength  :: Tree a -> Int
leftSpineLength = go 0
  where
    go :: Int -> Tree b -> Int
    go n (Node _ [])          = n
    -- force the counter at each step so no chain of (+1) thunks builds up
    go n (Node _ (child : _)) = let n' = n + 1 in n' `seq` go n' child

-- | The length (number of edges) on the right spine, i.e. the number of
-- steps taken when repeatedly following the last child until a leaf.
rightSpineLength :: Tree a -> Int
rightSpineLength = walk 0
  where
    walk :: Int -> Tree b -> Int
    walk acc (Node _ children)
      | null children = acc
      | otherwise     = walk (acc + 1) (last children)

--------------------------------------------------------------------------------

-- | Classifies the root of a tree: 'Left' with its label if it is a leaf
-- (no children), 'Right' with its label if it is an internal node.
classifyTreeNode :: Tree a -> Either a a
classifyTreeNode (Node x []) = Left x
classifyTreeNode (Node x _)  = Right x

-- | The label of the root if the root is a leaf, otherwise 'Nothing'.
isTreeLeaf :: Tree a -> Maybe a
isTreeLeaf (Node x children)
  | null children = Just x
  | otherwise     = Nothing

-- | The label of the root if the root has at least one child, otherwise 'Nothing'.
isTreeNode :: Tree a -> Maybe a
isTreeNode (Node _ [])      = Nothing
isTreeNode (Node x (_ : _)) = Just x

-- | 'True' when the root has no children.
isTreeLeaf_ :: Tree a -> Bool
isTreeLeaf_ = null . subForest

-- | 'True' when the root has at least one child.
isTreeNode_ :: Tree a -> Bool
isTreeNode_ = not . null . subForest

-- | The number of direct children of the root.
treeNodeNumberOfChildren :: Tree a -> Int
treeNodeNumberOfChildren = length . subForest

--------------------------------------------------------------------------------
-- counting

-- | The number of internal nodes (nodes with at least one child);
-- leaves are not counted.
countTreeNodes :: Tree a -> Int
countTreeNodes (Node _ [])       = 0
countTreeNodes (Node _ children) = 1 + sum (map countTreeNodes children)

-- | The number of leaves (nodes without children).
countTreeLeaves :: Tree a -> Int
countTreeLeaves (Node _ [])       = 1
countTreeLeaves (Node _ children) = sum (map countTreeLeaves children)

-- | The number of nodes (leaves included) whose label satisfies the predicate.
countTreeLabelsWith :: (a -> Bool) -> Tree a -> Int
countTreeLabelsWith f = length . filter f . flatten

-- | The number of subtrees (rooted at every node, leaves included) that
-- satisfy the predicate.
countTreeNodesWith :: (Tree a -> Bool) -> Tree a -> Int
countTreeNodesWith f = go
  where
    go node = fromEnum (f node) + sum (map go (subForest node))

--------------------------------------------------------------------------------

-- | Adds unique labels to the nodes (including leaves) of a 'Tree'.
-- Labels are consecutive integers starting from 1, assigned in the order
-- of the 'Traversable' instance of 'Tree' (root before its subtrees).
addUniqueLabelsTree :: Tree a -> Tree (a,Int)
addUniqueLabelsTree tree = evalState (traverse step tree) 1
  where
    -- pair the label with the current counter and bump the counter
    step x = state (\i -> ((x, i), i + 1))

-- | Adds unique labels to the nodes (including leaves) of a 'Forest'.
-- Labels are consecutive integers starting from 1; the trees are numbered
-- one after another, each in the order of the 'Traversable' instance of 'Tree'.
addUniqueLabelsForest :: Forest a -> Forest (a,Int)
addUniqueLabelsForest forest = evalState (traverse (traverse step) forest) 1
  where
    -- pair the label with the current counter and bump the counter
    step x = state (\i -> ((x, i), i + 1))

-- | Like 'addUniqueLabelsTree', but keeps only the new integer labels.
addUniqueLabelsTree_ :: Tree a -> Tree Int
addUniqueLabelsTree_ tree = snd <$> addUniqueLabelsTree tree

-- | Like 'addUniqueLabelsForest', but keeps only the new integer labels.
addUniqueLabelsForest_ :: Forest a -> Forest Int
addUniqueLabelsForest_ forest = [ fmap snd t | t <- addUniqueLabelsForest forest ]

--------------------------------------------------------------------------------

-- | Attaches the depth to each node. The depth of the root is 0, and each
-- child is one deeper than its parent.
labelDepthTree :: Tree a -> Tree (a,Int)
labelDepthTree = annotate 0
  where
    annotate :: Int -> Tree b -> Tree (b, Int)
    annotate d (Node x kids) = Node (x, d) [ annotate (d + 1) k | k <- kids ]

-- | 'labelDepthTree' applied to every tree of a forest (each root has depth 0).
labelDepthForest :: Forest a -> Forest (a,Int)
labelDepthForest forest = [ labelDepthTree tree | tree <- forest ]

-- | Replaces every label by the depth of its node (root = 0).
labelDepthTree_ :: Tree a -> Tree Int
labelDepthTree_ tree = snd <$> labelDepthTree tree

-- | Replaces every label in a forest by the depth of its node (roots = 0).
labelDepthForest_ :: Forest a -> Forest Int
labelDepthForest_ forest = [ fmap snd t | t <- labelDepthForest forest ]

--------------------------------------------------------------------------------

-- | Attaches the number of children to each node (0 for leaves).
labelNChildrenTree :: Tree a -> Tree (a,Int)
labelNChildrenTree (Node x subforest) = Node (x, arity) children
  where
    arity    = length subforest
    children = [ labelNChildrenTree s | s <- subforest ]

-- | 'labelNChildrenTree' applied to every tree of a forest.
labelNChildrenForest :: Forest a -> Forest (a,Int)
labelNChildrenForest forest = [ labelNChildrenTree tree | tree <- forest ]

-- | Replaces every label by the number of children of its node.
labelNChildrenTree_ :: Tree a -> Tree Int
labelNChildrenTree_ tree = snd <$> labelNChildrenTree tree

-- | Replaces every label in a forest by the number of children of its node.
labelNChildrenForest_ :: Forest a -> Forest Int
labelNChildrenForest_ forest = [ fmap snd t | t <- labelNChildrenForest forest ]

--------------------------------------------------------------------------------

-- | Computes the set of equivalence classes of rooted trees (in the
-- sense that the leaves of a node are /unordered/)
-- with @n = length xs@ leaves where the set of heights of
-- the leaves matches the given list of numbers @xs@.
-- The height is defined as the number of /edges/ from the leaf to the root.
--
-- TODO: better name?
derivTrees :: [Int] -> [Tree ()]
derivTrees = derivTrees' . map (+ 1)

-- | Worker for 'derivTrees'. Here each entry of the input is a height
-- measured in /nodes/ (edges + 1) from the current root down to a leaf.
--
-- * no heights: no trees;
--
-- * a single height @n >= 1@: the path of @n@ nodes (a single leaf when @n == 1@);
--
-- * several heights, all positive: split the multiset of heights into blocks
--   (one block per child of the root), decrement every height by one, and
--   combine the recursively built subtrees in all possible ways.
--
-- Any non-positive height yields no trees.
derivTrees' :: [Int] -> [Tree ()]
derivTrees' [] = []
derivTrees' [n]
  | n >= 1    = [unfoldTree step 1]
  | otherwise = []
  where
    -- node k on the path gets one child until the path has n nodes
    step :: Int -> ((), [Int])
    step k
      | k < n     = ((), [k + 1])
      | otherwise = ((), [])
derivTrees' ks
  | all (> 0) ks =
      [ Node () sub
      | part <- partitionMultiset ks
      , sub  <- listTensor (map (derivTrees' . map (subtract 1)) part)
      ]
  | otherwise = []

--------------------------------------------------------------------------------