{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE UndecidableInstances #-}

-- | A priority-queue monad transformer backed by an array-based binary heap.
--
-- The heap is a min-heap: 'queuePeek' returns the smallest element according
-- to the element type's 'Ord' instance.
module Control.Monad.Queue.Heap
  ( HeapT
  , HeapM
  , runHeapT
  , runHeapM
  , runHeapTOn
  , runHeapMOn
  ) where

import Control.Monad.State.Strict
import Control.Monad.Array
import Control.Monad.ST.Class
import Control.Monad
import Control.Monad.RWS.Class
import Control.Monad.Queue.Class
import Control.Monad.Identity

-- | Monad transformer based on an array implementation of a standard binary
-- heap.
--
-- Heap elements occupy slots @[0 .. size - 1]@ of the underlying array; the
-- current element count is kept in the inner @'StateT' 'Int'@. The field is
-- named 'unHeapT' (and is not exported) so that it does not clash with the
-- public 'runHeapT' runner.
newtype HeapT e m a = HeapT {unHeapT :: ArrayT e (StateT Int m) a}
  deriving ( Functor
           , Applicative
           , Monad
           , MonadReader r
           , MonadST s
           , MonadWriter w
           , MonadFix
           , MonadIO
           )

-- | 'HeapT' over 'Identity'.
type HeapM e = HeapT e Identity

-- | Runs a 'HeapT' transformer starting with an empty heap.
runHeapT :: (Monad m, Ord e) => HeapT e m a -> m a
runHeapT m = evalStateT (runArrayT_ 16 (unHeapT m)) 0

-- | Runs a 'HeapT' transformer starting with a heap initialized to hold the
-- specified list. (Since this can be done with linear preprocessing, this is
-- more efficient than inserting the elements one by one.)
--
-- Only the first @n@ elements of the list are used; if the list is shorter
-- than @n@, the heap simply starts with fewer elements.
runHeapTOn :: (Monad m, Ord e)
  => HeapT e m a -- ^ The transformer operation.
  -> Int         -- ^ The starting size of the heap (should equal the length of the list)
  -> [e]         -- ^ The initial contents of the heap
  -> m a
runHeapTOn m n l = evalStateT (runArrayT_ 16 build) size
  where
    xs = take n l
    size = length xs
    build = do
      -- Grow the array first: the initial allocation may be smaller than
      -- the number of elements, and the writes below are unchecked.
      ensureHeap size
      zipWithM_ unsafeWriteAt [0 ..] xs
      -- Bottom-up heap construction: sift down every internal node, from
      -- the last one (index size/2 - 1) back to the root, so each subtree is
      -- already a heap when its parent is processed.
      forM_ (reverse [0 .. size `quot` 2 - 1]) $ \i ->
        unsafeReadAt i >>= heapDown size i
      unHeapT m

-- | Runs a 'HeapM' computation starting with an empty heap.
runHeapM :: Ord e => HeapM e a -> a
runHeapM = runIdentity . runHeapT

-- | Runs a 'HeapM' computation starting with a heap holding the given list;
-- see 'runHeapTOn'.
runHeapMOn :: Ord e => HeapM e a -> Int -> [e] -> a
runHeapMOn m n l = runIdentity (runHeapTOn m n l)

instance MonadTrans (HeapT e) where
  -- Lift through the inner StateT, then through ArrayT.
  lift = HeapT . lift . lift

-- | Ensures the array can hold @n@ elements (indices @[0 .. n - 1]@). When the
-- size reported by 'getSize' is too small, the array is resized to @2 * n@,
-- so repeated growth is amortized.
ensureHeap :: MonadArray e m => Int -> m ()
ensureHeap n = do
  cap <- getSize
  when (n - 1 >= cap) (resize (2 * n))

-- | @heapUp i x@ places @x@ into the hole at index @i@, moving it toward the
-- root past every parent that is strictly greater than it. Parents that are
-- moved are shifted down into the hole.
heapUp :: (MonadArray e m, Ord e) => Int -> e -> m ()
heapUp = heapUp'
  where
    heapUp' 0 x = unsafeWriteAt 0 x
    heapUp' i x = do
      let j = (i - 1) `quot` 2
      aj <- unsafeReadAt j
      if x >= aj
        then unsafeWriteAt i x
        else unsafeWriteAt i aj >> heapUp' j x

-- | @heapDown size i x@ places @x@ into the hole at index @i@ of a heap with
-- @size@ elements, moving it toward the leaves past every smaller child.
-- Children that are moved are shifted up into the hole.
heapDown :: (MonadArray e m, Ord e) => Int -> Int -> e -> m ()
heapDown size = heapDown'
  where
    heapDown' i x =
      let lch = 2 * i + 1
          rch = lch + 1
      in case compare rch size of
           -- Both children exist: compare against the smaller one.
           LT -> do
             al <- unsafeReadAt lch
             ar <- unsafeReadAt rch
             let (ach, ch) = if al < ar then (al, lch) else (ar, rch)
             if ach < x
               then unsafeWriteAt i ach >> heapDown' ch x
               else unsafeWriteAt i x
           -- Only the left child exists (lch == size - 1); it is a leaf, so
           -- at most one swap is needed.
           EQ -> do
             al <- unsafeReadAt lch
             if al < x
               then unsafeWriteAt i al >> unsafeWriteAt lch x
               else unsafeWriteAt i x
           -- No children: i is a leaf.
           GT -> unsafeWriteAt i x

instance (Ord e, Monad m) => MonadQueue e (HeapT e m) where
  {-# SPECIALIZE instance Ord e => MonadQueue e (HeapT e Identity) #-}
  queuePeek = HeapT $ do
    size <- get
    if size > 0
      then liftM Just (unsafeReadAt 0)
      else return Nothing

  queueInsert x = HeapT $ do
    size <- get
    ensureHeap (size + 1)
    put (size + 1)
    heapUp size x

  -- Removes the root. Deleting from an empty heap is a no-op; previously it
  -- drove the size negative and touched index -1.
  queueDelete = HeapT $ do
    size <- get
    when (size > 0) $ do
      let size' = size - 1
      put size'
      -- Move the last element into the root's hole and sift it down.
      lastElem <- unsafeReadAt size'
      heapDown size' 0 lastElem
      -- Overwrite the vacated slot with a placeholder; presumably this lets
      -- the removed value be collected (depends on ArrayT's storage).
      unsafeWriteAt size'
        (error "Control.Monad.Queue.Heap.queueDelete: read of vacated slot")

  queueSize = HeapT get

instance MonadState s m => MonadState s (HeapT e m) where
  get = lift get
  put = lift . put