{-# LANGUAGE BangPatterns #-}
-- | Functions for finding /lowest common ancestors/ in binary trees in /O(1)/ time, with /O(n)/ preprocessing.
module Data.RangeMin.LCA.Binary (Index, BinTree(..), quickLCABinary, lcaBinary) where

import Data.RangeMin

import Data.RangeMin.Common.Vector.Utils

import Control.Monad
import Data.RangeMin.Common.ST
import Data.RangeMin.LCA.IndexM
import qualified Data.Vector.Unboxed as UV
import qualified Data.Vector.Primitive as PV
import qualified Data.Vector as V

-- | A generic binary tree.  A tree is either empty, or a node carrying a value
-- together with a left and a right subtree.
data BinTree a
	= Tip
	-- ^ The empty tree.
	| BinTree a (BinTree a) (BinTree a)
	-- ^ A node: its value, then its left subtree, then its right subtree.

-- | Exposes the root of a tree as its value plus its left and right subtrees,
-- or 'Nothing' when the tree is empty.
unfoldBin :: BinTree a -> Maybe (a, BinTree a, BinTree a)
unfoldBin node = case node of
	Tip -> Nothing
	BinTree v left right -> Just (v, left, right)

-- | Distance of a node below the root.
type Depth = Int

-- | One record of a depth-annotated traversal: a node's depth, its index, and
-- its value.  Depth and index are strict and unpacked; the value stays lazy.
data Trav a
	= D
		{-# UNPACK #-} !Depth
		{-# UNPACK #-} !Index
		a

-- | Labels every node with the value returned by 'getIndex', requesting labels
-- in inorder (left subtree, then the node, then the right subtree) --
-- presumably consecutive positions; see 'getIndex'.  Also returns the 'Int'
-- produced by 'execIndexM', which 'quickLCABinary' reports as the node count.
inorderBin :: BinTree a -> (BinTree (Index, a), Int)
inorderBin = execIndexM . label
	where
		label Tip = return Tip
		label (BinTree x l r) = liftM3 (\ l' i r' -> BinTree (i, x) l' r') (label l) getIndex (label r)

{-# INLINE inorderD #-}
-- | Depth-annotated inorder traversal.  Every node contributes @f depth value@,
-- with the root at depth 0, and the results are prepended onto the supplied
-- tail list (difference-list style), so @inorderD f t []@ lists them in order.
inorderD :: (Int -> a -> b) -> BinTree a -> [b] -> [b]
inorderD f = walk 0
	where
		walk !depth node rest = maybe rest visit (unfoldBin node)
			where
				-- Left subtree first, then this node, then the (lazily built) right subtree.
				visit (x, l, r) = walk (depth + 1) l (f depth x : walk (depth + 1) r rest)

-- | Takes a binary tree and indexes it inorder, returning the number of nodes, the indexed
-- tree, and the lowest common ancestor function.  The LCA function accepts the indices
-- assigned here and returns the ancestor's index paired with its value.
quickLCABinary :: BinTree a -> (Int, BinTree (Index, a), Index -> Index -> (Index, a))
quickLCABinary tree = assemble (inorderBin tree)
	where
		-- Matching the pair in a function argument forces it, just as a case would.
		assemble (indexed, size) = (size, indexed, lcaBinary size fst indexed)

-- | Similar to 'LCA.lowestCommonAncestor', but optimized for binary trees.  This method can reasonably
-- be expected to run twice as fast as 'lowestCommonAncestor'.
--
-- @lcaBinary n ix tree@ preprocesses @tree@ and returns a function that, given the
-- indices (as produced by @ix@) of two nodes, returns the value stored at their
-- lowest common ancestor.
--
-- * @n@ is the number of nodes in @tree@; it sizes every lookup table.
--
-- * @ix@ extracts the index of each node value.  Each index is written into a
--   table of size @n@, so the indices presumably must be distinct and lie in
--   @[0, n)@ (NOTE(review): confirm with callers).
--
-- Approach: record the depth of every node in inorder.  The lowest common ancestor
-- of two nodes is the shallowest node lying between them (inclusive) in that
-- order, which a range-minimum query over the depths finds.
lcaBinary :: Int -> (a -> Index) -> BinTree a -> Index -> Index -> a
lcaBinary !n ix tree = lca
	-- Lookup tables, all filled in a single inorder pass:
	--   depths ! p  = depth of the node at inorder position p
	--   ixs ! k     = inorder position of the node whose index is k
	--   vals ! p    = value of the node at inorder position p
	where	depths :: PV.Vector Int
		ixs :: UV.Vector Int
		-- vals :: V.Vector a
		-- NOTE(review): presumably left commented out because ScopedTypeVariables is
		-- not enabled, so this a would not refer to the a in the signature above.
		!(!depths, !ixs, !vals) = inlineRunST $ do
			!depthsM <- new n
			!ixsM <- new n
			!valsM <- new n
			-- Consume the traversal: i is the inorder position, d the depth,
			-- j the index extracted by ix, and x the node value.
			let go !i (D d j x:ts) = do
				write depthsM i d
				write ixsM j i
				write valsM i x
				go (i+1) ts
			    go _ [] = return ()
			go 0 (inorderD (\ d x -> D d (ix x) x) tree [])
			-- The annotation on the last unsafeFreeze presumably just pins valsM to a
			-- boxed vector; the other two result types follow from the signatures of
			-- depths and ixs above.
			liftM3 (,,) (unsafeFreeze depthsM) (unsafeFreeze ixsM) 
				((unsafeFreeze :: V.MVector s a -> ST s (V.Vector a)) valsM)
		-- Range-minimum query over depths, returning a position.  The arguments
		-- appear to be (start position, range length); confirm against Data.RangeMin.
		rM :: Int -> Int -> Int
		!rM = intRangeMin depths
		-- Map both indices to inorder positions, find the minimum-depth position in
		-- the inclusive range between them, and return the value stored there.
		-- Equal positions mean the same node, which is its own ancestor.
		{-# NOINLINE lca #-}
		lca !i !j = let
			iIx = ixs ! i
			jIx = ixs ! j
			in vals ! case compare iIx jIx of
				EQ	-> iIx
				LT	-> rM iIx (jIx - iIx + 1)
				GT	-> rM jIx (iIx - jIx + 1)