module Data.MQueue.Heap (Heap, getSize) where
import Data.MQueue.Class
import Data.MQueue.MonadHelpers
import Control.Monad.ST.Class
import Data.Tuple.HT
import Data.Array.Base
import Data.Array.ST
import Data.STRef
import Control.Monad.ST
import Control.Monad
-- | Mutable heap state: how many elements are live, plus the backing array.
-- Slots @0 .. size - 1@ of 'arr' hold the heap; any further slots are spare
-- capacity.
data STHeap s e = STH
  { size :: !Int                -- ^ number of elements currently stored
  , arr  :: !(STArray s Int e)  -- ^ backing storage, indexed from 0
  }
-- | A mutable min-heap in the 'ST' state thread @s@ holding elements of
-- type @e@.  The heap state sits behind an 'STRef' so that operations can
-- swap in a larger backing array when the heap grows.
newtype Heap s e = H
  { unHeap :: STRef s (STHeap s e)  -- ^ reference to the current heap state
  }
-- | 'Heap' as an 'MQueue' in any monad that can lift 'ST' actions of the
-- heap's state thread @s@.  Every method lifts the matching 'ST'-level
-- helper defined in this module.
instance (Monad m, MonadST m, StateThread m ~ s, Ord e)
      => MQueue (Heap s e) m where
  type MQueueKey (Heap s e) = e
  newQueue     = liftST newHeap
  push h k     = liftST (pushHeap h k)
  pushAll h ks = liftST (pushAllHeap h ks)
  pop_ h       = liftST (popHeap_ h)
  peek h       = liftST (peekHeap h)
-- | Number of elements currently stored in the heap, lifted into any monad
-- that can run the heap's 'ST' actions.
getSize :: (Monad m, MonadST m, StateThread m ~ s) => Heap s e -> m Int
getSize heap = liftST (getHeapSize heap)
-- | Allocate an empty heap whose backing array starts with room for 16
-- elements.
newHeap :: ST s (Heap s e)
newHeap = do
  storage <- newArray_ (0, 15)
  ref     <- newSTRef (STH 0 storage)
  return (H ref)
-- | Insert one key, letting 'pusher' grow the backing array first when it
-- is full.
pushHeap :: Ord e => Heap s e -> e -> ST s ()
pushHeap heap key = onHeap_ heap (pusher key)
-- | Push every key in @ks@ onto the heap.
--
-- The backing array is grown once, up front, so it can hold all the new
-- keys.  Each key is then inserted with 'unsafePusher'.  The heap returned
-- by one insertion (which carries the incremented size) is the input to the
-- next, so every key lands in its own slot and the final size counts all of
-- them.
--
-- No type signature on purpose: the inferred type accepts whatever
-- containers 'length' and 'foldM' accept in the base version in use,
-- matching the 'pushAll' method this implements.
pushAllHeap h ks = onHeap_ h $ \heap@STH{size} -> do
  grown <- ensureSize (size + length ks) heap
  foldM (flip unsafePusher) grown ks
-- | Remove the root element of the heap; the work is done by 'popper'.
popHeap_ :: Ord e => Heap s e -> ST s ()
popHeap_ = (`onHeap_` popper)
-- | The element at the root (index 0), or 'Nothing' when the heap is empty.
-- The heap itself is left untouched.
peekHeap :: Heap s e -> ST s (Maybe e)
peekHeap heap = queryHeap heap inspect
  where
    inspect STH{size = n, arr = a}
      | n <= 0    = return Nothing
      | otherwise = fmap Just (unsafeRead a 0)
-- | Number of elements currently stored in the heap.
getHeapSize :: Heap s e -> ST s Int
getHeapSize heap = queryHeap heap (\st -> return (size st))
-- | Read the current heap state and hand it to @query@.  The stored state
-- is not replaced.
queryHeap :: Heap s e -> (STHeap s e -> ST s a) -> ST s a
queryHeap (H ref) query = readSTRef ref >>= query
-- | Run @update@ on the current heap state and store the state it returns.
onHeap_ :: Heap s e -> (STHeap s e -> ST s (STHeap s e)) -> ST s ()
onHeap_ (H ref) update = modSTRef_ ref update
-- | Monadic counterpart of 'modifySTRef'.  Reads the reference, runs the
-- update action on the value read, and then writes the action's result back.
modSTRef_ :: STRef s a -> (a -> ST s a) -> ST s ()
modSTRef_ ref f = do
  old     <- readSTRef ref
  updated <- f old
  writeSTRef ref updated
