{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}

-- | Splay trees annotated with a monoidal measure.
--
-- Every 'Branch' caches the measure of its whole subtree, so prefix measures
-- can be searched ('query', 'split') much like a finger tree.  Most operations
-- splay the node they touch to the root.
module Data.SplayTree
  ( SplayTree (..)
  , Measured (..)
  , empty
  , (|>)
  , (<|)
  , (><)
  , null
  , singleton
  , size
  , split
  , query
  , memberSplay
  , delete
  , insert
  , difference
  , intersection
  , balance
  , deepL
  , deepR
  , fromList
  , fromListBalance
  , fmap'
  , traverse'
  ) where

import Prelude hiding (foldr, null)

import Control.Applicative hiding (empty)
import Control.DeepSeq
import Control.Monad
import Data.Data
-- Data.Foldable exports its own 'null' (base >= 4.8); hide it so this
-- module's 'null' is unambiguous.
import Data.Foldable hiding (null)
import Data.Maybe
import Data.Monoid
-- Needed for the Semigroup superclass of Monoid (GHC >= 8.4).
import Data.Semigroup (Semigroup (..))

infixr 5 ><
infixr 5 <|
infixl 5 |>

{-# INLINE (><) #-}
{-# INLINE (<|) #-}
{-# INLINE (|>) #-}

-- | Types that can be summarised by a monoidal measure.  A tree built with
-- the smart constructors measures as the left-to-right 'mconcat' of its
-- elements' measures.
class Monoid (Measure a) => Measured a where
  type Measure a :: *
  measure :: a -> Measure a

-- | A splay tree.  Each 'Branch' stores the cached measure of its subtree,
-- its left child, its element and its right child.  The cache is kept
-- consistent by the internal smart constructor @branch@; building 'Branch'
-- values by hand must keep it consistent too.
data SplayTree a where
  Tip    :: SplayTree a
  Branch :: (Measure a) -> (SplayTree a) -> !a -> (SplayTree a) -> SplayTree a
  deriving (Typeable)

instance (NFData a, NFData (Measure a)) => NFData (SplayTree a) where
  rnf Tip              = ()
  rnf (Branch m l a r) = m `deepseq` l `deepseq` a `deepseq` rnf r

-- | Trees are equal when their in-order element sequences are equal; the
-- shape of the tree is not observed.
instance (Eq a) => Eq (SplayTree a) where
  xs == ys = toList xs == toList ys

-- | Lexicographic ordering of the in-order element sequences.
instance (Ord a) => Ord (SplayTree a) where
  compare xs ys = compare (toList xs) (toList ys)

instance (Show a, Show (Measure a)) => Show (SplayTree a) where
  show Tip = "Tip"
  show (Branch v l a r) =
    "Branch {ann {" ++ show v ++ "}, lChild {" ++ show l ++ "}, value {"
      ++ show a ++ "}, rChild {" ++ show r ++ "}}"

-- | Concatenation of trees (see '><').
instance Measured a => Semigroup (SplayTree a) where
  (<>) = (><)

instance Measured a => Monoid (SplayTree a) where
  mempty  = Tip
  mappend = (><)

-- | A tree is measured by the measure cached at its root.
instance (Measured a) => Measured (SplayTree a) where
  type Measure (SplayTree a) = Measure a
  measure Tip               = mempty
  measure (Branch v _ _ _)  = v

-- | A single-element tree.
leaf :: Measured a => a -> SplayTree a
leaf a = Branch (measure a) Tip a Tip

-- | Smart constructor: build a node and compute its cached measure as
-- @measure l <> measure a <> measure r@ (skipping empty children).
branch :: Measured a => SplayTree a -> a -> SplayTree a -> SplayTree a
branch l a r = Branch mm l a r
  where
    mm = case (l, r) of
      (Tip, Tip)                         -> measure a
      (Tip, Branch rm _ _ _)             -> measure a `mappend` rm
      (Branch lm _ _ _, Tip)             -> lm `mappend` measure a
      (Branch lm _ _ _, Branch rm _ _ _) -> mconcat [lm, measure a, rm]

-- | In-order traversal.
instance Foldable SplayTree where
  foldMap _ Tip              = mempty
  foldMap f (Branch _ l a r) = mconcat [foldMap f l, f a, foldMap f r]
  {-# INLINE foldMap #-}
  foldl = myFoldl
  {-# INLINE foldl #-}

-- | In-order left fold that forces the accumulator at every step.
myFoldl :: (a -> b -> a) -> a -> SplayTree b -> a
myFoldl f i0 tree = go i0 tree
  where
    go !i Tip = i
    go !acc (Branch _ l a r) =
      let a1 = go acc l
          a2 = a1 `seq` f a1 a
      in a2 `seq` go a2 r
{-# INLINE myFoldl #-}

-- -------------------------------------------
-- Construction

-- | The empty tree.
empty :: SplayTree a
empty = Tip

-- | /O(1)/. A tree containing one element.
singleton :: Measured a => a -> SplayTree a
singleton = leaf

-- | Record the path from the given node down its left spine to the leftmost
-- 'Tip', pushing one 'DescL' frame per visited node onto @path@.
leftSpine :: SplayTree a -> [Thread a] -> [Thread a]
leftSpine Tip              path = path
leftSpine (Branch _ l a r) path = leftSpine l (DescL a r : path)

-- | Record the path from the given node down its right spine to the
-- rightmost 'Tip', pushing one 'DescR' frame per visited node onto @path@.
rightSpine :: SplayTree a -> [Thread a] -> [Thread a]
rightSpine Tip              path = path
rightSpine (Branch _ l a r) path = rightSpine r (DescR a l : path)

-- | Add an element to the left end; the new element is splayed to the root.
(<|) :: (Measured a) => a -> SplayTree a -> SplayTree a
a <| Tip = branch Tip a Tip
a <| t   = ascendSplay (leaf a) (leftSpine t [])

-- | Add an element to the right end; the new element is splayed to the root.
(|>) :: (Measured a) => SplayTree a -> a -> SplayTree a
Tip |> b = leaf b
t   |> b = ascendSplay (leaf b) (rightSpine t [])

-- | Append two trees.  The left tree is grafted onto the leftmost 'Tip' of
-- the right tree and then splayed up, so its old root becomes the new root.
(><) :: (Measured a) => SplayTree a -> SplayTree a -> SplayTree a
Tip >< ys  = ys
xs  >< Tip = xs
l   >< r   = ascendSplay l (leftSpine r [])

-- | /O(n)/. Create a Tree from a finite list of elements.
fromList :: (Measured a) => [a] -> SplayTree a
fromList = foldl' (|>) Tip

-- | /O(n)/. Create a Tree from a finite list of elements.
--
-- After the tree is created, it is balanced. This is useful with sorted data,
-- which would otherwise create a completely unbalanced tree.
fromListBalance :: (Measured a) => [a] -> SplayTree a
fromListBalance = balance . fromList

-- -------------------------------------------
-- Deconstruction

-- | Is the tree empty?
null :: SplayTree a -> Bool
null Tip = True
null _   = False

-- | Split a tree at the point where the predicate on the prefix measure
-- changes from False to True.  The element at which it first becomes True
-- starts the second tree.  If the predicate never becomes True, every
-- element belongs to the first tree.
split :: Measured a
      => (Measure a -> Bool) -> SplayTree a -> (SplayTree a, SplayTree a)
split _p Tip = (Tip, Tip)
split p tree = case query p tree of
  -- The found node is at the root, so its left subtree is exactly the
  -- elements before it.
  Just (_, Branch _ l a r) -> (l, branch Tip a r)
  -- Predicate never became True: nothing goes to the right-hand tree.
  _                        -> (tree, Tip)

-- | Find the first element at which the predicate on the prefix measure
-- (everything up to and including that element) becomes True.  Returns the
-- element together with the tree splayed so that element is at the root, or
-- 'Nothing' if the tree is empty or the predicate is False for the whole
-- tree's measure.  The predicate is expected to be monotone (False then True
-- along the tree).
query :: (Measured a, Measure a ~ Measure (SplayTree a))
      => (Measure a -> Bool) -> SplayTree a -> Maybe (a, SplayTree a)
query _ Tip = Nothing
query p t
  | p (measure t) = Just (descend mempty t [])
  | otherwise     = Nothing
  where
    -- @acc@ is the measure of everything to the left of @node@.  Only
    -- non-empty children are ever entered, so the Tip clause is unreachable.
    descend _ Tip _ =
      error "SplayTree.query: internal error, descended into a Tip"
    descend acc node@(Branch _ l a r) path
      | not (null l) && p ml = descend acc l (DescL a r : path)
      -- Stop here if the predicate flips at this element, or if there is
      -- nowhere further right to go (only possible for a non-monotone p).
      | p mm || null r       = (a, ascendSplay node path)
      | otherwise            = descend mm r (DescR a l : path)
      where
        ml = acc `mappend` measure l
        mm = ml `mappend` measure a
{-# INLINE query #-}

-- --------------------------
-- Basic interface

-- | /O(n)/. Number of elements in the tree.
size :: SplayTree a -> Int
size = foldl' (\acc _ -> acc + 1) 0

-- | Look for an element by searching for the first prefix measure that is
-- @>= measure a@, and compare the element found there with @a@.  Returns the
-- splayed tree when the search finds a node, otherwise the input tree.
-- NOTE(review): this presumes an ordered measure (e.g. a max-key), so that
-- the search predicate is monotone — confirm for the measures in use.
memberSplay :: (Measured a, Ord (Measure a), Eq a)
            => a -> SplayTree a -> (Bool, SplayTree a)
memberSplay a tree = case query (>= measure a) tree of
  Just (a', splayed@Branch{}) -> (a == a', splayed)
  _                           -> (False, tree)
{-# INLINE memberSplay #-}

-- | Remove an element (located as in 'memberSplay').  If it is not present,
-- the (possibly splayed) tree is returned with its elements unchanged.
delete :: (Measured a, Ord (Measure a), Eq a) => a -> SplayTree a -> SplayTree a
delete a tree = case memberSplay a tree of
  -- The match is at the root: join its two subtrees.
  (True, Branch _ l _ r) -> l >< r
  (_, splayed)           -> splayed

-- | Insert an element before the first element whose prefix measure reaches
-- @measure a@ (or at the end if none does).  If the element found there is
-- equal to @a@, the splayed tree is returned without inserting.
insert :: (Measured a, Ord (Measure a), Eq a) => a -> SplayTree a -> SplayTree a
insert a tree = case query (>= measure a) tree of
  Just (_, splayed@(Branch _ l a' r))
    | a == a'   -> splayed
    -- The found node is at the root; put @a@ between its left subtree and
    -- it, giving the in-order sequence l, a, a', r.
    | otherwise -> branch l a (branch Tip a' r)
  -- No prefix measure reaches @measure a@: it belongs at the end.
  _ -> tree |> a

-- --------------------------
-- Set operations

-- | Remove every element of the second tree from the first.
difference :: (Measured a, Ord (Measure a), Eq a)
           => SplayTree a -> SplayTree a -> SplayTree a
difference l r = foldl' (flip delete) l r

-- | Elements of the second tree that are also members of the first.
intersection :: (Measured a, Ord (Measure a), Eq a)
             => SplayTree a -> SplayTree a -> SplayTree a
intersection l r = fst $ foldl' step (empty, l) r
  where
    -- Force the accumulated tree at each step so insert thunks don't pile up.
    step (acc, testSet) x = case memberSplay x testSet of
      (True, t')  -> let acc' = insert x acc in acc' `seq` (acc', t')
      (False, t') -> (acc, t')

-- --------------------------
-- Traversals

-- | Like fmap, but with a more restrictive type.
fmap' :: Measured b => (a -> b) -> SplayTree a -> SplayTree b
fmap' _ Tip              = Tip
fmap' f (Branch _ l a r) = branch (fmap' f l) (f a) (fmap' f r)

-- | Like traverse, but with a more restrictive type.
traverse' :: (Measured b, Applicative f)
          => (a -> f b) -> SplayTree a -> f (SplayTree b)
traverse' _ Tip              = pure Tip
traverse' f (Branch _ l a r) = branch <$> traverse' f l <*> f a <*> traverse' f r

-- | Descend to the deepest left-hand branch.
deepL :: Measured a => SplayTree a -> SplayTree a
deepL = deep descendL

-- | Descend to the deepest right-hand branch.
deepR :: Measured a => SplayTree a -> SplayTree a
deepR = deep descendR

-- | Descend a tree using the provided @descender@ until it reaches a 'Tip',
-- then splay back up to recreate the tree.  The new focus will be the last
-- node accessed in the tree.  An empty tree is returned unchanged.
deep :: Measured a
     => (SplayTree a -> [Thread a] -> Maybe (SplayTree a, [Thread a]))
     -> SplayTree a
     -> SplayTree a
deep descender tree = uncurry ascendSplay . desc $ descender tree []
  where
    desc (Just (Tip, zp))           = (Tip, zp)
    desc (Just (b@(Branch {}), zp)) = desc $ descender b zp
    desc Nothing                    = (tree, [])
{-# INLINE deep #-}

-- -------------------------------------------
-- Splay-tree internals
-- Descents record their path as a zipper (a stack of 'Thread' frames, the
-- innermost frame first) so that descending and splaying back up can be done
-- in a single pass.

-- | One frame of a recorded descent.  'DescL' means the walk went into the
-- left child, remembering the parent's element and right subtree; 'DescR'
-- means it went into the right child, remembering the parent's element and
-- left subtree.
data Thread a
  = DescL !a !(SplayTree a)
  | DescR !a !(SplayTree a)

-- | Step into the left child, pushing a 'DescL' frame.  'Nothing' on 'Tip'.
descendL :: SplayTree a -> [Thread a] -> Maybe (SplayTree a, [Thread a])
descendL Tip _ = Nothing
descendL (Branch _ left x right) path = Just (left, DescL x right : path)

-- | Step into the right child, pushing a 'DescR' frame.  'Nothing' on 'Tip'.
descendR :: SplayTree a -> [Thread a] -> Maybe (SplayTree a, [Thread a])
descendR Tip _ = Nothing
descendR (Branch _ left x right) path = Just (right, DescR x left : path)

-- | Rebuild one level: plug a subtree back into the position one frame
-- describes.
up :: Measured a => SplayTree a -> Thread a -> SplayTree a
up sub frame = case frame of
  DescL x right -> branch sub x right
  DescR x left  -> branch left x sub
{-# INLINE up #-}

-- | Bring the root's left child up to the root.  In textbook terms this is a
-- right rotation; it is named after the 'DescL' direction it undoes.  A tree
-- whose root has no left branch is returned unchanged.
rotateL :: (Measured a) => SplayTree a -> SplayTree a
rotateL node = case node of
  Branch _ (Branch _ ll lx lr) x r -> branch ll lx (branch lr x r)
  _                                -> node

-- | Bring the root's right child up to the root.  In textbook terms this is a
-- left rotation; it is named after the 'DescR' direction it undoes.  A tree
-- whose root has no right branch is returned unchanged.
rotateR :: (Measured a) => SplayTree a -> SplayTree a
rotateR node = case node of
  Branch _ l x (Branch _ rl rx rr) -> branch (branch l x rl) rx rr
  _                                -> node

-- | Splay a focus subtree up through a recorded path, consuming frames until
-- the path is empty, and return the rebuilt tree.
ascendSplay :: Measured a => SplayTree a -> [Thread a] -> SplayTree a
ascendSplay = climb
  where
    climb !focus []   = focus
    climb !focus path = case ascendSplay' focus path of
      (focus', path') -> climb focus' path'

-- | One splay step.  It consumes two frames (zig-zig or zig-zag), or the last
-- single frame (zig), and returns the new focus and the remaining path.
ascendSplay' :: Measured a
             => SplayTree a -> [Thread a] -> (SplayTree a, [Thread a])
ascendSplay' !x path = case path of
  -- zig-zig: rotate at the grandparent first, then at the parent.
  p@DescL{} : g@DescL{} : rest -> (rotateL (rotateL (up (up x p) g)), rest)
  p@DescR{} : g@DescR{} : rest -> (rotateR (rotateR (up (up x p) g)), rest)
  -- zig-zag: rotate at the parent, then at the grandparent.
  p@DescR{} : g@DescL{} : rest -> (rotateL (up (rotateR (up x p)) g), rest)
  p@DescL{} : g@DescR{} : rest -> (rotateR (up (rotateL (up x p)) g), rest)
  -- zig: a single frame left, directly below the root.
  [p@DescL{}] -> (rotateL (up x p), [])
  [p@DescR{}] -> (rotateR (up x p), [])
  [] -> error "SplayTree: internal error, ascendSplay' called past root"

-- ---------------------------
-- Element wrapper

-- | Newtype wrapper around an element, so the wrapper can carry its own
-- 'Measured' instance independently of the element type.
newtype ElemD a = ElemD { getElemD :: a }
  deriving (Show, Ord, Eq, Num, Enum)
-- | /O(n)/. Rebalance a splay tree.  The order of elements does not change.
--
-- The tree is rebuilt from its in-order element list into a tree of minimal
-- height.  A subtree of @n@ elements gets @(n - 1) `div` 2@ elements on the
-- left and the rest on the right.  All cached measures are recomputed by
-- 'branch'.
balance :: Measured a => SplayTree a -> SplayTree a
balance tree = case build (size tree) (toList tree) of
  (balanced, _) -> balanced
  where
    -- Build a balanced tree from the first @n@ list elements and return the
    -- unconsumed remainder of the list.
    build :: Measured b => Int -> [b] -> (SplayTree b, [b])
    build n xs
      | n <= 0    = (Tip, xs)
      | otherwise =
          let nLeft = (n - 1) `div` 2
          in case build nLeft xs of
               (l, x : rest) -> case build (n - 1 - nLeft) rest of
                 (r, rest') -> (branch l x r, rest')
               -- size and toList agree, so the list cannot run out early.
               (_, []) ->
                 error "SplayTree.balance: internal error, ran out of elements"