{-# LANGUAGE BangPatterns #-}
-- | Functions for finding /lowest common ancestors/ in trees in /O(1)/ time, with /O(n)/ preprocessing.
module Data.RangeMin.LCA (Index, lowestCommonAncestor, quickLCA) where

import Data.RangeMin.LCA.IndexM
import Control.Monad
import Data.RangeMin
import Data.RangeMin.Common.ST
import Data.RangeMin.Common.Vector
-- import qualified Data.RangeMin.Mixed as Mix
-- import qualified Data.Vector.Generic as G
import qualified Data.Vector.Unboxed as UV
import qualified Data.Vector.Primitive as PV
import qualified Data.Vector as V
-- import qualified Data.Vector.Fusion.Stream as S
-- import Data.Vector.Fusion.Stream.Size

import Data.Tree

-- | Labels a tree in depth-first order. Each node is paired with an 'Index'
-- obtained from 'getIndex' before any of its children are visited
-- (pre-order). The accompanying 'Int' is whatever 'execIndexM' reports once
-- the walk has finished; 'quickLCA' treats it as the number of nodes.
dfOrder :: Tree a -> (Tree (Index, a), Int)
dfOrder tree = execIndexM (number tree)
  where
    number (Node payload kids) = do
      idx <- getIndex
      liftM (Node (idx, payload)) (mapM number kids)

-- | Distance of a node below the root of the tree; the tour starts at 0.
type Depth = Int

-- | One step of an Euler tour: the depth of the visited node, the node's
-- index, and the node's label.
data Trav a =
  Trav
    {-# UNPACK #-} !Depth
    {-# UNPACK #-} !Index
    a

-- | @travel f t rest@ prepends the Euler tour of @t@ onto @rest@.
--
-- Every node @x@ is emitted as a 'Trav' holding its depth (the root is at
-- depth 0), its index @f x@ and its label @x@. A node is emitted once on
-- arrival and once more after returning from each of its children.
travel :: (a -> Index) -> Tree a -> [Trav a] -> [Trav a]
travel f = walk 0
  where
    walk !depth (Node x kids) rest = here : foldr descend rest kids
      where
        here = Trav depth (f x) x
        descend kid acc = walk (depth + 1) kid (here : acc)

-- | Numbers the nodes of a tree in depth-first order and builds a
-- lowest-common-ancestor query over those numbers.
--
-- The result is a triple of:
--
--   * the number of nodes,
--   * the tree with every label paired with its index, and
--   * a function that, given the indices of two nodes, returns the
--     @(index, label)@ pair of their lowest common ancestor.
quickLCA :: Tree a -> (Int, Tree (Index, a), Index -> Index -> (Index, a))
quickLCA tree = assemble (dfOrder tree)
  where
    assemble (numbered, count) =
      (count, numbered, lowestCommonAncestor count fst numbered)

-- | @'lowestCommonAncestor' n ix tree@ takes a tree whose nodes are mapped by
-- @ix@ to a unique index in the range @0..n-1@, and returns a function
-- which takes two indices (corresponding to two nodes in the tree) and returns
-- the label of their /lowest common ancestor/.
--
-- This takes /O(n)/ preprocessing and answers queries in /O(1)/, as it is an
-- application of "Data.RangeMin". The tree is flattened into an Euler tour.
-- The lowest common ancestor of two nodes is the shallowest tour entry lying
-- between an occurrence of each of them.
--
-- Both query arguments must be indices that @ix@ actually assigns to some
-- node of @tree@; the result for any other index is not meaningful.
--
-- For binary trees, consider using "Data.RangeMin.LCA.Binary".
lowestCommonAncestor :: Int -> (a -> Index) -> Tree a -> Index -> Index -> a
lowestCommonAncestor !n ix tree = vals `seq` lca
  where
    ixs :: UV.Vector Int
    depths :: PV.Vector Int
    -- vals :: V.Vector a
    -- (presumably omitted because a is not in scope without ScopedTypeVariables)
    !(!depths, !ixs, !vals) = inlineRunST $ do
      -- The tour emits each node once, plus once more after each of its
      -- children. A tree of at most n nodes therefore yields at most 2n-1
      -- tour entries.
      let !m = 2 * n - 1
      -- depthsM and valsM hold the depth and the label at each tour position.
      !depthsM <- new m
      -- ixsM maps a node index to a tour position where that node occurs.
      -- Node indices lie in 0..n-1, so n slots are enough.
      !ixsM <- new n
      !valsM <- new m
      let go !i (Trav d j a : ts) = do
            write depthsM i d
            -- Later occurrences of a node overwrite earlier ones. Any
            -- occurrence is a valid endpoint for the range-minimum query.
            write ixsM j i
            write valsM i a
            go (i + 1) ts
          go _ [] = return ()
      go 0 (travel ix tree [])
      liftM3 (,,) (unsafeFreeze depthsM) (unsafeFreeze ixsM)
        ((unsafeFreeze :: V.MVector s a -> ST s (V.Vector a)) valsM)
    -- rM start len: presumably the tour position of the minimum depth in the
    -- window of length len beginning at start. TODO confirm with intRangeMin.
    rM :: Int -> Int -> Int
    !rM = intRangeMin depths
    {-# NOINLINE lca #-}
    lca !i !j =
      let iIx = ixs ! i
          jIx = ixs ! j
      in vals ! case compare iIx jIx of
           EQ -> iIx
           LT -> rM iIx (jIx - iIx + 1)
           GT -> rM jIx (iIx - jIx + 1)