-- | Make sure the backing array can hold at least @n@ elements.
--
-- When the current capacity is already at least @n@, the heap is returned
-- unchanged.  Otherwise a new array with room for @5 * n `quot` 4@ elements
-- is allocated and the @size@ live elements are copied into it.  The heap
-- is then returned pointing at the new array.
ensureSize :: Int -> STHeap s e -> ST s (STHeap s e)
ensureSize n h@STH{..} = do
  cap <- getNumElements arr
  if cap >= n
    then return h
    else do
      -- Size the new array relative to the request (5/4 of n) so that
      -- further pushes have some headroom before the next reallocation.
      arr' <- newArray_ (0, 5 * n `quot` 4 - 1)
      forM_ [0 .. size - 1] $ \k -> unsafeRead arr k >>= unsafeWrite arr' k
      return h{arr = arr'}
-- | Sift @ai@ up from slot @i@ of a min-heap stored in @arr@.
--
-- Walking towards the root, each parent that is not strictly smaller than
-- @ai@ is moved down into the current slot.  @ai@ is written into the slot
-- where the walk stops, which is at the latest the root.  The parent of
-- slot @k@ is @(k - 1) `quot` 2@.
heapUp :: Ord e => STArray s Int e -> Int -> e -> ST s ()
heapUp arr i0 ai = arr `seq` climb i0
  where
    climb 0 = unsafeWrite arr 0 ai
    climb i = do
      let parent = (i - 1) `quot` 2
      ap <- unsafeRead arr parent
      if ap < ai
        then unsafeWrite arr i ai
        else unsafeWrite arr i ap >> climb parent
-- | Place @k@ in slot @size@ and sift it up, returning the heap with its
-- size incremented.  Capacity is not checked, so the backing array must
-- already have room for @size + 1@ elements.
unsafePusher :: Ord e => e -> STHeap s e -> ST s (STHeap s e)
unsafePusher k heap@STH{size = n, arr = a} = do
  heapUp a n k
  return heap{size = n + 1}
-- | Insert @k@.  The backing array is grown first if it cannot hold one
-- more element.
pusher :: Ord e => e -> STHeap s e -> ST s (STHeap s e)
pusher k heap@STH{size = n} = do
  roomy <- ensureSize (n + 1) heap
  unsafePusher k roomy
-- | Sift @ai@ down from slot @i@ of a min-heap occupying slots
-- @0 .. n - 1@ of @arr@.
--
-- The children of slot @k@ are @2k + 1@ and @2k + 2@.  While the smaller
-- child is strictly smaller than @ai@, that child moves up into the current
-- slot and the walk continues from the child's slot.  @ai@ is then written
-- where the walk stops.  When the two children are equal, the right child
-- is followed.
heapDown :: Ord e => Int -> STArray s Int e -> Int -> e -> ST s ()
heapDown n arr i0 ai = arr `seq` siftFrom i0
  where
    siftFrom i =
      let lchild = 2 * i + 1
          rchild = lchild + 1
      in case compare rchild n of
           -- Both children exist: follow the smaller one.
           LT -> do
             al <- unsafeRead arr lchild
             ar <- unsafeRead arr rchild
             let (ach, ch) = if al < ar then (al, lchild) else (ar, rchild)
             if ach < ai
               then unsafeWrite arr i ach >> siftFrom ch
               else unsafeWrite arr i ai
           -- Only the left child exists (slot n - 1), so it is a leaf.
           EQ -> do
             al <- unsafeRead arr lchild
             if al < ai
               then unsafeWrite arr i al >> unsafeWrite arr lchild ai
               else unsafeWrite arr i ai
           -- No children inside the heap: slot i is a leaf.
           GT -> unsafeWrite arr i ai
-- | Remove the root (minimum) element from the heap.
--
-- The last element is taken out of its slot, and that slot is overwritten
-- with an error thunk so the array no longer holds a reference to it.  The
-- taken element is then sifted down from the root over the remaining
-- @size - 1@ elements, which overwrites the old root.
--
-- Popping an empty heap is a no-op.  Without this guard the function would
-- perform an unchecked 'unsafeRead' at index -1.
popper :: Ord e => STHeap s e -> ST s (STHeap s e)
popper h@STH{..}
  | size <= 0 = return h
  | otherwise = do
      lastKey <- unsafeRead arr newSize
      unsafeWrite arr newSize (error "Data.MQueue.Heap: vacated heap slot")
      -- With a single element, the root *is* the last element.  Sifting it
      -- back into slot 0 would keep the popped value alive, so skip it.
      when (newSize > 0) $ heapDown newSize arr 0 lastKey
      return h{size = newSize}
  where
    newSize = size - 1