module Control.Monad.Queue.Heap (HeapM, runHeapM) where
import Control.Applicative (Applicative (..))
import Control.Monad

import Control.Monad.Identity
import Control.Monad.RWS.Class
import Control.Monad.State.Strict

import Control.Monad.Array
import Control.Monad.Queue.Class
import Control.Monad.ST.Class
-- | A priority-queue monad backed by a binary min-heap that lives in a
-- resizable array.
--
-- The 'Int' state is the number of elements currently in the heap; array
-- slots @[0 .. size - 1]@ hold the heap, with the smallest element at
-- index 0.
newtype HeapM s e a = HeapM
  { execHeapM :: StateT Int (ArrayM s e) a
  }
-- | Mapping is derived from the 'Monad' instance below.
instance Functor (HeapM s e) where
  fmap = liftM

-- | Required as a superclass of 'Monad' since GHC 7.10 (the
-- Applicative-Monad Proposal). 'pure' only needs the underlying monad's
-- 'return', so no 'Applicative' instance is required of 'ArrayM'.
instance Applicative (HeapM s e) where
  pure x = HeapM (return x)
  (<*>) = ap

-- | Sequencing threads straight through the wrapped
-- @StateT Int (ArrayM s e)@ computation.
instance Monad (HeapM s e) where
  return = pure
  m >>= k = HeapM (execHeapM m >>= execHeapM . k)
-- | Lifting delegates to the 'liftST' of the wrapped state-transformer
-- stack and re-wraps the result in 'HeapM'.
instance MonadST s (HeapM s e) where
  liftST action = HeapM (liftST action)
-- | Run a 'HeapM' computation starting from an empty heap.
--
-- The backing array is created by 'runArrayM_' with the argument 16
-- (presumably its initial capacity; inserts grow it on demand via
-- 'ensureHeap'), and the element count starts at 0.
runHeapM :: Ord e => (forall s . HeapM s e a) -> a
runHeapM m = runArrayM_ initialSlots (evalStateT (execHeapM m) emptyCount)
  where
    initialSlots = 16
    emptyCount = 0
-- | Run a 'HeapM' computation on a heap pre-loaded with (at most) the
-- first @n@ elements of a list.
--
-- The elements are copied into the backing array, which is first grown with
-- 'ensureHeap' if they do not fit. They are then arranged into a min-heap by
-- bottom-up construction: every internal node is sifted down, starting from
-- the last one (@size `quot` 2 - 1@) and working back to the root. That
-- order matters. Sifting nodes from the root downwards, as an increasing
-- index loop does, does not produce a valid heap.
--
-- If the list has fewer than @n@ elements, the heap holds only the
-- elements actually supplied, so no unwritten slot is ever treated as
-- part of the heap.
runHeapTOn :: (Ord e)
           => (forall s . HeapM s e a)  -- ^ computation to run
           -> Int                       -- ^ maximum number of elements to load
           -> [e]                       -- ^ initial elements, in any order
           -> a
runHeapTOn m n l = runArrayM_ 16 $ flip evalStateT size $ do
    -- Make room before writing: the initial array may be too small.
    ensureHeap size
    zipWithM_ unsafeWriteAt [0 ..] xs
    -- Heapify from the last internal node back to the root.
    forM_ (reverse [0 .. size `quot` 2 - 1]) $ \i ->
      unsafeReadAt i >>= heapDown size i
    execHeapM m
  where
    xs   = take n l
    size = length xs
-- | Make sure the backing array can hold at least @n@ elements, i.e. that
-- index @n - 1@ is in range (assuming 'getSize' reports the array's
-- current capacity).
--
-- When it is too small, the array is resized to @2 * n@ slots. This leaves
-- headroom so that consecutive inserts do not trigger a resize each time.
ensureHeap :: MonadArray e m => Int -> m ()
ensureHeap n = do
  cap <- getSize
  -- @n > cap@ is the same test as @n - 1 >= cap@.
  when (n > cap) $ resize (2 * n)
-- | Sift an element up from position @i@ towards the root of a min-heap.
--
-- @heapUp i x@ treats slot @i@ as a hole. Each parent strictly greater than
-- @x@ is moved down into the hole, and @x@ is written where the walk stops.
-- On a tie with its parent, @x@ stays below it. The previous contents of
-- slot @i@ are never read, so it may be an unused slot.
heapUp :: (MonadArray e m, Ord e) => Int -> e -> m ()
heapUp 0 x = unsafeWriteAt 0 x
heapUp i x = do
  let parent = (i - 1) `quot` 2
  pv <- unsafeReadAt parent
  if x >= pv
    then unsafeWriteAt i x
    else unsafeWriteAt i pv >> heapUp parent x
-- | Sift an element down from position @i@ in a min-heap holding @size@
-- elements.
--
-- @heapDown size i x@ treats slot @i@ as a hole. While a child smaller than
-- @x@ exists, the smaller child moves up into the hole and the hole follows
-- it. Finally @x@ is written into the hole. The children of slot @k@ are
-- @2k+1@ and @2k+2@.
heapDown :: (MonadArray e m, Ord e) => Int -> Int -> e -> m ()
heapDown size = sink
  where
    sink hole x
      -- Both children are inside the heap.
      | right < size = do
          l <- unsafeReadAt left
          r <- unsafeReadAt right
          let (smaller, child)
                | l < r     = (l, left)
                | otherwise = (r, right)
          if smaller < x
            then unsafeWriteAt hole smaller >> sink child x
            else unsafeWriteAt hole x
      -- Only the left child exists. It is the last element, so it has no
      -- children of its own and the walk ends after at most one swap.
      | right == size = do
          l <- readAt left
          if l < x
            then unsafeWriteAt hole l >> unsafeWriteAt left x
            else unsafeWriteAt hole x
      -- The hole is a leaf.
      | otherwise = unsafeWriteAt hole x
      where
        left  = 2 * hole + 1
        right = left + 1
-- | Priority-queue operations on the array-backed min-heap.
--
-- The state holds the current element count. Array slots
-- @[0 .. size - 1]@ form the heap, with the minimum at index 0.
instance (Ord e) => MonadQueue e (HeapM s e) where
  -- The minimum is at the root, if the heap is non-empty.
  queuePeek = HeapM $ do
    size <- get
    if size > 0 then liftM Just (unsafeReadAt 0) else return Nothing

  -- Grow the array if needed, bump the count, then sift the new element
  -- up from the first free slot (index @size@).
  queueInsert x = HeapM $ do
    size <- get
    ensureHeap (size + 1)
    put $! size + 1
    heapUp size x

  -- Remove the minimum by moving the last element into the root and
  -- sifting it down. On an empty heap this does nothing, so the count
  -- never goes negative and index -1 is never touched.
  queueDelete = HeapM $ do
    size <- get
    when (size > 0) $ do
      let lastIx = size - 1
      put lastIx
      unsafeReadAt lastIx >>= heapDown lastIx 0
      -- Overwrite the vacated slot so the array no longer references the
      -- removed element. The slot lies past the heap and is never read.
      unsafeWriteAt lastIx (error "Control.Monad.Queue.Heap: vacated heap slot")

  queueSize = HeapM get