{-# LANGUAGE BangPatterns #-}
-- | Functions for finding /lowest common ancestors/ in binary trees in /O(1)/ time, with /O(n)/ preprocessing.
module Data.RangeMin.LCA.Binary (Index, BinTree(..), quickLCABinary, lcaBinary) where

import Data.RangeMin

import Data.RangeMin.Common.Vector

import qualified Data.RangeMin.Mixed as Mix
import qualified Data.RangeMin.LCA as LCA()
import Data.RangeMin.LCA.IndexM
import qualified Data.Vector.Unboxed as UV
import qualified Data.Vector as V
import qualified Data.Vector.Generic as G
import qualified Data.Vector.Fusion.Stream as S

import Prelude hiding (foldr)

-- | A generic binary tree whose internal nodes each carry a label of type @a@.
data BinTree a
  = Tip
    -- ^ The empty tree.
  | BinTree a (BinTree a) (BinTree a)
    -- ^ A node: its label, then its left subtree, then its right subtree.

-- | Exposes one layer of a 'BinTree'.
--
-- Returns the root label together with the left and right subtrees, or
-- 'Nothing' when the tree is empty.
unfoldBin :: BinTree a -> Maybe (a, BinTree a, BinTree a)
unfoldBin node = case node of
  Tip                  -> Nothing
  BinTree lbl lft rgt  -> Just (lbl, lft, rgt)

-- | Depth of a node, counted in edges below the root of its tree.
type Depth = Int

-- | One entry of a depth-annotated inorder traversal (as built by 'inorderD').
--
-- Fields, in order:
--
-- * the node's depth (the root is at depth 0);
-- * the index the caller-supplied index function assigned to the node;
-- * the node's label.
--
-- Depth and index are strict and unpacked; the label stays lazy.
data Trav a = Trav
  {-# UNPACK #-} !Depth
  {-# UNPACK #-} !Index
  a

-- | Pairs every node's label with an index drawn from 'getIndex' in inorder.
--
-- The left subtree is labelled first, then the node itself, then the right
-- subtree. Also returns the count produced by 'execIndexM', which callers
-- treat as the number of nodes. Indices presumably start at 0; TODO confirm
-- against "Data.RangeMin.LCA.IndexM".
inorderBin :: BinTree a -> (BinTree (Index, a), Int)
inorderBin = execIndexM . label
  where
    label Tip = return Tip
    label (BinTree x lft rgt) =
      label lft >>= \lft' ->
      getIndex >>= \idx ->
      label rgt >>= \rgt' ->
      return (BinTree (idx, x) lft' rgt')

{-# INLINE inorderD #-}
-- | Prepends the inorder traversal of a tree to the list @rest@.
--
-- For each node the traversal records its depth (the root is at depth 0),
-- the index that @f@ assigns to its label, and the label itself. Taking the
-- tail as an argument avoids list appends.
inorderD :: (a -> Index) -> BinTree a -> [Trav a] -> [Trav a]
inorderD f = walk 0
  where
    walk !depth node rest = case node of
      Tip -> rest
      BinTree x lft rgt ->
        let child = walk (depth + 1)
        in  child lft (Trav depth (f x) x : child rgt rest)

-- | Indexes a binary tree in inorder and prepares it for lowest common
-- ancestor queries.
--
-- Returns a triple:
--
-- * the number of nodes;
-- * the tree with every label paired with its inorder index;
-- * a query function that takes two such indices and returns the index and
--   label of their lowest common ancestor.
quickLCABinary :: BinTree a -> (Int, BinTree (Index, a), Index -> Index -> (Index, a))
quickLCABinary = package . inorderBin
  where
    package (indexed, count) = (count, indexed, lcaBinary count fst indexed)

-- | Similar to 'LCA.lowestCommonAncestor', but optimized for binary trees.  This method can reasonably
-- be expected to run twice as fast as 'LCA.lowestCommonAncestor'.
--
-- @lcaBinary n ix tree@ preprocesses @tree@ and returns a query function. Given the
-- indices (under @ix@) of two nodes, the query returns the label of their lowest
-- common ancestor.
--
-- * @n@ should be the number of nodes in @tree@. At most @n@ nodes of the inorder
--   traversal are kept, and @n@ also sizes the index-to-position lookup table.
--
-- * @ix@ assigns an index to each node's label. The indices are presumably required to
--   be distinct and to lie in @[0, n)@, as the inorder indices built by 'quickLCABinary'
--   do. TODO confirm against 'vec'.
--
-- The preprocessed vectors are bound once per application to @n@, @ix@ and @tree@, so
-- reusing the returned function shares that work across queries.
lcaBinary :: Int -> (a -> Index) -> BinTree a -> Index -> Index -> a
lcaBinary !n ix tree = lca
	-- trav: one ((depth, index), label) entry per node, in inorder.
	where	!trav = G.unstream $ S.map (\ (Trav d i a) -> ((d, i), a)) $ S.fromListN n (inorderD ix tree [])
		-- ixs: maps a node's index to its inorder position. The stream pairs each
		-- index with its position in trav. This assumes 'vec' writes each (k, v) pair
		-- into slot k of an n-element vector. NOTE(review): confirm.
		ixs :: UV.Vector Int
		!ixs = vec n $ S.map (\ (a, b) -> (b, a)) $ S.indexed $ S.map (snd . fst) $ G.stream trav
		-- Split trav into unboxed (depth, index) pairs and boxed labels. The type
		-- variables in this annotation are fresh; it only fixes the vector types.
		!(dixs, !vals) = (Mix.unzip :: Mix.MixVector UV.Vector V.Vector (a, b) -> (UV.Vector a, V.Vector b))
					trav
		-- depths: the depth of the node at each inorder position.
		(!depths, _) = UV.unzip dixs
		-- rM start len: presumably the position of a minimum-depth entry in
		-- [start, start + len); see 'vecRangeMin'.
		rM :: Int -> Int -> Int
		!rM = vecRangeMin depths
		-- The LCA of two nodes is the unique shallowest node between them in inorder.
		-- The LCA's subtree occupies a contiguous inorder range that contains both
		-- nodes, so every node in between lies in that subtree and is deeper than the
		-- LCA itself, except the LCA. NOINLINE presumably keeps this closure, and the
		-- vectors it captures, from being inlined away from the shared bindings.
		{-# NOINLINE lca #-}
		lca !i !j = let
			iIx = ixs ! i
			jIx = ixs ! j
			in vals ! case compare iIx jIx of
				EQ	-> iIx	-- same node: it is its own lowest common ancestor
				LT	-> rM iIx (jIx - iIx + 1)
				GT	-> rM jIx (iIx - jIx + 1)