{-# LANGUAGE Rank2Types, MultiParamTypeClasses, FlexibleInstances, GeneralizedNewtypeDeriving, TypeFamilies, UndecidableInstances #-}

{- | Safe implementation of an array-backed binary heap.  The 'HeapT' transformer requires that the underlying monad provide a 'MonadST' instance, meaning that the bottom-level monad must be 'ST'.  This restriction protects referential transparency: it rules out the nondeterministic, multi-threaded behavior that would arise if, for example, the list monad '[]' were at the bottom level.  (The 'HeapM' monad takes care of the 'ST' bottom level automatically.)
-}
module Control.Monad.Queue.Heap (HeapM, HeapT, runHeapM, runHeapMOn, runHeapT, runHeapTOn) where

import Control.Monad.Array.ArrayT
import Control.Monad.Array.Class

import Control.Monad.ST
import Control.Monad.ST.Class

import Control.Monad.State.Strict
import Control.Monad.RWS.Class
import Control.Monad.Queue.Class

import Control.Monad

-- | Monad based on an array implementation of a standard binary heap, with
-- 'ST' as the bottom-level monad.  Run it with 'runHeapM' or 'runHeapMOn'.
type HeapM s e = HeapT e (ST s)
-- | Monad transformer based on an array implementation of a standard binary heap.
-- The 'StateT' 'Int' layer holds the number of elements currently in the heap;
-- the 'ArrayT' layer stores the elements themselves, in heap order.
--
-- NOTE(review): since GHC 7.10, 'Monad' has 'Applicative' (and 'Functor') as
-- superclasses, and 'MonadPlus' requires 'Alternative'; this deriving clause
-- probably needs those classes added to build on a modern compiler — confirm.
newtype HeapT e m a = HeapT {execHeapT :: StateT Int (ArrayT e m) a} deriving (Monad, MonadPlus, MonadFix, MonadReader r, MonadWriter w)

-- | Lifting passes through both layers of the representation: the action is
-- first lifted into 'ArrayT', then into the 'StateT' that tracks the heap size.
instance MonadTrans (HeapT e) where
	lift action = HeapT (lift (lift action))

-- | Forwards 'MonadState' to the underlying monad.  The heap's own 'StateT'
-- layer (which holds the element count) is skipped, so user state and heap
-- bookkeeping never interfere with each other.
instance MonadState s m => MonadState s (HeapT e m) where
	get = HeapT (lift (lift get))
	put newState = HeapT . lift . lift $ put newState

-- | Runs a 'HeapM' computation starting with an empty heap and returns its
-- result.  The rank-2 argument type lets 'runST' seal the state thread.
runHeapM :: Ord e => (forall s . HeapM s e a) -> a
runHeapM heapComp = runST (runHeapT heapComp)

-- | Runs a 'HeapM' computation starting with a heap initialized to hold the
-- given list.  The 'Int' and list arguments are passed unchanged to
-- 'runHeapTOn', which heapifies the list in place.
runHeapMOn :: Ord e => (forall s . HeapM s e a) -> Int -> [e] -> a
runHeapMOn heapComp size elems = runST (runHeapTOn heapComp size elems)

-- | Runs a 'HeapT' computation starting with an empty heap.  The backing
-- array starts with room for 16 elements; 'queueInsert' grows it on demand.
runHeapT :: (MonadST m, Monad m) => HeapT e m a -> m a
runHeapT (HeapT body) = runArrayT_ initialCapacity $ evalStateT body 0
	where
		-- Starting size of the backing array (the heap itself starts empty).
		initialCapacity = 16

-- | Runs a 'HeapT' computation starting with a heap initialized to hold the
-- specified list.  The elements are heapified in place in linear time, which
-- is more efficient than inserting them one by one.
--
-- At most the first @n@ elements of the list are used.  If the list is
-- shorter than @n@, the heap starts with only the list's elements (the
-- remaining array slots are left as spare capacity and never read).
runHeapTOn :: (MonadST m, Monad m, Ord e) =>
	HeapT e m a -- ^ The transformer operation.
	-> Int -- ^ The starting capacity of the heap (normally the length of the list)
	-> [e] -- ^ The initial contents of the heap
	-> m a
