{-# LANGUAGE BangPatterns, MagicHash #-}
module Data.RangeMin.Cartesian (buildDepths) where

import Data.RangeMin.Common.Types
import qualified Data.Vector.Generic as G
import qualified Data.Vector.Primitive as PV
import qualified Data.Vector.Fusion.Stream.Monadic as SM
import Data.Vector.Fusion.Stream (Step (..), Stream)
import qualified Data.Vector.Fusion.Stream as S
import Data.RangeMin.Common.Vector
import Data.RangeMin.Cartesian.STInt
import Prelude hiding (read)

-- | A 'CartesianTree' is a tree, specified in the following format: the
-- @i@th entry of the vector is the parent of node @i@, or is @-1@ for a
-- root.  'makeTree' can produce more than one @-1@ entry (e.g. when the
-- minimum value occurs more than once), so in general this describes a
-- forest rather than a single tree.
type CartesianTree = PV.Vector Int

-- | A list of 'Int's with a strict, unpacked head (the tail is lazy).
-- 'makeTree' uses it as the stack of candidate indices in its
-- nearest-smaller-neighbour scans.
data IL = IL {-# UNPACK #-} !Int IL | Nil

{-# INLINE mapAccumSM #-}
-- | Monadic, left-to-right 'Data.List.mapAccumL' over a monadic stream.
-- The accumulator is threaded through the elements in order, and every
-- input element yields exactly one output element.  'Skip' steps are passed
-- through and 'Done' ends the output.  Because the output has the same
-- length as the input, the input's size hint is reused unchanged.
mapAccumSM :: Monad m => (b -> a -> m (c, b)) -> b -> SM.Stream m a -> SM.Stream m c
mapAccumSM f z0 (SM.Stream step s0 size) = SM.Stream step' (z0, s0) size
  where
    step' (acc, s) = step s >>= \ r -> case r of
      Done       -> return Done
      Skip s'    -> return (Skip (acc, s'))
      Yield x s' -> f acc x >>= \ (y, acc') -> return (Yield y (acc', s'))

{-# INLINE mapAccumS #-}
-- | 'mapAccumSM' specialised to the pure 'Stream' type: the pure step
-- function is lifted into the stream's monad with 'return'.
mapAccumS :: (b -> a -> (c, b)) -> b -> Stream a -> Stream c
mapAccumS f = mapAccumSM lifted
  where
    lifted acc x = return (f acc x)

{-# INLINE buildDepths #-}
-- | /O(n)/.  Takes a comparison function @(<=?)@ and a vector @xs@, and
-- returns a @'PV.Vector' 'Int'@ with the following property.  For any
-- indices @i <= j@, the position of the minimum element of the result
-- within @[i, j]@ equals the position of the minimum element of @xs@
-- within @[i, j]@ under @(<=?)@.  In both vectors, ties go to the earlier
-- index.
--
-- This lets the specialised range-min implementation for
-- @'PV.Vector' 'Int'@ serve other 'Vector' implementations, other element
-- types and other comparison functions.
--
-- Internally, this builds the Cartesian tree of @xs@ implicitly, as a
-- vector of parent indices rather than as a linked structure.  The result
-- is the depth of each element in that tree.
buildDepths :: Vector v a => LEq a -> v a -> PV.Vector Int
buildDepths (<=?) xs = makeDepths tree
  where
    tree = makeTree (<=?) (G.length xs) (\ i -> xs ! i)

-- | Computes the depth of every node of a tree given in parent-vector form.
-- Entry @i@ of the input is the parent of node @i@, or @-1@ if node @i@ is
-- a root.  Roots have depth @0@, and every other node is one deeper than
-- its parent.
--
-- Depths are memoised in the output vector, which starts filled with @-1@
-- as an \"unvisited\" marker, so each node's depth is computed only once.
-- The recursion towards the root stops at the first already-visited
-- ancestor.  Its depth is therefore bounded by the height of the tree.
makeDepths :: CartesianTree -> PV.Vector Int
makeDepths !parents = inlineCreate $ do
  let !size = PV.length parents
  !memo <- newWith size (-1)
  let depthOf !node = toSTInt $ do
        known <- read memo node
        if known /= -1
          then return known            -- already visited: reuse the memo
          else do
            let !par = parents ! node
            !d <- if par == -1
                    then return 0      -- a root
                    else do
                      !parentDepth <- runSTInt (depthOf par)
                      return (parentDepth + 1)
            write memo node d
            return d
  mapM_ (runSTInt . depthOf) [0 .. size - 1]
  return memo

{-# INLINE makeTree #-}
-- | /O(n)/.  Constructs the Cartesian tree of the @n@ elements
-- @look 0, ..., look (n-1)@ under @(<=?)@.  The result uses the
-- parent-vector format of 'CartesianTree'.
--
-- The tree is built with two nearest-smaller-value scans (see @neighbors@):
--
-- 1. A right-to-left scan stores, for each index @i@, the nearest index to
--    the right of @i@ whose value is strictly smaller, or @-1@.
--
-- 2. A left-to-right scan finds the nearest strictly smaller index to the
--    left.  It then overwrites entry @i@ with whichever of the two
--    neighbours holds the larger value (the right one when
--    @look left <=? look right@), or with @-1@ when neither exists.
--
-- Both scans look for /strictly/ smaller values, so every occurrence of a
-- repeated minimum gets parent @-1@, and the result can be a forest.
makeTree :: LEq a -> Int -> (Int -> a) -> CartesianTree
makeTree (<=?) !n look = inlineCreate $ do
  !dest <- new n
  -- Pass 1 (right to left): entry i := nearest strictly smaller index on
  -- the right, or -1.
  let rightToLeft = S.unfoldr stepDown n
      stepDown !i
        | i == 0    = Nothing
        | otherwise = let !i' = i - 1 in Just ((i', look i'), i')
  S.mapM_ (\ (IP i j) -> write dest i j) (neighbors rightToLeft)
  -- Pass 2 (left to right): combine with the nearest strictly smaller
  -- index on the left and keep the larger-valued neighbour as the parent.
  -- The (-1, -1) case falls into the first alternative and yields -1.
  let parent (IP i left) = do
        right <- read dest i
        write dest i $ case (left, right) of
          (-1, _) -> right
          (_, -1) -> left
          _       -> if look left <=? look right then right else left
  S.mapM_ parent (neighbors (S.generate n (\ !i -> (i, look i))))
  return dest
 where
  -- For each (i, x_i) in stream order, yields IP i j.  Here j is the most
  -- recently seen index with not (x_i <=? x_j), i.e. x_j strictly smaller
  -- for a total order, or -1 if there is none.  The IL stack holds the
  -- remaining candidate indices, most recent first.
  {-# INLINE neighbors #-}
  neighbors xs = mapAccumS suc Nil xs
    where
      suc stk (!i, xi) = run stk
        where
          run Nil = (IP i (-1), IL i Nil)
          run stk0@(IL j rest)
            | not (xi <=? look j) = (IP i j, IL i stk0)
            | otherwise           = run rest

-- Rewrite rule that replaces 'buildDepths' by its own definition, so that
-- occurrences of it are unfolded into 'makeDepths' / 'makeTree'.
-- NOTE(review): 'buildDepths' is already marked INLINE.  Presumably this
-- rule exists to force the unfolding where the INLINE pragma alone does
-- not fire.  Confirm whether it is still needed.
{-# RULES
	"buildDepths" buildDepths = \ (<=?) xs -> makeDepths (makeTree (<=?) (G.length xs) (xs !));
	#-}