module DataSketches.Quantiles.RelativeErrorQuantile.Internal.InequalitySearch where

import Control.Monad.Primitive
import Data.Vector.Generic.Mutable (MVector)
import qualified Data.Vector.Generic.Mutable as MV

-- | Criterion tag for strict less-than (@<@) searches. Like the other three
-- tags below, it has a single nullary constructor of the same name, so a
-- value only serves to select an 'InequalitySearch' instance (for example as
-- the first argument of 'find').
data (:<) = (:<)
-- | Criterion tag for less-than-or-equal (@<=@) searches.
data (:<=) = (:<=)
-- | Criterion tag for strict greater-than (@>@) searches.
data (:>) = (:>)
-- | Criterion tag for greater-than-or-equal (@>=@) searches.
data (:>=) = (:>=)
-- Adapted from the JavaDoc of the Java DataSketches InequalitySearch class.
--
-- This provides efficient, unique and unambiguous binary searching with
-- inequality comparison criteria over ordered arrays of values that may
-- include duplicates. The Java original supports the criteria <, <=, ==, >=
-- and > (== is not an inequality, but it is included there for convenience).
-- This module defines only the four inequality criteria (:<), (:<=), (:>) and
-- (:>=). All criteria share the same search algorithm.

--  To make the search unique and unambiguous, the traditional binary search
--  is modified to look for a pair of adjacent indices {A, B} (B = A + 1) in
--  the values array, instead of a single value. For every criterion, if the
--  search reaches the end of its range without finding such a pair, the
--  'resolve' class method decides what to return. Where the Java original
--  returns -1 for a key that cannot be resolved, the instances in this module
--  return the length of the searched vector instead.

--  Given an ordered array arr and a search key v, the criteria search as
--  follows:
--
--   - LT (:<):  find the highest-ranked adjacent pair {A, B} such that
--               arr[A] < v <= arr[B]. The normal return is index A.
--   - LE (:<=): find the highest-ranked adjacent pair {A, B} such that
--               arr[A] <= v < arr[B]. The normal return is index A.
--   - EQ:       find the adjacent pair {A, B} such that
--               arr[A] <= v <= arr[B]. The normal return is whichever of A
--               or B equals v, otherwise -1. (Java original only; there is
--               no such criterion in this module.)
--   - GE (:>=): find the lowest-ranked adjacent pair {A, B} such that
--               arr[A] < v <= arr[B]. The normal return is index B.
--   - GT (:>):  find the lowest-ranked adjacent pair {A, B} such that
--               arr[A] <= v < arr[B]. The normal return is index B.
-- | A search criterion understood by 'find'. An instance says how the key
-- compares with an adjacent pair of elements, which index of a matching pair
-- is the result, and what to return when no pair matches.
class InequalitySearch crit where
  -- | Locate the key relative to the adjacent pair @(lower, upper)@.
  inequalityCompare
    :: Ord a
    => crit
    -> a         -- ^ the key being searched for
    -> a         -- ^ element at the lower index of the pair
    -> a         -- ^ element at the upper index of the pair
    -> Ordering
    -- ^ 'GT': the answer lies higher in the vector; 'LT': it lies lower;
    -- 'EQ': this pair is the bounding pair that holds the answer.

  -- | Choose the result index from a pair for which 'inequalityCompare'
  -- returned 'EQ'.
  getIndex
    :: (PrimMonad m, MVector v a, Ord a)
    => crit
    -> v (PrimState m) a  -- ^ the vector being searched
    -> Int                -- ^ index of the lower element of the pair
    -> Int                -- ^ index of the upper element of the pair
    -> a                  -- ^ the key being searched for
    -> m Int

  -- | Produce the result when the search range is exhausted without any pair
  -- comparing 'EQ'.
  resolve
    :: crit
    -> Int         -- ^ length of the searched vector
    -> (Int, Int)  -- ^ final @(lo, hi)@ cursor of the search
    -> (Int, Int)  -- ^ initial search range @(low, high)@
    -> Int
    -- ^ the resolved index; the instances in this module return the vector
    -- length when nothing qualifies.

-- | Strict less-than criterion. A hit is a pair @(A, B)@ with
-- @arr[A] < key <= arr[B]@, and the result is the lower index @A@.
instance InequalitySearch (:<) where
  inequalityCompare _ key lower upper =
    case (key <= lower, upper < key) of
      (True, _) -> LT   -- key is not above the lower element: go lower
      (_, True) -> GT   -- even the upper element is below the key: go higher
      _         -> EQ
  getIndex _ _ lowerIx _ _ = return lowerIx
  -- The cursor ran off the top of the range: every pair said "search higher",
  -- so (for sorted input) the last index is the answer; otherwise nothing
  -- qualifies and the vector length is returned.
  resolve _ len (lo, _) (_, high)
    | lo >= high = high
    | otherwise  = len

