-- | This module contains the 'SegmentTree' data structure, its
-- constructor and the query function.
--
-- A 'SegmentTree' answers range queries over a list of 'Monoid' values:
-- 'queryTree' combines (with '<>', left to right) every element whose index
-- falls inside a closed interval.
--
-- Example Usage:
--
-- @
-- import Data.Monoid
-- import Data.SegmentTree
-- ...
-- st = mkTree $ map Sum [0..10]
-- ...
-- queryTree st (0, 10) == Sum 55
-- queryTree st (5, 10) == Sum 45
-- queryTree st (0, 4)  == Sum 10
-- @
module Data.SegmentTree
  ( SegmentTree (..)
  , mkTree
  , queryTree
  ) where

import Data.Monoid
import Text.Printf

-- | Internal binary tree. A 'Leaf' holds one list element; a 'Branch'
-- caches the combination ('<>') of its two children's cached values.
--
-- There is deliberately no datatype context here: @data (Monoid a) => ...@
-- requires the deprecated @DatatypeContexts@ extension and adds nothing,
-- because the functions that need 'Monoid' already ask for it.
data Tree a
  = Branch a (Tree a) (Tree a)
  | Leaf a

-- | The value cached at the root of a 'Tree'.
getCargo :: Tree a -> a
getCargo (Branch x _ _) = x
getCargo (Leaf x) = x

-- | A 'SegmentTree' is a binary tree and the bounds of its
-- corresponding interval (a closed range of list indices).
data SegmentTree a = SegmentTree (Tree a) (Int, Int)

-- | Split the closed interval @(l, u)@ into the bounds of its left and right
-- subtrees. When the size is odd, the left half gets the extra index.
-- Building, showing and querying must all agree on this split, so it is
-- defined once here.
splitBounds :: (Int, Int) -> ((Int, Int), (Int, Int))
splitBounds (l, u) = ((l, l + m), (l + m + 1, u))
  where
    m = (u - l) `div` 2

-- | Renders the tree as an ASCII diagram of the index interval covered by
-- each node. The stored values are not shown. An empty tree renders as @[]@.
instance Show (SegmentTree a) where
  show (SegmentTree tree (lo, hi))
    | hi < lo = unlines ["[]"]
    | otherwise = unlines (render tree (lo, hi))
    where
      -- One line per node: the node's interval, followed by its indented
      -- subtrees.
      render :: Tree b -> (Int, Int) -> [String]
      render (Leaf _) (l, _) = [printf "[%d]" l]
      render (Branch _ left right) (l, u) =
        printf "[%d..%d]" l u
          : indent True (render left leftBounds)
          ++ indent False (render right rightBounds)
        where
          (leftBounds, rightBounds) = splitBounds (l, u)

      -- Attach a subtree's lines under its parent. The first line gets a
      -- connector ("|-- " when a sibling follows, "`-- " for the last
      -- child). The remaining lines get padding of the same width, so each
      -- nesting level lines up under its parent's label.
      indent :: Bool -> [String] -> [String]
      indent _ [] = []
      indent hasNext (x : xs) = (connector ++ x) : map (padding ++) xs
        where
          (connector, padding)
            | hasNext = ("|-- ", "|   ")
            | otherwise = ("`-- ", "    ")

-- | Build the 'SegmentTree' for the given list. Time: O(n).
--
-- The tree covers the indices @(0, length xs - 1)@. An empty list yields an
-- empty tree with bounds @(0, -1)@, on which every query returns 'mempty'.
mkTree :: (Monoid a) => [a] -> SegmentTree a
mkTree [] = SegmentTree (Leaf mempty) (0, -1)
mkTree xs = SegmentTree (fst (build listBounds xs)) listBounds
  where
    listBounds = (0, length xs - 1)

    -- Build the subtree for the closed interval @(l, u)@ from the front of
    -- the list, and return it together with the unconsumed suffix.
    -- Threading the suffix through (instead of re-'drop'ping) visits each
    -- element once.
    build :: (Monoid b) => (Int, Int) -> [b] -> (Tree b, [b])
    build (l, u) ys
      | l == u = case ys of
          y : rest -> (Leaf y, rest)
          [] -> error "Data.SegmentTree.mkTree: impossible, list shorter than its bounds"
      | otherwise =
          let (leftBounds, rightBounds) = splitBounds (l, u)
              (left, afterLeft) = build leftBounds ys
              (right, afterRight) = build rightBounds afterLeft
           in (Branch (getCargo left <> getCargo right) left right, afterRight)

-- | Query the 'SegmentTree' for the specified closed interval. Returns the
-- left-to-right combination ('<>') of the elements whose indices lie inside
-- it. Indices outside the tree's bounds are ignored. An empty (@l > u@) or
-- disjoint interval yields 'mempty'.
-- Time: O(log n)
queryTree :: (Monoid a) => SegmentTree a -> (Int, Int) -> a
queryTree (SegmentTree root rootBounds) (l, u) = go root rootBounds
  where
    -- We are querying for (l, u); (s, e) is the interval covered by node.
    go node (s, e)
      -- The node's interval is disjoint from the query.
      | l > e || u < s = mempty
      -- The node's interval lies entirely inside the query.
      | l <= s && u >= e = getCargo node
      -- Partial overlap: descend into both children.
      | otherwise = case node of
          Branch _ left right ->
            let (leftBounds, rightBounds) = splitBounds (s, e)
             in go left leftBounds <> go right rightBounds
          -- A leaf spans a single index, so it is always either disjoint or
          -- fully covered. This alternative only keeps the match total.
          Leaf x -> x