{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE TypeFamilies #-}
{-# OPTIONS_GHC -fno-warn-name-shadowing #-}

-- | Array-based implementation of an entirely traditional binary min-heap.
--
-- The heap lives in a growable 'STArray' (initial capacity 16) held in an
-- 'STRef'.  The smallest element (per 'Ord') is kept at index 0, and the
-- children of slot @i@ are slots @2i+1@ and @2i+2@.
module Data.MQueue.Heap (Heap, getSize) where

import Data.MQueue.Class
import Data.MQueue.MonadHelpers

import Control.Arrow((***))
import Control.Monad.ST.Class
--import Data.Tuple.HT
import Data.Array.Base
import Data.Array.ST
import Data.STRef
import Control.Monad.ST
import Control.Monad
-- import Control.Arrow ((***))

-- | Mutable heap state: the number of live elements and the backing array.
-- Slots @[0 .. size - 1]@ hold the heap; slots at or beyond @size@ are unused.
data STHeap s e = STH
  { size :: {-# UNPACK #-} !Int
  , arr  :: {-# UNPACK #-} !(STArray s Int e)
  }

-- | A mutable binary min-heap in state thread @s@ holding elements of type @e@.
newtype Heap s e = H {unHeap :: STRef s (STHeap s e)}

-- | A 'Heap' is usable from any 'MonadST' monad whose state thread is @s@;
-- every operation is an 'ST' action lifted with 'liftST'.
instance (Monad m, MonadST m, StateThread m ~ s, Ord e) => MQueue (Heap s e) m where
  {-# SPECIALIZE instance Ord e => MQueue (Heap s e) (ST s) #-}
  {-# SPECIALIZE instance Ord e => MQueue (Heap RealWorld e) IO #-}
  {-# INLINE pushAll #-}
  type MQueueKey (Heap s e) = e
  newQueue = liftST newHeap
  push h = liftST . pushHeap h
  pushAll h = liftST . pushAllHeap h
  pop_ = liftST . popHeap_
  peek = liftST . peekHeap

{-# SPECIALIZE getSize :: Heap s e -> ST s Int #-}
-- | The number of elements currently stored in the heap.
getSize :: (Monad m, MonadST m, StateThread m ~ s) => Heap s e -> m Int
getSize = liftST . getHeapSize

----------------------------------------------------------------------

-- | Allocate an empty heap with room for 16 elements.
newHeap :: ST s (Heap s e)
newHeap = liftM H (liftM (STH 0) (newArray_ (0, 15)) >>= newSTRef)

-- | Push a single element, growing the backing array first if it is full.
pushHeap :: Ord e => Heap s e -> e -> ST s ()
pushHeap h = onHeap_ h . pusher

{-# INLINE pushAllHeap #-}
-- Push every element of @ks@.  Room for all of them is reserved once up front.
-- Each element is then sifted up in turn, and the heap returned by one push
-- (with its incremented size) is fed to the next.  Threading the heap this way
-- matters: pushing each key against the same pre-push heap would start every
-- sift-up at the same slot and overwrite earlier keys.
-- (Left unannotated so it matches whatever container 'pushAll' is given.)
pushAllHeap h ks = onHeap_ h $ \ heap@STH{size} -> do
  heap' <- ensureSize (size + length ks) heap
  foldM (flip unsafePusher) heap' ks

-- | Remove the minimum element.  Does nothing if the heap is empty.
popHeap_ :: Ord e => Heap s e -> ST s ()
popHeap_ h = onHeap_ h popper

-- | The minimum element, or 'Nothing' if the heap is empty.
peekHeap :: Heap s e -> ST s (Maybe e)
peekHeap h = queryHeap h $ \ STH{..} ->
  if size > 0 then liftM Just (unsafeRead arr 0) else return Nothing

-- | The number of live elements.
getHeapSize :: Heap s e -> ST s Int
getHeapSize h = queryHeap h (return . size)

----------------------------------------------------------------------

-- | Run a read-only query against the current heap state.
queryHeap :: Heap s e -> (STHeap s e -> ST s a) -> ST s a
queryHeap = (>>=) . readSTRef . unHeap

-- onHeap :: Heap s e -> (STHeap s e -> ST s (a, STHeap s e)) -> ST s a
-- onHeap = modSTRef . unHeap

-- | Replace the heap state with the result of an effectful update.
onHeap_ :: Heap s e -> (STHeap s e -> ST s (STHeap s e)) -> ST s ()
onHeap_ = modSTRef_ . unHeap

-- modSTRef :: STRef s a -> (a -> ST s (b, a)) -> ST s b
-- modSTRef ref f = do (ans, x') <- f =<< readSTRef ref
--                     writeSTRef ref x'
--                     return ans

-- | Read an 'STRef', run the effectful update, and write the result back.
modSTRef_ :: STRef s a -> (a -> ST s a) -> ST s ()
modSTRef_ ref f = readSTRef ref >>= f >>= writeSTRef ref

-- | Make sure the backing array has at least @n@ slots.
--
-- When it is too small, the live prefix @[0 .. size - 1]@ is copied into a
-- fresh array of @5n/4@ slots.  Geometric growth keeps repeated single pushes
-- amortised O(1).
ensureSize :: Int -> STHeap s e -> ST s (STHeap s e)
ensureSize n h@STH{..} = do
  cap <- getNumElements arr
  if cap < n
    then do
      arr' <- newArray_ (0, 5 * n `quot` 4 - 1)
      forM_ [0 .. size - 1] $ \ k -> unsafeRead arr k >>= unsafeWrite arr' k
      return h{arr = arr'}
    else return h

-- | Sift @ai@ up from hole @i@ towards the root.
--
-- Each parent that is not strictly smaller than @ai@ is moved down into the
-- hole.  @ai@ is written where the climb stops, or at the root.
heapUp :: Ord e => STArray s Int e -> Int -> e -> ST s ()
heapUp !arr i ai = heapUp' i
  where
    heapUp' 0 = unsafeWrite arr 0 ai
    heapUp' i = do
      let !j = (i - 1) `quot` 2
      aj <- unsafeRead arr j
      if aj < ai
        then unsafeWrite arr i ai
        else unsafeWrite arr i aj >> heapUp' j

-- | Append @k@ at slot @size@ and sift it up.  The caller must already have
-- ensured capacity for @size + 1@ elements; no bounds check happens here.
unsafePusher :: Ord e => e -> STHeap s e -> ST s (STHeap s e)
unsafePusher k h@STH{..} = do
  heapUp arr size k
  return h{size = size + 1}

-- | Push @k@, growing the backing array first if needed.
pusher :: Ord e => e -> STHeap s e -> ST s (STHeap s e)
pusher k h@STH{size} = ensureSize (size + 1) h >>= unsafePusher k

-- | Sift @ai@ down from hole @i@ in a heap of @n@ elements.
--
-- At each step the smaller child is moved up while it is strictly smaller than
-- @ai@.  @ai@ is written where the descent stops.
heapDown :: Ord e => Int -> STArray s Int e -> Int -> e -> ST s ()
heapDown n !arr i ai = lt `seq` heapDown' i
  where
    lt = (<) -- hack to minimize polymorphism overhead
    heapDown' i =
      let rchild = 2 * i + 2
          lchild = rchild - 1
      in case compare rchild n of
        -- Both children exist.
        LT -> do
          al <- unsafeRead arr lchild
          ar <- unsafeRead arr rchild
          let (ach, ch) = if al < ar then (al, lchild) else (ar, rchild)
          if ach `lt` ai
            then unsafeWrite arr i ach >> heapDown' ch
            else unsafeWrite arr i ai
        -- Only the left child exists, and it is the last element.
        EQ -> do
          al <- unsafeRead arr lchild
          if al `lt` ai
            then unsafeWrite arr i al >> unsafeWrite arr lchild ai
            else unsafeWrite arr i ai
        -- Leaf.
        GT -> unsafeWrite arr i ai

-- | Remove the root (minimum) element.
--
-- Popping an empty heap is a no-op.  Without that guard, @size - 1@ would be
-- @-1@ and the unchecked 'unsafeRead' below would read outside the array.
popper :: Ord e => STHeap s e -> ST s (STHeap s e)
popper h@STH{..}
  | size <= 0 = return h
  | otherwise = do
      let s' = size - 1
      -- Detach the last element; it is re-inserted by sifting down from the root.
      ai <- unsafeRead arr s'
      -- Clear the vacated slot so the array does not keep the value alive.
      unsafeWrite arr s' (error "Data.MQueue.Heap.popper: read of vacated slot")
      -- If that was the only element, the root slot has just been cleared.
      -- Otherwise sift the detached element down into the root position.
      when (s' > 0) $ heapDown s' arr 0 ai
      return h{size = s'}