{-# LANGUAGE FlexibleInstances, MultiParamTypeClasses, TypeFamilies, BangPatterns, NamedFieldPuns, RecordWildCards #-}
{-# OPTIONS -fno-warn-name-shadowing #-}

-- | Array-based implementation of a traditional binary min-heap: the smallest key is always at the root.
module Data.MQueue.Heap (Heap, getSize) where

import Data.MQueue.Class
import Data.MQueue.MonadHelpers

import Control.Monad.ST.Class
import Data.Tuple.HT

import Data.Array.Base

import Data.Array.ST
import Data.STRef
import Control.Monad.ST
import Control.Monad
-- import Control.Arrow ((***))

-- | Mutable heap state: the number of live elements and the backing array.
-- The live elements occupy indices @0 .. size - 1@ of 'arr', in heap order.
data STHeap s e = STH
	{ size :: {-# UNPACK #-} !Int
	, arr  :: {-# UNPACK #-} !(STArray s Int e)
	}

-- | A heap handle: a mutable reference to the current 'STHeap', so that the
-- state (including a reallocated, larger array) can be swapped out in place.
newtype Heap s e = H {unHeap :: STRef s (STHeap s e)}

instance (Monad m, MonadST m, StateThread m ~ s, Ord e) => MQueue (Heap s e) m where
	{-# SPECIALIZE instance Ord e => MQueue (Heap s e) (ST s) #-}
	{-# SPECIALIZE instance Ord e => MQueue (Heap RealWorld e) IO #-}
	{-# INLINE pushAll #-}
	-- The heap orders its elements directly, so the key is the element.
	type MQueueKey (Heap s e) = e
	-- Every operation is implemented in 'ST' and lifted into @m@.
	newQueue = liftST newHeap
	push h k = liftST (pushHeap h k)
	pushAll h ks = liftST (pushAllHeap h ks)
	pop_ h = liftST (popHeap_ h)
	peek h = liftST (peekHeap h)

{-# SPECIALIZE getSize :: Heap s e -> ST s Int #-}
-- | Number of keys currently stored in the heap, readable from any
-- 'MonadST' monad whose state thread matches the heap's.
getSize :: (Monad m, MonadST m, StateThread m ~ s) => Heap s e -> m Int
getSize h = liftST (getHeapSize h)

----------------------------------------------------------------------

{-# INLINE pushAllHeap #-}

-- | Allocate an empty heap whose backing array has 16 slots.
newHeap :: ST s (Heap s e)
newHeap = liftM H (liftM (STH 0) (newArray_ (0, 15)) >>= newSTRef)

-- | Insert one key, growing the backing array first if it is full.
pushHeap :: Ord e => Heap s e -> e -> ST s ()
pushHeap h = onHeap_ h . pusher

-- | Insert every key of @ks@.
--
-- The backing array is grown once, up front, to hold all of the new keys.
-- The heap state is then threaded through the individual pushes with
-- 'foldM'. Each push must see the size left by the previous one, so that
-- every key lands in its own fresh slot and the stored size ends up
-- incremented by the number of keys.
pushAllHeap h ks = onHeap_ h insertAll
	where
		insertAll st@STH{size} = do
			grown <- ensureSize (size + length ks) st
			foldM (flip unsafePusher) grown ks

-- | Remove the smallest key.
popHeap_ :: Ord e => Heap s e -> ST s ()
popHeap_ h = onHeap_ h popper

-- | The smallest key, or 'Nothing' if the heap is empty.
peekHeap :: Heap s e -> ST s (Maybe e)
peekHeap h = queryHeap h (\ STH{..} -> if size > 0 then liftM Just (unsafeRead arr 0) else return Nothing)

-- | Number of keys currently stored.
getHeapSize :: Heap s e -> ST s Int
getHeapSize h = queryHeap h (return . size)

----------------------------------------------------------------------

-- | Read the current heap state and run @k@ on it. The reference itself is
-- not updated afterwards.
queryHeap :: Heap s e -> (STHeap s e -> ST s a) -> ST s a
queryHeap (H ref) k = readSTRef ref >>= k

-- onHeap :: Heap s e -> (STHeap s e -> ST s (a, STHeap s e)) -> ST s a
-- onHeap = modSTRef . unHeap

-- | Transform the current heap state with @f@ and store the result back in
-- the handle's reference.
onHeap_ :: Heap s e -> (STHeap s e -> ST s (STHeap s e)) -> ST s ()
onHeap_ (H ref) f = modSTRef_ ref f

-- modSTRef :: STRef s a -> (a -> ST s (b, a)) -> ST s b
-- modSTRef ref f = do	(ans, x') <- f =<< readSTRef ref
-- 			writeSTRef ref x'
-- 			return ans

-- | Replace the contents of an 'STRef' with the result of running a monadic
-- update on its current value.
modSTRef_ :: STRef s a -> (a -> ST s a) -> ST s ()
modSTRef_ ref f = do
	x <- readSTRef ref
	x' <- f x
	writeSTRef ref x'

-- | Make sure the backing array has at least @n@ slots.
--
-- If it is already large enough, the heap is returned unchanged. Otherwise a
-- new array of @5 * n `quot` 4@ slots is allocated. The @size@ live elements
-- are copied into it, and a heap pointing at the new array is returned.
ensureSize :: Int -> STHeap s e -> ST s (STHeap s e)
ensureSize n h@STH{..} = do
	cap <- getNumElements arr
	if cap >= n
		then return h
		else do
			let newCap = 5 * n `quot` 4
			bigger <- newArray_ (0, newCap - 1)
			forM_ [0 .. size - 1] $ \i -> unsafeRead arr i >>= unsafeWrite bigger i
			return h{arr = bigger}

-- | Sift key @x@ up from the hole at index @i0@ and write it in place.
--
-- Parents that are not strictly smaller than @x@ are moved down into the
-- hole, one level at a time. This stops at the first strictly smaller
-- parent, or at the root.
heapUp :: Ord e => STArray s Int e -> Int -> e -> ST s ()
heapUp !arr i0 x = go i0
	where
		go i
			| i == 0 = unsafeWrite arr 0 x
			| otherwise = do
				let !parent = (i - 1) `quot` 2
				p <- unsafeRead arr parent
				if p < x
					then unsafeWrite arr i x
					else do
						unsafeWrite arr i p
						go parent

-- | Place key @k@ at index @size@, sift it up, and bump the size.
--
-- Does not grow the array: the caller must already have ensured there is a
-- free slot at index @size@ (as 'pusher' does via 'ensureSize').
unsafePusher :: Ord e => e -> STHeap s e -> ST s (STHeap s e)
unsafePusher k h = do
	heapUp (arr h) (size h) k
	return h{size = size h + 1}

-- | Insert key @k@, first growing the backing array so it can hold one more
-- element.
pusher :: Ord e => e -> STHeap s e -> ST s (STHeap s e)
pusher k h = do
	grown <- ensureSize (size h + 1) h
	unsafePusher k grown

-- | Sift key @x@ down from the hole at index @i0@ in a heap of @n@ elements,
-- and write it in place.
--
-- At each level the smaller child is chosen; on a tie the right child is
-- chosen. That child is moved up into the hole while it is strictly smaller
-- than @x@. A node whose only child is the left one (right child index equal
-- to @n@) is handled without recursing, since that child is a leaf.
heapDown :: Ord e => Int -> STArray s Int e -> Int -> e -> ST s ()
heapDown n !arr i0 x = less `seq` sift i0
	where
		-- Comparison bound once and forced before the loop; intended to
		-- reduce the overhead of the polymorphic '<'.
		less = (<)
		sift hole
			| right < n = do
				lv <- unsafeRead arr left
				rv <- unsafeRead arr right
				let (cv, c) = if lv < rv then (lv, left) else (rv, right)
				if cv `less` x
					then unsafeWrite arr hole cv >> sift c
					else unsafeWrite arr hole x
			| right == n = do
				lv <- unsafeRead arr left
				if lv `less` x
					then unsafeWrite arr hole lv >> unsafeWrite arr left x
					else unsafeWrite arr hole x
			| otherwise = unsafeWrite arr hole x
			where
				right = 2 * hole + 2
				left = right - 1

-- | Remove the smallest (root) key.
--
-- Popping an empty heap returns it unchanged, instead of performing an
-- unchecked read at index @-1@.
--
-- Otherwise the last element is taken out of slot @size - 1@ and sifted
-- down from the root into the hole left by the removed minimum. The vacated
-- slot is overwritten with a placeholder so the array no longer references
-- the moved element.
popper :: Ord e => STHeap s e -> ST s (STHeap s e)
popper h@STH{..}
	| size <= 0 = return h
	| otherwise = do
		let s' = size - 1
		ai <- unsafeRead arr s'
		unsafeWrite arr s' (error "Data.MQueue.Heap.popper: read of vacated heap slot")
		-- With no elements left there is nothing to sift. Slot 0 held the
		-- element being removed and has just been cleared; sifting would
		-- write that element back into it.
		when (s' > 0) (heapDown s' arr 0 ai)
		return h{size = s'}