-- | Trees, forests, etc. See:
--   Donald E. Knuth: The Art of Computer Programming, vol 4, pre-fascicle 4A.

module Math.Combinat.Trees 
  ( -- * Types
    BinTree(..)
  , leaf
  , BinTree'(..)
  , forgetNodeDecorations
  , module Data.Tree 
  , Paren(..)
  , parenthesesToString
  , stringToParentheses
    -- * Bijections
  , forestToNestedParentheses
  , forestToBinaryTree
  , nestedParenthesesToForest
  , nestedParenthesesToForestUnsafe
  , nestedParenthesesToBinaryTree
  , nestedParenthesesToBinaryTreeUnsafe
  , binaryTreeToForest
  , binaryTreeToNestedParentheses
    -- * Nested parentheses
  , nestedParentheses 
  , randomNestedParentheses
  , nthNestedParentheses
  , countNestedParentheses
  , fasc4A_algorithm_P
  , fasc4A_algorithm_W
  , fasc4A_algorithm_U
    -- * Binary trees
  , binaryTrees
  , countBinaryTrees
  , binaryTreesNaive
  , randomBinaryTree
  , fasc4A_algorithm_R
  ) 
  where

import Control.Monad
import Control.Monad.ST

import Data.Array
import Data.Array.ST

import Data.List
import Data.Tree (Tree(..),Forest(..))

import System.Random

import Math.Combinat.Helper

-- * Types

-- | A binary tree with leaves decorated with type @a@.
--
-- Internal nodes ('Branch') carry no decoration. In this module the
-- \"size\" of a tree is its number of 'Branch' nodes: e.g.
-- @'binaryTreesNaive' 0 == ['leaf']@.
--
-- Note: the derived 'Ord' instance depends on constructor order
-- ('Branch' sorts before 'Leaf'), so do not reorder the constructors.
data BinTree a
  = Branch (BinTree a) (BinTree a)
    -- ^ internal node with a left and a right subtree
  | Leaf a
    -- ^ leaf carrying a decoration of type @a@
  deriving (Eq,Ord,Show,Read)

-- | The undecorated leaf, i.e. the unique binary tree with no internal nodes.
leaf :: BinTree ()
leaf = Leaf ()

