{-# LANGUAGE BangPatterns #-}

-- | Functions for finding /lowest common ancestors/ in binary trees in /O(1)/ time, with
-- /O(n)/ preprocessing.
module Data.RangeMin.LCA.Binary (Index, BinTree(..), quickLCABinary, lcaBinary) where

import Data.RangeMin
import Data.RangeMin.Common.Vector
import Data.RangeMin.LCA.IndexM

import qualified Data.RangeMin.Fusion as F
import qualified Data.Vector.Primitive as PV
import qualified Data.Vector as V

-- | A generic binary tree.
data BinTree a = Tip | BinTree a (BinTree a) (BinTree a)

-- | Exposes the root of a tree: 'Nothing' for 'Tip', otherwise the node's value together
-- with its left and right subtrees.
unfoldBin :: BinTree a -> Maybe (a, BinTree a, BinTree a)
unfoldBin (BinTree x l r) = Just (x, l, r)
unfoldBin Tip = Nothing

-- | Distance of a node from the root (see 'inorderD': the root is at depth 0).
type Depth = Int

-- | One record of the inorder traversal: a node's depth, its index, and its value.
data Trav a = D {-# UNPACK #-} !Depth {-# UNPACK #-} !Index a

-- | Replaces every node's value @x@ with @(i, x)@, where @i@ is obtained from 'getIndex'
-- when the node is visited inorder (left subtree, then the node, then the right subtree).
-- The second component is whatever 'execIndexM' reports; 'quickLCABinary' treats it as
-- the number of nodes.
--
-- NOTE(review): presumably 'getIndex' hands out 0, 1, 2, ... so that @i@ is the node's
-- inorder position — confirm against "Data.RangeMin.LCA.IndexM".
inorderBin :: BinTree a -> (BinTree (Index, a), Int)
inorderBin = execIndexM . trav
  where
    trav Tip = return Tip
    trav (BinTree x l r) = do
      l' <- trav l
      -- The index is taken between the two subtrees, which is what makes this inorder.
      i <- getIndex
      r' <- trav r
      return (BinTree (i, x) l' r')

-- | @inorderD f t xs@ applies @f@ to the depth and value of every node of @t@ and prepends
-- the results, in inorder, to @xs@ (a difference-list style fold). The root has depth 0
-- and each child is one deeper than its parent.
{-# INLINE inorderD #-}
inorderD :: (Int -> a -> b) -> BinTree a -> [b] -> [b]
inorderD f = inorderD' 0
  where
    inorderD' !d t xs = case unfoldBin t of
        -- Left subtree's results come first, then this node, then the right subtree,
        -- then whatever was already in the accumulator.
        Just (x, l, r) -> inorderD'' l (f d x : inorderD'' r xs)
        Nothing -> xs
      where
        inorderD'' = inorderD' (d + 1)

-- | Takes a binary tree and indexes it inorder, returning the number of nodes, the indexed
-- tree, and the lowest common ancestor function. The returned function takes the indices
-- stored in the first component of each node of the indexed tree.
quickLCABinary :: BinTree a -> (Int, BinTree (Index, a), Index -> Index -> (Index, a))
quickLCABinary tree = case inorderBin tree of
  (iTree, n) -> (n, iTree, lcaBinary n fst iTree)

-- | Similar to 'LCA.lowestCommonAncestor', but optimized for binary trees. This method can
-- reasonably be expected to run twice as fast as 'lowestCommonAncestor'.
--
-- @lcaBinary n ix tree@ takes the number of nodes @n@, a function @ix@ giving the index of
-- each node's value (expected to be distinct and within @[0, n)@), and the tree itself.
-- The resulting function takes the indices of two nodes and returns the value stored at
-- their lowest common ancestor.
lcaBinary :: Int -> (a -> Index) -> BinTree a -> Index -> Index -> a
lcaBinary !n ix tree = lca
  where
    -- The tree flattened in inorder and unzipped: for each inorder position p,
    -- depths ! p, labels ! p and vals ! p are the depth, the index (as given by ix)
    -- and the value of the node at that position.
    !(!depths, !labels, !vals) =
      (F.unzip3 :: V.Vector (Int, Int, a) -> (PV.Vector Int, PV.Vector Int, V.Vector a)) $
        F.map (\ (D d i x) -> (d, i, x)) $
          F.fromListN n (inorderD (\ d x -> D d (ix x) x) tree [])
    -- Inverse of labels: position ! i is the inorder position of the node whose index
    -- is i. Queries arrive as node indices, so they must be translated into positions
    -- before consulting the range-min structure. Indexing labels directly is only
    -- correct when ix already numbers the nodes inorder (as quickLCABinary's does).
    !position = PV.update_ (PV.replicate n 0) labels (PV.enumFromN 0 n)
    rM :: Int -> Int -> Int
    !rM = intRangeMin depths
    {-# NOINLINE lca #-}
    lca !i !j =
      let iPos = position ! i
          jPos = position ! j
      -- In an inorder layout, the shallowest node between two nodes (inclusive) is their
      -- lowest common ancestor; rM is given the start of that span and its length.
      in vals ! case compare iPos jPos of
          EQ -> iPos
          LT -> rM iPos (jPos - iPos + 1)
          GT -> rM jPos (iPos - jPos + 1)