{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE RecordWildCards #-}
{-# OPTIONS -fno-warn-name-shadowing #-}

-- | Array-based implementation of an entirely traditional binary heap.
--
-- The heap is a min-heap: 'peek' returns the least element according to
-- the element type's 'Ord' instance.
module Data.MQueue.Heap (Heap, getSize) where

import Data.MQueue.Class
import Data.MQueue.MonadHelpers

import Control.Monad.ST.Class
import Data.Tuple.HT

import Data.Array.Base
import Data.Array.ST
import Data.STRef
import Control.Monad.ST
import Control.Monad
-- import Control.Arrow ((***))

-- | Mutable heap state: the number of live elements and the backing array.
--
-- Slots @[0 .. size-1]@ hold the heap in the usual implicit-tree layout
-- (the children of slot @i@ are @2i+1@ and @2i+2@); slots at or beyond
-- @size@ are spare capacity and are never read.
data STHeap s e = STH
  { size :: {-# UNPACK #-} !Int
  , arr  :: {-# UNPACK #-} !(STArray s Int e)
  }

-- | A mutable binary min-heap living in state thread @s@.
newtype Heap s e = H {unHeap :: STRef s (STHeap s e)}

instance (Monad m, MonadST m, StateThread m ~ s, Ord e) => MQueue (Heap s e) m where
  {-# SPECIALIZE instance Ord e => MQueue (Heap s e) (ST s) #-}
  {-# SPECIALIZE instance Ord e => MQueue (Heap RealWorld e) IO #-}
  {-# INLINE pushAll #-}
  type MQueueKey (Heap s e) = e
  newQueue = liftST newHeap
  push h = liftST . pushHeap h
  pushAll h = liftST . pushAllHeap h
  pop_ = liftST . popHeap_
  peek = liftST . peekHeap

{-# SPECIALIZE getSize :: Heap s e -> ST s Int #-}
-- | Number of elements currently stored in the heap.
getSize :: (Monad m, MonadST m, StateThread m ~ s) => Heap s e -> m Int
getSize = liftST . getHeapSize

----------------------------------------------------------------------

-- | Allocate an empty heap with an initial capacity of 16 slots.
newHeap :: ST s (Heap s e)
newHeap = do
  store <- newArray_ (0, 15)
  ref <- newSTRef (STH 0 store)
  return (H ref)

-- | Insert a single element, growing the backing array if it is full.
pushHeap :: Ord e => Heap s e -> e -> ST s ()
pushHeap h = onHeap_ h . pusher

{-# INLINE pushAllHeap #-}
-- | Insert every element of @ks@, in order.
--
-- The backing array is grown at most once, to fit all of @ks@, and each
-- element is then sifted up against the heap produced by the previous
-- insertion, so that every insert sees the updated size.
-- (No signature is given so the container type follows whatever the
-- class's 'pushAll' method uses.)
pushAllHeap h ks = onHeap_ h $ \heap@STH{size} -> do
  grown <- ensureSize (size + length ks) heap
  foldM (flip unsafePusher) grown ks

-- | Remove the least element. Popping an empty heap does nothing.
popHeap_ :: Ord e => Heap s e -> ST s ()
popHeap_ h = onHeap_ h popper

-- | The least element, or 'Nothing' if the heap is empty.
peekHeap :: Heap s e -> ST s (Maybe e)
peekHeap h = queryHeap h $ \STH{..} ->
  if size > 0 then liftM Just (unsafeRead arr 0) else return Nothing

-- | Number of live elements.
getHeapSize :: Heap s e -> ST s Int
getHeapSize h = queryHeap h (return . size)

----------------------------------------------------------------------

-- | Run a read-only action against the current heap state.
queryHeap :: Heap s e -> (STHeap s e -> ST s a) -> ST s a
queryHeap = (>>=) . readSTRef . unHeap

-- onHeap :: Heap s e -> (STHeap s e -> ST s (a, STHeap s e)) -> ST s a
-- onHeap = modSTRef . unHeap

-- | Replace the heap state with the result of an update action.
onHeap_ :: Heap s e -> (STHeap s e -> ST s (STHeap s e)) -> ST s ()
onHeap_ = modSTRef_ . unHeap

-- modSTRef :: STRef s a -> (a -> ST s (b, a)) -> ST s b
-- modSTRef ref f = do (ans, x') <- f =<< readSTRef ref
--                     writeSTRef ref x'
--                     return ans

-- | Monadic read-modify-write of an 'STRef'.
modSTRef_ :: STRef s a -> (a -> ST s a) -> ST s ()
modSTRef_ ref f = readSTRef ref >>= f >>= writeSTRef ref

-- | Make sure the backing array has room for at least @n@ elements.
--
-- When it does not, a new array of @5n/4@ slots is allocated and the live
-- elements @[0 .. size-1]@ are copied over; the 25% headroom makes repeated
-- single pushes grow the array geometrically.
ensureSize :: Int -> STHeap s e -> ST s (STHeap s e)
ensureSize n h@STH{..} = do
  cap <- getNumElements arr
  if cap >= n
    then return h
    else do
      arr' <- newArray_ (0, 5 * n `quot` 4 - 1)
      forM_ [0 .. size - 1] $ \i -> unsafeRead arr i >>= unsafeWrite arr' i
      return h{arr = arr'}

-- | Sift @ai@ up from slot @i@: parents greater than or equal to @ai@ are
-- moved down one level until @ai@ can be written into its final slot.
heapUp :: Ord e => STArray s Int e -> Int -> e -> ST s ()
heapUp !arr i ai = heapUp' i
  where
    heapUp' 0 = unsafeWrite arr 0 ai
    heapUp' i = do
      let !j = (i - 1) `quot` 2
      aj <- unsafeRead arr j
      if aj < ai
        then unsafeWrite arr i ai
        else unsafeWrite arr i aj >> heapUp' j

-- | Append @k@ at slot @size@ and sift it up.
--
-- Does NOT check capacity: the caller must already have ensured the array
-- holds at least @size + 1@ slots.
unsafePusher :: Ord e => e -> STHeap s e -> ST s (STHeap s e)
unsafePusher k h@STH{..} = do
  heapUp arr size k
  return h{size = size + 1}

-- | Capacity-checked single insert.
pusher :: Ord e => e -> STHeap s e -> ST s (STHeap s e)
pusher k h@STH{size} = ensureSize (size + 1) h >>= unsafePusher k

-- | Sift @ai@ down from slot @i@ within a heap of @n@ live elements: the
-- smaller child is moved up while it is strictly less than @ai@.
heapDown :: Ord e => Int -> STArray s Int e -> Int -> e -> ST s ()
heapDown n !arr i ai = lt `seq` heapDown' i
  where
    lt = (<) -- hack to minimize polymorphism overhead
    heapDown' i =
      let rchild = 2 * i + 2
          lchild = rchild - 1
      in case compare rchild n of
           -- both children exist
           LT -> do
             al <- unsafeRead arr lchild
             ar <- unsafeRead arr rchild
             let (ach, ch) = if al < ar then (al, lchild) else (ar, rchild)
             if ach `lt` ai
               then unsafeWrite arr i ach >> heapDown' ch
               else unsafeWrite arr i ai
           -- only the left child exists (and it is the last element)
           EQ -> do
             al <- unsafeRead arr lchild
             if al `lt` ai
               then unsafeWrite arr i al >> unsafeWrite arr lchild ai
               else unsafeWrite arr i ai
           -- leaf
           GT -> unsafeWrite arr i ai

-- | Remove the root.
--
-- The last element is taken out of its slot, and that slot is cleared so
-- the array no longer references it. If elements remain, the taken element
-- is then sifted down from the root. An empty heap is returned unchanged
-- rather than reading index @-1@.
popper :: Ord e => STHeap s e -> ST s (STHeap s e)
popper h@STH{..}
  | size <= 0 = return h
  | otherwise = do
      let s' = size - 1
      ai <- unsafeRead arr s'
      unsafeWrite arr s' (error "Data.MQueue.Heap: read of a vacated heap slot")
      -- When s' == 0 the popped root was also the last element; sifting
      -- would write it back into slot 0 and keep it alive.
      when (s' > 0) $ heapDown s' arr 0 ai
      return h{size = s'}