module Control.Monad.Queue.Heap (HeapM, HeapT, runHeapM, runHeapMOn, runHeapT, runHeapTOn) where
import Control.Monad.Array.ArrayT
import Control.Monad.Array.Class
import Control.Monad.ST
import Control.Monad.ST.Class
import Control.Monad.State.Strict
import Control.Monad.RWS.Class
import Control.Monad.Queue.Class
import Control.Monad
-- | A heap computation over elements of type @e@ that runs directly in the
-- 'ST' monad with state thread @s@; it is 'HeapT' specialised to @ST s@.
type HeapM s e = HeapT e (ST s)
-- | Monad transformer providing a priority queue of @e@ values on top of @m@.
--
-- The wrapped computation is a strict 'StateT' whose 'Int' state is the
-- number of elements currently stored in the heap, layered over an 'ArrayT'
-- that holds the elements themselves. 'execHeapT' unwraps the newtype to
-- that underlying stack.
newtype HeapT e m a = HeapT {execHeapT :: StateT Int (ArrayT e m) a} deriving (Monad, MonadPlus, MonadFix, MonadReader r, MonadWriter w)
-- | Lifting an action of the base monad goes through both wrapped layers:
-- first into the array transformer, then into the size-tracking state
-- transformer, and finally under the 'HeapT' constructor.
instance MonadTrans (HeapT e) where
  lift baseAction = HeapT (lift (lift baseAction))
-- | The state visible to users of 'HeapT' is the state of the base monad @m@;
-- both operations are simply lifted. The heap's own 'Int' size counter is
-- internal and is not exposed through this instance.
instance MonadState s m => MonadState s (HeapT e m) where
  get = lift get
  put newState = lift (put newState)
-- | Run a heap computation in 'ST' and return its result as a pure value.
-- The computation is executed via 'runHeapT', i.e. starting with a size
-- counter of zero.
runHeapM :: Ord e => (forall s . HeapM s e a) -> a
runHeapM heapAction = runST (runHeapT heapAction)
-- | Run a heap computation in 'ST' on a heap pre-populated from a list and
-- return its result as a pure value. The arguments are passed unchanged to
-- 'runHeapTOn': the computation, the element count, and the elements.
runHeapMOn :: Ord e => (forall s . HeapM s e a) -> Int -> [e] -> a
runHeapMOn heapAction count elems = runST (runHeapTOn heapAction count elems)
-- | Run a heap computation in the base monad, starting from an empty heap.
-- The size counter starts at 0 and the backing array is created with
-- 'runArrayT_' using an initial size argument of 16.
runHeapT :: (MonadST m, Monad m) => HeapT e m a -> m a
runHeapT heapAction = runArrayT_ initialCapacity fromEmpty
  where
    -- Size argument handed to 'runArrayT_' for the backing array.
    initialCapacity = 16
    -- The unwrapped computation with the element count initialised to zero.
    fromEmpty = evalStateT (execHeapT heapAction) 0
-- | Run a heap computation on a heap pre-populated from a list.
--
-- Arguments, in order: the computation to run, the number @n@ of elements to
-- load (also passed to 'runArrayT_' as the array size), and the source list.
--
-- At most the first @n@ elements of the list are loaded. If the list holds
-- fewer than @n@ elements, the heap starts with as many elements as were
-- actually supplied, so heapification never reads a slot that was not
-- written. The loaded elements are then arranged into heap order by sifting
-- every position down, from the last index to the root, before the
-- computation runs.
runHeapTOn :: (MonadST m, Monad m, Ord e) =>
  HeapT e m a
  -> Int
  -> [e]
  -> m a
runHeapTOn m n l = runArrayT_ n $ flip evalStateT count $ do
    -- Copy the initial elements into slots 0 .. count - 1.
    zipWithM_ unsafeWriteAt [0 ..] initial
    -- Bottom-up heapify: sift each element down within the first 'count' slots.
    forM_ [count - 1, count - 2 .. 0] $ \i ->
      unsafeReadAt i >>= heapDown count i
    execHeapT m
  where
    -- Never load more than the n slots that were requested.
    initial = take n l
    -- Number of elements really present; used as the starting heap size.
    count = length initial
