module Data.Set.WBTree (
WBTree(..)
, Size
, size
, empty
, singleton
, insert
, fromList
, toList
, member
, delete
, deleteMin
, deleteMax
, null
, union
, intersection
, difference
, join
, merge
, split
, minimum
, maximum
, valid
) where
import Data.List (foldl')
import Prelude hiding (minimum, maximum, null)
-- | Number of elements stored in a tree.
type Size = Int

-- | A set of elements stored as a weight-balanced binary search tree.
--
-- Every 'Node' caches the element count of the subtree it roots. That
-- count is always derived from the children's counts (see 'node') and is
-- read by every balance check. Keeping it strict and unpacked avoids
-- allocating a thunk for it each time a node is rebuilt.
data WBTree a
  = -- | The empty tree.
    Leaf
  | -- | @Node size left x right@: elements of @left@ are below @x@,
    -- elements of @right@ are above it, and @size@ is the total number of
    -- elements in this subtree (these are the invariants 'valid' checks).
    Node {-# UNPACK #-} !Size (WBTree a) a (WBTree a)
  deriving (Show)
-- | Two trees are equal when their in-order element lists are equal,
-- regardless of how each tree happens to be shaped internally.
instance (Eq a) => Eq (WBTree a) where
  (==) xs ys = toList xs == toList ys
-- | Number of elements in the tree, read from the count cached in the
-- root node (an empty tree has none).
size :: WBTree a -> Size
size t = case t of
  Node n _ _ _ -> n
  Leaf -> 0
-- | Is the tree empty?
--
-- Only the root constructor is inspected, so this takes constant time.
-- Unlike comparing against 'Leaf' with '==', it places no 'Eq'
-- requirement on the elements.
null :: WBTree a -> Bool
null Leaf = True
null _ = False
-- | The tree holding no elements.
empty :: WBTree a
empty = Leaf
-- | A tree holding exactly one element, with a cached size of 1.
singleton :: a -> WBTree a
singleton = \v -> Node 1 Leaf v Leaf
-- | Internal smart constructor: assembles a 'Node' from its parts and
-- fills in the cached size from the children. It performs no rebalancing.
node :: WBTree a -> a -> WBTree a -> WBTree a
node l x r = Node total l x r
  where
    total = 1 + size l + size r
-- | Insert an element. If an equal element is already present, the tree
-- is returned unchanged and the previously stored element is kept.
insert :: Ord a => a -> WBTree a -> WBTree a
insert k = go
  where
    go Leaf = singleton k
    go t@(Node _ l x r) = case compare k x of
      LT -> balanceR (go l) x r -- left side grew: it may now be too heavy
      GT -> balanceL l x (go r) -- right side grew: it may now be too heavy
      EQ -> t
-- | Build a tree by inserting the list's elements one at a time, left to
-- right, with a strict left fold. For duplicates, the first occurrence is
-- the one stored, because 'insert' keeps an element that is already present.
fromList :: Ord a => [a] -> WBTree a
fromList xs = foldl' step empty xs
  where
    step acc v = insert v acc
-- | Elements in in-order sequence (ascending for a valid tree).
--
-- The list is built through an accumulator holding everything to the
-- right of the current subtree, so no list appends are needed.
toList :: WBTree a -> [a]
toList = walk []
  where
    walk rest Leaf = rest
    walk rest (Node _ l x r) = walk (x : walk rest r) l
-- | Is an element equal to the key present in the tree? Searches a single
-- root-to-leaf path, steered by 'compare'.
member :: Ord a => a -> WBTree a -> Bool
member k = search
  where
    search Leaf = False
    search (Node _ l x r) = case compare k x of
      EQ -> True
      LT -> search l
      GT -> search r
-- | Rebuild a node whose right side may have become too heavy (because the
-- right subtree grew or the left one shrank). If the children are still
-- within the 'isBalanced' bound the node is assembled as is; otherwise it
-- is rotated to the left.
balanceL :: WBTree a -> a -> WBTree a -> WBTree a
balanceL l x r =
  if isBalanced l r
    then node l x r
    else rotateL l x r
-- | Mirror image of 'balanceL': rebuild a node whose left side may have
-- become too heavy. If it has, the node is rotated to the right.
balanceR :: WBTree a -> a -> WBTree a -> WBTree a
balanceR l x r =
  if isBalanced r l
    then node l x r
    else rotateR l x r
-- | Rebalance a node whose right child is too heavy. A single rotation is
-- chosen when the right child's inner (left) subtree is light compared to
-- its outer (right) subtree, per 'isSingle'; otherwise a double rotation
-- is used. Calls 'error' if the right child is empty.
rotateL :: WBTree a -> a -> WBTree a -> WBTree a
rotateL l x r = case r of
  Node _ rl _ rr
    | isSingle rl rr -> singleL l x r
    | otherwise -> doubleL l x r
  _ -> error "rotateL"
-- | Mirror image of 'rotateL': rebalance a node whose left child is too
-- heavy. The left child's inner subtree is its right one. Calls 'error' if
-- the left child is empty.
rotateR :: WBTree a -> a -> WBTree a -> WBTree a
rotateR l x r = case l of
  Node _ ll _ lr
    | isSingle lr ll -> singleR l x r
    | otherwise -> doubleR l x r
  _ -> error "rotateR"
-- | Single left rotation; the right child becomes the root:
--
-- > (l x (rl rx rr))  ==>  ((l x rl) rx rr)
--
-- Calls 'error' if the right child is empty.
singleL :: WBTree a -> a -> WBTree a -> WBTree a
singleL l x r = case r of
  Node _ rl rx rr -> node (node l x rl) rx rr
  _ -> error "singleL"
-- | Single right rotation; the left child becomes the root:
--
-- > ((ll lx lr) x r)  ==>  (ll lx (lr x r))
--
-- Calls 'error' if the left child is empty.
singleR :: WBTree a -> a -> WBTree a -> WBTree a
singleR l x r = case l of
  Node _ ll lx lr -> node ll lx (node lr x r)
  _ -> error "singleR"
-- | Double left rotation; the right child's left child becomes the root:
--
-- > (l x ((rll rlx rlr) rx rr))  ==>  ((l x rll) rlx (rlr rx rr))
--
-- Calls 'error' unless the right child and its left child are both nodes.
doubleL :: WBTree a -> a -> WBTree a -> WBTree a
doubleL l x r = case r of
  Node _ (Node _ rll rlx rlr) rx rr ->
    node (node l x rll) rlx (node rlr rx rr)
  _ -> error "doubleL"
-- | Double right rotation; the left child's right child becomes the root:
--
-- > ((ll lx (lrl lrx lrr)) x r)  ==>  ((ll lx lrl) lrx (lrr x r))
--
-- Calls 'error' unless the left child and its right child are both nodes.
doubleR :: WBTree a -> a -> WBTree a -> WBTree a
doubleR l x r = case l of
  Node _ ll lx (Node _ lrl lrx lrr) ->
    node (node ll lx lrl) lrx (node lrr x r)
  _ -> error "doubleR"
deleteMin :: WBTree a -> WBTree a
deleteMin (Node _ Leaf _ r) = r
deleteMin (Node _ l x r) = balanceL (deleteMin l) x r
deleteMin Leaf = Leaf
-- | Remove the rightmost (largest) element. An empty tree is returned
-- unchanged.
deleteMax :: WBTree a -> WBTree a
deleteMax t = case t of
  Leaf -> Leaf
  Node _ l _ Leaf -> l
  -- the right side loses an element, so the left side may now be too heavy
  Node _ l x r -> balanceR l x (deleteMax r)
-- | Remove the element equal to the key, if present. Otherwise a tree
-- with the same elements is returned.
delete :: Ord a => a -> WBTree a -> WBTree a
delete k = go
  where
    go Leaf = Leaf
    go (Node _ l x r) = case compare k x of
      LT -> balanceL (go l) x r -- left shrank: right may be too heavy
      GT -> balanceR l x (go r) -- right shrank: left may be too heavy
      EQ -> glue l r -- drop x by fusing its two subtrees
-- | Check every structural invariant, stopping at the first failure:
-- weight balance at each node, strict ordering of the elements, and
-- correctness of every cached size.
valid :: Ord a => WBTree a -> Bool
valid t = and [balanced t, ordered t, validsize t]
-- | Every node's two children satisfy 'isBalanced' in both directions.
balanced :: WBTree a -> Bool
balanced t = case t of
  Leaf -> True
  Node _ l _ r ->
    and [isBalanced l r, isBalanced r l, balanced l, balanced r]
-- | The elements are strictly increasing in in-order sequence.
--
-- Each node is checked against the lower and upper bounds inherited from
-- its ancestors. Both bounds are predicates, so the root can start
-- unbounded with @const True@.
ordered :: Ord a => WBTree a -> Bool
ordered = within (const True) (const True)
  where
    within _ _ Leaf = True
    within aboveLow belowHigh (Node _ l x r) =
      aboveLow x
        && belowHigh x
        && within aboveLow (< x) l
        && within (> x) belowHigh r
-- | Every cached size equals the actual number of elements beneath it.
validsize :: WBTree a -> Bool
validsize t = count t == Just (size t)
  where
    -- Recomputed element count, or Nothing as soon as some node's cached
    -- size disagrees with the counts of its children.
    count Leaf = Just 0
    count (Node s l _ r) = do
      n <- count l
      m <- count r
      if n + m + 1 == s then Just s else Nothing
-- | Combine two trees and a middle element into one balanced tree.
--
-- Expects every element of the first tree to be below @x@ and every
-- element of the second to be above it, as produced by 'split'. While one
-- side is too heavy for the other, the function descends along the
-- heavier side's inner spine and rebalances on the way back up.
join :: Ord a => WBTree a -> a -> WBTree a -> WBTree a
join Leaf x r = insert x r
join l x Leaf = insert x l
join l@(Node _ ll lx lr) x r@(Node _ rl rx rr)
  -- right is too heavy: push l and x into r's left spine
  | not (isBalanced l r) = balanceR (join l x rl) rx rr
  -- left is too heavy: push x and r into l's right spine
  | not (isBalanced r l) = balanceL ll lx (join lr x r)
  | otherwise = node l x r
-- | Concatenate two trees, where every element of the first is below
-- every element of the second.
--
-- This follows the same descent as 'join' but has no middle element. Once
-- the two sides are in balance with each other, they are fused by 'glue'.
merge :: WBTree a -> WBTree a -> WBTree a
merge Leaf r = r
merge l Leaf = l
merge l@(Node _ ll lx lr) r@(Node _ rl rx rr)
  | not (isBalanced l r) = balanceR (merge l rl) rx rr
  | not (isBalanced r l) = balanceL ll lx (merge lr r)
  | otherwise = glue l r
-- | Fuse two trees that are in balance with each other, where every
-- element of the first is below every element of the second.
--
-- An extreme element is promoted from the larger side to become the new
-- root: the maximum of the left side if it is strictly larger, otherwise
-- the minimum of the right side.
glue :: WBTree a -> WBTree a -> WBTree a
glue l r = case (l, r) of
  (Leaf, _) -> r
  (_, Leaf) -> l
  _
    | size r < size l -> balanceL (deleteMax l) (maximum l) r
    | otherwise -> balanceR l (minimum r) (deleteMin r)
-- | Partition a tree around a key into the elements strictly below it and
-- the elements strictly above it. An element equal to the key, if present,
-- appears in neither half.
split :: Ord a => a -> WBTree a -> (WBTree a, WBTree a)
split k = go
  where
    go Leaf = (Leaf, Leaf)
    go (Node _ l x r) = case compare k x of
      EQ -> (l, r)
      LT ->
        let (below, above) = go l
         in (below, join above x r)
      GT ->
        let (below, above) = go r
         in (join l x below, above)
-- | The leftmost (smallest) element. Calls 'error' on an empty tree.
minimum :: WBTree a -> a
minimum t = case t of
  Node _ l x _ -> leftmost x l
  _ -> error "minimum"
  where
    -- keep the most recently visited element until the left spine ends
    leftmost best Leaf = best
    leftmost _ (Node _ l x _) = leftmost x l
-- | The rightmost (largest) element. Calls 'error' on an empty tree.
maximum :: WBTree a -> a
maximum t = case t of
  Node _ _ x r -> rightmost x r
  _ -> error "maximum"
  where
    -- keep the most recently visited element until the right spine ends
    rightmost best Leaf = best
    rightmost _ (Node _ _ x r) = rightmost x r
-- | Set union, computed by divide and conquer.
--
-- The first tree is split around the root of the second, the matching
-- halves are united recursively, and the results are joined back around
-- that root. When both trees hold an equal element, the copy from the
-- second tree is kept, because 'split' drops the first tree's copy.
union :: Ord a => WBTree a -> WBTree a -> WBTree a
union t1 Leaf = t1
union Leaf t2 = t2
union t1 (Node _ l x r) =
  let (lo, hi) = split x t1
   in join (union lo l) x (union hi r)
-- | Set intersection, computed by divide and conquer around the root of
-- the second tree.
--
-- That root is kept, via 'join', only if the first tree also contains it.
-- Otherwise the two recursive results are concatenated with 'merge'.
intersection :: Ord a => WBTree a -> WBTree a -> WBTree a
intersection Leaf _ = Leaf
intersection _ Leaf = Leaf
intersection t1 (Node _ l x r) =
  let (lo, hi) = split x t1
      below = intersection lo l
      above = intersection hi r
   in if member x t1 then join below x above else merge below above
-- | Elements of the first tree that are not in the second.
--
-- The first tree is split around the root of the second. That split
-- discards any equal element, and the two recursive differences are then
-- concatenated with 'merge'.
difference :: Ord a => WBTree a -> WBTree a -> WBTree a
difference Leaf _ = Leaf
difference t1 Leaf = t1
difference t1 (Node _ l x r) =
  let (lo, hi) = split x t1
   in merge (difference lo l) (difference hi r)
-- | Balance bound: at every node, each child's weight (size plus one) may
-- be at most @delta@ times its sibling's weight (see 'isBalanced').
--
-- NOTE(review): (delta, gamma) = (3, 2) with size+1 weights appears to be
-- the integer parameter pair analysed by Hirai & Yamamoto. Confirm against
-- that analysis before changing either constant.
delta :: Int
delta = 3
-- | Rotation threshold: a single rotation is used when the inner
-- grandchild's weight is below @gamma@ times the outer grandchild's weight
-- (see 'isSingle'). Otherwise a double rotation is used.
gamma :: Int
gamma = 2
-- | @isBalanced a b@: does @b@ weigh at most 'delta' times as much as @a@?
-- A tree's weight is its size plus one, so empty trees weigh 1.
isBalanced :: WBTree a -> WBTree a -> Bool
isBalanced a b = weight b <= delta * weight a
  where
    weight t = size t + 1
-- | @isSingle inner outer@: the rotations use this to decide whether a
-- single rotation suffices. It holds when the inner subtree's weight (size
-- plus one) is strictly below 'gamma' times the outer subtree's weight.
isSingle :: WBTree a -> WBTree a -> Bool
isSingle a b = weight a < gamma * weight b
  where
    weight t = size t + 1