module Data.Set.Splay (
Splay(..)
, empty
, singleton
, insert
, fromList
, toList
, member
, delete
, deleteMin
, deleteMax
, null
, union
, intersection
, difference
, split
, minimum
, maximum
, valid
, (===)
, showSet
, printSet
) where
import Data.List (foldl')
import Prelude hiding (minimum, maximum, null)
-- | A binary search tree used as a self-adjusting (splay) set.
--
-- Invariant maintained by the operations in this module: an in-order
-- traversal yields strictly increasing elements (see 'valid').
data Splay a
  = Leaf                        -- ^ the empty tree
  | Node (Splay a) a (Splay a)  -- ^ left subtree, element, right subtree
  deriving (Show)
-- | Set equality: two trees are equal when their in-order element lists
-- are equal, regardless of tree shape. Use '===' to compare shapes.
instance Eq a => Eq (Splay a) where
  (==) lhs rhs = toList lhs == toList rhs
-- | Structural equality: 'True' only when both trees have exactly the same
-- shape and the same element at every position. Two trees holding the same
-- set can differ under '===' while being equal under '=='.
(===) :: Eq a => Splay a -> Splay a -> Bool
s === t = case (s, t) of
  (Leaf, Leaf) -> True
  (Node sl sx sr, Node tl tx tr) -> sx == tx && sl === tl && sr === tr
  _ -> False
-- | Top-down splay split.
--
-- @split k t@ returns @(lt, found, gt)@ where @lt@ holds the elements of
-- @t@ smaller than @k@, @gt@ holds the elements larger than @k@, and
-- @found@ is 'True' iff @t@ contained an element comparing 'EQ' to @k@
-- (that element appears in neither half).
--
-- The search path is consumed two levels at a time: when both steps go
-- the same way the two nodes are rotated (zig-zig); when they go opposite
-- ways the nodes are sent to different halves (zig-zag).
split :: Ord a => a -> Splay a -> (Splay a, Bool, Splay a)
split _ Leaf = (Leaf,False,Leaf)
split k x@(Node xl xk xr) = case compare k xk of
  -- k is at the root: its subtrees are exactly the two halves.
  EQ -> (xl, True, xr)
  -- k belongs in the right subtree; x and everything left of it is < k.
  GT -> case xr of
    Leaf -> (x, False, Leaf)
    Node yl yk yr -> case compare k yk of
      EQ -> (Node xl xk yl, True, yr)
      -- zig-zig: rotate y above x, then keep splitting y's right subtree.
      GT -> let (lt, b, gt) = split k yr
            in (Node (Node xl xk yl) yk lt, b, gt)
      -- zig-zag: x collects what is below k in yl, y what is above it.
      LT -> let (lt, b, gt) = split k yl
            in (Node xl xk lt, b, Node gt yk yr)
  -- k belongs in the left subtree: mirror image of the GT branch.
  LT -> case xl of
    Leaf -> (Leaf, False, x)
    Node yl yk yr -> case compare k yk of
      EQ -> (yl, True, Node yr xk xr)
      -- zig-zag: y collects what is below k in yr, x what is above it.
      GT -> let (lt, b, gt) = split k yr
            in (Node yl yk lt, b, Node gt xk xr)
      -- zig-zig: rotate y above x, then keep splitting y's left subtree.
      LT -> let (lt, b, gt) = split k yl
            in (lt, b, Node gt yk (Node yr xk xr))
-- | The empty set.
empty :: Splay a
empty = Leaf

-- | 'True' exactly for the empty set.
null :: Splay a -> Bool
null t = case t of
  Leaf    -> True
  Node {} -> False

-- | A set containing exactly one element.
singleton :: a -> Splay a
singleton v = Node empty v empty
-- | Insert an element, making it the new root.
--
-- The tree is split around @x@; any element already comparing equal to
-- @x@ is dropped, so @x@ replaces it.
insert :: Ord a => a -> Splay a -> Splay a
insert x t =
  let (smaller, _, larger) = split x t
  in Node smaller x larger
-- | Build a set by inserting the list's elements from left to right, so a
-- later duplicate replaces an earlier one.
fromList :: Ord a => [a] -> Splay a
fromList xs = foldl' step Leaf xs
  where
    step acc y = insert y acc
-- | In-order traversal, which is ascending order for a valid tree.
-- An accumulator is threaded right-to-left so no list appends are needed.
toList :: Splay a -> [a]
toList = flatten []
  where
    flatten acc Leaf         = acc
    flatten acc (Node l v r) = flatten (v : flatten acc r) l
-- | Membership test that also returns the restructured tree.
--
-- If @x@ is present, it becomes the root. Otherwise the two halves from
-- 'split' are rejoined. When both are non-empty, the largest element
-- smaller than @x@ ends up at the root.
member :: Ord a => a -> Splay a -> (Bool, Splay a)
member x t = case split x t of
  (l, found, r)
    | found     -> (True, Node l x r)
    | otherwise -> (False, rejoin l r)
  where
    -- Every element of the first tree is smaller than every element of
    -- the second.
    rejoin Leaf hi = hi
    rejoin lo Leaf = lo
    rejoin lo hi   = let (top, lo') = deleteMax lo in Node lo' top hi
-- | Smallest element, together with the tree restructured so that element
-- is the root (with an empty left subtree). Calls 'error' on an empty tree.
minimum :: Splay a -> (a, Splay a)
minimum t = case t of
  Leaf -> error "minimum"
  _    -> (lo, Node Leaf lo rest)
  where
    (lo, rest) = deleteMin t
-- | Largest element, together with the tree restructured so that element
-- is the root (with an empty right subtree). Calls 'error' on an empty tree.
maximum :: Splay a -> (a, Splay a)
maximum t = case t of
  Leaf -> error "maximum"
  _    -> (hi, Node rest hi Leaf)
  where
    (hi, rest) = deleteMax t
-- | Remove the smallest element and return it with the remaining tree.
--
-- Walks the left spine two nodes at a time, rotating each pair (zig-zig)
-- so the spine is shortened on the way back up. Calls 'error' on an empty
-- tree.
deleteMin :: Splay a -> (a, Splay a)
deleteMin t = case t of
  Leaf -> error "deleteMin"
  Node Leaf v above -> (v, above)
  Node (Node Leaf cv cr) v above -> (cv, Node cr v above)
  Node (Node gl cv cr) v above ->
    let (smallest, gl') = deleteMin gl
    in (smallest, Node gl' cv (Node cr v above))
-- | Remove the largest element and return it with the remaining tree.
--
-- Walks the right spine two nodes at a time, rotating each pair (zig-zig)
-- so the spine is shortened on the way back up. Calls 'error' on an empty
-- tree.
deleteMax :: Splay a -> (a, Splay a)
deleteMax t = case t of
  Leaf -> error "deleteMax"
  Node below v Leaf -> (v, below)
  Node below v (Node cl cv Leaf) -> (cv, Node below v cl)
  Node below v (Node cl cv gr) ->
    let (largest, gr') = deleteMax gr
    in (largest, Node (Node below v cl) cv gr')
-- | Remove an element from the set.
--
-- The tree is split at @x@. If @x@ was present, the two halves are glued
-- back together: every element of the left half is smaller than every
-- element of the right half. If @x@ is absent, the original tree is
-- returned unchanged.
delete :: Ord a => a -> Splay a -> Splay a
delete x t = case split x t of
  (l, True, r) -> glue l r
  _            -> t
  where
    -- Concatenate two trees whose ranges do not overlap (first < second).
    -- A general 'union' would re-split the right tree at every node of the
    -- left tree's spine. Here it is enough to splay the left tree's
    -- maximum to its root (via 'deleteMax') and hang the right tree off it.
    glue Leaf hi = hi
    glue lo Leaf = lo
    glue lo hi   = let (top, lo') = deleteMax lo in Node lo' top hi
-- | Set union by divide and conquer: split the second tree at the first
-- tree's root and recurse on each side. When both sets contain an equal
-- element, the one from the first argument is kept.
union :: Ord a => Splay a -> Splay a -> Splay a
union s t = case s of
  Leaf -> t
  Node a x b ->
    let (below, _, above) = split x t
    in Node (union below a) x (union above b)
-- | Set intersection. Splits the first tree at the second tree's root and
-- recurses on each side. Elements kept are taken from the second argument.
intersection :: Ord a => Splay a -> Splay a -> Splay a
intersection Leaf _ = Leaf
intersection _ Leaf = Leaf
intersection t1 (Node l x r) = case split x t1 of
  (l', True, r')  -> Node (intersection l' l) x (intersection r' r)
  -- x is not shared, so the two partial results lie strictly below and
  -- strictly above x. They only need concatenating, not a general union.
  (l', False, r') -> glue (intersection l' l) (intersection r' r)
  where
    -- Concatenate two trees whose ranges do not overlap (first < second)
    -- by splaying the first tree's maximum to its root.
    glue Leaf hi = hi
    glue lo Leaf = lo
    glue lo hi   = let (top, lo') = deleteMax lo in Node lo' top hi
-- | Set difference: the elements of the first set not in the second.
-- Splits the first tree at the second tree's root, which drops any
-- element equal to that root, and recurses on each side.
difference :: Ord a => Splay a -> Splay a -> Splay a
difference Leaf _ = Leaf
difference t1 Leaf = t1
difference t1 (Node l x r) = glue (difference l' l) (difference r' r)
  where
    (l', _, r') = split x t1
    -- The left result lies strictly below x and the right one strictly
    -- above. Concatenate them by splaying the left result's maximum to its
    -- root rather than running a general 'union'.
    glue Leaf hi = hi
    glue lo Leaf = lo
    glue lo hi   = let (top, lo') = deleteMax lo in Node lo' top hi
-- | Check the search-tree invariant: in-order elements strictly increase
-- (which also rules out duplicates).
valid :: Ord a => Splay a -> Bool
valid = isOrdered
-- | 'True' when every adjacent pair of in-order elements is strictly
-- increasing. Empty and single-element trees are trivially ordered.
isOrdered :: Ord a => Splay a -> Bool
isOrdered t = and (zipWith (<) xs (drop 1 xs))
  where
    xs = toList t
-- | Render the tree as an indented outline, starting with no prefix.
showSet :: Show a => Splay a -> String
showSet t = showSet' "" t
-- | Render a subtree given the current line prefix.
--
-- A node prints its element on one line, then two child lines, left child
-- first, each starting with the prefix and @"+ "@. Children are rendered
-- with the prefix extended by one more leading space. An empty subtree
-- renders as just a newline.
showSet' :: Show a => String -> Splay a -> String
showSet' _ Leaf = "\n"
showSet' pref (Node l x r) =
  concat
    [ show x, "\n"
    , pref, "+ ", showSet' deeper l
    , pref, "+ ", showSet' deeper r
    ]
  where
    deeper = " " ++ pref
-- | Write the outline produced by 'showSet' to standard output.
printSet :: Show a => Splay a -> IO ()
printSet s = putStr (showSet s)