{-# LANGUAGE BangPatterns #-}
-- | Lowest common ancestor queries on trees: /O(n)/ preprocessing and
-- /O(1)/ time per query.
module Data.RangeMin.LCA (Index, lowestCommonAncestor, quickLCA) where

import Data.RangeMin.LCA.IndexM

import Control.Monad

import Data.RangeMin
import Data.RangeMin.Common.ST
import Data.RangeMin.Common.Vector
-- import qualified Data.RangeMin.Mixed as Mix
-- import qualified Data.Vector.Generic as G
import qualified Data.Vector.Unboxed as UV
import qualified Data.Vector.Primitive as PV
import qualified Data.Vector as V
-- import qualified Data.Vector.Fusion.Stream as S
-- import Data.Vector.Fusion.Stream.Size

import Data.Tree

-- | Numbers every node of a tree in pre-order (a node takes its index before
-- any of its children do), pairing each label with that index. Also returns
-- the final counter of the 'IndexM' computation, which 'quickLCA' uses as
-- the node count.
dfOrder :: Tree a -> (Tree (Index, a), Int)
dfOrder = execIndexM . stamp
  where
    -- Take this node's index first, then number the subtrees left to right.
    stamp (Node x kids) =
      liftM2 (\k kids' -> Node (k, x) kids') getIndex (mapM stamp kids)

-- | Distance from the root (the root has depth 0).
type Depth = Int

-- | One step of an Euler tour.
data Trav a = Trav
  {-# UNPACK #-} !Depth  -- ^ depth of the visited node
  {-# UNPACK #-} !Index  -- ^ index of the visited node
  a                      -- ^ label of the visited node

-- | @'travel' f tree rest@ prepends the Euler tour of @tree@ to @rest@.
-- A node is emitted once on entry and once more after each of its child
-- subtrees. Each entry carries the node's depth, its index under @f@ and its
-- label.
travel :: (a -> Index) -> Tree a -> [Trav a] -> [Trav a]
travel f = walk 0
  where
    walk !depth (Node x kids) rest = here : foldr visit rest kids
      where
        here = Trav depth (f x) x
        -- Tour one child one level deeper, then come back to this node.
        visit kid acc = walk (depth + 1) kid (here : acc)

-- | Numbers a tree in depth-first (pre-)order and builds an LCA query over it.
-- Returns the node count, the numbered tree, and a function that maps two
-- node indices to the @(index, label)@ pair of their lowest common ancestor.
quickLCA :: Tree a -> (Int, Tree (Index, a), Index -> Index -> (Index, a))
quickLCA tree = (count, numbered, lowestCommonAncestor count fst numbered)
  where
    -- Strict binding: the numbering pair is forced before the triple is built.
    !(numbered, count) = dfOrder tree

-- | @'lowestCommonAncestor' n ix tree@ expects @ix@ to send every node of
-- @tree@ to a distinct index in the range @0..n-1@. It returns a query
-- function which, given the indices of two nodes, yields the label of their
-- /lowest common ancestor/.
--
-- Preprocessing is /O(n)/ and each query is /O(1)/, by reduction to the
-- range-minimum structures of "Data.RangeMin".
--
-- For binary trees, consider using "Data.RangeMin.LCA.Binary".
lowestCommonAncestor :: Int -> (a -> Index) -> Tree a -> Index -> Index -> a
lowestCommonAncestor !n ix tree = vals `seq` lca
  where
    -- depths : depth of each Euler-tour entry
    -- ixs    : for each node index, a tour position at which that node occurs
    -- vals   : label of each Euler-tour entry
    ixs :: UV.Vector Int
    depths :: PV.Vector Int
    -- vals :: V.Vector a   (naming the outer 'a' would need ScopedTypeVariables)
    !(!depths, !ixs, !vals) = inlineRunST $ do
      -- 'travel' emits each node once plus once more per child edge, so the
      -- tour has 2n-1 entries when n is the number of nodes in the tree.
      let !m = 2 * n - 1
      !depthsM <- new m
      -- This table is keyed by node index (0..n-1), not by tour position,
      -- so n slots are enough.
      !ixsM <- new n
      !valsM <- new m
      let go !i (Trav d j a : ts) = do
            write depthsM i d
            -- Later visits overwrite earlier ones. Any occurrence of a node
            -- works as an endpoint of the range-minimum query below.
            write ixsM j i
            write valsM i a
            go (i + 1) ts
          go _ [] = return ()
      go 0 (travel ix tree [])
      liftM3 (,,)
        (unsafeFreeze depthsM)
        (unsafeFreeze ixsM)
        ((unsafeFreeze :: V.MVector s a -> ST s (V.Vector a)) valsM)

    -- @rM start len@: position of the minimum depth within
    -- depths[start .. start+len-1] (the "Data.RangeMin" RangeMin convention).
    rM :: Int -> Int -> Int
    !rM = intRangeMin depths

    -- The LCA is the shallowest tour entry between the two nodes' recorded
    -- positions. If both positions coincide, the nodes are the same node.
    {-# NOINLINE lca #-}
    lca !i !j = vals ! pos
      where
        p = ixs ! i
        q = ixs ! j
        lo = min p q
        hi = max p q
        pos
          | lo == hi = lo
          | otherwise = rM lo (hi - lo + 1)