module DataSketches.Quantiles.RelativeErrorQuantile.Internal.InequalitySearch where
import Control.Monad.Primitive
import Data.Vector.Generic.Mutable (MVector)
import qualified Data.Vector.Generic.Mutable as MV
data (:<) = (:<)
data (:<=) = (:<=)
data (:>) = (:>)
data (:>=) = (:>=)
-- JavaDoc copypasta
--
-- This provides efficient, unique and unambiguous binary searching for inequality comparison criteria
-- for ordered arrays of values that may include duplicate values. The inequality criteria include
-- <, >, ==, >=, <=. All the inequality criteria use the same search algorithm.
-- (Although == is not an inequality, it is included for convenience.)
-- In order to make the searching unique and unambiguous, we modified the traditional binary
-- search algorithm to search for adjacent pairs of values {A, B} in the values array
-- instead of just a single value, where A and B are the array indicies of two
-- adjacent values in the array. For all the search criteria, if the algorithm reaches the ends of
-- the search range, the algorithm calls the resolve() method to determine what to
-- return to the caller. If the key value cannot be resolved, it returns a -1 to the caller.
-- Given an array of values arr[] and the search key value v, the algorithms for
-- the searching criteria are as follows:
--
-- LT: Find the highest ranked adjacent pair {A, B} such that:
-- arr[A] < v ≤ arr[B]. The normal return is the index A.
--
-- LE: Find the highest ranked adjacent pair {A, B} such that:
-- arr[A] ≤ v < arr[B]. The normal return is the index A.
--
-- EQ: Find the adjacent pair {A, B} such that:
-- arr[A] ≤ v ≤ arr[B]. The normal return is the index A or B whichever
-- equals v, otherwise it returns -1.
--
-- GE: Find the lowest ranked adjacent pair {A, B} such that:
-- arr[A] < v ≤ arr[B]. The normal return is the index B.
--
-- GT: Find the lowest ranked adjacent pair {A, B} such that:
-- arr[A] ≤ v < arr[B]. The normal return is the index B.
--
--
class InequalitySearch s where
inequalityCompare :: Ord a
=> s
-> a
-- ^ V
-> a
-- ^ A
-> a
-- ^ B
-> Ordering
-- ^ 'GT' means we must search higher in the array, 'LT' means we must
-- search lower in the array, or `EQ`, which means we have found
-- the correct bounding pair.
getIndex
:: (PrimMonad m, MVector v a, Ord a)
=> s
-> v (PrimState m) a
-> Int
-> Int
-> a
-> m Int
resolve
:: s
-> Int -- Vector length
-> (Int, Int)
-- ^ Final low index, high index (lo, hi)
-> (Int, Int)
-- ^ Initial search region (low, high)
-> Int
-- ^ A thing
instance InequalitySearch (:<) where
inequalityCompare _ v a b
| v <= a = LT
| b < v = GT
| otherwise = EQ
getIndex _ _ a _ _ = pure a
resolve _ vl (lo, hi) (low, high) = if lo >= high then high else vl
instance InequalitySearch (:<=) where
inequalityCompare _ v a b
| v < a = LT
| b <= v = GT
| otherwise = EQ
getIndex _ _ a _ _ = pure a
resolve _ vl (lo, hi) (low, high) = if lo >= high then high else vl
instance InequalitySearch (:>) where
inequalityCompare _ v a b
| v < a = LT
| b <= v = GT
| otherwise = EQ
getIndex _ _ _ b _ = pure b
resolve _ vl (lo, hi) (low, high) = if hi <= low then low else vl
instance InequalitySearch (:>=) where
inequalityCompare _ v a b
| v <= a = LT
| b < v = GT
| otherwise = EQ
getIndex _ _ _ b _ = pure b
resolve _ vl (lo, hi) (low, high) = if hi <= low then low else vl
find :: (InequalitySearch s, PrimMonad m, MVector v a, Ord a) => s -> v (PrimState m) a -> Int -> Int -> a -> m Int
find strat v low high x = go low (high - 1)
where
go lo hi
| lo <= hi && lo < high = do
let mid = lo + ((hi - lo) `div` 2)
midV <- MV.read v mid
midV' <- MV.read v (mid + 1)
case inequalityCompare strat x midV midV' of
LT -> go lo (mid - 1)
EQ -> getIndex strat v mid (mid + 1) x
GT -> go (mid + 1) hi
| otherwise = pure $! resolve strat (MV.length v) (lo, hi) (low, high)
{-# INLINE find #-}