{-# LANGUAGE RankNTypes, FlexibleInstances, GeneralizedNewtypeDeriving, MultiParamTypeClasses, UndecidableInstances #-}

module Control.Monad.Queue.Heap (HeapM, runHeapM) where

import Control.Applicative (Applicative (..))
import Control.Monad
import Control.Monad.Array
import Control.Monad.Identity
import Control.Monad.Queue.Class
import Control.Monad.RWS.Class
import Control.Monad.ST.Class
import Control.Monad.State.Strict

-- | A priority-queue monad backed by an array-based binary min-heap.
--
-- The 'Int' state is the number of elements currently stored in the heap;
-- the wrapped 'ArrayM' array holds those elements and may have spare
-- capacity beyond them.
newtype HeapM s e a = HeapM
  { execHeapM :: StateT Int (ArrayM s e) a
  }

-- | 'Functor' and 'Applicative' are defined from the 'Monad' operations
-- ('liftM' and 'ap'), so they only rely on the wrapped 'StateT' computation
-- being a monad.  Since GHC 7.10 they are required superclasses of 'Monad';
-- without them the 'Monad' instance below is rejected.
instance Functor (HeapM s e) where
  fmap = liftM

instance Applicative (HeapM s e) where
  pure x = HeapM (return x)
  (<*>) = ap

-- | Sequencing threads the wrapped 'StateT' computation through unchanged.
instance Monad (HeapM s e) where
  return = pure
  m >>= k = HeapM (execHeapM m >>= execHeapM . k)

-- | Lifts an 'ST' action into 'HeapM' by delegating to the 'liftST' of the
-- wrapped 'StateT' computation.
instance MonadST s (HeapM s e) where
  liftST action = HeapM (liftST action)

-- | Runs a 'HeapM' computation starting with an empty heap (element count 0).
-- The value 16 passed to 'runArrayM_' is presumably the initial array
-- capacity; insertions grow the array on demand.
runHeapM :: Ord e => (forall s . HeapM s e a) -> a
runHeapM m = runArrayM_ initialCapacity (execHeapM m `evalStateT` emptyCount)
  where
    initialCapacity = 16
    emptyCount = 0

-- | Runs a 'HeapM' computation starting with a heap initialized to hold the
-- specified list.  The heap is built with Floyd's linear-time bottom-up
-- construction, which is more efficient than inserting the elements one by
-- one.
runHeapTOn :: (Ord e)
           => (forall s . HeapM s e a) -- ^ The heap computation to run
           -> Int                      -- ^ The starting size of the heap (must be equal to the length of the list)
           -> [e]                      -- ^ The initial contents of the heap
           -> a
runHeapTOn m n l = runArrayM_ 16 $ flip evalStateT n $ do
    -- Grow the array to hold all n elements before the unchecked writes.
    ensureHeap n
    -- Only indices 0 .. n-1 are written, even if the list is longer.
    mapM_ (uncurry unsafeWriteAt) (zip [0 .. n - 1] l)
    -- Sift down every internal node from the last one (n `quot` 2 - 1) back
    -- to the root, so both subtrees of a node are already heaps when it is
    -- processed.  Going upward from the root does not produce a valid heap.
    mapM_ (\i -> unsafeReadAt i >>= heapDown n i) (reverse [0 .. n `quot` 2 - 1])
    execHeapM m

-- | Makes sure index @n - 1@ fits within the array size reported by
-- 'getSize'.  If it does not, the array is resized to @2 * n@, so a run of
-- insertions triggers only occasional resizes.
ensureHeap :: MonadArray e m => Int -> m ()
ensureHeap n = getSize >>= \capacity -> unless (n - 1 < capacity) (resize (2 * n))

-- | @heapUp i x@ sifts @x@ up from the hole at index @i@ toward the root.
-- Each parent that @x@ is smaller than is moved down into the hole.  @x@ is
-- written where it is no smaller than its parent, or at the root (index 0).
heapUp :: (MonadArray e m, Ord e) => Int -> e -> m ()
heapUp = siftUp
  where
    siftUp 0 x = unsafeWriteAt 0 x
    siftUp hole x = do
        above <- unsafeReadAt parent
        if x >= above
          then unsafeWriteAt hole x
          else unsafeWriteAt hole above >> siftUp parent x
      where
        parent = (hole - 1) `quot` 2

-- | @heapDown size i x@ places @x@ into the hole at index @i@ of a heap with
-- @size@ elements.  While the smaller child is smaller than @x@, that child
-- moves up into the hole and the hole moves down to the child's index.
-- Ties between the two children go to the right child.
heapDown :: (MonadArray e m, Ord e) => Int -> Int -> e -> m ()
heapDown size = sift
  where
    sift hole x
      -- Both children exist.
      | right < size = do
          l <- unsafeReadAt left
          r <- unsafeReadAt right
          let (smaller, smallerIx) = if l < r then (l, left) else (r, right)
          if smaller < x
            then unsafeWriteAt hole smaller >> sift smallerIx x
            else unsafeWriteAt hole x
      -- Only the left child exists, and it is the last element.
      | right == size = do
          l <- readAt left
          if l < x
            then unsafeWriteAt hole l >> unsafeWriteAt left x
            else unsafeWriteAt hole x
      -- The hole is a leaf.
      | otherwise = unsafeWriteAt hole x
      where
        left = 2 * hole + 1
        right = left + 1

-- | Binary min-heap implementation of the queue operations.  The 'Int' state
-- is the number of live elements; they occupy array indices @0 .. size - 1@,
-- with the smallest element (by 'Ord') at index 0.
instance (Ord e) => MonadQueue e (HeapM s e) where
  -- The smallest element sits at the root, if the heap is non-empty.
  queuePeek = HeapM $ do
    size <- get
    if size > 0 then liftM Just (unsafeReadAt 0) else return Nothing
  -- Grow the array if needed, then sift the new element up from the first
  -- free slot.
  queueInsert x = HeapM $ do
    size <- get
    ensureHeap (size + 1)
    put (size + 1)
    heapUp size x
  -- Remove the root.  On an empty heap this is a no-op.  The guard keeps the
  -- size from going negative and avoids unchecked access at index -1.
  queueDelete = HeapM $ do
    size <- get
    when (size > 0) $ do
      let lastIx = size - 1
      put lastIx
      -- Move the last element into the root's hole and sift it down through
      -- the remaining lastIx elements.
      unsafeReadAt lastIx >>= heapDown lastIx 0
      -- Overwrite the vacated slot, presumably so the array does not retain
      -- the removed element; it must never be read.
      unsafeWriteAt lastIx (error "Control.Monad.Queue.Heap: read of a vacated heap slot")
  queueSize = HeapM get