{-# LANGUAGE BangPatterns #-}
module Data.RangeMin.Cartesian (equivVectorBy) where

import Control.Monad
import Data.RangeMin.Common.Types
import Data.RangeMin.Common.Vector
import qualified Data.RangeMin.Fusion as F
import Prelude hiding (read)

import qualified Data.Vector as V
import qualified Data.Vector.Primitive as PV
import qualified Data.Vector.Storable as SV
import qualified Data.Vector.Unboxed as UV

-- | A 'CartesianTree' is a tree (in general a forest), specified in the
-- following format: the @i@th entry of the vector is the parent of node @i@,
-- or is @-1@ for a root.  'makeTree' gives parent @-1@ to every node that has
-- no strictly smaller element on either side, so when the minimum element
-- occurs more than once there are several roots.
type CartesianTree = PVector Value

-- | A list of indices whose head is strict and unpacked (the tail is lazy).
-- 'makeTree' uses it as the stack for its left and right sweeps.
data IL
  = IL {-# UNPACK #-} !Index IL
  | Nil

{-# SPECIALIZE equivVectorBy :: LEq a -> V.Vector a -> PV.Vector Int #-}
{-# SPECIALIZE INLINE equivVectorBy :: UV.Unbox a => LEq a -> UV.Vector a -> PV.Vector Int #-}
{-# SPECIALIZE INLINE [1] equivVectorBy :: PV.Prim a => LEq a -> PV.Vector a -> PV.Vector Int #-}
{-# SPECIALIZE INLINE [1] equivVectorBy :: SV.Storable a => LEq a -> SV.Vector a -> PV.Vector Int #-}
-- | /O(n)/.  Given a comparison function @(<=?)@ and a vector @xs@, this
-- function constructs a @'PV.Vector' 'Int'@ of the same length with the
-- following property: for any range of indices from @i@ to @j@, the position
-- of the minimum element in the result vector is the same as the position of
-- the minimum element of @xs@ under @(<=?)@.  (In both cases, ties are broken
-- in favour of the earlier index.)
--
-- This allows us to use the specialized range-min implementation on
-- @'PV.Vector' 'Int'@, even for other 'Vector' implementations, other element
-- types, and other comparison functions.
--
-- Internally, this function constructs the Cartesian tree of the input vector
-- (implicitly, as an array of parent indices rather than a pointer-based
-- tree, to save memory) and returns the vector of the depth of each element
-- in the tree.
equivVectorBy :: Vector v a => LEq a -> v a -> PV.Vector Int
-- 'makeTree' only needs a length and an indexing function, so it is
-- independent of the concrete vector representation.
equivVectorBy (<=?) !xs = makeDepths (makeTree (<=?) (vlength xs) (xs !))

-- | Takes a tree (or forest) in the 'CartesianTree' format, in which the
-- @i@th entry is the parent of node @i@ or @-1@ for a root.  Returns the
-- vector of the depths of each node; roots have depth 0.
--
-- Runs in /O(n)/ time without deep recursion.  For each unlabelled node, the
-- function climbs once through its unlabelled ancestors to the first labelled
-- one (or past a root), then walks the same path again to assign depths.  Each
-- node is therefore labelled exactly once.  The previous recursive version
-- nested one call per ancestor, which is /n/ deep on path-shaped trees.
makeDepths :: CartesianTree -> PV.Vector Value
makeDepths !parents = create $ do
    let !n = PV.length parents
    -- -1 marks a node whose depth has not been assigned yet.
    !dest <- newWith n (-1)
    let -- Climbs from node @j@ through unlabelled nodes, adding one to @k@ per
        -- node passed.  Returns the final count together with the depth of the
        -- labelled node where the climb stopped, or -1 if it went past a root.
        climb !j !k
          | j == -1   = return (k, -1)
          | otherwise = do
              dj <- read dest j
              if dj == -1
                then climb (parents ! j) (k + 1)
                else return (k, dj)
        -- Labels @k@ nodes, starting at @j@ and moving towards the root, with
        -- depths @d@, @d - 1@, ...
        fill !j !d !k = when (k > 0) $ do
          write dest j d
          fill (parents ! j) (d - 1) (k - 1)
        -- Labels node @i@ and all of its unlabelled ancestors.  This does
        -- nothing if @i@ already has a depth, because 'climb' then returns k = 0.
        label !i = do
          (k, base) <- climb i 0
          fill i (base + k) k
    F.mapM_ label (F.enumN n)
    return dest

-- | Result of pushing an index onto an 'IL' stack in 'makeTree'.
data S = S
  {-# UNPACK #-} !Int  -- index left on top after popping, or -1 if none
  IL                   -- the stack with the new index pushed

{-# INLINE makeTree #-}
-- | Builds the Cartesian tree of the @n@ elements @look 0 .. look (n - 1)@ as
-- a parent array (see 'CartesianTree').
--
-- Two monotone-stack sweeps find, for every index, the nearest index on each
-- side whose element is strictly smaller (assuming @(<=?)@ is a total
-- preorder; TODO confirm for all callers).  The parent is whichever of those
-- two neighbours holds the larger element, taking the right one on ties.  A
-- node with no strictly smaller neighbour on either side gets @-1@.
makeTree :: LEq a -> Length -> (Index -> a) -> CartesianTree
makeTree (<=?) !n look = create $ do
    !parents <- new n
    -- Right-to-left sweep: temporarily store each node's nearest strictly
    -- smaller neighbour to the right (or -1).
    let sweepRight !stack !i
          | i < 0     = return ()
          | otherwise = do
              let !(S right stack') = push stack i (look i)
              write parents i right
              sweepRight stack' (i - 1)
    sweepRight Nil (n - 1)
    -- Left-to-right sweep: combine that with the nearest strictly smaller
    -- neighbour to the left, overwriting each entry with the final parent.
    let sweepLeft !stack !i
          | i >= n    = return ()
          | otherwise = do
              right <- read parents i
              let !(S left stack') = push stack i (look i)
              write parents i (pickParent left right)
              sweepLeft stack' (i + 1)
    sweepLeft Nil 0
    return parents
  where
    -- Pops every entry whose element is not strictly smaller than @xi@, then
    -- pushes @i@.  Returns the surviving top index (-1 if the stack emptied)
    -- together with the new stack.
    push stack !i xi = go stack
      where
        go Nil = S (-1) (IL i Nil)
        go whole@(IL j rest)
          | xi <=? look j = go rest
          | otherwise     = S j (IL i whole)
    -- -1 means "no smaller neighbour on that side".
    pickParent left right
      | left == -1               = right
      | right == -1              = left
      | look left <=? look right = right
      | otherwise                = left