{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
module Data.Select.Mutable.Median
(select)
where
import Data.Vector.Generic.Mutable (MVector)
import qualified Data.Vector.Generic.Mutable as Vector
import Data.Vector.Mutable.Partition
import Data.Median.Optimal
import Data.Select.Optimal
import Control.Monad.ST
import Control.Applicative.LiftMany
#if !MIN_VERSION_base(4,8,0)
import Data.Functor ((<$>))
import Control.Applicative (pure)
#endif
-- | Median-of-medians selection over the inclusive index range @l'@..@r'@
-- of the mutable vector @xs@, comparing elements with @lte@.
--
-- The result is an index into @xs@. A single-element range yields its own
-- index; ranges of two to five elements are delegated to 'select2' ..
-- 'select5'. Larger ranges are partitioned around a pivot and the search
-- continues on whichever side of the partition point @n'@ falls, stopping
-- as soon as the two coincide.
--
-- The pivot is found by moving one representative of every run of up to
-- five elements to the front of the range and recursively selecting the
-- middle of those representatives.
--
-- The vector is rearranged in place, and all element access goes through
-- 'Vector.unsafeRead' and 'Vector.unsafeSwap'.
--
-- NOTE(review): @n'@ is presumably an absolute index inside @l'@..@r'@ and
-- the result the position holding the element of that rank afterwards;
-- confirm against 'partition' and the callers.
select
    :: MVector v a
    => (a -> a -> Bool) -> v s a -> Int -> Int -> Int -> ST s Int
select lte !xs !l' !r' !n' = search l' r' n'
  where
    -- Selection restricted to the inclusive window lo..hi.
    search lo hi k =
        case hi - lo of
            0 -> pure lo
            1 -> rebase (liftA2 (select2 lte rank) (elemAt 0) (elemAt 1))
            2 ->
                rebase
                    (liftA3
                         (select3 lte rank)
                         (elemAt 0)
                         (elemAt 1)
                         (elemAt 2))
            3 ->
                rebase
                    (liftA4
                         (select4 lte rank)
                         (elemAt 0)
                         (elemAt 1)
                         (elemAt 2)
                         (elemAt 3))
            4 ->
                rebase
                    (liftA5
                         (select5 lte rank)
                         (elemAt 0)
                         (elemAt 1)
                         (elemAt 2)
                         (elemAt 3)
                         (elemAt 4))
            _ -> do
                pivotIx <- gather lo
                split <- partition lte xs lo hi pivotIx
                case compare k split of
                    LT -> search lo (split - 1) k
                    EQ -> pure k
                    GT -> search (split + 1) hi k
      where
        -- Target position relative to the start of the window.
        rank = k - lo
        -- Element at the given offset from the start of the window.
        elemAt off = Vector.unsafeRead xs (lo + off)
        -- Turn a window-relative result back into an absolute index.
        rebase act = (lo +) <$> act
        -- Once every group has contributed a representative, the
        -- representatives occupy lo .. lo + (hi - lo) `div` 5; select the
        -- middle one of that prefix.
        medianOfMedians =
            search lo (lo + ((hi - lo) `div` 5)) (lo + ((hi - lo) `div` 10))
        -- Walk the window in steps of five starting at groupStart, swapping
        -- each group's representative into the next free slot at the front.
        gather !groupStart =
            case hi - groupStart of
                -- One or two elements left: the first one stands in for
                -- the group.
                0 -> finish groupStart
                1 -> finish groupStart
                2 -> do
                    off <-
                        liftA3
                            (median3 lte)
                            (groupElem 0)
                            (groupElem 1)
                            (groupElem 2)
                    finish (groupStart + off)
                3 -> do
                    off <-
                        liftA4
                            (median4 lte)
                            (groupElem 0)
                            (groupElem 1)
                            (groupElem 2)
                            (groupElem 3)
                    finish (groupStart + off)
                4 -> do
                    off <- medianOfFive
                    finish (groupStart + off)
                _ -> do
                    off <- medianOfFive
                    promote (groupStart + off)
                    gather (groupStart + 5)
          where
            -- Destination slot for this group's representative.
            !slot = lo + ((groupStart - lo) `div` 5)
            -- Element at the given offset from the start of this group.
            groupElem off = Vector.unsafeRead xs (groupStart + off)
            -- Swap the element at src into this group's slot.
            promote src = Vector.unsafeSwap xs src slot
            -- Last group: place its representative, then pick the pivot.
            finish src = do
                promote src
                medianOfMedians
            -- Offset computed by 'median5' for a full group of five.
            medianOfFive =
                liftA5
                    (median5 lte)
                    (groupElem 0)
                    (groupElem 1)
                    (groupElem 2)
                    (groupElem 3)
                    (groupElem 4)
{-# INLINE select #-}