-- | Binary trees, forests, etc. See:
--   Donald E. Knuth: The Art of Computer Programming, vol 4, pre-fascicle 4A.

module Math.Combinat.Trees.Binary
( -- * Types
BinTree(..)
, leaf
, BinTree'(..)
, forgetNodeDecorations
, module Data.Tree
, Paren(..)
, parenthesesToString
, stringToParentheses
-- * Bijections
, forestToNestedParentheses
, forestToBinaryTree
, nestedParenthesesToForest
, nestedParenthesesToForestUnsafe
, nestedParenthesesToBinaryTree
, nestedParenthesesToBinaryTreeUnsafe
, binaryTreeToForest
, binaryTreeToNestedParentheses
-- * Nested parentheses
, nestedParentheses
, randomNestedParentheses
, nthNestedParentheses
, countNestedParentheses
, fasc4A_algorithm_P
, fasc4A_algorithm_W
, fasc4A_algorithm_U
-- * Binary trees
, binaryTrees
, countBinaryTrees
, binaryTreesNaive
, randomBinaryTree
, fasc4A_algorithm_R
)
where

--------------------------------------------------------------------------------

import Control.Applicative

import Data.Array
import Data.Array.ST

import Data.List
import Data.Tree (Tree(..),Forest(..))

import Data.Monoid
import Data.Foldable (Foldable(foldMap))
import Data.Traversable (Traversable(traverse))

import System.Random

import Math.Combinat.Helper
import Math.Combinat.Numbers (factorial,binomial)

--------------------------------------------------------------------------------
-- * Types

-- | A binary tree with leaves decorated with type @a@.
-- Internal nodes carry no decoration.
data BinTree a
  = Branch (BinTree a) (BinTree a)  -- ^ internal node with a left and a right subtree
  | Leaf a                          -- ^ leaf carrying a decoration

-- | The tree consisting of a single leaf carrying the trivial decoration @()@.
leaf :: BinTree ()
leaf = Leaf ()