-- | A binary tree with leaves and internal nodes decorated 
-- with types @a@ and @b@, respectively.
--
-- 'forgetNodeDecorations' converts it to a 'BinTree' by dropping the
-- internal-node decorations; 'fasc4A_algorithm_R' returns trees of this type.
data BinTree' a b
  = Branch' (BinTree' a b) b (BinTree' a b)
    -- ^ internal node: left subtree, node decoration, right subtree
  | Leaf' a
    -- ^ leaf carrying a decoration of type @a@
  deriving (Eq,Ord,Show,Read)

-- | Drops the decorations of the internal nodes, keeping the shape of the
-- tree and the leaf decorations.
forgetNodeDecorations :: BinTree' a b -> BinTree a
forgetNodeDecorations = go where
  go (Leaf' x)       = Leaf x
  go (Branch' l _ r) = Branch (go l) (go r)
-- Maps over the leaf decorations; the shape of the tree is unchanged.
instance Functor BinTree where
  fmap f tree = case tree of
    Leaf x       -> Leaf (f x)
    Branch lt rt -> Branch (fmap f lt) (fmap f rt)

-- | A single parenthesis. Lists of these encode forests and binary trees
-- (see 'forestToNestedParentheses' and 'binaryTreeToNestedParentheses').
-- The derived 'Ord' instance places 'LeftParen' before 'RightParen'.
data Paren = LeftParen | RightParen deriving (Eq,Ord,Show,Read)

-- | Renders a single parenthesis as the character @(@ or @)@.
parenToChar :: Paren -> Char
parenToChar p = case p of
  LeftParen  -> '('
  RightParen -> ')'

-- | Renders a sequence of parentheses as a string, one character each.
parenthesesToString :: [Paren] -> String
parenthesesToString ps = [ parenToChar p | p <- ps ]

-- | Parses a string made of the characters @(@ and @)@. Any other
-- character yields an 'error' when the corresponding element is forced.
stringToParentheses :: String -> [Paren]
stringToParentheses = map fromChar where
  fromChar '(' = LeftParen
  fromChar ')' = RightParen
  fromChar _   = error "stringToParentheses: invalid character"

-- * Bijections

-- | Encodes a forest as a sequence of nested parentheses: every node becomes
-- @(@, followed by the encoding of its children, followed by @)@.
-- Node labels are ignored.
--
-- Uses an accumulator instead of appending @[RightParen]@ at every level,
-- so the cost is linear in the size of the forest, not size times depth.
forestToNestedParentheses :: Forest a -> [Paren]
forestToNestedParentheses ts0 = goForest ts0 [] where
  -- goForest ts rest == encoding of ts ++ rest
  goForest ts rest = foldr goTree rest ts
  -- goTree t rest == encoding of t ++ rest
  goTree (Node _ sf) rest = LeftParen : goForest sf (RightParen : rest)

-- | Bijection from forests to binary trees: the children of the first node
-- become the left subtree, and the remaining sibling trees become the right
-- subtree. The empty forest maps to 'leaf'. Node labels are ignored.
forestToBinaryTree :: Forest a -> BinTree ()
forestToBinaryTree = go where
  go []         = leaf
  go (t : sibs) = Branch (go (subForest t)) (go sibs)
-- | Parses a sequence of parentheses into a forest (inverse of
-- 'forestToNestedParentheses', up to node labels). Returns 'Nothing' if the
-- input is not a properly nested sequence.
nestedParenthesesToForest :: [Paren] -> Maybe (Forest ())
nestedParenthesesToForest ps = 
  case parseForest ps of 
    ([], forest) -> Just forest
    _            -> Nothing
  where
    -- Parses as many consecutive trees as possible; also returns the
    -- unconsumed remainder of the input.
    parseForest :: [Paren] -> ( [Paren] , Forest () )
    parseForest xs = unfoldEither parseTree xs
    -- Parses one "( forest )" group; on failure returns the input unchanged
    -- (as 'Left') so that 'parseForest' stops there.
    parseTree :: [Paren] -> Either [Paren] ( [Paren] , Tree () )  
    parseTree orig@(LeftParen:qs) =
      case parseForest qs of
        (RightParen:rest, ts) -> Right (rest, Node () ts)
        _                     -> Left orig
    parseTree qs = Left qs

-- | Like 'nestedParenthesesToForest', but calls 'error' (with a descriptive
-- message) when the input is not a properly nested sequence.
nestedParenthesesToForestUnsafe :: [Paren] -> Forest ()
nestedParenthesesToForestUnsafe ps =
  case nestedParenthesesToForest ps of
    Just forest -> forest
    Nothing     -> error "nestedParenthesesToForestUnsafe: input is not a properly nested sequence of parentheses"

-- | Parses a sequence of parentheses into a binary tree (inverse of
-- 'binaryTreeToNestedParentheses', up to leaf decorations). Returns
-- 'Nothing' if the input is not a properly nested sequence.
nestedParenthesesToBinaryTree :: [Paren] -> Maybe (BinTree ())
nestedParenthesesToBinaryTree ps = 
  case parseForest ps of 
    ([], tree) -> Just tree
    _          -> Nothing
  where
    -- Parses as many consecutive groups as possible and folds them into a
    -- right-leaning spine of 'Branch'es ending in 'leaf'; also returns the
    -- unconsumed remainder of the input.
    parseForest :: [Paren] -> ( [Paren] , BinTree () )
    parseForest xs =
      let (rest, ts) = unfoldEither parseTree xs
      in  (rest, foldr Branch leaf ts)
    -- Parses one "( ... )" group; on failure returns the input unchanged
    -- (as 'Left') so that 'parseForest' stops there.
    parseTree :: [Paren] -> Either [Paren] ( [Paren] , BinTree () )  
    parseTree orig@(LeftParen:qs) =
      case parseForest qs of
        (RightParen:rest, t) -> Right (rest, t)
        _                    -> Left orig
    parseTree qs = Left qs
-- | Like 'nestedParenthesesToBinaryTree', but calls 'error' (with a
-- descriptive message) when the input is not a properly nested sequence.
nestedParenthesesToBinaryTreeUnsafe :: [Paren] -> BinTree ()
nestedParenthesesToBinaryTreeUnsafe ps =
  case nestedParenthesesToBinaryTree ps of
    Just tree -> tree
    Nothing   -> error "nestedParenthesesToBinaryTreeUnsafe: input is not a properly nested sequence of parentheses"

-- | Encodes a binary tree as nested parentheses: @Branch l r@ becomes
-- @(@, the encoding of @l@, @)@, and then the encoding of @r@. A leaf
-- encodes as the empty sequence. Leaf decorations are ignored.
--
-- Threads an accumulator instead of using left-nested @++@, so left-deep
-- trees are encoded in linear rather than quadratic time.
binaryTreeToNestedParentheses :: BinTree a -> [Paren]
binaryTreeToNestedParentheses t0 = go t0 [] where
  -- go t acc == encoding of t ++ acc
  go (Branch l r) acc = LeftParen : go l (RightParen : go r acc)
  go (Leaf _)     acc = acc

-- | Inverse of 'forestToBinaryTree' (up to labels): a branch becomes a node
-- whose children come from the left subtree, followed by the sibling trees
-- coming from the right subtree. A leaf becomes the empty forest.
binaryTreeToForest :: BinTree a -> Forest ()
binaryTreeToForest t = case t of
  Leaf _     -> []
  Branch l r -> Node () (binaryTreeToForest l) : binaryTreeToForest r

-- * Nested parentheses

-- | All sequences of nested parentheses of length @2n@.
-- Same as 'fasc4A_algorithm_P'.
nestedParentheses :: Int -> [[Paren]]
nestedParentheses n = fasc4A_algorithm_P n

-- | A uniformly random sequence of nested parentheses of length @2n@,
-- together with the updated generator. Same as 'fasc4A_algorithm_W'.
randomNestedParentheses :: RandomGen g => Int -> g -> ([Paren],g)
randomNestedParentheses n gen = fasc4A_algorithm_W n gen

-- | The @k@-th sequence of nested parentheses of length @2n@, in the order
-- of 'nestedParentheses'. Same as 'fasc4A_algorithm_U'.
nthNestedParentheses :: Int -> Integer -> [Paren]
nthNestedParentheses n k = fasc4A_algorithm_U n k

-- | The number of sequences of nested parentheses of length @2n@ (the
-- Catalan number @C(n)@). Same as 'countBinaryTrees'.
countNestedParentheses :: Int -> Integer
countNestedParentheses n = countBinaryTrees n

-- | Generates all sequences of nested parentheses of length 2n.
-- Order is lexicographic (when right parentheses are considered 
-- smaller than left ones).
-- Based on \"Algorithm P\" in Knuth, but less efficient because of
-- the \"idiomatic\" code.
--
-- For @n = 0@ the result is the single empty sequence @[[]]@, consistent
-- with @'countNestedParentheses' 0 == 1@ and @'binaryTreesNaive' 0 == ['leaf']@.
fasc4A_algorithm_P :: Int -> [[Paren]]
fasc4A_algorithm_P 0 = [[]]
fasc4A_algorithm_P 1 = [[LeftParen,RightParen]]
fasc4A_algorithm_P n = unfold next ( start , [] ) where 
  -- A state (ls, rs) stands for the sequence @reverse ls ++ rs@.
  start = concat $ replicate n [RightParen,LeftParen]  -- already reversed!
  next :: ([Paren],[Paren]) -> ( [Paren] , Maybe ([Paren],[Paren]) )
  next ( (a:b:ls) , [] ) = next ( ls , b:a:[] )
  next ( lls@(l:ls) , rrs@(_:rs) ) = ( visit , new ) where
    visit = reverse lls ++ rrs
    new = 
      case l of 
        RightParen -> Just ( ls , LeftParen:RightParen:rs )
        LeftParen  -> 
          findj ( lls , [] ) ( reverse (RightParen:rs) , [] ) 

  -- Searches for the position to advance; 'Nothing' means the last
  -- sequence has been visited.
  findj :: ([Paren],[Paren]) -> ([Paren],[Paren]) -> Maybe ([Paren],[Paren])
  findj ( [] , _ ) _ = Nothing
  findj ( lls@(l:ls) , rs) ( xs , ys ) = 
    case l of
      LeftParen  -> case xs of
        (a:_:as) -> findj ( ls, RightParen:rs ) ( as , LeftParen:a:ys )
        _        -> findj ( lls, [] ) ( reverse rs ++ xs , ys) 
      RightParen -> Just ( reverse ys ++ xs ++ reverse (LeftParen:rs) ++ ls , [] )
-- | Generates a uniformly random sequence of nested parentheses of length 2n.
-- Based on \"Algorithm W\" in Knuth. Returns the sequence together with the
-- updated generator. A negative @n@ is treated as 0 (the empty sequence is
-- returned) instead of looping forever.
fasc4A_algorithm_W :: RandomGen g => Int -> g -> ([Paren],g)
fasc4A_algorithm_W n' rnd0 = worker (rnd0,n,n,[]) where
  -- the numbers we use are of order n^2, so for n >> 2^16 
  -- on a 32 bit machine, we need big integers.
  n = fromIntegral (max 0 n') :: Integer  
  -- State: (generator, p, q, parentheses emitted so far). Generation stops
  -- when q reaches 0. Each step draws x uniformly from
  -- [0 .. (q+p)(q-p+1)-1] and emits one parenthesis depending on x.
  -- The output is built by consing.
  worker :: RandomGen g => (g,Integer,Integer,[Paren]) -> ([Paren],g)
  worker (rnd,_,0,parens) = (parens,rnd)
  worker (rnd,p,q,parens) = 
    if x < (q+1)*(q-p) 
      then worker (rnd' , p   , q-1 , LeftParen :parens)
      else worker (rnd' , p-1 , q   , RightParen:parens)
    where
      (x,rnd') = randomR ( 0 , (q+p)*(q-p+1)-1 ) rnd

-- | Nth sequence of nested parentheses of length 2n. 
-- The order is the same as in 'fasc4A_algorithm_P'.
-- Based on \"Algorithm U\" in Knuth.
fasc4A_algorithm_U 
  :: Int               -- ^ n
  -> Integer           -- ^ N; should satisfy 1 <= N <= C(n) 
  -> [Paren]
fasc4A_algorithm_U n' bign0 = reverse $ worker (bign0,c0,n,n,[]) where
  n = fromIntegral n' :: Integer
  -- c0 = C(n), the n-th Catalan number, via C(p) = (4p-2) C(p-1) / (p+1)
  -- (the division is exact). A strict fold avoids a chain of thunks.
  c0 = foldl' f 1 [2..n]
  f c p = ((4*p-2)*c) `div` (p+1) 
  -- State: (remaining rank, current count c, p, q, output so far).
  -- Stops when q reaches 0; the output is consed and reversed at the end.
  worker :: (Integer,Integer,Integer,Integer,[Paren]) -> [Paren]
  worker (_   ,_,_,0,parens) = parens
  worker (bign,c,p,q,parens) = 
    if bign <= c' 
      then worker (bign    , c'   , p   , q-1 , RightParen:parens)
      else worker (bign-c' , c-c' , p-1 , q   , LeftParen :parens)
    where
      c' = ((q+1)*(q-p)*c) `div` ((q+p)*(q-p+1))
-- * Binary trees

-- | All binary trees with @n@ internal nodes.
-- Currently this simply delegates to 'binaryTreesNaive'.
binaryTrees :: Int -> [BinTree ()]
binaryTrees n = binaryTreesNaive n

-- | The number of binary trees with @n@ internal nodes, which is the
-- Catalan number C(n) = \\frac { 1 } { n+1 } \\binom { 2n } { n }.
-- This is also the counting function for forests and nested parentheses.
countBinaryTrees :: Int -> Integer
countBinaryTrees n = central `div` (toInteger n + 1)
  where
    -- central binomial coefficient (2n choose n)
    central = binomial (2*n) n
-- | Generates all binary trees with @n@ internal nodes. The naive algorithm:
-- for every @i@ in @[0..n-1]@, the root gets every left subtree with @i@
-- nodes combined with every right subtree with @n-1-i@ nodes (subtrees are
-- recomputed, not shared). Returns @[]@ for negative @n@.
binaryTreesNaive :: Int -> [BinTree ()]
binaryTreesNaive 0 = [ leaf ]
binaryTreesNaive n = 
  [ Branch l r 
  | i <- [0..n-1] 
  , l <- binaryTreesNaive i 
  , r <- binaryTreesNaive (n-1-i) 
  ]

-- | Generates a uniformly random binary tree with @n@ internal nodes by
-- running 'fasc4A_algorithm_R' and discarding all decorations.
-- Returns the tree together with the updated generator.
randomBinaryTree :: RandomGen g => Int -> g -> (BinTree (), g)
randomBinaryTree n gen =
  let (decorated, gen') = fasc4A_algorithm_R n gen
      plain             = forgetNodeDecorations decorated
  in  (fmap (const ()) plain, gen')

-- | Grows a uniformly random binary tree. 
-- \"Algorithm R\" (Remy's procedure) in Knuth.
-- Nodes are decorated with odd numbers, leaves with even numbers (from the
-- set @[0..2n]@). Uses mutable arrays internally.
fasc4A_algorithm_R :: RandomGen g => Int -> g -> (BinTree' Int Int, g)
fasc4A_algorithm_R n0 rnd0 = res where
  res = runST $ do
    ar <- newArray (0,2*n0) 0
    rnd' <- worker rnd0 1 ar
    -- 'freeze' copies the array; 'unsafeFreeze' is not exported by
    -- "Data.Array.ST" in current versions of the array package.
    links <- freeze ar
    return (toTree links, rnd')
  -- Link array: entry 0 is the root; for an internal (odd) node i, entry i
  -- is its left child and entry i+1 its right child.
  toTree links = f (links!0) where
    f i = if odd i 
      then Branch' (f $ links!i) i (f $ links!(i+1)) 
      else Leaf' i  
  -- Performs one insertion step for each n = 1..n0, updating the links.
  worker :: RandomGen g => g -> Int -> STUArray s Int Int -> ST s g
  worker rnd n ar = 
    if n > n0
      then return rnd
      else do
        writeArray ar (n2-b)   n2
        lk <- readArray ar k
        writeArray ar (n2-1+b) lk
        writeArray ar k        (n2-1)
        worker rnd' (n+1) ar
    where
      n2 = n+n
      -- k selects an existing link, b whether the old subtree goes left/right
      (x,rnd') = randomR (0,4*n-3) rnd
      (k,b) = x `divMod` 2