{-# OPTIONS_GHC -fglasgow-exts #-}
-----------------------------------------------------------------------------
-- |
-- Module      :  Data.Tree.AVL.Internals.HeightUtils
-- Copyright   :  (c) Adrian Hey 2004,2005
-- License     :  BSD3
--
-- Maintainer  :  http://homepages.nildram.co.uk/~ahey/em.png
-- Stability   :  stable
-- Portability :  portable
--
-- AVL tree height related utilities.
--
-- The functions defined here are not exported by the main Data.Tree.AVL module
-- because they violate the policy for AVL tree equality used elsewhere in this library.
-- You need to import this module explicitly if you want to use any of these functions.
-----------------------------------------------------------------------------
module Data.Tree.AVL.Internals.HeightUtils
        (height,addHeight,compareHeight, -- heightInt,
         fastAddSize,
        ) where 

import Data.Tree.AVL.Types(AVL(..))

#ifdef __GLASGOW_HASKELL__
import GHC.Base
#include "ghcdefs.h"
#else
#include "h98defs.h"
#endif

-- {-# INLINE heightInt #-} -- Don't want this
-- heightInt :: AVL e -> Int
-- heightInt t = ASINT(addHeight L(0) t)

-- | Compute the height of an AVL tree. This is @addHeight@ started from a
-- zero accumulator, so the empty tree has height 0.
--
-- Complexity: O(log n)
{-# INLINE height #-}
height :: AVL e -> UINT
height tree = addHeight L(0) tree

-- | Add the height of a tree to an accumulator (the first argument) and
-- return the sum.
--
-- Only one subtree is followed at each node, and the balance constructor
-- says how many levels that step accounts for: 1 for a @Z@ node, 2 for an
-- @N@ or @P@ node. (Presumably @N@ and @P@ mark the nodes whose followed
-- subtree is the shorter one; see "Data.Tree.AVL.Types" to confirm.)
--
-- Complexity: O(log n)
addHeight :: UINT -> AVL e -> UINT
addHeight acc tree = case tree of
    E       -> acc                           -- nothing left to count
    N l _ _ -> addHeight INCINT2(acc) l      -- follow left subtree, count 2
    Z l _ _ -> addHeight INCINT1(acc) l      -- follow left subtree, count 1
    P _ _ r -> addHeight INCINT2(acc) r      -- follow right subtree, count 2

-- | Compare the heights of two trees without computing either height in full.
-- Both trees are descended in lock step while a running height difference is
-- maintained, so this should beat evaluating both heights when the trees differ
-- significantly in height. If you need the two heights anyway, it will be
-- quicker to evaluate them both and compare the results.
--
-- Complexity: O(log n), where n is the size of the smaller of the two trees.
compareHeight :: AVL a -> AVL b -> Ordering
compareHeight = ch L(0) where
 -- Lock-step descent. The first argument is the running difference hA-hB
 -- accumulated from the levels already stepped over in tree A and tree B.
 ch :: UINT -> AVL a -> AVL b -> Ordering
 ch diff tA tB = case tA of
    E -> case tB of
        E         -> COMPAREUINT diff L(0)
        N lb _ _  -> chA DECINT2(diff) lb
        Z lb _ _  -> chA DECINT1(diff) lb
        P _  _ rb -> chA DECINT2(diff) rb
    N la _ _ -> case tB of
        E         -> chB INCINT2(diff) la
        N lb _ _  -> ch          diff  la lb
        Z lb _ _  -> ch  INCINT1(diff) la lb
        P _  _ rb -> ch          diff  la rb
    Z la _ _ -> case tB of
        E         -> chB INCINT1(diff) la
        N lb _ _  -> ch  DECINT1(diff) la lb
        Z lb _ _  -> ch          diff  la lb
        P _  _ rb -> ch  DECINT1(diff) la rb
    P _ _ ra -> case tB of
        E         -> chB INCINT2(diff) ra
        N lb _ _  -> ch          diff  ra lb
        Z lb _ _  -> ch  INCINT1(diff) ra lb
        P _  _ rb -> ch          diff  ra rb
 -- Tree A is exhausted: keep descending tree B, which can only lower the
 -- difference. Stop as soon as hA-hB<0, or when tree B ends.
 chA diff restB = case COMPAREUINT diff L(0) of
    LT -> LT
    EQ -> case restB of
        E -> EQ
        _ -> LT          -- any remaining node of B makes B the taller tree
    GT -> case restB of
        E        -> GT
        N lb _ _ -> chA DECINT2(diff) lb
        Z lb _ _ -> chA DECINT1(diff) lb
        P _ _ rb -> chA DECINT2(diff) rb
 -- Tree B is exhausted: keep descending tree A, which can only raise the
 -- difference. Stop as soon as hA-hB>0, or when tree A ends.
 chB diff restA = case COMPAREUINT diff L(0) of
    GT -> GT
    EQ -> case restA of
        E -> EQ
        _ -> GT          -- any remaining node of A makes A the taller tree
    LT -> case restA of
        E        -> LT
        N la _ _ -> chB INCINT2(diff) la
        Z la _ _ -> chB INCINT1(diff) la
        P _ _ ra -> chB INCINT2(diff) ra


{-----------------------------------------
Notes for fast size calculation.
 case (h,avl)
      (0,_      ) -> 0            -- Must be E
      (1,_      ) -> 1            -- Must be (Z  E        _  E       )
      (2,N _ _ _) -> 2            -- Must be (N  E        _ (Z E _ E))
      (2,Z _ _ _) -> 3            -- Must be (Z (Z E _ E) _ (Z E _ E))
      (2,P _ _ _) -> 2            -- Must be (P (Z E _ E) _  E       )
      (3,N _ _ r) -> 2 + size 2 r -- Must be (N (Z E _ E) _  r       )
      (3,P l _ _) -> 2 + size 2 l -- Must be (P  l        _ (Z E _ E))
------------------------------------------}

-- | Fast size calculation: adds the number of elements in the tree to the
-- first argument. Roughly half of the tree nodes are never visited, because
-- a tree of small height can only have one of a few shapes, so its size is
-- known from its height and root constructor alone.
-- It is still O(n), but with a substantial saving in constant factors.
--
-- Complexity: O(n)
fastAddSize :: UINT -> AVL e -> UINT
fastAddSize n tree = case tree of
    E       -> n
    -- For an N node the height is 2 more than that of its left subtree.
    N l _ r -> case addHeight L(2) l of
        L(2) -> INCINT2(n)          -- height 2: exactly two elements
        h    -> fasN n h l r
    -- For a Z node the height is 1 more than that of its left subtree.
    Z l _ r -> case addHeight L(1) l of
        L(1) -> INCINT1(n)          -- height 1: a single element
        L(2) -> INCINT3(n)          -- height 2: exactly three elements
        h    -> fasZ n h l r
    -- For a P node the height is 2 more than that of its right subtree.
    P l _ r -> case addHeight L(2) r of
        L(2) -> INCINT2(n)          -- height 2: exactly two elements
        h    -> fasP n h l r

-- Local helpers for fastAddSize, one per root constructor (N, Z, P).
-- Arguments: size accumulator, height h of the node, left subtree, right subtree.
-- The node itself and both subtrees are added to the accumulator.
-- These only work if h >= 3 !!
fasN, fasZ, fasP :: UINT -> UINT -> AVL e -> AVL e -> UINT
fasN acc h l r = case h of
    -- Left subtree has height 1 (a single element), so count it with the root.
    L(3) -> fas INCINT2(acc) L(2) r
    -- h >= 4: left subtree has height h-2, right subtree has height h-1.
    _    -> fas (fas INCINT1(acc) DECINT2(h) l) DECINT1(h) r
-- Both subtrees have height h-1.
fasZ acc h l r = fas (fas INCINT1(acc) DECINT1(h) l) DECINT1(h) r
fasP acc h l r = case h of
    -- Right subtree has height 1 (a single element), so count it with the root.
    L(3) -> fas INCINT2(acc) L(2) l
    -- h >= 4: right subtree has height h-2, left subtree has height h-1.
    _    -> fas (fas INCINT1(acc) DECINT2(h) r) DECINT1(h) l

-- Local utility used by fasN, fasZ and fasP. Only works if h >= 2 !!
-- Arguments: size accumulator, height h of the tree, the tree itself.
-- Returns the accumulator plus the number of elements in the tree.
--
-- A tree of height 2 has a size fixed by its root constructor (2 for N or P,
-- 3 for Z), so it is counted without being traversed. Taller trees are handed
-- to the helper matching their root constructor.
--
-- An empty tree can never have height >= 2, so both E clauses indicate a
-- broken invariant (a bug in the caller or a malformed tree) and call error.
fas :: UINT -> UINT -> AVL e -> UINT
fas _ L(2)  E        = error "fas: Bug0"
fas n L(2) (N _ _ _) = INCINT2(n)
fas n L(2) (Z _ _ _) = INCINT3(n)
fas n L(2) (P _ _ _) = INCINT2(n)
-- So h must be >= 3 if we get here
fas n h    (N l _ r) = fasN n h l r
fas n h    (Z l _ r) = fasZ n h l r
fas n h    (P l _ r) = fasP n h l r
-- Explicit catch-all keeps the match exhaustive, so an impossible input gives
-- a named error rather than an anonymous pattern match failure.
fas _ _     E        = error "fas: Bug1"