-- | Priority-queue operations backed by a binary heap stored in the array
-- layer. The 'Int' state of the wrapped 'StateT' is the number of elements
-- in the heap; the smallest element (by 'Ord') is kept at index 0.
instance (MonadST m, Monad m, Ord e) => MonadQueue (HeapT e m) where
  type QKey (HeapT e m) = e

  -- Return the element at the root of the heap, or Nothing when it is empty.
  queuePeek = HeapT $ do
    size <- get
    if size > 0 then liftM Just (unsafeReadAt 0) else return Nothing

  -- Make room for one more element, bump the size, and sift the new element
  -- up from the first free slot (index == old size).
  queueInsert x = HeapT $ do
    size <- get
    ensureHeap (size + 1)
    put (size + 1)
    heapUp size x

  -- Remove the root element. On an empty queue this is a no-op; without the
  -- guard the size would become -1 and slot -1 would be accessed.
  queueDelete = HeapT $ do
    size <- get
    when (size > 0) $ do
      let lastIx = size - 1
      put lastIx
      -- Re-insert the last element from the root, within the shrunken heap.
      lastElem <- unsafeReadAt lastIx
      heapDown lastIx 0 lastElem
      -- Overwrite the vacated slot so it no longer references the element.
      unsafeWriteAt lastIx undefined

  -- The number of elements currently in the heap.
  queueSize = HeapT get
-- | Make sure the backing array can hold @n@ elements.
--
-- The value returned by 'askSize' is treated as the current capacity
-- (NOTE(review): inferred from its use here — confirm against
-- "Control.Monad.Array.Class"). When slot @n - 1@ would fall outside it, the
-- array is resized to @2 * n@; otherwise nothing happens.
ensureHeap :: MonadArray m => Int -> m ()
ensureHeap n = do
  cap <- askSize
  when (n - 1 >= cap) (resize (2 * n))
-- | Sift an element up towards the root.
--
-- @heapUp i x@ treats index @i@ as an empty hole and finds the final position
-- for @x@: while the parent of the hole is greater than @x@, the parent is
-- moved down into the hole and the hole moves to the parent's index. @x@ is
-- written once, at the position where the walk stops (the root, or the first
-- position whose parent is not greater than @x@).
heapUp :: (MonadArray m, e ~ ArrayElem m, Ord e) => Int -> e -> m ()
heapUp = siftUp
  where
    -- The root has no parent: the element lands here.
    siftUp 0 x = unsafeWriteAt 0 x
    siftUp i x = do
      -- Parent of index i in the implicit binary tree.
      let parent = (i - 1) `quot` 2
      parentVal <- unsafeReadAt parent
      if x >= parentVal
        then unsafeWriteAt i x
        else unsafeWriteAt i parentVal >> siftUp parent x
-- | Sift an element down from a given position.
--
-- @heapDown size i x@ treats index @i@ as an empty hole in a heap occupying
-- indices @0 .. size - 1@ and finds the final position for @x@: while the
-- smaller child of the hole is less than @x@, that child is moved up into the
-- hole and the hole moves to the child's index. When both children compare
-- equal-or-greater on the left, the right child is the one considered.
heapDown :: (MonadArray m, e ~ ArrayElem m, Ord e) => Int -> Int -> e -> m ()
heapDown size = sift
  where
    sift i x
      -- Both children exist.
      | right < size = do
          leftVal <- unsafeReadAt left
          rightVal <- unsafeReadAt right
          let (childVal, child)
                | leftVal < rightVal = (leftVal, left)
                | otherwise = (rightVal, right)
          if childVal < x
            then unsafeWriteAt i childVal >> sift child x
            else unsafeWriteAt i x
      -- Only the left child exists; it is the last element, so no recursion.
      -- NOTE(review): this branch reads with 'readAt' while every other
      -- access uses the unsafe variants — confirm whether that is intended.
      | right == size = do
          leftVal <- readAt left
          if leftVal < x
            then unsafeWriteAt i leftVal >> unsafeWriteAt left x
            else unsafeWriteAt i x
      -- No children: the hole is a leaf.
      | otherwise = unsafeWriteAt i x
      where
        left = 2 * i + 1
        right = left + 1