{-# LANGUAGE RankNTypes, FlexibleInstances, GeneralizedNewtypeDeriving, MultiParamTypeClasses, UndecidableInstances #-}

module Control.Monad.Queue.Heap (HeapM, HeapT, runHeapM, runHeapMIO, runHeapT) where

import Control.Monad.ST
import Control.Monad.ST.Class
import Control.Monad.Array.ArrayT
import Control.Monad.Array.Class

import Control.Monad.State.Strict
import Control.Monad.RWS.Class
import Control.Monad.Queue.Class

import Control.Monad

-- | Monad based on an array implementation of a standard binary heap.
-- This is simply 'HeapT' layered over the 'ST' monad, with the same state
-- thread @s@ used for both; run it with 'runHeapM' or 'runHeapMIO'.
type HeapM s e = HeapT s e (ST s)
-- | Monad transformer providing a priority queue that is stored as a standard
-- binary heap inside an array.  The 'Int' carried by the 'StateT' layer is
-- the number of elements currently held in the heap; the elements themselves
-- live in the underlying 'ArrayT' layer.
newtype HeapT s e m a = HeapT
  { execHeapT :: StateT Int (ArrayT s e m) a
    -- ^ Unwraps a 'HeapT' action into the underlying state-over-array computation.
  }
  deriving (Monad, MonadST s, MonadFix, MonadReader r, MonadWriter w)

-- | A base-monad action is lifted first through the 'ArrayT' layer and then
-- through the 'StateT' layer that tracks the heap size.
instance MonadTrans (HeapT s e) where
  lift action = HeapT (lift (lift action))

-- | 'get' and 'put' are forwarded to the base monad's state; the heap's own
-- internal size counter is not what these methods read or write.
instance (MonadST s m, MonadState t m) => MonadState t (HeapT s e m) where
  get = lift get
  put newState = lift (put newState)

-- | Executes a 'HeapM' computation purely (via 'runST'), beginning with an
-- empty heap, and returns its result.
runHeapM :: Ord e => (forall s . HeapM s e a) -> a
runHeapM m = runST (runHeapT m)

-- | Executes a 'HeapM' computation over the 'RealWorld' state thread inside
-- 'IO' (via 'stToIO'), beginning with an empty heap.
runHeapMIO :: Ord e => HeapM RealWorld e a -> IO a
runHeapMIO = stToIO . runHeapT

-- | Runs a 'HeapT' computation in its base monad, beginning with an empty
-- heap (the size counter starts at 0) and discarding the final size.
--
-- The literal 16 is handed to 'runArrayT_'; presumably it is the initial
-- capacity of the backing array (TODO confirm against 'runArrayT_').
runHeapT :: (MonadST s m, Monad m) => HeapT s e m a -> m a
runHeapT (HeapT body) = runArrayT_ 16 (evalStateT body 0)

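-- NOTE: 'runHeapTOn' is currently disabled; its draft implementation is kept commented out below.  It is meant to run a 'HeapM' computation starting with a heap initialized to hold the specified list.  (Since this can be done with linear preprocessing, this is more efficient than inserting the elements one by one.)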
-- runHeapTOn :: (Ord e) => (forall s . HeapM s e a) -- ^ The transformer operation.
-- 				-> Int -- ^ The starting size of the heap (must be equal to the length of the list)
-- 				-> [e] -- ^ The initial contents of the heap
-- 				-> a
-- runHeapTOn m n l = runArrayM_ 16 $ flip evalStateT n $ do	mapM_ (uncurry unsafeWriteAt) (zip [0..] l)
-- 								mapM_ (\ i -> unsafeReadAt i >>= heapDown n i) [0..n-1]
-- 								execHeapM m

-- Makes sure the backing array can hold the requested number of elements.
-- When the size reported by 'getSize' is smaller than that number, the array
-- is resized to twice the requested number; otherwise nothing happens.
ensureHeap :: MonadArray e m => Int -> m ()
ensureHeap needed =
  getSize >>= \capacity ->
    when (capacity <= needed - 1) (resize (2 * needed))

-- Sifts a value up from a hole at the given index: while the value is
-- smaller than the element stored at the parent slot, the parent is moved
-- down into the hole and the hole moves up.  The value is finally written
-- into the hole, either at the root or below a parent that is not larger.
heapUp :: (MonadArray e m, Ord e) => Int -> e -> m ()
heapUp = climb
  where
    -- The root has no parent, so the value settles here.
    climb 0 x = unsafeWriteAt 0 x
    climb i x = do
      let parent = (i - 1) `quot` 2
      parentVal <- unsafeReadAt parent
      if x >= parentVal
        then unsafeWriteAt i x
        else unsafeWriteAt i parentVal >> climb parent x

-- Sifts a value down from a hole at the given index, in a heap whose valid
-- indices are @0 .. size - 1@.  At each step the smaller child is compared
-- with the value: if the child is smaller it is moved up into the hole and
-- the hole moves down; otherwise the value is written into the hole.
heapDown :: (MonadArray e m, Ord e) => Int -> Int -> e -> m ()
heapDown size = sink
  where
    sink i x
      -- Both children lie inside the heap.
      | right < size = do
          leftVal <- unsafeReadAt left
          rightVal <- unsafeReadAt right
          let (childVal, child) =
                if leftVal < rightVal then (leftVal, left) else (rightVal, right)
          if childVal < x
            then unsafeWriteAt i childVal >> sink child x
            else unsafeWriteAt i x
      -- Only the left child lies inside the heap; it is the last element,
      -- so no further descent is needed.
      -- NOTE(review): this branch reads with 'readAt' while the others use
      -- 'unsafeReadAt' -- confirm whether the difference is intentional.
      | right == size = do
          leftVal <- readAt left
          if leftVal < x
            then unsafeWriteAt i leftVal >> unsafeWriteAt left x
            else unsafeWriteAt i x
      -- No children inside the heap: the value settles here.
      | otherwise = unsafeWriteAt i x
      where
        left = 2 * i + 1
        right = left + 1

-- | Priority-queue operations backed by the array heap.  The heap is ordered
-- so that a parent is never larger than its children, hence the element at
-- index 0 is a minimum with respect to 'Ord'.
instance (MonadST s m, Monad m, Ord e) => MonadQueue e (HeapT s e m) where
  -- Returns the root element without removing it, or 'Nothing' when the
  -- heap is empty.
  queuePeek = HeapT $ do
    size <- get
    if size > 0 then liftM Just (unsafeReadAt 0) else return Nothing

  -- Grows the array if necessary, bumps the size counter, and sifts the new
  -- element up from the first free slot.
  queueInsert x = HeapT $ do
    size <- get
    ensureHeap (size + 1)
    put (size + 1)
    heapUp size x

  -- Removes the root element.  On an empty heap this is a no-op: without the
  -- guard the size counter would become -1 and the unchecked read at index
  -- -1 would access memory outside the array.
  queueDelete = HeapT $ do
    size <- get
    when (size > 0) $ do
      let lastIx = size - 1
      put lastIx
      -- Re-insert the last element from the root, within the shrunken heap.
      lastElem <- unsafeReadAt lastIx
      heapDown lastIx 0 lastElem
      -- Overwrite the vacated slot; presumably so the array no longer
      -- retains a reference to the moved element (TODO confirm intent).
      unsafeWriteAt lastIx undefined

  -- Number of elements currently in the heap.
  queueSize = HeapT get