{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE ScopedTypeVariables #-}

-- | Skew heaps: a min-priority queue that also keeps track of how many
-- elements it holds.
module Data.Heap
  ( Heap
  , empty
  , singleton
  , insert
  , removeMin
  , union
  , mapMaybe
  , size
  ) where

-- | A min-heap of elements of type @a@.
--
-- The first field caches the number of elements (so 'size' does not
-- have to walk the tree); the second field is the tree itself.
data Heap a = Heap {-# UNPACK #-} !Int !(Heap1 a)
  deriving Show

-- | The underlying skew-heap tree.
--
-- N.B.: unlike 'Heap', the fields here are lazy, so every function that
-- builds a 'Node' is responsible for forcing the subtrees it creates.
data Heap1 a = Nil | Node a (Heap1 a) (Heap1 a)
  deriving Show

-- | Merge two heaps into one holding the elements of both.
{-# INLINEABLE union #-}
union :: Ord a => Heap a -> Heap a -> Heap a
union (Heap m a) (Heap n b) = Heap (m + n) (a `union1` b)

-- Merge two trees. The tree with the smaller root (the first one on a
-- tie) keeps its root on top; the other tree is merged into that root's
-- right subtree, the merge result becomes the new left subtree, and the
-- old left subtree moves over to the right.
{-# INLINEABLE union1 #-}
union1 :: forall a. Ord a => Heap1 a -> Heap1 a -> Heap1 a
union1 = meld
  where
    -- All the recursion goes through this local worker rather than
    -- through union1 itself: its type carries no Ord constraint (the
    -- dictionary is captured from the enclosing scope), which gives
    -- better generated code.
    meld :: Heap1 a -> Heap1 a -> Heap1 a
    meld Nil t = t
    meld t Nil = t
    meld t@(Node x tl tr) u@(Node y ul ur)
      | x <= y = mkNode x (meld tr u) tl
      | otherwise = mkNode y (meld ur t) ul

    -- Build a node, forcing the (freshly merged) left child first,
    -- because Heap1's fields are lazy.
    mkNode :: a -> Heap1 a -> Heap1 a -> Heap1 a
    mkNode v !l r = Node v l r

-- | A heap holding exactly one element. The element is forced to weak
-- head normal form before the heap is built.
{-# INLINE singleton #-}
singleton :: a -> Heap a
singleton x = x `seq` Heap 1 (Node x Nil Nil)

-- | The heap with no elements.
{-# INLINE empty #-}
empty :: Heap a
empty = Heap 0 Nil

-- | Add one element to a heap (the element is forced, as in 'singleton').
{-# INLINEABLE insert #-}
insert :: Ord a => a -> Heap a -> Heap a
insert x h = singleton x `union` h

-- | Split off the minimum element together with the remaining heap, or
-- return 'Nothing' when the heap is empty.
{-# INLINEABLE removeMin #-}
removeMin :: Ord a => Heap a -> Maybe (a, Heap a)
removeMin (Heap n tree) =
  case tree of
    Nil -> Nothing
    Node x l r -> Just (x, Heap (n - 1) (union1 l r))

-- | Map a function over a heap, dropping every value that maps to
-- 'Nothing'. May be more efficient when the function being mapped is
-- mostly monotonic (the sift-down step then usually stops at once).
{-# INLINEABLE mapMaybe #-}
mapMaybe :: Ord b => (a -> Maybe b) -> Heap a -> Heap b
mapMaybe f (Heap _ tree) = Heap (countFrom 0 tree') tree'
  where
    tree' = mapTree tree

    -- Count the nodes of a tree with a strict accumulator.
    countFrom !acc t =
      case t of
        Nil -> acc
        Node _ l r -> countFrom (countFrom (acc + 1) l) r

    -- Work bottom-up: both subtrees are mapped (and forced) first; then
    -- the node's own value is either dropped, by merging the two
    -- subtrees, or forced and pushed down to restore heap order.
    mapTree Nil = Nil
    mapTree (Node x l r) =
      let !l' = mapTree l
          !r' = mapTree r
      in maybe (union1 l' r') (\ !y -> siftDown y l' r') (f x)

    -- Build a node from a value and two valid heaps, moving the value
    -- down while one of the children holds something smaller.
    siftDown x l@(Node y ll lr) r@(Node z rl rr)
      | y < x, y <= z = let !l2 = siftDown x ll lr in Node y l2 r
      | z < x, z <= y = let !r2 = siftDown x rl rr in Node z l r2
    siftDown x Nil (Node y l r)
      | y < x = let !r2 = siftDown x l r in Node y Nil r2
    siftDown x (Node y l r) Nil
      | y < x = let !l2 = siftDown x l r in Node y l2 Nil
    siftDown x l r = Node x l r

-- | The number of elements in the heap, read from the cached count.
{-# INLINE size #-}
size :: Heap a -> Int
size (Heap count _) = count

-- Testing code:
-- import Test.QuickCheck
-- import qualified Data.List as List
-- import qualified Data.Maybe as Maybe
--
-- instance (Arbitrary a, Ord a) => Arbitrary (Heap a) where
--   arbitrary = sized arb
--     where
--       arb 0 = return empty
--       arb n =
--         frequency
--           [(1, singleton <$> arbitrary),
--            (n-1, union <$> arb' <*> arb')]
--         where
--           arb' = arb (n `div` 2)
--
-- toList :: Ord a => Heap a -> [a]
-- toList = List.unfoldr removeMin
--
-- invariant :: Ord a => Heap a -> Bool
-- invariant h@(Heap n h1) =
--   n == length (toList h) && ord h1
--   where
--     ord Nil = True
--     ord (Node x l r) = ord1 x l && ord1 x r
--
--     ord1 _ Nil = True
--     ord1 x h@(Node y _ _) = x <= y && ord h
--
-- prop_1 h = withMaxSuccess 10000 $ invariant h
-- prop_2 x h = withMaxSuccess 10000 $ invariant (insert x h)
-- prop_3 h =
--   withMaxSuccess 1000 $
--   case removeMin h of
--     Nothing -> discard
--     Just (_, h) -> invariant h
-- prop_4 h = withMaxSuccess 10000 $ List.sort (toList h) == toList h
-- prop_5 x h = withMaxSuccess 10000 $ toList (insert x h) == List.insert x (toList h)
-- prop_6 x h =
--   withMaxSuccess 1000 $
--   case removeMin h of
--     Nothing -> discard
--     Just (x, h') -> toList h == List.insert x (toList h')
-- prop_7 h1 h2 = withMaxSuccess 10000 $
--   invariant (union h1 h2)
-- prop_8 h1 h2 = withMaxSuccess 10000 $
--   toList (union h1 h2) == List.sort (toList h1 ++ toList h2)
-- prop_9 (Blind f) h = withMaxSuccess 10000 $
--   invariant (mapMaybe f h)
-- prop_10 (Blind f) h = withMaxSuccess 1000000 $
--   toList (mapMaybe f h) == List.sort (Maybe.mapMaybe f (toList h))
--
-- return []
-- main = $quickCheckAll