{-# LANGUAGE BangPatterns #-}

-- | Functions for finding /lowest common ancestors/ in binary trees in /O(1)/
-- time, with /O(n)/ preprocessing.
module Data.RangeMin.LCA.Binary
  ( Index
  , BinTree(..)
  , quickLCABinary
  , lcaBinary
  ) where

import Data.RangeMin
import Data.RangeMin.Common.Vector.Utils

import Control.Monad

import Data.RangeMin.Common.ST
import Data.RangeMin.LCA.IndexM

import qualified Data.Vector.Unboxed as UV
import qualified Data.Vector.Primitive as PV
import qualified Data.Vector as V

-- | A generic binary tree.
data BinTree a = Tip | BinTree a (BinTree a) (BinTree a)

-- | Splits a node into its label and its left and right subtrees, or yields
-- 'Nothing' for a 'Tip'.
unfoldBin :: BinTree a -> Maybe (a, BinTree a, BinTree a)
unfoldBin Tip = Nothing
unfoldBin (BinTree x l r) = Just (x, l, r)

-- | Distance of a node from the root (the root is at depth 0).
type Depth = Int

-- | One node as seen by the inorder walk in 'lcaBinary': its depth, the
-- index assigned to it by the caller's index function, and its label.
data Trav a = D {-# UNPACK #-} !Depth {-# UNPACK #-} !Index a

-- | Pairs every node with an index taken from 'getIndex', visiting the left
-- subtree, then the node, then the right subtree. Returns the relabelled
-- tree together with the 'Int' produced by 'execIndexM' ('quickLCABinary'
-- uses it as the node count).
inorderBin :: BinTree a -> (BinTree (Index, a), Int)
inorderBin = execIndexM . label
  where
    -- liftM3 runs its three actions left to right, so the node's own index
    -- is drawn after its whole left subtree and before its right subtree.
    label Tip = return Tip
    label (BinTree x l r) =
      liftM3 (\l' i r' -> BinTree (i, x) l' r') (label l) getIndex (label r)

{-# INLINE inorderD #-}
-- | Walks the tree inorder, applying @f@ to each node's depth (root = 0) and
-- label, and prepends the results, in inorder sequence, onto the supplied
-- tail list (difference-list style, so no '++' is needed).
inorderD :: (Int -> a -> b) -> BinTree a -> [b] -> [b]
inorderD f = visit 0
  where
    visit !depth node rest = maybe rest descend (unfoldBin node)
      where
        descend (x, l, r) = below l (f depth x : below r rest)
        below = visit (depth + 1)

-- | Numbers the nodes of a binary tree inorder and returns the number of
-- nodes, the numbered tree, and a lowest-common-ancestor query over those
-- numbers that answers with the ancestor's number and label.
quickLCABinary :: BinTree a -> (Int, BinTree (Index, a), Index -> Index -> (Index, a))
quickLCABinary = package . inorderBin
  where
    package (iTree, n) = (n, iTree, lcaBinary n fst iTree)

-- | Similar to 'LCA.lowestCommonAncestor', but optimized for binary trees.
-- This method can reasonably be expected to run twice as fast as
-- 'lowestCommonAncestor'.
--
-- @n@ sizes the internal lookup tables and should be the number of nodes in
-- the tree; @ix@ should give each node a distinct index in @[0, n)@. The
-- partially applied result answers queries on those indices with the label
-- of the two nodes' lowest common ancestor.
lcaBinary
  :: Int           -- ^ @n@: number of nodes; the size of every lookup table
  -> (a -> Index)  -- ^ @ix@: index of a node; expected distinct and in @[0, n)@
  -> BinTree a     -- ^ the tree to preprocess
  -> Index         -- ^ index (as assigned by @ix@) of the first node
  -> Index         -- ^ index (as assigned by @ix@) of the second node
  -> a             -- ^ label of their lowest common ancestor
lcaBinary !n ix tree = query
  where
    -- Three tables filled in a single pass over the inorder walk:
    --   depths : inorder position      -> depth of the node there
    --   ixs    : node index (from ix)  -> inorder position of that node
    --   vals   : inorder position      -> label of the node there
    -- The signatures on depths and ixs fix which vector type the shared
    -- 'new' / 'unsafeFreeze' calls produce. vals has no signature because
    -- @a@ is not in scope here (no ScopedTypeVariables); the annotation on
    -- its freeze below pins it to a boxed vector instead.
    depths :: PV.Vector Int
    ixs :: UV.Vector Int
    !(!depths, !ixs, !vals) = inlineRunST $ do
      !depthsM <- new n
      !ixsM <- new n
      !valsM <- new n
      let fill !pos (D d j x : rest) = do
            write depthsM pos d
            write ixsM j pos
            write valsM pos x
            fill (pos + 1) rest
          fill _ [] = return ()
      fill 0 (inorderD (\d x -> D d (ix x) x) tree [])
      depthsF <- unsafeFreeze depthsM
      ixsF <- unsafeFreeze ixsM
      valsF <- (unsafeFreeze :: V.MVector s a -> ST s (V.Vector a)) valsM
      return (depthsF, ixsF, valsF)

    -- Range-minimum over depths. It returns a position (its result is used
    -- to index vals), and is called here as @rangeMin start len@ where len
    -- counts the inclusive span.
    rangeMin :: Int -> Int -> Int
    !rangeMin = intRangeMin depths

    -- In inorder sequence, the shallowest node lying between two nodes
    -- (inclusive) is their lowest common ancestor, so one range-minimum
    -- lookup answers the query.
    {-# NOINLINE query #-}
    query !i !j = vals ! lcaPos
      where
        p = ixs ! i
        q = ixs ! j
        lcaPos
          | p == q = p
          | p < q = rangeMin p (q - p + 1)
          | otherwise = rangeMin q (p - q + 1)