{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
module Data.Select.Mutable.Quick
(select)
where
import Data.Vector.Generic.Mutable (MVector)
import qualified Data.Vector.Generic.Mutable as Vector
import Data.Vector.Mutable.Partition
import Data.Median.Optimal
import Data.Select.Optimal
import Control.Applicative.LiftMany
import Control.Monad.ST
#if !MIN_VERSION_base(4,8,0)
import Data.Functor ((<$>))
import Control.Applicative (pure)
#endif
-- | Quickselect over the inclusive index range @[l', r']@ of a mutable vector.
--
-- Returns an index whose element is the one that would occupy absolute
-- position @n@ if @xs[l'..r']@ were sorted according to @lte@.
--
-- Ranges of at most five elements are resolved with the fixed-size
-- selectors (@select2@ .. @select5@), which only read the vector.
--
-- Larger ranges are partitioned in place around a median-of-three pivot, so
-- the contents of @xs[l'..r']@ may be permuted. The search then continues on
-- the side that contains @n@.
--
-- No bounds checks are performed: every read is 'Vector.unsafeRead', so the
-- caller must ensure @0 <= l' <= n <= r' < length xs@.
select
  :: MVector v a
  => (a -> a -> Bool) -- ^ ordering predicate (presumably @lte x y@ = @x <= y@)
  -> v s a            -- ^ vector to search; the range may be permuted
  -> Int              -- ^ @l'@: first index of the range (inclusive)
  -> Int              -- ^ @r'@: last index of the range (inclusive)
  -> Int              -- ^ @n@: absolute target position inside the range
  -> ST s Int
select lte !xs !l' !r' !n = go l' r'
  where
    -- Invariant: the target position n lies within [l, r].
    go !l !r =
      case r - l of
        0 -> pure l
        -- Small ranges: the selectors take the rank relative to l and return
        -- an offset from l.
        1 ->
          (l +) <$>
            liftA2
              (select2 lte (n - l))
              (Vector.unsafeRead xs l)
              (Vector.unsafeRead xs (l + 1))
        2 ->
          (l +) <$>
            liftA3
              (select3 lte (n - l))
              (Vector.unsafeRead xs l)
              (Vector.unsafeRead xs (l + 1))
              (Vector.unsafeRead xs (l + 2))
        3 ->
          (l +) <$>
            liftA4
              (select4 lte (n - l))
              (Vector.unsafeRead xs l)
              (Vector.unsafeRead xs (l + 1))
              (Vector.unsafeRead xs (l + 2))
              (Vector.unsafeRead xs (l + 3))
        4 ->
          (l +) <$>
            liftA5
              (select5 lte (n - l))
              (Vector.unsafeRead xs l)
              (Vector.unsafeRead xs (l + 1))
              (Vector.unsafeRead xs (l + 2))
              (Vector.unsafeRead xs (l + 3))
              (Vector.unsafeRead xs (l + 4))
        s -> do
          -- Median-of-three pivot: sample the first, middle and last elements.
          let !mid = l + s `div` 2
          k <-
            liftA3
              (median3 lte)
              (Vector.unsafeRead xs l)
              (Vector.unsafeRead xs mid)
              (Vector.unsafeRead xs r)
          -- median3 only sees the three sampled values. Its result identifies
          -- which argument is the median (presumably 0, 1 or 2 — TODO confirm
          -- against Data.Median.Optimal), not a vector offset.
          -- Map it back to the sampled index. Using `l + k` directly would
          -- pick l, l+1 or l+2, which discards the median-of-three choice and
          -- degrades to quadratic time on sorted input.
          let !pivot = case k of
                0 -> l
                1 -> mid
                _ -> r
          -- NOTE(review): assumes partition returns the pivot's final index,
          -- with "smaller" elements before it and "larger" ones after it.
          i <- partition lte xs l r pivot
          case compare n i of
            EQ -> pure n
            LT -> go l (i - 1)
            GT -> go (i + 1) r
    {-# INLINABLE go #-}
{-# INLINE select #-}