{-# LANGUAGE Rank2Types, MultiParamTypeClasses, FlexibleInstances, GeneralizedNewtypeDeriving, TypeFamilies, UndecidableInstances #-}
{- |
Safe implementation of an array-backed binary heap.

The 'HeapT' transformer requires that the underlying monad provide a
'MonadST' instance, meaning that the bottom-level monad must be 'ST'.
This critical restriction protects referential transparency: it rules out
the multi-threaded (nondeterministic) behavior that would arise if, for
example, the '[]' monad were at the bottom level.  (The 'HeapM' monad takes
care of the 'ST' bottom level automatically.)
-}
module Control.Monad.Queue.Heap
  ( HeapM
  , HeapT
  , runHeapM
  , runHeapMOn
  , runHeapT
  , runHeapTOn
  , UHeapT
  , runUHeapT
  ) where

import Control.Monad.Array.ArrayT
import Control.Monad.Array.Unboxed
import Control.Monad.Array.Class
import Control.Monad.ST
import Control.Monad.ST.Class
import Data.Array.Vector
import Control.Monad.State.Strict
import Control.Monad.RWS.Class
import Control.Monad.Queue.Class
import Control.Monad

-- | Monad based on an array implementation of a standard binary heap.
type HeapM s e = HeapT e (ST s)

-- | Monad transformer based on an array implementation of a standard binary
-- heap.  The 'StateT' layer holds the number of elements currently in the
-- heap; those elements occupy slots @[0 .. size-1]@ of the array.
newtype HeapT e m a = HeapT {execHeapT :: StateT Int (ArrayT e m) a}
  deriving (Monad, MonadPlus, MonadFix, MonadReader r, MonadWriter w)

-- | Same layout as 'HeapT', but backed by 'UArrayT' (from
-- "Control.Monad.Array.Unboxed"); running it requires a 'UA' instance for
-- the element type.
newtype UHeapT e m a = UHeapT {execUHeapT :: StateT Int (UArrayT e m) a}
  deriving (Monad, MonadFix, MonadReader r, MonadWriter w)

instance MonadTrans (HeapT e) where
  lift = HeapT . lift . lift

-- | 'get' and 'put' act on the underlying monad's state, not on the heap's
-- internal size counter.
instance MonadState s m => MonadState s (HeapT e m) where
  get = lift get
  put = lift . put

-- | Runs a 'HeapM' computation starting with an empty heap.
runHeapM :: Ord e => (forall s . HeapM s e a) -> a
runHeapM m = runST $ runHeapT m

-- | Runs a 'HeapM' computation starting with a heap initialized to hold the
-- given list.  The 'Int' argument must equal the length of the list (see
-- 'runHeapTOn').
runHeapMOn :: Ord e => (forall s . HeapM s e a) -> Int -> [e] -> a
runHeapMOn m n l = runST $ runHeapTOn m n l

-- | Runs a 'HeapT' computation starting with an empty heap whose backing
-- array is created with size 16 (it grows on demand when inserting).
runHeapT :: (MonadST m, Monad m) => HeapT e m a -> m a
runHeapT m = runArrayT_ 16 (evalStateT (execHeapT m) 0)

-- | Runs a 'UHeapT' computation starting with an empty heap whose backing
-- array is created with size 16 (it grows on demand when inserting).
runUHeapT :: (MonadST m, Monad m, UA e, Ord e) => UHeapT e m a -> m a
runUHeapT m = evalUArrayT 16 (evalStateT (execUHeapT m) 0)

-- | Runs a 'HeapT' computation starting with a heap initialized to hold the
-- specified list.  (Since this can be done with linear preprocessing, this is
-- more efficient than inserting the elements one by one.)
runHeapTOn :: (MonadST m, Monad m, Ord e)
           => HeapT e m a -- ^ The transformer operation.
           -> Int         -- ^ The starting size of the heap (must be equal to the length of the list).
           -> [e]         -- ^ The initial contents of the heap.
           -> m a
runHeapTOn m n l = runArrayT_ n $ flip evalStateT n $ do
    -- Copy the list into slots [0 .. n-1].
    mapM_ (uncurry unsafeWriteAt) (zip [0 .. n - 1] l)
    -- Bottom-up heapify: sift down each node that has a child, last one
    -- first.  Leaves need no work (sifting a leaf down just writes it back
    -- in place), so the loop starts at the last parent.
    mapM_ (\i -> unsafeReadAt i >>= heapDown n i) [lastParent, lastParent - 1 .. 0]
    execHeapT m
  where
    -- Highest index with at least one child (2*i+1 < n); every higher
    -- index is a leaf.  Negative (empty loop) when n <= 1.
    lastParent = n `quot` 2 - 1

instance (MonadST m, Monad m, Ord e) => MonadQueue (HeapT e m) where
  type QKey (HeapT e m) = e
  queuePeek = HeapT $ do
    size <- get
    if size > 0 then liftM Just (unsafeReadAt 0) else return Nothing
  queueInsert x = HeapT $ do
    size <- get
    ensureHeap (size + 1)
    put (size + 1)
    heapUp size x
  -- Removes the root (minimum) element.  On an empty heap this is a no-op
  -- rather than driving the size negative and touching index -1.
  queueDelete = HeapT $ do
    size <- get
    when (size > 0) $ do
      let lastIx = size - 1
      put lastIx
      -- Move the last element into the root's place and sift it down over
      -- the remaining lastIx elements.
      unsafeReadAt lastIx >>= heapDown lastIx 0
      -- Overwrite the vacated slot; presumably so the array does not keep
      -- the removed element alive.
      unsafeWriteAt lastIx undefined
  queueSize = HeapT get

instance (MonadST m, Monad m, UA e, Ord e) => MonadQueue (UHeapT e m) where
  type QKey (UHeapT e m) = e
  queuePeek = UHeapT $ do
    size <- get
    if size > 0 then liftM Just (unsafeReadAt 0) else return Nothing
  queueInsert x = UHeapT $ do
    size <- get
    ensureHeap (size + 1)
    put (size + 1)
    heapUp size x
  -- Removes the root (minimum) element.  On an empty heap this is a no-op
  -- rather than driving the size negative and touching index -1.
  queueDelete = UHeapT $ do
    size <- get
    when (size > 0) $ do
      let lastIx = size - 1
      put lastIx
      -- Move the last element into the root's place and sift it down over
      -- the remaining lastIx elements.
      unsafeReadAt lastIx >>= heapDown lastIx 0
  queueSize = UHeapT get

{-# INLINE ensureHeap #-}
-- | Makes sure the backing array can hold @n@ elements (indices
-- @[0 .. n-1]@).  When it is too small, it is resized to @4 * n@.
ensureHeap :: MonadArray m => Int -> m ()
ensureHeap n = do
  cap <- askSize
  when (n - 1 >= cap) (resize (4 * n))

{-# INLINE heapUp #-}
-- | @heapUp i x@ stores @x@ into the min-heap, treating index @i@ as an
-- empty hole: while @x@ is smaller than the parent of the hole, the parent
-- is moved down into the hole and the hole moves up.  Stops at the root.
heapUp :: (MonadArray m, e ~ ArrayElem m, Ord e) => Int -> e -> m ()
heapUp = heapUp'
  where
    heapUp' 0 x = unsafeWriteAt 0 x
    heapUp' i x = do
      let j = (i - 1) `quot` 2  -- parent index
      aj <- unsafeReadAt j
      if x >= aj
        then unsafeWriteAt i x
        else unsafeWriteAt i aj >> heapUp' j x

{-# INLINE heapDown #-}
-- | @heapDown size i x@ stores @x@ into the min-heap occupying indices
-- @[0 .. size-1]@, treating index @i@ as an empty hole: while the smaller
-- child of the hole is smaller than @x@, that child is moved up into the
-- hole and the hole moves down.
heapDown :: (MonadArray m, e ~ ArrayElem m, Ord e) => Int -> Int -> e -> m ()
heapDown size = heapDown'
  where
    heapDown' i x =
      let lch = 2 * i + 1
          rch = lch + 1
      in case compare rch size of
        -- Both children are in the heap: follow the smaller one.
        LT -> do
          al <- unsafeReadAt lch
          ar <- unsafeReadAt rch
          let (ach, ch) = if al < ar then (al, lch) else (ar, rch)
          if ach < x
            then unsafeWriteAt i ach >> heapDown' ch x
            else unsafeWriteAt i x
        -- Only the left child exists; here lch == size - 1, so it is in
        -- bounds and is itself a leaf (no further sifting needed).
        EQ -> do
          al <- unsafeReadAt lch
          if al < x
            then unsafeWriteAt i al >> unsafeWriteAt lch x
            else unsafeWriteAt i x
        -- No children: the hole is a leaf.
        GT -> unsafeWriteAt i x