{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE UndecidableInstances #-}

-- | A priority-queue monad backed by a binary min-heap stored in a mutable array.
module Control.Monad.Queue.Heap (HeapM, HeapT, runHeapM, runHeapMIO, runHeapT) where

import Control.Monad.ST
import Control.Monad.ST.Class
import Control.Monad.Array.ArrayT
import Control.Monad.Array.Class
import Control.Monad.State.Strict
import Control.Monad.RWS.Class
import Control.Monad.Queue.Class
import Control.Monad

-- | Monad based on an array implementation of a standard binary heap.
type HeapM s e = HeapT s e (ST s)

-- | Monad transformer based on an array implementation of a standard binary heap.
--
-- The 'Int' state is the number of elements currently in the heap; those
-- elements occupy array indices @0 .. size-1@.
--
-- 'Functor' and 'Applicative' are derived alongside 'Monad' because they are
-- superclasses of 'Monad' (required since GHC 7.10).
newtype HeapT s e m a = HeapT {execHeapT :: StateT Int (ArrayT s e m) a}
  deriving (Functor, Applicative, Monad, MonadST s, MonadFix, MonadReader r, MonadWriter w)

-- | Lifts a base-monad action through both the size-tracking 'StateT' layer
-- and the 'ArrayT' layer.
instance MonadTrans (HeapT s e) where
  lift = HeapT . lift . lift

-- | 'get' / 'put' reach the state of the base monad @m@, not the heap's
-- internal size counter, which stays hidden.
instance (MonadST s m, MonadState t m) => MonadState t (HeapT s e m) where
  get = lift get
  put = lift . put

-- | Runs a 'HeapM' computation starting with an empty heap.
runHeapM :: Ord e => (forall s . HeapM s e a) -> a
runHeapM m = runST (runHeapT m)

-- | Runs a 'HeapM' computation in 'IO', starting with an empty heap.
runHeapMIO :: Ord e => HeapM RealWorld e a -> IO a
runHeapMIO m = stToIO (runHeapT m)

-- | Runs a 'HeapT' computation in its base monad, starting with an empty heap
-- (size counter 0).
runHeapT :: (MonadST s m, Monad m) => HeapT s e m a -> m a
-- 16 is presumably the initial array size passed to 'runArrayT_'; the array
-- is grown on demand when elements are inserted.
runHeapT m = runArrayT_ 16 (evalStateT (execHeapT m) 0)

-- Disabled: a runner that starts from a heap pre-loaded with a list. Building
-- the heap bottom-up in place takes linear time, which beats inserting the
-- elements one by one. The sketch below is not compiled; it refers to
-- 'runArrayM_' and 'execHeapM', which this module does not define.
--
-- NOTE(review): bottom-up heapify must sift down from the last internal node
-- (index n `quot` 2 - 1) back to the root. The earlier sketch iterated over
-- [0 .. n-1] in ascending order, which does not produce a valid heap, and it
-- did not size the array for more than 16 elements. Both are corrected below.
--
-- runHeapTOn :: Ord e
--            => (forall s . HeapM s e a) -- The transformer operation.
--            -> Int                      -- The starting size of the heap (must equal the length of the list).
--            -> [e]                      -- The initial contents of the heap.
--            -> a
-- runHeapTOn m n l = runArrayM_ (max 16 n) $ flip evalStateT n $ do
--   mapM_ (uncurry unsafeWriteAt) (zip [0 ..] l)
--   mapM_ (\i -> unsafeReadAt i >>= heapDown n i) [n `quot` 2 - 1, n `quot` 2 - 2 .. 0]
--   execHeapM m
-- | Makes sure the backing array can hold @n@ elements (indices @0 .. n-1@).
-- When it cannot, the array is resized to @2 * n@. Asking for twice the
-- required size presumably keeps the number of resizes logarithmic across
-- repeated inserts.
ensureHeap :: MonadArray e m => Int -> m ()
ensureHeap n = do
  cap <- getSize
  when (n - 1 >= cap) (resize (2 * n))

-- | @heapUp i x@ fills hole @i@ with @x@, sifting towards the root.
--
-- While @x@ is smaller than the parent at @(i - 1) `quot` 2@, the parent is
-- moved down into the hole and the hole moves up. Used after appending at
-- index @i@ to restore the min-heap property.
heapUp :: (MonadArray e m, Ord e) => Int -> e -> m ()
heapUp = go
  where
    go 0 x = unsafeWriteAt 0 x
    go i x = do
      let parent = (i - 1) `quot` 2
      pv <- unsafeReadAt parent
      if x >= pv
        then unsafeWriteAt i x
        else unsafeWriteAt i pv >> go parent x

-- | @heapDown size i x@ fills hole @i@ with @x@, sifting towards the leaves.
--
-- @size@ is the number of elements in the heap. While the smaller child of
-- the hole is smaller than @x@, that child is moved up into the hole and the
-- hole moves down.
heapDown :: (MonadArray e m, Ord e) => Int -> Int -> e -> m ()
heapDown size = go
  where
    go i x =
      let lch = 2 * i + 1
          rch = lch + 1
      in case compare rch size of
        -- Both children exist: compare against the smaller one.
        LT -> do
          al <- unsafeReadAt lch
          ar <- unsafeReadAt rch
          let (ach, ch) = if al < ar then (al, lch) else (ar, rch)
          if ach < x then unsafeWriteAt i ach >> go ch x else unsafeWriteAt i x
        -- Only the left child exists; it is the last element, hence a leaf.
        EQ -> do
          al <- readAt lch
          if al < x then unsafeWriteAt i al >> unsafeWriteAt lch x else unsafeWriteAt i x
        -- No children: the hole is a leaf.
        GT -> unsafeWriteAt i x

-- | Min-heap priority queue: the smallest element sits at array index 0 and
-- the 'Int' state counts the elements.
instance (MonadST s m, Monad m, Ord e) => MonadQueue e (HeapT s e m) where
  -- The minimum is at the root, index 0.
  queuePeek = HeapT $ do
    size <- get
    if size > 0 then liftM Just (unsafeReadAt 0) else return Nothing

  -- Grow the array if needed, then place the new element at index @size@
  -- and sift it up.
  queueInsert x = HeapT $ do
    size <- get
    ensureHeap (size + 1)
    put (size + 1)
    heapUp size x

  -- Move the last element into the root's hole and sift it down, then
  -- overwrite the vacated slot (presumably so the array no longer references
  -- the old value). On an empty heap this is a no-op; without the guard the
  -- size would go negative and index -1 would be read/written unchecked.
  queueDelete = HeapT $ do
    size <- get
    when (size > 0) $ do
      let newSize = size - 1
      put newSize
      unsafeReadAt newSize >>= heapDown newSize 0
      unsafeWriteAt newSize (error "Control.Monad.Queue.Heap: read of a vacated heap slot")

  queueSize = HeapT get