-- | Less-than-or-equal criterion. A hit is a pair @(A, B)@ with
-- @arr[A] <= key < arr[B]@, and the result is the lower index @A@.
instance InequalitySearch (:<=) where
  inequalityCompare _ key lower upper =
    if key < lower
      then LT
      else if upper <= key
        then GT
        else EQ
  getIndex _ _ i _ _ = return i
  -- The vector length signals "not found" unless the cursor reached the top of
  -- the range, in which case the last index is returned.
  resolve _ len (lo, _) (_, high)
    | lo < high = len
    | otherwise = high

-- | Strict greater-than criterion. A hit is a pair @(A, B)@ with
-- @arr[A] <= key < arr[B]@, and the result is the upper index @B@.
instance InequalitySearch (:>) where
  inequalityCompare _ v a b
    | v < a = LT
    | b <= v = GT
    | otherwise = EQ
  getIndex _ _ _ b _ = pure b
  -- Reached only when no pair compared 'EQ'. On sorted input every pair then
  -- pointed the same way.
  --
  -- * If all said "search lower", 'find' leaves hi == low - 1. On a range of
  --   at least two elements, arr[low] was compared and is greater than the
  --   key, so low is the answer.
  -- * If all said "search higher", hi stays at high - 1 and no element
  --   exceeds the key.
  --
  -- The test must therefore be a strict hi < low. A non-strict hi <= low would
  -- also fire in the second case for a two-element range (high == low + 1, so
  -- hi == low) and wrongly return low, e.g. for [1, 2] with key 2.
  resolve _ vl (_, hi) (low, _)
    | hi < low = low
    | otherwise = vl

-- | Greater-than-or-equal criterion. A hit is a pair @(A, B)@ with
-- @arr[A] < key <= arr[B]@, and the result is the upper index @B@.
instance InequalitySearch (:>=) where
  inequalityCompare _ v a b
    | v <= a = LT
    | b < v = GT
    | otherwise = EQ
  getIndex _ _ _ b _ = pure b
  -- Reached only when no pair compared 'EQ'. On sorted input every pair then
  -- pointed the same way.
  --
  -- * If all said "search lower", 'find' leaves hi == low - 1. On a range of
  --   at least two elements, arr[low] was compared and is >= the key, so low
  --   is the answer.
  -- * If all said "search higher", hi stays at high - 1 and every element is
  --   below the key.
  --
  -- The test must therefore be a strict hi < low. A non-strict hi <= low would
  -- also fire in the second case for a two-element range (high == low + 1, so
  -- hi == low) and wrongly return low, e.g. for [1, 2] with key 5.
  resolve _ vl (_, hi) (low, _)
    | hi < low = low
    | otherwise = vl

-- | Binary-search the inclusive index range @[low, high]@ of @v@ for the
-- index that the criterion @strat@ selects relative to the key @x@. The
-- vector is expected to be sorted in ascending order.
--
-- The search runs over adjacent pairs @(i, i + 1)@ with @low <= i <= high - 1@.
-- When 'inequalityCompare' reports 'EQ' for a pair, 'getIndex' chooses the
-- result from that pair. Once the range is exhausted, 'resolve' decides,
-- given the vector length, the final @(lo, hi)@ cursor and the initial
-- @(low, high)@ range.
--
-- NOTE(review): for a one-element range (@low == high@) no pair is examined,
-- so the result comes from 'resolve' without @x@ ever being compared to the
-- element. Confirm that callers never rely on such ranges.
find :: (InequalitySearch s, PrimMonad m, MVector v a, Ord a) => s -> v (PrimState m) a -> Int -> Int -> a -> m Int
find strat v low high x = loop low (high - 1)
  where
    -- lo/hi bound the index of the lower element of the candidate pair.
    loop lo hi
      | lo > hi || lo >= high =
          let answer = resolve strat (MV.length v) (lo, hi) (low, high)
          in answer `seq` return answer
      | otherwise = do
          let mid = lo + (hi - lo) `div` 2
          (lower, upper) <- (,) <$> MV.read v mid <*> MV.read v (mid + 1)
          case inequalityCompare strat x lower upper of
            LT -> loop lo (mid - 1)
            GT -> loop (mid + 1) hi
            EQ -> getIndex strat v mid (mid + 1) x
{-# INLINE find #-}