runHeapTOn m n l = runArrayT_ n $ flip evalStateT size $ do
		zipWithM_ unsafeWriteAt [0 ..] initial
		-- Bottom-up heapify.  Indices beyond lastParent are leaves, which are
		-- already trivial heaps, so only the internal nodes need sifting.
		mapM_ (\ i -> unsafeReadAt i >>= heapDown size i) [lastParent, lastParent - 1 .. 0]
		execHeapT m
	where
		-- Never write past the array of size n.
		initial = take n l
		-- The heap size is the number of elements actually written, so no
		-- unwritten slot is ever compared during heapify or later operations.
		size = length initial
		lastParent = size `quot` 2 - 1

-- | Min-priority-queue operations on the array-backed binary heap.  The
-- 'StateT' layer holds the number of elements; the array holds them in heap
-- order, so the smallest element is always at index 0.
instance (MonadST m, Monad m, Ord e) => MonadQueue (HeapT e m) where
	type QKey (HeapT e m) = e
	-- The smallest element, or 'Nothing' when the heap is empty.
	queuePeek = HeapT $ do
		size <- get
		if size > 0 then liftM Just (unsafeReadAt 0) else return Nothing
	-- Grow the array if needed, then sift the new element up from the first
	-- free slot.
	queueInsert x = HeapT $ do
		size <- get
		ensureHeap (size + 1)
		put (size + 1)
		heapUp size x
	-- Remove the smallest element: take the last element out of the heap's
	-- range and sift it down from the root hole.  Deleting from an empty heap
	-- is a no-op; without the guard it would access index -1 and leave a
	-- negative size.
	queueDelete = HeapT $ do
		size <- get
		when (size > 0) $ do
			let lastIx = size - 1
			put lastIx
			unsafeReadAt lastIx >>= heapDown lastIx 0
			-- Overwrite the vacated slot so it no longer holds the moved element.
			unsafeWriteAt lastIx (error "Control.Monad.Queue.Heap: read of a vacated heap slot")
	queueSize = HeapT get

-- | @ensureHeap n@ makes room for @n@ elements.  When the array size reported
-- by 'askSize' is smaller than @n@, the array is resized to @2 * n@, so growth
-- is geometric rather than one slot at a time.
ensureHeap :: MonadArray m => Int -> m ()
ensureHeap n = askSize >>= \ cap -> when (cap <= n - 1) (resize (2 * n))

-- | @heapUp i x@ places @x@ into the hole at index @i@ of a min-heap.  Each
-- ancestor larger than @x@ is moved down one level, until @x@ is no smaller
-- than its parent or reaches the root.  Slots @0 .. i-1@ are expected to
-- already be in heap order.
heapUp :: (MonadArray m, e ~ ArrayElem m, Ord e) => Int -> e -> m ()
heapUp = siftUp
	where
		siftUp 0 x = unsafeWriteAt 0 x
		siftUp hole x = do
			let parent = (hole - 1) `quot` 2
			parentVal <- unsafeReadAt parent
			if x >= parentVal
				then unsafeWriteAt hole x
				else do
					unsafeWriteAt hole parentVal
					siftUp parent x

-- | @heapDown size i x@ places @x@ into the hole at index @i@ of a min-heap
-- holding @size@ elements.  The smaller child is repeatedly moved up into the
-- hole while it is smaller than @x@; then @x@ is written into the final hole.
heapDown :: (MonadArray m, e ~ ArrayElem m, Ord e) => Int -> Int -> e -> m ()
heapDown size = siftDown
	where
		siftDown hole x
			| right < size = do
				leftVal <- unsafeReadAt left
				rightVal <- unsafeReadAt right
				let (childVal, child) = if leftVal < rightVal then (leftVal, left) else (rightVal, right)
				if childVal < x
					then unsafeWriteAt hole childVal >> siftDown child x
					else unsafeWriteAt hole x
			| right == size = do
				-- Only the left child is inside the heap.
				leftVal <- readAt left
				if leftVal < x
					then unsafeWriteAt hole leftVal >> unsafeWriteAt left x
					else unsafeWriteAt hole x
			| otherwise = unsafeWriteAt hole x
			where
				left = 2 * hole + 1
				right = left + 1