{-|
  Purely functional red-black trees.

   * Chris Okasaki, \"Red-Black Trees in a Functional Setting\",
     Journal of Functional Programming, 9(4), pp 471-477, July 1999

   * Stefan Kahrs, \"Red-black trees with types\",
     Journal of functional programming, 11(04), pp 425-432, July 2001
-}
module Data.RBTree (
  -- * Data structures
    RBTree(..)
  , Color(..)
  , BlackHeight
  -- * Creating red-black trees
  , empty
  , singleton
  , insert
  , fromList
  -- * Converting a list
  , toList
  -- * Membership
  , member
  -- * Deleting
  , delete
  , deleteMin
  , deleteMax
  -- * Checking
  , null
  -- * Set operations
  , union
  , intersection
  , difference
  -- * Helper functions
  , join
  , merge
  , split
  , valid
  , minimum
  , maximum
  , showTree
  , printTree
  ) where

import Data.List (foldl')
import Prelude hiding (minimum, maximum, null)

----------------------------------------------------------------
-- Part to be shared
----------------------------------------------------------------

{-|
  A red-black tree. Each 'Node' stores its color, its black height,
  the left subtree, the element and the right subtree.
-}
data RBTree a = Leaf -- color is Black
              | Node Color !BlackHeight !(RBTree a) a !(RBTree a)
              deriving (Show)

data Color = B -- ^ Black
           | R -- ^ Red
           deriving (Eq,Show)

{-|
  Red nodes have the same BlackHeight of their parent.
-}
type BlackHeight = Int

----------------------------------------------------------------

-- Two trees are equal when they hold the same elements in the same
-- order; their shapes and colors are not compared.
instance (Eq a) => Eq (RBTree a) where
    t1 == t2 = toList t1 == toList t2

----------------------------------------------------------------

-- The black height stored in the root; a Leaf has height 0.
height :: RBTree a -> BlackHeight
height Leaf = 0
height (Node _ h _ _ _) = h

----------------------------------------------------------------

{-|
See if the red black tree is empty.

This matches on the constructor directly, so it needs no 'Eq'
instance for the elements and does not traverse the tree.

>>> Data.RBTree.null empty
True
>>> Data.RBTree.null (singleton 1)
False
-}
null :: RBTree a -> Bool
null Leaf = True
null _    = False

----------------------------------------------------------------

{-| Empty tree.

>>> height empty
0
-}
empty :: RBTree a
empty = Leaf

{-| Singleton tree.

>>> height (singleton 'a')
1
-}
singleton :: Ord a => a -> RBTree a
singleton x = Node B 1 Leaf x Leaf

----------------------------------------------------------------

{-| Creating a tree from a list. O(N log N)

>>> empty == fromList []
True
>>> singleton 'a' == fromList ['a']
True
>>> fromList [5,3,5] == fromList [5,3]
True
-}
fromList :: Ord a => [a] -> RBTree a
fromList = foldl' (flip insert) empty

----------------------------------------------------------------

{-| Creating a list from a tree. O(N)

>>> toList (fromList [5,3])
[3,5]
>>> toList empty
[]
-}
toList :: RBTree a -> [a]
toList t = inorder t []
  where
    -- In-order walk that conses onto an accumulator instead of using (++).
    inorder Leaf xs = xs
    inorder (Node _ _ l x r) xs = inorder l (x : inorder r xs)

----------------------------------------------------------------

{-| Checking if this element is a member of a tree?

>>> member 5 (fromList [5,3])
True
>>> member 1 (fromList [5,3])
False
-}
member :: Ord a => a -> RBTree a -> Bool
member _ Leaf = False
member x (Node _ _ l y r) = case compare x y of
    LT -> member x l
    GT -> member x r
    EQ -> True

----------------------------------------------------------------

-- Both color invariants: equal black count on every path and no
-- Red node with a Red child.
isBalanced :: RBTree a -> Bool
isBalanced t = isBlackSame t && isRedSeparate t

-- Every root-to-leaf path must contain the same number of black nodes.
isBlackSame :: RBTree a -> Bool
isBlackSame t = case blacks t of
    []   -> True -- 'blacks' yields one entry per Leaf, so this is not reached
    n:ns -> all (n ==) ns

-- Black count of every root-to-leaf path, counting the Leaf itself.
blacks :: RBTree a -> [Int]
blacks = blacks' 0
  where
    blacks' n Leaf = [n+1]
    blacks' n (Node R _ l _ r) = blacks' n  l ++ blacks' n  r
    blacks' n (Node B _ l _ r) = blacks' n' l ++ blacks' n' r
      where
        n' = n + 1

-- No Red node may have a Red child.
isRedSeparate :: RBTree a -> Bool
isRedSeparate = reds B

-- The first argument is the color of the parent of the tree being checked.
reds :: Color -> RBTree t -> Bool
reds _ Leaf = True
reds R (Node R _ _ _ _) = False
reds _ (Node c _ l _ r) = reds c l && reds c r

-- The in-order element list must be strictly increasing.
isOrdered :: Ord a => RBTree a -> Bool
isOrdered t = ordered $ toList t
  where
    ordered [] = True
    ordered [_] = True
    ordered (x:y:xys) = x < y && ordered (y:xys)

-- Checks that the BlackHeight stored in every node is consistent:
-- a Black node stores the height expected at its position, a Red node
-- stores the height of its parent, and a Leaf is only allowed where the
-- expected height is 0.
--
-- A tree whose root is Red is reported as not well-formed (False)
-- rather than aborting with 'error', so that 'valid' is total.
blackHeight :: RBTree a -> Bool
blackHeight Leaf = True
blackHeight t@(Node B i _ _ _) = bh i t
  where
    bh n Leaf = n == 0
    bh n (Node R h l _ r) = n == h' && bh n l && bh n r
      where
        h' = h - 1
    bh n (Node B h l _ r) = n == h && bh n' l && bh n' r
      where
        n' = n - 1
blackHeight _ = False

----------------------------------------------------------------

-- Paint the root Red; a Leaf cannot be recolored.
turnR :: RBTree a -> RBTree a
turnR Leaf = error "turnR"
turnR (Node _ h l x r) = Node R h l x r

-- Paint the root Black; a Leaf is rejected.
turnB :: RBTree a -> RBTree a
turnB Leaf = error "turnB"
turnB (Node _ h l x r) = Node B h l x r

-- Like 'turnB', but a Leaf is passed through unchanged.
turnB' :: RBTree a -> RBTree a
turnB' Leaf = Leaf
turnB' (Node _ h l x r) = Node B h l x r

----------------------------------------------------------------

-- The leftmost element. Calls 'error' on an empty tree.
minimum :: RBTree a -> a
minimum (Node _ _ Leaf x _) = x
minimum (Node _ _ l _ _) = minimum l
minimum _ = error "minimum"

-- The rightmost element. Calls 'error' on an empty tree.
maximum :: RBTree a -> a
maximum (Node _ _ _ x Leaf) = x
maximum (Node _ _ _ _ r) = maximum r
maximum _ = error "maximum"

----------------------------------------------------------------

-- Render the tree as text, one node per line: color, element and
-- black height, with the two subtrees introduced by "+ ".
showTree :: Show a => RBTree a -> String
showTree = showTree' ""

-- Worker for 'showTree'; the first argument is the current line prefix.
showTree' :: Show a => String -> RBTree a -> String
showTree' _ Leaf = "\n"
showTree' pref (Node k h l x r) = show k ++ " " ++ show x ++ " (" ++ show h ++ ")\n"
                               ++ pref ++ "+ " ++ showTree' pref' l
                               ++ pref ++ "+ " ++ showTree' pref' r
  where
    pref' = " " ++ pref

-- Write the 'showTree' rendering to standard output.
printTree :: Show a => RBTree a -> IO ()
printTree = putStr . showTree

----------------------------------------------------------------

-- True only for a Node whose color is Red; a Leaf counts as Black.
isRed :: RBTree a -> Bool
isRed (Node R _ _ _ _ ) = True
isRed _ = False

----------------------------------------------------------------
-- Basic operations
----------------------------------------------------------------

-- Checks the color invariants, the stored black heights and the
-- element ordering. Returns False for a tree whose root is Red.
valid :: Ord a => RBTree a -> Bool
valid t = isBalanced t && blackHeight t && isOrdered t

----------------------------------------------------------------
-- Chris Okasaki
--

{-| Insertion. O(log N)

>>> insert 5 (fromList [5,3]) == fromList [3,5]
True
>>> insert 7 (fromList [5,3]) == fromList [3,5,7]
True
>>> insert 5 empty == singleton 5
True
-}
insert :: Ord a => a -> RBTree a -> RBTree a
insert kx t = turnB (insert' kx t)

-- Worker for 'insert': the new element enters as a Red node and the
-- path back to the root is rebalanced. An element already present
-- leaves the subtree unchanged. The root may come back Red; 'insert'
-- repaints it.
insert' :: Ord a => a -> RBTree a -> RBTree a
insert' kx Leaf = Node R 1 Leaf kx Leaf
insert' kx s@(Node c h l x r) = case compare kx x of
    LT -> balanceL c h (insert' kx l) x r
    GT -> balanceR c h l x (insert' kx r)
    EQ -> s

-- Rebuild a node whose left subtree may have a Red node with a Red
-- child: under a Black parent the three nodes are rearranged into a
-- Red node with two Black children. Anything else is rebuilt as is.
balanceL :: Color -> BlackHeight -> RBTree a -> a -> RBTree a -> RBTree a
balanceL B h (Node R _ (Node R _ a x b) y c) z d =
    Node R (h+1) (Node B h a x b) y (Node B h c z d)
balanceL B h (Node R _ a x (Node R _ b y c)) z d =
    Node R (h+1) (Node B h a x b) y (Node B h c z d)
balanceL k h l x r = Node k h l x r

-- Mirror image of 'balanceL' for the right subtree.
balanceR :: Color -> BlackHeight -> RBTree a -> a -> RBTree a -> RBTree a
balanceR B h a x (Node R _ b y (Node R _ c z d)) =
    Node R (h+1) (Node B h a x b) y (Node B h c z d)
balanceR B h a x (Node R _ (Node R _ b y c) z d) =
    Node R (h+1) (Node B h a x b) y (Node B h c z d)
balanceR k h l x r = Node k h l x r

----------------------------------------------------------------

-- A tree paired with a flag. The deletion workers set the flag to True
-- when the returned tree is one Black node short and the caller still
-- has to repair it.
type RBTreeBDel a = (RBTree a, Bool)

-- Counterpart of 'unbalancedR': used by the deletion workers when the
-- right subtree has come back one Black node short.
unbalancedL :: Color -> BlackHeight -> RBTree a -> a -> RBTree a -> RBTreeBDel a
-- The left child is Black: repaint it Red and rebalance. The shortage
-- is passed upwards exactly when this node was Black itself.
unbalancedL c h l@(Node B _ _ _ _) x r
  = (balanceL B h (turnR l) x r, c == B)
-- The left child is Red: it becomes the new Black root and the old
-- root moves to the right, which settles the shortage here.
unbalancedL B h (Node R lh ll lx lr@(Node B _ _ _ _)) x r
  = (Node B lh ll lx (balanceL B h (turnR lr) x r), False)
unbalancedL _ _ _ _ _ = error "unbalancedL"

-- The left tree lacks one Black node
unbalancedR :: Color -> BlackHeight -> RBTree a -> a -> RBTree a -> (RBTree a, Bool)
-- Decreasing one Black node in the right
unbalancedR c h l x r@(Node B _ _ _ _)
  = (balanceR B h l x (turnR r), c == B)
-- Taking one Red node from the right and adding it to the left as Black
unbalancedR B h l x (Node R rh rl@(Node B _ _ _ _) rx rr)
  = (Node B rh (balanceR B h l x (turnR rl)) rx rr, False)
unbalancedR _ _ _ _ _ = error "unbalancedR"
----------------------------------------------------------------

{-| Deleting the minimum element. O(log N)

>>> deleteMin (fromList [5,3,7]) == fromList [5,7]
True
>>> deleteMin empty == empty
True
-}
deleteMin :: RBTree a -> RBTree a
deleteMin Leaf = empty
deleteMin t = turnB' s
  where
    ((s, _), _) = deleteMin' t

-- Worker for 'deleteMin': returns the tree without its leftmost
-- element, a flag telling whether that tree is one Black node short,
-- and the removed element. Calls 'error' on a Leaf.
deleteMin' :: RBTree a -> (RBTreeBDel a, a)
deleteMin' Leaf                                 = error "deleteMin'"
deleteMin' (Node B _ Leaf x Leaf)               = ((Leaf, True), x)
deleteMin' (Node B _ Leaf x r@(Node R _ _ _ _)) = ((turnB r, False), x)
deleteMin' (Node R _ Leaf x r)                  = ((r, False), x)
deleteMin' (Node c h l x r)                     = if d then (tD, m) else (tD', m)
  where
    ((l',d),m) = deleteMin' l
    -- The left subtree came back short: repair on the way up.
    tD  = unbalancedR c (h-1) l' x r
    tD' = (Node c h l' x r, False)

----------------------------------------------------------------

{-| Deleting the maximum

>>> deleteMax (fromList [(5,"a"), (3,"b"), (7,"c")]) == fromList [(3,"b"), (5,"a")]
True
>>> deleteMax empty == empty
True
-}
deleteMax :: RBTree a -> RBTree a
deleteMax Leaf = empty
deleteMax t = turnB' s
  where
    ((s, _), _) = deleteMax' t

-- Worker for 'deleteMax'; mirror image of 'deleteMin''.
deleteMax' :: RBTree a -> (RBTreeBDel a, a)
deleteMax' Leaf                                 = error "deleteMax'"
deleteMax' (Node B _ Leaf x Leaf)               = ((Leaf, True), x)
deleteMax' (Node B _ l@(Node R _ _ _ _) x Leaf) = ((turnB l, False), x)
deleteMax' (Node R _ l x Leaf)                  = ((l, False), x)
deleteMax' (Node c h l x r)                     = if d then (tD, m) else (tD', m)
  where
    ((r',d),m) = deleteMax' r
    -- The right subtree came back short: repair on the way up.
    tD  = unbalancedL c (h-1) l x r'
    tD' = (Node c h l x r', False)

----------------------------------------------------------------

-- Used when a Black node is replaced by its left subtree: a Red root
-- is repainted Black, which makes up for the removed node (flag False);
-- otherwise the tree is returned as is and reported short (flag True).
blackify :: RBTree a -> RBTreeBDel a
blackify s@(Node R _ _ _ _) = (turnB s, False)
blackify s                  = (s, True)

{-| Deleting this element from a tree. O(log N)

>>> delete 5 (fromList [5,3]) == singleton 3
True
>>> delete 7 (fromList [5,3]) == fromList [3,5]
True
>>> delete 5 empty == empty
True
-}
delete :: Ord a => a -> RBTree a -> RBTree a
delete x t = turnB' s
  where
    (s,_) = delete' x t

-- Worker for 'delete'. The flag in the result is True when the
-- returned tree is one Black node short. An element that is not in
-- the tree leaves it unchanged.
delete' :: Ord a => a -> RBTree a -> RBTreeBDel a
delete' _ Leaf = (Leaf, False)
delete' x (Node c h l y r) = case compare x y of
    LT -> let (l',d) = delete' x l
              t = Node c h l' y r
          in if d then unbalancedR c (h-1) l' y r else (t, False)
    GT -> let (r',d) = delete' x r
              t = Node c h l y r'
          in if d then unbalancedL c (h-1) l y r' else (t, False)
    EQ -> case r of
        -- No right subtree: the left subtree takes this node's place.
        Leaf -> if c == B then blackify l else (l, False)
        -- Otherwise the minimum of the right subtree replaces the element.
        _    -> let ((r',d),m) = deleteMin' r
                    t = Node c h l m r'
                in if d then unbalancedL c (h-1) l m r' else (t, False)

----------------------------------------------------------------
-- Set operations
----------------------------------------------------------------

{-
  Each element of t1 < g.
  Each element of t2 > g.
  Both t1 and t2 must be Black.
-}
join :: Ord a => RBTree a -> a -> RBTree a -> RBTree a
join Leaf g t2 = insert g t2
join t1 g Leaf = insert g t1
join t1 g t2 = case compare h1 h2 of
    LT -> turnB $ joinLT t1 g t2 h1
    GT -> turnB $ joinGT t1 g t2 h2
    EQ -> Node B (h1+1) t1 g t2
  where
    h1 = height t1
    h2 = height t2

-- The root of result must be red.
-- Descends the left spine of t2 to a node whose height equals h1,
-- attaches t1 and g there and rebalances on the way back.
joinLT :: Ord a => RBTree a -> a -> RBTree a -> BlackHeight -> RBTree a
joinLT t1 g t2@(Node c h l x r) h1
  | h == h1   = Node R (h+1) t1 g t2
  | otherwise = balanceL c h (joinLT t1 g l h1) x r
joinLT _ _ _ _ = error "joinLT"

-- The root of result must be red.
-- Mirror image of 'joinLT': descends the right spine of t1.
joinGT :: Ord a => RBTree a -> a -> RBTree a -> BlackHeight -> RBTree a
joinGT t1@(Node c h l x r) g t2 h2
  | h == h2   = Node R (h+1) t1 g t2
  | otherwise = balanceR c h l x (joinGT r g t2 h2)
joinGT _ _ _ _ = error "joinGT"

----------------------------------------------------------------

{-
  Each element of t1 < each element of t2
  Both t1 and t2 must be Black.
-}
merge :: Ord a => RBTree a -> RBTree a -> RBTree a
merge Leaf t2 = t2
merge t1 Leaf = t1
merge t1 t2 = case compare h1 h2 of
    LT -> turnB $ mergeLT t1 t2 h1
    GT -> turnB $ mergeGT t1 t2 h2
    EQ -> turnB $ mergeEQ t1 t2
  where
    h1 = height t1
    h2 = height t2

-- Descends the left spine of t2 to a node whose height equals h1 and
-- merges there with 'mergeEQ', rebalancing on the way back.
mergeLT :: Ord a => RBTree a -> RBTree a -> BlackHeight -> RBTree a
mergeLT t1 t2@(Node c h l x r) h1
  | h == h1   = mergeEQ t1 t2
  | otherwise = balanceL c h (mergeLT t1 l h1) x r
mergeLT _ _ _ = error "mergeLT"

-- Mirror image of 'mergeLT': descends the right spine of t1.
mergeGT :: Ord a => RBTree a -> RBTree a -> BlackHeight -> RBTree a
mergeGT t1@(Node c h l x r) t2 h2
  | h == h2   = mergeEQ t1 t2
  | otherwise = balanceR c h l x (mergeGT r t2 h2)
mergeGT _ _ _ = error "mergeGT"

{-
  Merging two trees whose heights are the same.
  The root must be either
     a red with height + 1
  or
     a black with height
-}
mergeEQ :: Ord a => RBTree a -> RBTree a -> RBTree a
mergeEQ Leaf Leaf = Leaf
mergeEQ t1@(Node _ h l x r) t2
  | h == h2'  = Node R (h+1) t1 m t2'
  | isRed l   = Node R (h+1) (turnB l) x (Node B h r m t2')
  -- unnecessary for LL
  -- A pattern guard takes the Red right child apart only when it really
  -- is a Red node, so no partial binding is needed.
  | Node R _ rl rx rr <- r
              = Node B h (Node R h l x rl) rx (Node R h rr m t2')
  | otherwise = Node B h (turnR t1) m t2'
  where
    -- The minimum of t2 becomes the element that glues both trees together.
    m   = minimum t2
    t2' = deleteMin t2
    h2' = height t2'
mergeEQ _ _ = error "mergeEQ"

----------------------------------------------------------------

-- Splits a tree around kx: the first component holds the elements
-- smaller than kx, the second the elements greater than kx. kx itself,
-- if present, is in neither. Subtrees are repainted Black before they
-- are handed to 'join' or returned.
split :: Ord a => a -> RBTree a -> (RBTree a, RBTree a)
split _ Leaf = (Leaf,Leaf)
split kx (Node _ _ l x r) = case compare kx x of
    LT -> (lt, join gt x (turnB' r))
      where
        (lt,gt) = split kx l
    GT -> (join (turnB' l) x lt, gt)
      where
        (lt,gt) = split kx r
    EQ -> (turnB' l, turnB' r)

{- LL
split :: Ord a => a -> RBTree a -> (RBTree a, RBTree a)
split _ Leaf = (Leaf,Leaf)
split kx (Node _ _ l x r) = case compare kx x of
    LT -> (lt, join gt x r) where (lt,gt) = split kx l
    GT -> (join l x lt, gt) where (lt,gt) = split kx r
    EQ -> (turnB' l, r)
-}

----------------------------------------------------------------

{-| Creating a union tree from two trees. O(N + M)

>>> union (fromList [5,3]) (fromList [5,7]) == fromList [3,5,7]
True
-}
union :: Ord a => RBTree a -> RBTree a -> RBTree a
union t1 Leaf = t1 -- ensured Black thanks to split
union Leaf t2 = turnB' t2
union t1 (Node _ _ l x r) = join (union l' l) x (union r' r)
  where
    (l',r') = split x t1

{-| Creating an intersection tree from trees. O(N + M)

>>> intersection (fromList [5,3]) (fromList [5,7]) == singleton 5
True
-}
intersection :: Ord a => RBTree a -> RBTree a -> RBTree a
intersection Leaf _ = Leaf
intersection _ Leaf = Leaf
intersection t1 (Node _ _ l x r)
  | member x t1 = join (intersection l' l) x (intersection r' r)
  | otherwise   = merge (intersection l' l) (intersection r' r)
  where
    (l',r') = split x t1

{-| Creating a difference tree from trees. O(N + M)

>>> difference (fromList [5,3]) (fromList [5,7]) == singleton 3
True
-}
difference :: Ord a => RBTree a -> RBTree a -> RBTree a
difference Leaf _  = Leaf
difference t1 Leaf = t1 -- ensured Black thanks to split
difference t1 (Node _ _ l x r) = merge (difference l' l) (difference r' r)
  where
    (l',r') = split x t1