{-# LANGUAGE BangPatterns #-}
-- | Functions for finding /lowest common ancestors/ in binary trees in /O(1)/ time, with /O(n)/ preprocessing.
module Data.RangeMin.LCA.Binary (Index, BinTree(..), quickLCABinary, lcaBinary) where

import Data.RangeMin

import Data.RangeMin.Common.Vector

import Data.RangeMin.LCA.IndexM
import qualified Data.RangeMin.Fusion as F
import qualified Data.Vector.Primitive as PV
import qualified Data.Vector as V

-- | A generic binary tree.  A tree is either the empty 'Tip', or a 'BinTree' node holding
-- a value followed by its left and right subtrees, in that order.
data BinTree a = Tip | BinTree a (BinTree a) (BinTree a)

-- Takes one layer off a tree: a node yields its value with its left and right
-- subtrees, the empty tree yields 'Nothing'.
unfoldBin :: BinTree a -> Maybe (a, BinTree a, BinTree a)
unfoldBin t = case t of
	Tip -> Nothing
	BinTree v lt rt -> Just (v, lt, rt)

-- Depth of a node below the root; 'inorderD' starts counting at 0 for the root.
type Depth = Int
-- One record per visited node, as built in 'lcaBinary': the node's depth, its index,
-- and the node value itself.  The depth and index fields are strict and unpacked;
-- the value field is left lazy.
data Trav a = D {-# UNPACK #-} !Depth {-# UNPACK #-} !Index a

-- Pairs every node's value with an index obtained from 'getIndex' after its left subtree
-- has been labelled and before its right subtree is, i.e. at the node's inorder position.
-- The 'Int' in the result is whatever 'execIndexM' reports alongside the labelled tree;
-- 'quickLCABinary' uses it as the number of nodes.
inorderBin :: BinTree a -> (BinTree (Index, a), Int)
inorderBin tree = execIndexM (label tree) where
	label Tip = return Tip
	label (BinTree x l r) =
		label l >>= \ l' ->
		getIndex >>= \ i ->
		label r >>= \ r' ->
		return (BinTree (i, x) l' r')

-- Inorder traversal written as a difference list: the images of all nodes under @f@ are
-- prepended, in inorder, to the list given as the last argument.  @f@ receives each node's
-- depth (the root has depth 0) together with the node's value.
{-# INLINE inorderD #-}
inorderD :: (Int -> a -> b) -> BinTree a -> [b] -> [b]
inorderD f = walk 0 where
	walk !_ Tip rest = rest
	walk !depth (BinTree x l r) rest =
		let deeper = walk (depth + 1)
		in deeper l (f depth x : deeper r rest)

-- | Labels every node of a binary tree with its inorder index and prepares
-- lowest-common-ancestor queries on the result.  The returned triple holds, in order:
-- the number of nodes, the indexed tree, and the query function, which takes two node
-- indices and yields the (indexed) lowest common ancestor.
quickLCABinary :: BinTree a -> (Int, BinTree (Index, a), Index -> Index -> (Index, a))
quickLCABinary = assemble . inorderBin where
	assemble (labelled, size) = (size, labelled, lcaBinary size fst labelled)

-- | Similar to 'LCA.lowestCommonAncestor', but optimized for binary trees.  This method can reasonably
-- be expected to run twice as fast as 'lowestCommonAncestor'.
--
-- The arguments are the number of nodes @n@, a function assigning each node its index, and
-- the tree itself.  The indexing function is expected to map the nodes one-to-one onto
-- @0 .. n-1@; any such numbering is accepted, it does not have to follow the inorder
-- sequence.  The result is a query function that takes the indices of two nodes and
-- returns their lowest common ancestor.
lcaBinary :: Int -> (a -> Index) -> BinTree a -> Index -> Index -> a
lcaBinary !n ix tree = lca
  where	!(!depths, !ixs, !vals) = 
	  (F.unzip3 :: V.Vector (Int, Int, a) -> 
	      (PV.Vector Int, PV.Vector Int, V.Vector a))
		$ F.map (\ (D i j k) -> (i, j, k)) $ F.fromListN n (inorderD (\ d x -> D d (ix x) x) tree [])
	-- 'depths', 'ixs' and 'vals' are all laid out by inorder position, so 'ixs' maps a
	-- position to a node index.  Queries arrive as node indices and need the opposite
	-- direction, hence the inverse permutation: positions ! (ixs ! p) == p.
	positions :: PV.Vector Int
	!positions = PV.update_ (PV.replicate n 0) ixs (PV.enumFromN 0 n)
	rM :: Int -> Int -> Int
	!rM = intRangeMin depths
	{-# NOINLINE lca #-}
	lca !i !j = let
	  iPos = positions ! i
	  jPos = positions ! j
	  -- The shallowest node between the two inorder positions (inclusive) is the
	  -- lowest common ancestor; rM is given the start position and the range length.
	  in vals ! case compare iPos jPos of
	      EQ -> iPos
	      LT -> rM iPos (jPos - iPos + 1)
	      GT -> rM jPos (iPos - jPos + 1)