module Control.Monad.Queue.Heap (HeapM, HeapT, runHeapM, runHeapMIO, runHeapT) where
import Control.Monad.ST
import Control.Monad.ST.Class
import Control.Monad.Array.ArrayT
import Control.Monad.Array.Class
import Control.Monad.State.Strict
import Control.Monad.RWS.Class
import Control.Monad.Queue.Class
import Control.Monad
-- | 'HeapT' run directly over the 'ST' monad; the state thread @s@ of the
-- heap's array and of the base monad coincide.
type HeapM s e = HeapT s e (ST s)
-- | Priority-queue monad transformer backed by a binary min-heap.
--
-- Elements of type @e@ live in a mutable array (thread @s@) provided by
-- 'ArrayT'; the 'StateT' layer tracks how many array slots are in use.
newtype HeapT s e m a = HeapT
  { execHeapT :: StateT Int (ArrayT s e m) a
    -- ^ The 'Int' state is the number of live elements; they occupy
    -- the array indices @0 .. size - 1@.
  }
  deriving (Monad, MonadST s, MonadFix, MonadReader r, MonadWriter w)
-- | Lifting goes through both wrapped layers: first into 'ArrayT', then
-- into the size-tracking 'StateT'.
instance MonadTrans (HeapT s e) where
  lift action = HeapT (lift (lift action))
-- | State operations are forwarded to the base monad @m@. The heap's own
-- 'Int' size counter is internal and is not exposed through this instance.
instance (MonadST s m, MonadState t m) => MonadState t (HeapT s e m) where
  get = lift get
  put newState = lift (put newState)
-- | Run a pure heap computation, starting from an empty heap.
runHeapM :: Ord e => (forall s . HeapM s e a) -> a
runHeapM heapAction = runST (runHeapT heapAction)
-- | Run a heap computation in 'IO' via 'stToIO', starting from an empty heap.
runHeapMIO :: Ord e => HeapM RealWorld e a -> IO a
runHeapMIO = stToIO . runHeapT
-- | Run a heap computation in any 'MonadST' base monad.
--
-- The heap starts with zero elements; the backing array is created by
-- 'runArrayT_' with an initial size argument of 16.
runHeapT :: (MonadST s m, Monad m) => HeapT s e m a -> m a
runHeapT (HeapT body) = runArrayT_ initialArraySize (evalStateT body emptyCount)
  where
    initialArraySize = 16
    emptyCount = 0
-- | Make sure the backing array has room for at least @n@ elements
-- (indices @0 .. n - 1@).
--
-- If the current array size is too small, the array is resized to @2 * n@.
-- Doubling amortises the cost of repeated insertions.
ensureHeap :: MonadArray e m => Int -> m ()
ensureHeap n = do
  cap <- getSize
  -- Index n - 1 must be valid, so grow when it would fall outside the array.
  when (n - 1 >= cap) $ resize (2 * n)
-- | Sift an element up a binary min-heap stored in the array.
--
-- @heapUp i x@ treats index @i@ as a hole to be filled with @x@.
-- While @x@ is smaller than the parent at @(i - 1) \`quot\` 2@, the parent
-- is moved down into the hole and the hole moves up. Then @x@ is written
-- into its final position. The old contents of index @i@ are never read.
heapUp :: (MonadArray e m, Ord e) => Int -> e -> m ()
heapUp = siftUp
  where
    -- The root has no parent, so the element always lands there.
    siftUp 0 x = unsafeWriteAt 0 x
    siftUp i x = do
      let parent = (i - 1) `quot` 2
      parentVal <- unsafeReadAt parent
      if x >= parentVal
        then unsafeWriteAt i x
        else unsafeWriteAt i parentVal >> siftUp parent x
-- | Sift an element down a binary min-heap occupying indices
-- @0 .. size - 1@ of the array.
--
-- @heapDown size i x@ treats index @i@ as a hole to be filled with @x@.
-- While the smaller child is less than @x@, that child is moved up into
-- the hole. Then @x@ is written at the hole's final position. The old
-- contents of index @i@ are never read.
heapDown :: (MonadArray e m, Ord e) => Int -> Int -> e -> m ()
heapDown size = siftDown
  where
    siftDown hole x
      -- Both children are inside the heap: follow the smaller one.
      | right < size = do
          leftVal <- unsafeReadAt left
          rightVal <- unsafeReadAt right
          let (minVal, minIx) =
                if leftVal < rightVal then (leftVal, left) else (rightVal, right)
          if minVal < x
            then unsafeWriteAt hole minVal >> siftDown minIx x
            else unsafeWriteAt hole x
      -- Only the left child exists; it is the last slot, so at most one swap.
      | right == size = do
          leftVal <- readAt left
          if leftVal < x
            then unsafeWriteAt hole leftVal >> unsafeWriteAt left x
            else unsafeWriteAt hole x
      -- No children inside the heap: the hole is a leaf.
      | otherwise = unsafeWriteAt hole x
      where
        left = 2 * hole + 1
        right = left + 1
-- | Binary min-heap implementation of 'MonadQueue'.
--
-- The 'Int' state holds the element count. The elements occupy array
-- indices @0 .. size - 1@, with the minimum at index 0.
instance (MonadST s m, Monad m, Ord e) => MonadQueue e (HeapT s e m) where
  -- The minimum sits at the root (index 0), if the heap is non-empty.
  queuePeek = HeapT $ do
    size <- get
    if size > 0 then liftM Just (unsafeReadAt 0) else return Nothing

  -- Grow the array if needed, then sift the new element up from the first
  -- free slot (index size).
  queueInsert x = HeapT $ do
    size <- get
    ensureHeap (size + 1)
    put (size + 1)
    heapUp size x

  -- Remove the minimum: take the last element and sift it down from the
  -- root over the shrunken heap. Deleting from an empty heap is a no-op.
  -- Without this guard the size would go negative and index -1 would be
  -- read unchecked.
  queueDelete = HeapT $ do
    size <- get
    when (size > 0) $ do
      let lastIx = size - 1
      put lastIx
      lastElem <- unsafeReadAt lastIx
      heapDown lastIx 0 lastElem
      -- The vacated slot lies outside the heap and is never read again.
      -- Overwriting it presumably drops the array's reference to the
      -- removed element so it can be collected.
      unsafeWriteAt lastIx (error "Control.Monad.Queue.Heap: read of vacated heap slot")

  -- The element count is exactly the size state.
  queueSize = HeapT get