{-|
  Purely functional weight balanced trees, aka trees of bounded balance.

    * J. Nievergelt and E.M. Reingold, \"Binary search trees of
      bounded balance\", Proceedings of the fourth annual ACM symposium on
      Theory of computing, pp 137-142, 1972.

    * S. Adams, \"Implementing sets efficiently in a functional language\",
      Technical Report CSTR 92-10, University of Southampton, 1992.
      <http://groups.csail.mit.edu/mac/users/adams/BB/>

    * S. Adams, \"Efficient sets: a balancing act\",
      Journal of Functional Programming, Vol 3, Issue 4, pp 553-562.

    * Y. Hirai and K. Yamamoto,
      \"Balancing Weight-Balanced Trees\",
      Journal of Functional Programming. Vol 21, Issue 03, pp 287-307. 
      <http://mew.org/~kazu/proj/weight-balanced-tree/>
    
    * M. Straka, \"Adams' Trees Revisited: Correct and Efficient Implementation\",
      TFP 2011.
      <http://fox.ucw.cz/papers/bbtree/>
-}

module Data.WBTree (
  -- * Data structures
    WBTree(..)
  , Size
  , size
  -- * Creating weight-balanced trees
  , empty
  , singleton
  , insert
  , fromList
  -- * Converting to a list
  , toList
  -- * Membership
  , member
  -- * Deleting
  , delete
  , deleteMin
  , deleteMax
  -- * Checking
  , null
  -- * Set operations
  , union
  , intersection
  , difference
  -- * Helper functions
  , join
  , merge
  , split
  , minimum
  , maximum
  , valid
--  , showTree
--  , printTree
  ) where