-- | A binary tree with leaves and internal nodes decorated
-- with types @a@ and @b@, respectively.
data BinTree' a b
  = Branch' (BinTree' a b) b (BinTree' a b)  -- ^ left subtree, node decoration, right subtree
  | Leaf' a                                  -- ^ leaf carrying a decoration

-- | Drops the decorations of the internal nodes, keeping the shape of the
-- tree and the decorations of the leaves.
forgetNodeDecorations :: BinTree' a b -> BinTree a
forgetNodeDecorations tree = case tree of
  Leaf' x       -> Leaf x
  Branch' l _ r -> Branch (forgetNodeDecorations l) (forgetNodeDecorations r)

--------------------------------------------------------------------------------

-- | Maps a function over the leaf decorations, keeping the shape of the tree.
instance Functor BinTree where
  fmap g tree = case tree of
    Leaf x     -> Leaf (g x)
    Branch l r -> Branch (fmap g l) (fmap g r)

-- | Combines the leaf decorations, left subtree before right subtree.
instance Foldable BinTree where
  foldMap f (Leaf x) = f x
  -- 'mappend' must be used infix here; without the backticks the left-hand
  -- result would be applied as a function.
  foldMap f (Branch left right) = foldMap f left `mappend` foldMap f right

-- | Traverses the leaf decorations, left subtree before right subtree,
-- rebuilding a tree of the same shape.
instance Traversable BinTree where
  traverse f (Leaf x) = Leaf <$> f x
  traverse f (Branch left right) = Branch <$> traverse f left <*> traverse f right

--------------------------------------------------------------------------------

-- | A single parenthesis. Lists of these represent sequences of nested
-- parentheses; the derived 'Ord' instance has @LeftParen < RightParen@.
data Paren = LeftParen | RightParen deriving (Eq,Ord,Show,Read)

-- | Renders a parenthesis as the corresponding character.
parenToChar :: Paren -> Char
parenToChar p = case p of
  LeftParen  -> '('
  RightParen -> ')'

-- | Renders a sequence of parentheses as a string of @\'(\'@ and @\')\'@ characters.
parenthesesToString :: [Paren] -> String
parenthesesToString parens = [ parenToChar p | p <- parens ]

-- | Reads a string of @\'(\'@ and @\')\'@ characters as a sequence of
-- parentheses. Any other character results in a call to 'error' (raised
-- lazily, when the corresponding element is forced).
stringToParentheses :: String -> [Paren]
stringToParentheses = map charToParen where
  charToParen c
    | c == '('  = LeftParen
    | c == ')'  = RightParen
    | otherwise = error "stringToParentheses: invalid character"

--------------------------------------------------------------------------------
-- * Bijections

-- | Encodes a forest as nested parentheses: every tree becomes a matching
-- pair of parentheses enclosing the encoding of its sub-forest. Node labels
-- are ignored.
forestToNestedParentheses :: Forest a -> [Paren]
forestToNestedParentheses trees = concat (map encodeTree trees) where
  -- encodeTree :: Tree a -> [Paren]
  encodeTree (Node _ children) =
    [LeftParen] ++ forestToNestedParentheses children ++ [RightParen]

-- | Converts a forest to a binary tree: the empty forest becomes a leaf, and
-- a non-empty forest becomes a branch whose left subtree encodes the
-- children of the first tree and whose right subtree encodes the remaining
-- trees. Node labels are ignored.
forestToBinaryTree :: Forest a -> BinTree ()
forestToBinaryTree trees = case trees of
  []         -> leaf
  (t : rest) -> Branch (forestToBinaryTree (subForest t)) (forestToBinaryTree rest)

-- | Parses a sequence of nested parentheses into a forest.
-- Returns 'Nothing' when the whole input cannot be consumed, that is, when
-- the parentheses are not properly nested.
nestedParenthesesToForest :: [Paren] -> Maybe (Forest ())
nestedParenthesesToForest input
  | null leftover = Just trees
  | otherwise     = Nothing
  where
    (leftover, trees) = forestOf input

    -- Parses as many trees as possible; returns the unconsumed input too.
    forestOf :: [Paren] -> ( [Paren] , Forest () )
    forestOf = unfoldEither treeOf

    -- Parses a single tree; on failure the untouched input is returned in 'Left'.
    treeOf :: [Paren] -> Either [Paren] ( [Paren] , Tree () )
    treeOf whole@(LeftParen : inner) =
      case forestOf inner of
        (RightParen : after , children) -> Right (after, Node () children)
        _                               -> Left whole
    treeOf other = Left other

-- | Partial version of 'nestedParenthesesToForest': applies 'fromJust' to
-- its result, so it is only defined on input which parses successfully.
nestedParenthesesToForestUnsafe :: [Paren] -> Forest ()
nestedParenthesesToForestUnsafe ps = fromJust (nestedParenthesesToForest ps)

-- | Parses a sequence of nested parentheses into a binary tree.
-- Returns 'Nothing' when the whole input cannot be consumed, that is, when
-- the parentheses are not properly nested.
nestedParenthesesToBinaryTree :: [Paren] -> Maybe (BinTree ())
nestedParenthesesToBinaryTree input
  | null leftover = Just tree
  | otherwise     = Nothing
  where
    (leftover, tree) = forestOf input

    -- Parses a run of parenthesized groups and chains them into a
    -- right-leaning spine of branches terminated by a leaf.
    forestOf :: [Paren] -> ( [Paren] , BinTree () )
    forestOf xs = (remaining , foldr Branch leaf subtrees)
      where
        (remaining, subtrees) = unfoldEither treeOf xs

    -- Parses a single parenthesized group; on failure the untouched input is
    -- returned in 'Left'.
    treeOf :: [Paren] -> Either [Paren] ( [Paren] , BinTree () )
    treeOf whole@(LeftParen : inner) =
      case forestOf inner of
        (RightParen : after , sub) -> Right (after, sub)
        _                          -> Left whole
    treeOf other = Left other

-- | Partial version of 'nestedParenthesesToBinaryTree': applies 'fromJust'
-- to its result, so it is only defined on input which parses successfully.
nestedParenthesesToBinaryTreeUnsafe :: [Paren] -> BinTree ()
nestedParenthesesToBinaryTreeUnsafe ps = fromJust (nestedParenthesesToBinaryTree ps)

-- | Encodes a binary tree as nested parentheses: a leaf is the empty
-- sequence, and a branch is the parenthesized encoding of its left subtree
-- followed by the encoding of its right subtree. Leaf decorations are ignored.
binaryTreeToNestedParentheses :: BinTree a -> [Paren]
binaryTreeToNestedParentheses tree = case tree of
  Leaf _     -> []
  Branch l r ->
    [LeftParen] ++ binaryTreeToNestedParentheses l
                ++ [RightParen]
                ++ binaryTreeToNestedParentheses r

-- | Converts a binary tree to a forest: a leaf is the empty forest, and a
-- branch contributes one tree (whose children come from the left subtree)
-- in front of the forest coming from the right subtree.
-- Leaf decorations are ignored.
binaryTreeToForest :: BinTree a -> Forest ()
binaryTreeToForest tree = case tree of
  Leaf _     -> []
  Branch l r -> firstTree : binaryTreeToForest r
    where
      firstTree = Node { rootLabel = () , subForest = binaryTreeToForest l }

--------------------------------------------------------------------------------
-- * Nested parentheses

-- | All sequences of nested parentheses of length @2n@.
-- Synonym for 'fasc4A_algorithm_P'.
nestedParentheses :: Int -> [[Paren]]
nestedParentheses n = fasc4A_algorithm_P n

-- | A random sequence of nested parentheses of length @2n@, together with
-- the new state of the generator. Synonym for 'fasc4A_algorithm_W'.
randomNestedParentheses :: RandomGen g => Int -> g -> ([Paren],g)
randomNestedParentheses n gen = fasc4A_algorithm_W n gen

-- | The sequence of nested parentheses of length @2n@ with the given index.
-- Synonym for 'fasc4A_algorithm_U'.
nthNestedParentheses :: Int -> Integer -> [Paren]
nthNestedParentheses n idx = fasc4A_algorithm_U n idx

-- | The number of sequences of nested parentheses of length @2n@.
-- Synonym for 'countBinaryTrees'.
countNestedParentheses :: Int -> Integer
countNestedParentheses n = countBinaryTrees n

-- | Generates all sequences of nested parentheses of length @2n@.
-- Order is lexicographic (when right parentheses are considered
-- smaller than left ones).
-- Based on \"Algorithm P\" in Knuth, but less efficient because of
-- the \"idiomatic\" code.
--
-- For @n == 0@ the result is the single empty sequence (there is exactly
-- one sequence of length zero), in agreement with the Catalan number
-- @C(0) = 1@.
fasc4A_algorithm_P :: Int -> [[Paren]]
fasc4A_algorithm_P 0 = [[]]
fasc4A_algorithm_P 1 = [[LeftParen,RightParen]]
fasc4A_algorithm_P n = unfold next ( start , [] ) where
  -- The state is a pair @(prefix, suffix)@ where the prefix is stored in
  -- reversed order; the visited sequence is @reverse prefix ++ suffix@.
  start = concat $ replicate n [RightParen,LeftParen]  -- already reversed!

  -- Produces the sequence to visit together with the next state (if any).
  next :: ([Paren],[Paren]) -> ( [Paren] , Maybe ([Paren],[Paren]) )
  -- Initial step: move the last two parentheses over to the suffix.
  next ( (a:b:ls) , [] ) = next ( ls , b:a:[] )
  next ( lls@(l:ls) , rrs@(_:rs) ) = ( visit , new ) where
    visit = reverse lls ++ rrs
    new =
      case l of
        RightParen -> Just ( ls , LeftParen:RightParen:rs )
        LeftParen  ->
          findj ( lls , [] ) ( reverse (RightParen:rs) , [] )

  -- Searches backwards for the position from which the next sequence is
  -- built; 'Nothing' means the prefix got exhausted, so the enumeration ends.
  findj :: ([Paren],[Paren]) -> ([Paren],[Paren]) -> Maybe ([Paren],[Paren])
  findj ( [] , _ ) _ = Nothing
  findj ( lls@(l:ls) , rs) ( xs , ys ) =
    case l of
      LeftParen  -> case xs of
        (a:_:as) -> findj ( ls, RightParen:rs ) ( as , LeftParen:a:ys )
        _ -> findj ( lls, [] ) ( reverse rs ++ xs , ys)
      RightParen -> Just ( reverse ys ++ xs ++ reverse (LeftParen:rs) ++ ls , [] )

-- | Produces a uniformly random sequence of nested parentheses of length
-- @2n@, together with the updated random generator.
-- Follows \"Algorithm W\" in Knuth.
fasc4A_algorithm_W :: RandomGen g => Int -> g -> ([Paren],g)
fasc4A_algorithm_W n' rnd = go rnd n n [] where
  -- The intermediate quantities are of order n^2, hence for n >> 2^16 they
  -- would not fit a 32 bit machine integer; so we compute with 'Integer'.
  n = fromIntegral n' :: Integer

  -- Arguments: generator, the counters p and q, and the parentheses
  -- emitted so far (the most recently emitted one first).
  go :: RandomGen g => g -> Integer -> Integer -> [Paren] -> ([Paren],g)
  go gen _ 0 acc = (acc,gen)
  go gen p q acc
    | x < (q+1)*(q-p) = go gen' p     (q-1) (LeftParen  : acc)
    | otherwise       = go gen' (p-1) q     (RightParen : acc)
    where
      (x,gen') = randomR ( 0 , (q+p)*(q-p+1)-1 ) gen

-- | Nth sequence of nested parentheses of length 2n.
-- The order is the same as in 'fasc4A_algorithm_P'.
-- Based on \"Algorithm U\" in Knuth.
--
-- The index is not range-checked.
fasc4A_algorithm_U
  :: Int               -- ^ n
  -> Integer           -- ^ N; should satisfy 1 <= N <= C(n)
  -> [Paren]
fasc4A_algorithm_U n' bign0 = reverse $ worker (bign0,c0,n,n,[]) where
  n = fromIntegral n' :: Integer
  -- c0 is the Catalan number C(n), computed via the recurrence
  -- C(p) = (4p-2) * C(p-1) / (p+1).
  c0 = foldl' f 1 [2..n]
  f c p = ((4*p-2)*c) `div` (p+1)
  -- State: (remaining index, current count, p, q, parentheses emitted so
  -- far in reversed order).
  worker :: (Integer,Integer,Integer,Integer,[Paren]) -> [Paren]
  worker (_   ,_,_,0,parens) = parens
  worker (bign,c,p,q,parens) =
    if bign <= c'
      then worker (bign    , c'   , p   , q-1 , RightParen:parens)
      else worker (bign-c' , c-c' , p-1 , q   , LeftParen :parens)
    where
      c' = ((q+1)*(q-p)*c) `div` ((q+p)*(q-p+1))

--------------------------------------------------------------------------------
-- * Binary trees

-- | Generates all binary trees with n nodes.
-- At the moment just a synonym for 'binaryTreesNaive'.
binaryTrees :: Int -> [BinTree ()]
binaryTrees = binaryTreesNaive

-- | The number of binary trees with @n@ nodes, that is, the Catalan number
-- @C(n) = binomial(2n,n) / (n+1)@.
--
-- This is also the counting function for forests and nested parentheses.
countBinaryTrees :: Int -> Integer
countBinaryTrees n = binomial (2*n) n `div` (1 + fromIntegral n)

-- | Generates all binary trees with n nodes. The naive algorithm:
-- a tree with @n@ nodes is a branch whose two subtrees have @i@ and
-- @n-1-i@ nodes, for every possible @i@.
binaryTreesNaive :: Int -> [BinTree ()]
binaryTreesNaive 0 = [ leaf ]
binaryTreesNaive n =
  [ Branch l r
  | i <- [0..n-1]
  , l <- binaryTreesNaive i
  , r <- binaryTreesNaive (n-1-i)
  ]

-- | Generates a uniformly random binary tree, using 'fasc4A_algorithm_R'.
-- Returns the tree together with the updated random generator.
randomBinaryTree :: RandomGen g => Int -> g -> (BinTree (), g)
randomBinaryTree n rnd = (tree,rnd') where
  (decorated,rnd') = fasc4A_algorithm_R n rnd
  -- Throw away both the node and the leaf decorations.
  tree = fmap (const ()) $ forgetNodeDecorations decorated

-- | Grows a uniformly random binary tree.
-- \"Algorithm R\" (Remy's procedure) in Knuth.
-- Nodes are decorated with odd numbers, leaves with even numbers (from the
-- set @[0..2n]@). Uses mutable arrays internally.
-- Returns the tree together with the updated random generator.
fasc4A_algorithm_R :: RandomGen g => Int -> g -> (BinTree' Int Int, g)
fasc4A_algorithm_R n0 rnd = res where
  res = runST $ do
    -- The link table has 2*n0+1 entries, all initially 0.
    ar <- newArray (0,2*n0) 0
    rnd' <- worker rnd 1 ar
    -- The mutable array is not touched after this point.
    links <- unsafeFreeze ar
    return (toTree links, rnd')

  -- Rebuilds the tree from the link table, starting from entry 0.
  -- An odd number @i@ is an internal node whose children are stored at
  -- positions @i@ and @i+1@; an even number is a leaf.
  toTree links = f (links!0) where
    f i = if odd i
      then Branch' (f $ links!i) i (f $ links!(i+1))
      else Leaf' i

  -- One growing step for each n in [1..n0], threading the generator through.
  worker :: RandomGen g => g -> Int -> STUArray s Int Int -> ST s g
  worker rnd n ar = do
    if n > n0
      then return rnd
      else do
        writeArray ar (n2-b)   n2
        lk <- readArray ar k
        writeArray ar (n2-1+b) lk
        writeArray ar k        (n2-1)
        worker rnd' (n+1) ar
    where
      n2 = n+n
      -- x is drawn from [0..4n-3]; it is split into a position k in
      -- [0..2n-2] and a bit b in [0..1].
      (x,rnd') = randomR (0,4*n-3) rnd
      (k,b) = x `divMod` 2

--------------------------------------------------------------------------------