{-# LANGUAGE BangPatterns #-}
-- | Functions for finding /lowest common ancestors/ in binary trees in /O(1)/ time, with /O(n)/ preprocessing.
module Data.RangeMin.LCA.Binary (Index, BinTree(..), quickLCABinary, lcaBinary) where

import Data.RangeMin

import Data.RangeMin.Common.Vector

-- import Control.Monad
-- import Data.RangeMin.Common.ST
import Data.RangeMin.LCA.IndexM
import qualified Data.RangeMin.Fusion as F
import qualified Data.Vector.Unboxed as UV
import qualified Data.Vector.Primitive as PV
import qualified Data.Vector as V

-- | A generic binary tree: either empty, or a node carrying a value together
-- with its left and right subtrees.
data BinTree a
	= Tip					-- ^ The empty tree.
	| BinTree a (BinTree a) (BinTree a)	-- ^ A node: value, left subtree, right subtree.

-- Splits a node into its value and its two subtrees; an empty 'Tip' yields
-- 'Nothing'.
unfoldBin :: BinTree a -> Maybe (a, BinTree a, BinTree a)
unfoldBin node = case node of
	Tip			-> Nothing
	BinTree v left right	-> Just (v, left, right)

-- | Distance of a node below the root; 'inorderD' assigns the root depth 0.
type Depth = Int

-- Per-node traversal record: the node's depth, its 'Index', and its value.
data Trav a
	= D	{-# UNPACK #-} !Depth
		{-# UNPACK #-} !Index
		a

-- Labels every node with an 'Index' drawn from 'getIndex', visiting the left
-- subtree, then the node itself, then the right subtree (inorder).  The second
-- component is the count reported by 'execIndexM', which 'quickLCABinary' uses
-- as the number of nodes.
inorderBin :: BinTree a -> (BinTree (Index, a), Int)
inorderBin = execIndexM . label where
	label Tip = return Tip
	label (BinTree x l r) =
		label l >>= \ l' ->
		getIndex >>= \ i ->
		label r >>= \ r' ->
		return (BinTree (i, x) l' r')

-- Produces the inorder sequence of a tree, applying @f@ to each node's depth
-- (0 at the root, +1 per level) and value, and prepending the results onto the
-- given tail list.
{-# INLINE inorderD #-}
inorderD :: (Int -> a -> b) -> BinTree a -> [b] -> [b]
inorderD f = walk 0 where
	walk !depth node rest = maybe rest visit (unfoldBin node) where
		visit (x, l, r) = child l (f depth x : child r rest)
		child = walk (depth + 1)

-- | Numbers a binary tree's nodes in inorder and prepares an LCA query for it.
--
-- The result is a triple of the number of nodes, the tree with each value
-- paired with its inorder 'Index', and a function taking two node indices to
-- the @(index, value)@ pair of their lowest common ancestor.
quickLCABinary :: BinTree a -> (Int, BinTree (Index, a), Index -> Index -> (Index, a))
quickLCABinary = package . inorderBin where
	package (labelled, count) = (count, labelled, lcaBinary count fst labelled)

-- | Similar to 'LCA.lowestCommonAncestor', but optimized for binary trees.  This method can reasonably
-- be expected to run twice as fast as 'lowestCommonAncestor'.
--
-- @lcaBinary n ix tree@ takes the number of nodes @n@, a function @ix@ giving each node's
-- 'Index', and the tree.  It returns a function that maps the indices of two nodes to the value
-- stored at their lowest common ancestor.  The indices produced by @ix@ must be distinct and
-- lie in @[0, n)@.  They need not follow any particular traversal order.
lcaBinary :: Int -> (a -> Index) -> BinTree a -> Index -> Index -> a
lcaBinary !n ix tree = lca
	where	!(!depths, !ixs, !vals) = 
		  (F.unzip3 :: V.Vector (Int, Int, a) -> 
		      (PV.Vector Int, UV.Vector Int, V.Vector a))
			$ vmap (\ (D i j k) -> (i, j, k)) $ F.fromListN n (inorderD (\ d x -> D d (ix x) x) tree [])
		-- The three vectors are indexed by inorder position p: depths ! p, ixs ! p
		-- and vals ! p are the depth, index and value of the p-th node in inorder.
		vmap :: (a -> b) -> V.Vector a -> V.Vector b
		vmap = F.map
		-- Inverse of ixs: pos ! k is the inorder position of the node whose index
		-- is k.  Queries arrive as node indices, so they must be translated through
		-- this table.  Looking them up in ixs directly is only correct when ix
		-- happens to be the inorder numbering.
		pos :: UV.Vector Int
		!pos = UV.update (UV.replicate n 0) (UV.zip ixs (UV.enumFromN 0 n))
		-- As called below, rM takes a start position and a length, and returns the
		-- position of a minimum depth within that range.
		rM :: Int -> Int -> Int
		!rM = intRangeMin depths
		{-# NOINLINE lca #-}
		lca !i !j = let
		  p = pos ! i
		  q = pos ! j
		  -- The LCA is the shallowest node lying between the two positions in inorder.
		  in vals ! case compare p q of
		      EQ -> p
		      LT -> rM p (q - p + 1)
		      GT -> rM q (p - q + 1)