import Data.List (foldl')
import Prelude hiding (minimum, maximum, null)

----------------------------------------------------------------

-- | Number of elements stored in a tree.
type Size = Int

-- | A weight-balanced binary search tree.
--
-- Each 'Node' caches the size of the subtree it roots so that 'size' is
-- O(1). That field is read by every balance check, so it is kept strict
-- and unpacked: a lazy field would only accumulate short-lived
-- @size l + size r + 1@ thunks.
data WBTree a = Leaf | Node {-# UNPACK #-} !Size (WBTree a) a (WBTree a) deriving (Show)


-- | Trees are equal when they hold the same elements in the same order;
-- their internal shapes (and therefore balance history) are ignored.
instance (Eq a) => Eq (WBTree a) where
    t == u = xs == ys
      where
        xs = toList t
        ys = toList u

-- | Number of elements in the tree, read from the cached size field. O(1)
size :: WBTree a -> Size
size t = case t of
    Leaf         -> 0
    Node n _ _ _ -> n

----------------------------------------------------------------

{-|
See if the weight-balanced tree is empty. O(1)

This inspects the constructor directly, so, unlike comparing against
'Leaf' with '==', it needs no 'Eq' instance for the elements.

>>> Data.WBTree.null empty
True
>>> Data.WBTree.null (singleton 1)
False
-}

null :: WBTree a -> Bool
null Leaf = True
null _    = False

----------------------------------------------------------------

{-| The tree with no elements.

>>> size empty
0
-}

empty :: WBTree a
empty = Leaf

{-| A tree holding exactly one element.

>>> size (singleton 'a')
1
-}

singleton :: a -> WBTree a
singleton x = Node 1 empty x empty

----------------------------------------------------------------

-- | Build a 'Node' whose size field is derived from its two children.
-- No rebalancing is done; callers are responsible for the invariant.
node :: WBTree a -> a -> WBTree a -> WBTree a
node l x r = Node total l x r
  where
    total = 1 + size l + size r

----------------------------------------------------------------

{-| Insertion. O(log N)

An element that is already present leaves the tree unchanged.

>>> insert 5 (fromList [5,3]) == fromList [3,5]
True
>>> insert 7 (fromList [5,3]) == fromList [3,5,7]
True
>>> insert 5 empty            == singleton 5
True
-}

insert :: Ord a => a -> WBTree a -> WBTree a
insert k = go
  where
    go Leaf = singleton k
    go t@(Node _ l x r) = case compare k x of
        -- The left side grew, so it may now be too heavy: 'balanceR'
        -- rotates right when needed.
        LT -> balanceR (go l) x r
        -- Mirror case: the right side grew.
        GT -> balanceL l x (go r)
        EQ -> t

{-| Build a tree by inserting the list's elements from left to right;
duplicates collapse into a single element. O(N log N)

>>> empty == fromList []
True
>>> singleton 'a' == fromList ['a']
True
>>> fromList [5,3,5] == fromList [5,3]
True
-}

fromList :: Ord a => [a] -> WBTree a
fromList xs = foldl' step empty xs
  where
    step t x = insert x t

----------------------------------------------------------------

{-| List the elements of a tree in ascending (in-order) order. O(N)

>>> toList (fromList [5,3])
[3,5]
>>> toList empty
[]
-}

toList :: WBTree a -> [a]
toList = flip walk []
  where
    -- Accumulator-passing in-order walk: avoids repeated (++).
    walk Leaf acc           = acc
    walk (Node _ l x r) acc = walk l (x : walk r acc)

----------------------------------------------------------------

{-| Test whether an element occurs in the tree. O(log N)

>>> member 5 (fromList [5,3])
True
>>> member 1 (fromList [5,3])
False
-}

member :: Ord a => a -> WBTree a -> Bool
member k = search
  where
    search Leaf = False
    search (Node _ l x r) = case compare k x of
        LT -> search l
        GT -> search r
        EQ -> True

----------------------------------------------------------------

-- | Rebuild a node whose right side may have become too heavy (because
-- the right subtree grew or the left one shrank); rotates left if so.
balanceL :: WBTree a -> a -> WBTree a -> WBTree a
balanceL l x r =
    if isBalanced l r
        then node l x r
        else rotateL l x r

-- | Mirror of 'balanceL': rebuild a node whose left side may have become
-- too heavy; rotates right if so.
balanceR :: WBTree a -> a -> WBTree a -> WBTree a
balanceR l x r =
    if isBalanced r l
        then node l x r
        else rotateR l x r

-- | Restore balance for a right-heavy node. A single rotation suffices
-- when the right child's inner (left) subtree is light enough according
-- to 'isSingle'; otherwise a double rotation is used. A 'Leaf' right
-- subtree is rejected with an error.
rotateL :: WBTree a -> a -> WBTree a -> WBTree a
rotateL l x r = case r of
    Node _ rl _ rr
      | isSingle rl rr -> singleL l x r
      | otherwise      -> doubleL l x r
    Leaf               -> error "rotateL"

-- | Mirror of 'rotateL' for a left-heavy node. A 'Leaf' left subtree is
-- rejected with an error.
rotateR :: WBTree a -> a -> WBTree a -> WBTree a
rotateR l x r = case l of
    Node _ ll _ lr
      | isSingle lr ll -> singleR l x r
      | otherwise      -> doubleR l x r
    Leaf               -> error "rotateR"

-- | Single left rotation: the right child @rx@ becomes the new root.
--
-- >   x                rx
-- >  / \              /  \
-- > l   rx    ==>    x    rr
-- >    /  \         / \
-- >   rl   rr      l   rl
singleL :: WBTree a -> a -> WBTree a -> WBTree a
singleL l x r = case r of
    Node _ rl rx rr -> node (node l x rl) rx rr
    Leaf            -> error "singleL"

-- | Single right rotation: the left child @lx@ becomes the new root.
--
-- >      x            lx
-- >     / \          /  \
-- >   lx   r  ==>  ll    x
-- >  /  \               / \
-- > ll   lr            lr  r
singleR :: WBTree a -> a -> WBTree a -> WBTree a
singleR l x r = case l of
    Node _ ll lx lr -> node ll lx (node lr x r)
    Leaf            -> error "singleR"

-- | Double left rotation: the right child's left child @rlx@ becomes the
-- new root, with @x@ and @rx@ as its children. Requires the right child
-- and its left child to both be nodes; otherwise it errors.
doubleL :: WBTree a -> a -> WBTree a -> WBTree a
doubleL l x r = case r of
    Node _ (Node _ rll rlx rlr) rx rr ->
        node (node l x rll) rlx (node rlr rx rr)
    _ -> error "doubleL"

-- | Double right rotation: the left child's right child @lrx@ becomes the
-- new root, with @lx@ and @x@ as its children. Requires the left child
-- and its right child to both be nodes; otherwise it errors.
doubleR :: WBTree a -> a -> WBTree a -> WBTree a
doubleR l x r = case l of
    Node _ ll lx (Node _ lrl lrx lrr) ->
        node (node ll lx lrl) lrx (node lrr x r)
    _ -> error "doubleR"

----------------------------------------------------------------

{-| Remove the smallest element; the empty tree is returned as is.
O(log N)

>>> deleteMin (fromList [5,3,7]) == fromList [5,7]
True
>>> deleteMin empty == empty
True
-}

deleteMin :: WBTree a -> WBTree a
deleteMin t = case t of
    Leaf            -> Leaf
    Node _ Leaf _ r -> r
    -- The left side shrinks, so the right side may become too heavy.
    Node _ l x r    -> balanceL (deleteMin l) x r

{-| Remove the largest element; the empty tree is returned as is.
O(log N)

>>> deleteMax (fromList [(5,"a"), (3,"b"), (7,"c")]) == fromList [(3,"b"), (5,"a")]
True
>>> deleteMax empty == empty
True
-}

deleteMax :: WBTree a -> WBTree a
deleteMax t = case t of
    Leaf            -> Leaf
    Node _ l _ Leaf -> l
    -- The right side shrinks, so the left side may become too heavy.
    Node _ l x r    -> balanceR l x (deleteMax r)

----------------------------------------------------------------

{-| Remove an element from a tree; absent elements leave it unchanged.
O(log N)

>>> delete 5 (fromList [5,3]) == singleton 3
True
>>> delete 7 (fromList [5,3]) == fromList [3,5]
True
>>> delete 5 empty            == empty
True
-}

delete :: Ord a => a -> WBTree a -> WBTree a
delete k = go
  where
    go Leaf = Leaf
    go (Node _ l x r) = case compare k x of
        LT -> balanceL (go l) x r
        GT -> balanceR l x (go r)
        -- Found: splice the two children back together.
        EQ -> glue l r

----------------------------------------------------------------

{-| Check the tree's invariants: every node is balanced in both
directions, the elements are strictly increasing in order, and every
cached size field matches the real number of elements.
-}

valid :: Ord a => WBTree a -> Bool
valid t = and [balanced t, ordered t, validsize t]

-- | Every node's children are within the balance bound of each other.
balanced :: WBTree a -> Bool
balanced t = case t of
    Leaf         -> True
    Node _ l _ r -> and [isBalanced l r, isBalanced r l, balanced l, balanced r]

-- | Each element lies strictly between the bounds inherited from its
-- ancestors, i.e. an in-order walk is strictly increasing.
ordered :: Ord a => WBTree a -> Bool
ordered = within (const True) (const True)
  where
    within aboveLo belowHi t' = case t' of
        Leaf         -> True
        Node _ l x r ->
            aboveLo x && belowHi x
              && within aboveLo (< x) l
              && within (> x) belowHi r

-- | Every node's stored size equals one plus the real sizes of its
-- children.
validsize :: WBTree a -> Bool
validsize t = measure t == Just (size t)
  where
    measure Leaf = Just 0
    measure (Node s l _ r) = do
        n <- measure l
        m <- measure r
        if n + m + 1 == s then Just s else Nothing

----------------------------------------------------------------

{-| Joining two trees with an element. O(log N)

    Each element of the left tree must be less than the element.
    Each element of the right tree must be greater than the element.

    When one side is too heavy, the function descends that side's inner
    spine until the pieces are balanced. It then rebalances on the way
    back up.
-}

join :: Ord a => WBTree a -> a -> WBTree a -> WBTree a
join l x r = case (l, r) of
    (Leaf, _) -> insert x r
    (_, Leaf) -> insert x l
    (Node _ ll lx lr, Node _ rl rx rr)
      | rightLight && leftLight -> node l x r
      -- left too heavy: join into its right spine
      | rightLight              -> balanceL ll lx (join lr x r)
      -- right too heavy: join into its left spine
      | otherwise               -> balanceR (join l x rl) rx rr
      where
        rightLight = isBalanced l r
        leftLight  = isBalanced r l

{-| Merging two trees. O(log N)

    Each element of the left tree must be less than each element of
    the right tree.

    Like 'join' but without a middle element. Once the two pieces are
    balanced against each other, they are combined with 'glue'.
-}

merge :: WBTree a -> WBTree a -> WBTree a
merge l r = case (l, r) of
    (Leaf, _) -> r
    (_, Leaf) -> l
    (Node _ ll lx lr, Node _ rl rx rr)
      | rightLight && leftLight -> glue l r
      | rightLight              -> balanceL ll lx (merge lr r)
      | otherwise               -> balanceR (merge l rl) rx rr
      where
        rightLight = isBalanced l r
        leftLight  = isBalanced r l

-- | Concatenate two trees that are already balanced with respect to each
-- other, where every element on the left is smaller than every element
-- on the right. The larger side gives up its extreme element to serve as
-- the new root; on a tie the right side does.
glue :: WBTree a -> WBTree a -> WBTree a
glue l r = case (l, r) of
    (Leaf, _) -> r
    (_, Leaf) -> l
    _ | size l > size r -> balanceL (deleteMax l) (maximum l) r
      | otherwise       -> balanceR l (minimum r) (deleteMin r)

{-| Split a tree into the elements smaller and larger than the pivot.
The pivot itself, if present, is in neither half. O(log N)

>>> split 2 (fromList [5,3]) == (empty, fromList [3,5])
True
>>> split 3 (fromList [5,3]) == (empty, singleton 5)
True
>>> split 4 (fromList [5,3]) == (singleton 3, singleton 5)
True
>>> split 5 (fromList [5,3]) == (singleton 3, empty)
True
>>> split 6 (fromList [5,3]) == (fromList [3,5], empty)
True
-}

split :: Ord a => a -> WBTree a -> (WBTree a, WBTree a)
split k = go
  where
    go Leaf = (Leaf, Leaf)
    go (Node _ l x r) = case compare k x of
        LT -> let (below, above) = go l in (below, join above x r)
        GT -> let (below, above) = go r in (join l x below, above)
        EQ -> (l, r)

----------------------------------------------------------------

{-| The smallest element: the end of the left spine. O(log N)
Calls 'error' on the empty tree.

>>> minimum (fromList [3,5,1])
1
>>> minimum empty
*** Exception: minimum
-}

minimum :: WBTree a -> a
minimum t = case t of
    Node _ Leaf x _ -> x
    Node _ l _ _    -> minimum l
    Leaf            -> error "minimum"

{-| The largest element: the end of the right spine. O(log N)
Calls 'error' on the empty tree.

>>> maximum (fromList [3,5,1])
5
>>> maximum empty
*** Exception: maximum
-}

maximum :: WBTree a -> a
maximum t = case t of
    Node _ _ x Leaf -> x
    Node _ _ _ r    -> maximum r
    Leaf            -> error "maximum"

----------------------------------------------------------------

{-| The union of two trees. O(N + M)

The first tree is split at the second tree's root; the halves are unioned
recursively and then rejoined around that root.

>>> union (fromList [5,3]) (fromList [5,7]) == fromList [3,5,7]
True
-}

union :: Ord a => WBTree a -> WBTree a -> WBTree a
union t1 Leaf = t1
union Leaf t2 = t2
union t1 (Node _ l x r) =
    let (lo, hi) = split x t1
    in join (union lo l) x (union hi r)

{-| The intersection of two trees. O(N + M)

Keeps the elements present in both trees. For each shared element, the
copy kept is the one stored in the second tree.

>>> intersection (fromList [5,3]) (fromList [5,7]) == singleton 5
True
-}

intersection :: Ord a => WBTree a -> WBTree a -> WBTree a
intersection Leaf _ = Leaf
intersection _ Leaf = Leaf
intersection t1 (Node _ l x r)
  | found     = join l'' x r''
  | otherwise = merge l'' r''
  where
    -- A single descent both splits t1 at x and reports whether x was in
    -- it, instead of a 'split' followed by a separate 'member' search.
    (l', found, r') = splitMember x t1
    l'' = intersection l' l
    r'' = intersection r' r

-- | Like 'split', but also reports whether the pivot was present in the
-- tree. The pivot itself is in neither half.
splitMember :: Ord a => a -> WBTree a -> (WBTree a, Bool, WBTree a)
splitMember _ Leaf = (Leaf, False, Leaf)
splitMember k (Node _ l x r) = case compare k x of
    LT -> let (lt, b, gt) = splitMember k l in (lt, b, join gt x r)
    GT -> let (lt, b, gt) = splitMember k r in (join l x lt, b, gt)
    EQ -> (l, True, r)

{-| The elements of the first tree that are not in the second. O(N + M)

>>> difference (fromList [5,3]) (fromList [5,7]) == singleton 3
True
-}

difference :: Ord a => WBTree a -> WBTree a -> WBTree a
difference Leaf _  = Leaf
difference t1 Leaf = t1
difference t1 (Node _ l x r) =
    -- x itself is dropped: split excludes it, and merge adds no root.
    let (lo, hi) = split x t1
    in merge (difference lo l) (difference hi r)

----------------------------------------------------------------

-- | Balance parameter. A subtree's weight (its size plus one) may be at
-- most 'delta' times the weight of its sibling. Together with 'gamma'
-- this is the \<3,2\> parameter pair for the size-plus-one variant
-- discussed by Hirai and Yamamoto (see the module header).
delta :: Int
delta = 3

-- | Rotation parameter: selects between a single and a double rotation.
gamma :: Int
gamma = 2

-- | @isBalanced a b@ holds when @b@ is not too heavy relative to @a@:
-- @size b + 1 <= delta * (size a + 1)@.
isBalanced :: WBTree a -> WBTree a -> Bool
isBalanced a b = size b + 1 <= delta * (size a + 1)

-- | @isSingle a b@ holds when a single rotation suffices:
-- @size a + 1 < gamma * (size b + 1)@.
isSingle :: WBTree a -> WBTree a -> Bool
isSingle a b = gamma * (size b + 1) > size a + 1

{- Adams's variant
isBalanced :: WBTree a -> WBTree a -> Bool
isBalanced a b = x + y <= 1 || delta * x >= y
  where x = size a
        y = size b

isSingle :: WBTree a -> WBTree a -> Bool
isSingle a b = size a < gamma * size b
-}