-- | Skew heaps: mergeable min-heaps supporting insertion, union and
-- removal of the minimum element.

{-# LANGUAGE BangPatterns, ScopedTypeVariables #-}
module Data.Heap(
  Heap, empty, singleton, insert, removeMin, union, mapMaybe, size, toList) where

-- | A min-heap, implemented as a skew heap.
--
-- In every 'Node', the stored element is no larger than any element of
-- its subheaps, and the 'Int' field caches the number of elements in
-- that node's heap (its own element plus both subheaps).
--
-- N.B.: only the size field is strict. The element and the two
-- subheaps are lazy fields, so the functions in this module take care
-- to force them where appropriate.
data Heap a
  = Nil
  | Node
      {-# UNPACK #-} !Int  -- ^ number of elements in this heap
      a                    -- ^ the smallest element
      (Heap a)             -- ^ left subheap
      (Heap a)             -- ^ right subheap
  deriving Show

-- | Merge two heaps into a single heap holding the elements of both.
{-# INLINEABLE union #-}
union :: forall a. Ord a => Heap a -> Heap a -> Heap a
union = merge
  where
    -- Every recursive call goes through this local 'merge' instead of
    -- 'union'. 'merge' has no 'Ord' constraint in its own type (it uses
    -- the dictionary captured from 'union'), and GHC generates better
    -- code for it that way.
    merge :: Heap a -> Heap a -> Heap a
    merge Nil other = other
    merge other Nil = other
    merge hp@(Node sizeP minP leftP rightP) hq@(Node sizeQ minQ leftQ rightQ)
      | minP <= minQ =
          -- hp keeps the root: its right child absorbs hq (merged eagerly)
          -- and becomes the new left child; the old left child moves right.
          let !merged = merge rightP hq
          in Node total minP merged leftP
      | otherwise =
          -- Symmetric case: hq keeps the root.
          let !merged = merge rightQ hp
          in Node total minQ merged leftQ
      where
        total = sizeP + sizeQ

-- | A heap holding exactly one element, which is forced to WHNF when
-- the heap is evaluated.
{-# INLINE singleton #-}
singleton :: a -> Heap a
singleton x = x `seq` Node 1 x Nil Nil

-- | The heap containing no elements.
{-# INLINE empty #-}
empty :: Heap a
empty = Nil

-- | Add one element to a heap, by merging a one-element heap into it.
{-# INLINEABLE insert #-}
insert :: Ord a => a -> Heap a -> Heap a
insert x heap = singleton x `union` heap

-- | Split off the minimum element of a heap, returning it together with
-- a heap of the remaining elements, or 'Nothing' if the heap is empty.
{-# INLINEABLE removeMin #-}
removeMin :: Ord a => Heap a -> Maybe (a, Heap a)
removeMin heap =
  case heap of
    Nil -> Nothing
    Node _ smallest left right -> Just (smallest, left `union` right)

-- | List the elements of a heap, in no particular order.
toList :: Heap a -> [a]
toList heap = collect heap []
  where
    -- collect t rest puts t's elements (root, then the left subheap's,
    -- then the right subheap's) in front of rest.
    collect :: Heap b -> [b] -> [b]
    collect Nil rest = rest
    collect (Node _ x left right) rest = x : collect left (collect right rest)

-- | Apply a function to every element of a heap, keeping the 'Just'
-- results and dropping elements that map to 'Nothing'. This may be
-- cheaper when the function is mostly monotonic.
{-# INLINEABLE mapMaybe #-}
mapMaybe :: forall a b. Ord b => (a -> Maybe b) -> Heap a -> Heap b
mapMaybe f = go
  where
    go :: Heap a -> Heap b
    go Nil = Nil
    go (Node _ x left right) =
      -- Both subheaps are rebuilt and forced before the root is examined.
      let !left' = go left
          !right' = go right
      in case f x of
           -- The root is dropped: only the two rebuilt subheaps remain.
           Nothing -> union left' right'
           -- The root is kept (and forced). If it is still no larger than
           -- the roots of both rebuilt subheaps, 'insert' and 'union' here
           -- finish without recursing past their base cases.
           Just y -> y `seq` union (insert y left') right'

-- | The number of elements in a heap, read from the size cached in the
-- root node.
{-# INLINE size #-}
size :: Heap a -> Int
size heap =
  case heap of
    Nil -> 0
    Node n _ _ _ -> n

-- Testing code:
-- import Test.QuickCheck
-- import qualified Data.List as List
-- import qualified Data.Maybe as Maybe
-- 
-- instance (Arbitrary a, Ord a) => Arbitrary (Heap a) where
--   arbitrary = sized arb
--     where
--       arb 0 = return empty
--       arb n =
--         frequency
--           [(1, singleton <$> arbitrary),
--            (n-1, union <$> arb' <*> arb')]
--         where
--           arb' = arb (n `div` 2)
-- 
-- toSortedList :: Ord a => Heap a -> [a]
-- toSortedList = List.unfoldr removeMin
-- 
-- invariant :: Ord a => Heap a -> Bool
-- invariant h = ord h && sizeOK h
--   where
--     ord Nil = True
--     ord (Node _ x l r) = ord1 x l && ord1 x r
-- 
--     ord1 _ Nil = True
--     ord1 x h@(Node _ y _ _) = x <= y && ord h
-- 
--     sizeOK Nil = size Nil == 0
--     sizeOK (Node s _ l r) =
--       s == size l + size r + 1
-- 
-- prop_1 h = withMaxSuccess 100000 $ invariant h
-- prop_2 x h = withMaxSuccess 100000 $ invariant (insert x h)
-- prop_3 h =
--   withMaxSuccess 100000 $
--   case removeMin h of
--     Nothing -> discard
--     Just (_, h) -> invariant h
-- prop_4 h = withMaxSuccess 100000 $ List.sort (toSortedList h) == toSortedList h
-- prop_5 x h = withMaxSuccess 100000 $ toSortedList (insert x h) == List.insert x (toSortedList h)
-- prop_6 x h =
--   withMaxSuccess 100000 $
--   case removeMin h of
--     Nothing -> discard
--     Just (x, h') -> toSortedList h == List.insert x (toSortedList h')
-- prop_7 h1 h2 = withMaxSuccess 100000 $
--   invariant (union h1 h2)
-- prop_8 h1 h2 = withMaxSuccess 100000 $
--   toSortedList (union h1 h2) == List.sort (toSortedList h1 ++ toSortedList h2)
-- prop_9 (Blind f) h = withMaxSuccess 100000 $
--   invariant (mapMaybe f h)
-- prop_10 (Blind f) h = withMaxSuccess 1000000 $
--   toSortedList (mapMaybe f h) == List.sort (Maybe.mapMaybe f (toSortedList h))
-- 
-- return []
-- main = $quickCheckAll