{-# LANGUAGE FlexibleInstances, MultiParamTypeClasses, TypeFamilies, BangPatterns, NamedFieldPuns, RecordWildCards #-}
{-# OPTIONS -fno-warn-name-shadowing #-}

-- | Array-based implementation of an entirely traditional binary heap.
module Data.MQueue.Heap (Heap, getSize) where

import Data.MQueue.Class
import Data.MQueue.MonadHelpers

import Control.Arrow((***))
import Control.Monad.ST.Class
--import Data.Tuple.HT

import Data.Array.Base

import Data.Array.ST
import Data.STRef
import Control.Monad.ST
import Control.Monad
-- import Control.Arrow ((***))

-- | State of the heap at one point in time: how many elements are live and
-- the array that stores them (live elements occupy indices @0 .. size - 1@).
data STHeap s e = STH
	{ size :: {-# UNPACK #-} !Int -- ^ Number of elements currently in the heap.
	, arr :: {-# UNPACK #-} !(STArray s Int e) -- ^ Backing storage; may be larger than 'size'.
	}
-- | Handle to a mutable binary heap: a reference cell holding the current
-- 'STHeap' state.
newtype Heap s e = H
	{ unHeap :: STRef s (STHeap s e) -- ^ The cell that stores the heap state.
	}

-- | Queue interface for 'Heap': every method runs the matching @ST@ worker
-- defined in this module through 'liftST'.
instance (Monad m, MonadST m, StateThread m ~ s, Ord e) => MQueue (Heap s e) m where
	{-# SPECIALIZE instance Ord e => MQueue (Heap s e) (ST s) #-}
	{-# SPECIALIZE instance Ord e => MQueue (Heap RealWorld e) IO #-}
	{-# INLINE pushAll #-}
	type MQueueKey (Heap s e) = e
	newQueue = liftST newHeap
	push h k = liftST (pushHeap h k)
	pushAll h = \ ks -> liftST (pushAllHeap h ks)
	pop_ h = liftST (popHeap_ h)
	peek h = liftST (peekHeap h)

{-# SPECIALIZE getSize :: Heap s e -> ST s Int #-}
-- | Number of elements currently stored in the heap.
getSize :: (Monad m, MonadST m, StateThread m ~ s) => Heap s e -> m Int
getSize h = liftST (getHeapSize h)

----------------------------------------------------------------------

{-# INLINE pushAllHeap #-}
-- | Create an empty heap whose backing array has room for 16 elements
-- (indices 0 through 15).
newHeap :: ST s (Heap s e)
newHeap = do
	store <- newArray_ (0, 15)
	ref <- newSTRef (STH 0 store)
	return (H ref)
-- | Insert a single key into the heap, updating the stored heap state.
pushHeap :: Ord e => Heap s e -> e -> ST s ()
pushHeap h k = onHeap_ h (pusher k)
-- | Insert every key of @ks@ into the heap, in the order given.
--
-- Room for all the new keys is reserved once up front with 'ensureSize'.
-- The keys are then pushed one after another, and each push receives the
-- heap state returned by the previous push.  Threading the state is
-- essential: the element count has to advance after every insertion so that
-- the next key is sifted up from a fresh slot, and the final count has to be
-- the one written back to the heap.
pushAllHeap h ks = onHeap_ h pushKeys
	where
	pushKeys hp = do
		-- Grow (if needed) exactly once, for the final element count.
		roomy <- ensureSize (size hp + length ks) hp
		-- Capacity is guaranteed now, so the unchecked pusher is safe.
		foldM (flip unsafePusher) roomy ks
-- | Remove the root element of the heap, updating the stored heap state.
popHeap_ :: Ord e => Heap s e -> ST s ()
popHeap_ = flip onHeap_ popper
-- | Return the element stored at the root of the heap, or 'Nothing' when the
-- heap holds no elements.  The heap is not modified.
peekHeap :: Heap s e -> ST s (Maybe e)
peekHeap h = queryHeap h look
	where
	look (STH n store)
		| n > 0 = liftM Just (unsafeRead store 0)
		| otherwise = return Nothing
-- | Read the current element count out of the heap state.
getHeapSize :: Heap s e -> ST s Int
getHeapSize h = queryHeap h (\ st -> return (size st))

----------------------------------------------------------------------

-- | Read the heap's current state and hand it to the given action; the
-- stored state itself is not replaced.
queryHeap :: Heap s e -> (STHeap s e -> ST s a) -> ST s a
queryHeap (H ref) act = readSTRef ref >>= act

-- onHeap :: Heap s e -> (STHeap s e -> ST s (a, STHeap s e)) -> ST s a
-- onHeap = modSTRef . unHeap

-- | Run a state-transforming action on the heap and store the state it
-- returns back into the heap's reference cell.
onHeap_ :: Heap s e -> (STHeap s e -> ST s (STHeap s e)) -> ST s ()
onHeap_ (H ref) step = modSTRef_ ref step

-- modSTRef :: STRef s a -> (a -> ST s (b, a)) -> ST s b
-- modSTRef ref f = do	(ans, x') <- f =<< readSTRef ref
-- 			writeSTRef ref x'
-- 			return ans

-- | Monadic modify: read the reference, run the action on its contents, and
-- write the action's result back.
modSTRef_ :: STRef s a -> (a -> ST s a) -> ST s ()
modSTRef_ ref f = do
	old <- readSTRef ref
	new <- f old
	writeSTRef ref new

-- | Make sure the backing array can hold at least @n@ elements.
--
-- If the array already has @n@ or more slots the heap is returned unchanged.
-- Otherwise a new array with @5 * n `quot` 4@ slots is allocated, the live
-- elements (indices @0 .. size - 1@) are copied into it, and the returned
-- heap refers to the new array.  The element count is never changed.
ensureSize :: Int -> STHeap s e -> ST s (STHeap s e)
ensureSize n h@(STH live store) = do
	cap <- getNumElements store
	if cap >= n
		then return h
		else do
			bigger <- newArray_ (0, 5 * n `quot` 4 - 1)
			forM_ [0 .. live - 1] (\ i -> unsafeRead store i >>= unsafeWrite bigger i)
			return h{arr = bigger}

-- | Sift the value @ai@ up from slot @i@ towards the root.
--
-- Walking from @i@ to the root, every parent that is not strictly smaller
-- than @ai@ is moved down one level; @ai@ is written into the first slot
-- whose parent is strictly smaller, or into the root.  Array accesses are
-- unchecked, so slot @i@ must exist.
heapUp :: Ord e => STArray s Int e -> Int -> e -> ST s ()
heapUp !arr i ai = climb i
	where
	climb 0 = unsafeWrite arr 0 ai
	climb pos = do
		let !parent = (pos - 1) `quot` 2
		up <- unsafeRead arr parent
		if up < ai
			then unsafeWrite arr pos ai
			else do
				-- Parent is not smaller: pull it down and keep climbing.
				unsafeWrite arr pos up
				climb parent

-- | Insert a key without checking capacity: the key is sifted up from the
-- first slot past the live elements and the element count is incremented.
-- The caller must already have made room for that slot.
unsafePusher :: Ord e => e -> STHeap s e -> ST s (STHeap s e)
unsafePusher k h@(STH n store) = heapUp store n k >> return h{size = n + 1}

-- | Insert a key, first making sure the backing array has a free slot.
pusher :: Ord e => e -> STHeap s e -> ST s (STHeap s e)
pusher k h = do
	roomy <- ensureSize (size h + 1) h
	unsafePusher k roomy

-- | Sift the value @ai@ down from slot @i@ within the first @n@ slots.
--
-- At each step the smaller child is compared with @ai@: if it is strictly
-- smaller it moves up one level and the walk continues from its old slot,
-- otherwise @ai@ is written into the current slot.  Slots at index @n@ and
-- beyond are never read.  Array accesses are unchecked.
heapDown :: Ord e => Int -> STArray s Int e -> Int -> e -> ST s ()
heapDown n !arr i ai = lt `seq` sink i
	where
	lt = (<) -- evaluated once up front: hack to minimize polymorphism overhead
	sink pos
		-- Both children are inside the live region.
		| right < n = do
			lv <- unsafeRead arr left
			rv <- unsafeRead arr right
			let (best, bestIx) = if lv `lt` rv then (lv, left) else (rv, right)
			if best `lt` ai
				then unsafeWrite arr pos best >> sink bestIx
				else unsafeWrite arr pos ai
		-- Only the left child is inside the live region; it has no children.
		| right == n = do
			lv <- unsafeRead arr left
			if lv `lt` ai
				then unsafeWrite arr pos lv >> unsafeWrite arr left ai
				else unsafeWrite arr pos ai
		-- No children: the value settles here.
		| otherwise = unsafeWrite arr pos ai
		where
		right = 2 * pos + 2
		left = right - 1

-- | Remove the root (smallest) element of the heap.
--
-- The last live element is taken out of its slot and sifted down from the
-- root, which overwrites the old root; the element count drops by one.
--
-- Popping an empty heap is a no-op: there is no slot at index @-1@, and the
-- array accesses here are unchecked, so the read must not be attempted.
popper :: Ord e => STHeap s e -> ST s (STHeap s e)
popper h@STH{..}
	| size <= 0 = return h
	| otherwise = do
		let s' = size - 1
		ai <- unsafeRead arr s'
		-- Clear the vacated slot so the array no longer references the element.
		unsafeWrite arr s' (error "Data.MQueue.Heap.popper: vacated heap slot")
		-- When the heap becomes empty there is nothing to re-order; sifting
		-- would only write the removed element back into slot 0.
		when (s' > 0) (heapDown s' arr 0 ai)
		return h{size = s'}