module Data.Tree.Splay
(SplayTree, head, tail, singleton, empty, null, fromList, fromAscList, toList, toAscList, insert, lookup, (!!), splay, size, delete)
where
import Data.List (foldl')
import Prelude hiding (head, tail, lookup, null, (!!))
-- | A splay tree mapping keys of type @k@ to values of type @v@.
--
-- The 'Int' stored in every node is the cached size of that subtree, as
-- reported by 'size'.  No class context is attached to the declaration
-- itself (that needs the deprecated @DatatypeContexts@ extension); the
-- functions operating on the tree carry their own @Ord k@ constraint.
--
-- The derived 'Eq' and 'Ord' instances are structural: they compare the
-- shape of the tree and the cached sizes, not just the stored associations.
data SplayTree k v
  = Leaf
    -- ^ The empty tree.
  | SplayTree k v Int (SplayTree k v) (SplayTree k v)
    -- ^ Key, value, cached size, left subtree, right subtree.
  deriving (Ord, Eq)
-- | Build a tree holding exactly one association.
--
-- The node's cached size is 1, so that @'size' ('singleton' x)@ differs
-- from the size of a 'Leaf', which is 0.
singleton :: (Ord k) => (k,v) -> SplayTree k v
singleton (k, v) = SplayTree k v 1 Leaf Leaf
-- | The tree with no associations; it is simply the 'Leaf' constructor.
empty :: (Ord k) => SplayTree k v
empty = Leaf
-- | Test whether a tree is empty: 'True' for a 'Leaf', 'False' for any node.
null :: (Ord k) => SplayTree k v -> Bool
null tree = case tree of
  Leaf -> True
  SplayTree{} -> False
-- | Read the size cached at the root of a tree.
--
-- A 'Leaf' reports 0; a node reports the 'Int' stored in it, without
-- inspecting its subtrees.
size :: (Ord k) => SplayTree k v -> Int
size tree = case tree of
  Leaf -> 0
  SplayTree _ _ cached _ _ -> cached
-- | Search for a key, rotating the node where the search stops up to the root.
--
-- The result is the restructured tree rather than a value.  If the key is
-- present its node becomes the root; otherwise the root is the last node
-- visited before the search ran into a 'Leaf'.  Looking anything up in an
-- empty tree gives back 'Leaf'.
lookup :: (Ord k) => k -> SplayTree k v -> SplayTree k v
lookup _ Leaf = Leaf
lookup wanted node@(SplayTree key _ _ left right)
  | wanted == key = node
  | key > wanted = viaChild left (`zig` node)
  | otherwise = viaChild right (zag node)
  where
    -- Recurse into one child.  An empty result means the search ends at
    -- this node; otherwise the child's new root is rotated above it.
    viaChild child rotate =
      case lookup wanted child of
        Leaf -> node
        sub -> rotate sub
-- | Fetch the association at a zero-based in-order position.
--
-- The cached size of the left subtree decides where to go: a smaller index
-- lies in the left subtree, an equal index is this node, and a larger one
-- lies in the right subtree after skipping the left subtree and this node.
--
-- Raises @error "index out of bounds"@ when the walk reaches a 'Leaf',
-- which is what happens for negative or too-large indices.
(!!) :: (Ord k) => SplayTree k v -> Int -> (k,v)
(!!) Leaf _ = error "index out of bounds"
(!!) (SplayTree k v _ l r) n
  | n < leftCount = l !! n
  | n == leftCount = (k, v)
  -- Skip the whole left subtree plus this node itself.
  | otherwise = r !! (n - leftCount - 1)
  where
    leftCount = size l
-- | Rotate the node at a zero-based in-order position up to the root.
--
-- The position is located with the cached size of the left subtree, in
-- the same way as indexing: the target is first brought to the root of
-- the child that contains it and then rotated above the current node with
-- 'zig' (left child) or 'zag' (right child).
--
-- Raises @error "index out of bounds"@ when the walk reaches a 'Leaf',
-- which is what happens for negative or too-large indices.
splay :: (Ord k) => SplayTree k v -> Int -> SplayTree k v
splay Leaf _ = error "index out of bounds"
splay t@(SplayTree _ _ _ l r) n
  | n == leftCount = t
  | n < leftCount = zig (splay l n) t
  -- Skip the whole left subtree plus this node itself.
  | otherwise = zag t (splay r (n - leftCount - 1))
  where
    leftCount = size l
-- | Right rotation: lift a left child above its parent.
--
-- The first argument is the (possibly restructured) left child, the second
-- is the parent.  The parent's own left-subtree field is ignored; the first
-- argument stands in for it.  Either argument being a 'Leaf' raises
-- @error "tree corruption"@.
--
-- The new root keeps the parent's total size.  The demoted parent loses
-- the new root and the new root's left subtree, hence
-- @d - size l1 - 1@.
zig :: (Ord k) => SplayTree k v -> SplayTree k v -> SplayTree k v
zig Leaf _ = error "tree corruption"
zig _ Leaf = error "tree corruption"
zig (SplayTree k1 v1 _ l1 r1) (SplayTree k v d _ r) =
  SplayTree k1 v1 d l1 (SplayTree k v (d - size l1 - 1) r1 r)
-- | Left rotation: lift a right child above its parent.
--
-- The first argument is the parent, the second is the (possibly
-- restructured) right child.  The parent's own right-subtree field is
-- ignored; the second argument stands in for it.  Either argument being a
-- 'Leaf' raises @error "tree corruption"@.
--
-- The new root keeps the parent's total size.  The demoted parent loses
-- the new root and the new root's right subtree, hence
-- @d - size r1 - 1@.
zag :: (Ord k) => SplayTree k v -> SplayTree k v -> SplayTree k v
zag Leaf _ = error "tree corruption"
zag _ Leaf = error "tree corruption"
zag (SplayTree k v d l _) (SplayTree k1 v1 _ l1 r1) =
  SplayTree k1 v1 d (SplayTree k v (d - size r1 - 1) l l1) r1
-- | Insert an association, leaving the inserted key at the root.
--
-- 'lookup' first rotates the node where the search for the key stops up
-- to the root.  If that node already holds the key, only its value is
-- replaced, so a key never appears twice.  Otherwise the new node becomes
-- the root and the old root is pushed down on the side given by the key
-- comparison, keeping only its own subtree on that side.
--
-- Cached sizes count nodes: a fresh node has size 1 (a 'Leaf' has size 0),
-- and the demoted root's size is the old total minus the subtree that
-- moves up to hang directly under the new root.
insert :: (Ord k) => k -> v -> SplayTree k v -> SplayTree k v
insert k v t =
  case lookup k t of
    Leaf -> SplayTree k v 1 Leaf Leaf
    SplayTree k1 v1 d l r
      -- Key already present: overwrite the value, shape and size unchanged.
      | k1 == k -> SplayTree k v d l r
      -- Old root is smaller: it goes left and gives up its right subtree.
      | k1 < k -> SplayTree k v (d + 1) (SplayTree k1 v1 (d - size r) l Leaf) r
      -- Old root is larger: it goes right and gives up its left subtree.
      | otherwise -> SplayTree k v (d + 1) l (SplayTree k1 v1 (d - size l) Leaf r)
-- | The association stored at the root of the tree.
--
-- This is whatever node currently sits at the root, not necessarily the
-- one with the smallest key.  Applying it to an empty tree raises
-- @error "head of empty tree"@.
head :: (Ord k) => SplayTree k v -> (k,v)
head tree = case tree of
  SplayTree key val _ _ _ -> (key, val)
  Leaf -> error "head of empty tree"
-- | Remove the root node and join what is left into one tree.
--
-- When one subtree is empty the other one is the result.  Otherwise the
-- rightmost node of the left subtree is rotated to that subtree's root
-- with 'splayRight', which leaves its right side empty, and the original
-- right subtree is attached there.
--
-- Raises @error "tail of empty tree"@ on an empty tree, and
-- @error "splay tree corruption"@ if 'splayRight' does not return a node
-- with an empty right subtree.
tail :: (Ord k) => SplayTree k v -> SplayTree k v
tail tree = case tree of
  Leaf -> error "tail of empty tree"
  SplayTree _ _ _ Leaf right -> right
  SplayTree _ _ _ left Leaf -> left
  SplayTree _ _ _ left right -> graft (splayRight left) right
  where
    -- Hang the right subtree under a root whose right side is empty.
    graft (SplayTree k v n sub Leaf) right = SplayTree k v (n + size right) sub right
    graft _ _ = error "splay tree corruption"
-- | Remove a key from the tree, if it is there.
--
-- The key is first searched for with 'lookup', which restructures the
-- tree.  If the resulting root holds the key, the root is dropped with
-- 'tail'; otherwise the restructured tree is returned as it is.  Deleting
-- from an empty tree gives back 'Leaf'.
--
-- Raises @error "splay tree corruption"@ if 'lookup' returns 'Leaf' for a
-- non-empty tree.
delete :: (Ord k) => k -> SplayTree k v -> SplayTree k v
delete _ Leaf = Leaf
delete key tree = dropIfRoot (lookup key tree)
  where
    dropIfRoot Leaf = error "splay tree corruption"
    dropIfRoot found@(SplayTree rootKey _ _ _ _)
      | key == rootKey = tail found
      | otherwise = found
-- | Rotate the rightmost node of a tree up to the root.
--
-- Left rotations are applied at the root until the root has no right
-- child, so the result is either 'Leaf' or a node whose right subtree is
-- 'Leaf'.
--
-- Each rotation keeps the total size at the new root.  The demoted node
-- loses the promoted node and that node's right subtree, hence
-- @d1 - size r2 - 1@.
splayRight :: (Ord k) => SplayTree k v -> SplayTree k v
splayRight Leaf = Leaf
splayRight h@(SplayTree _ _ _ _ Leaf) = h
splayRight (SplayTree k1 v1 d1 l1 (SplayTree k2 v2 _ l2 r2)) =
  splayRight (SplayTree k2 v2 d1 (SplayTree k1 v1 (d1 - size r2 - 1) l1 l2) r2)
-- | Rotate the leftmost node of a tree up to the root.
--
-- Right rotations are applied at the root until the root has no left
-- child, so the result is either 'Leaf' or a node whose left subtree is
-- 'Leaf'.
--
-- Each rotation keeps the total size at the new root.  The demoted node
-- loses the promoted node and that node's left subtree, hence
-- @d1 - size l2 - 1@.
splayLeft :: (Ord k) => SplayTree k v -> SplayTree k v
splayLeft Leaf = Leaf
splayLeft h@(SplayTree _ _ _ Leaf _) = h
splayLeft (SplayTree k1 v1 d1 (SplayTree k2 v2 _ l2 r2) r1) =
  splayLeft (SplayTree k2 v2 d1 l2 (SplayTree k1 v1 (d1 - size l2 - 1) r2 r1))
-- | Build a tree by inserting the pairs one at a time, from left to right,
-- starting from the empty tree.
--
-- The fold is strict in the accumulated tree, so no chain of pending
-- 'insert' applications builds up while the list is consumed.  An empty
-- list yields 'Leaf'.
fromList :: (Ord k) => [(k,v)] -> SplayTree k v
fromList = foldl' (\acc (k, v) -> insert k v acc) Leaf
-- | Build a tree from a list of pairs.
--
-- This simply delegates to 'fromList'; the ordering of the input is
-- neither checked nor exploited.
fromAscList :: (Ord k) => [(k,v)] -> SplayTree k v
fromAscList pairs = fromList pairs
-- | List the associations of a tree; this is the same as 'toAscList'.
toList :: (Ord k) => SplayTree k v -> [(k,v)]
toList tree = toAscList tree
toAscList :: (Ord k) => SplayTree k v -> [(k,v)]
toAscList h@(SplayTree _ _ _ Leaf _) = head h : toAscList (tail h)
toAscList Leaf = []
toAscList h = toAscList $ splayLeft h