{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP          #-}

-- |
-- Module      : Data.Select.Mutable.Median
-- Description : median-of-medians internals on mutable, generic vectors.
-- Copyright   : (c) Donnacha Oisín Kidney, 2018
-- License     : MIT
-- Maintainer  : mail@doisinkidney.com
-- Stability   : experimental
-- Portability : portable
--
-- Median-of-medians internals on mutable, generic vectors.
module Data.Select.Mutable.Median
  (select)
  where

import           Data.Vector.Generic.Mutable          (MVector)
import qualified Data.Vector.Generic.Mutable          as Vector

import           Data.Vector.Mutable.Partition

import           Data.Median.Optimal
import           Data.Select.Optimal

import           Control.Monad.ST
import           Control.Applicative.LiftMany

#if !MIN_VERSION_base(4,8,0)
import           Data.Functor ((<$>))
import           Control.Applicative (pure)
#endif

-- | @'select' ('<=') xs lb ub n@ works on the slots of @xs@ in the
-- inclusive index range [@lb@,@ub@] and yields an /index/ into @xs@
-- (not the element itself) for the item of rank @n@.
--
-- @n@ is an absolute index into @xs@: it is compared directly with the
-- position produced by the partition step, and it is rebased to
-- @n - lb@ only when handed to the fixed-size selectors.
--
-- Ranges holding at most five elements are delegated to
-- 'select2' .. 'select5', whose result is treated as an offset from
-- @lb@. Larger ranges pick a pivot by median-of-medians and partition
-- around it.
--
-- The vector is modified in place (slots are swapped), and every access
-- uses the unchecked vector primitives. This function performs no
-- validation itself, so callers are assumed to pass
-- @lb <= n <= ub@ with both bounds inside the vector.
select
    :: MVector v a
    => (a -> a -> Bool) -> v s a -> Int -> Int -> Int -> ST s Int
select lte !xs !l' !r' !n' = search l' r' n'
  where
    -- Unchecked single-slot read and two-slot swap on @xs@.
    slot = Vector.unsafeRead xs
    exchange = Vector.unsafeSwap xs

    -- @search lo hi k@: selection restricted to the inclusive range
    -- [@lo@,@hi@], looking for absolute position @k@.
    search lo hi k =
        case hi - lo of
            0 -> pure lo
            1 -> rebase <$> liftA2 (select2 lte rank) (near 0) (near 1)
            2 ->
                rebase <$>
                liftA3 (select3 lte rank) (near 0) (near 1) (near 2)
            3 ->
                rebase <$>
                liftA4
                    (select4 lte rank)
                    (near 0)
                    (near 1)
                    (near 2)
                    (near 3)
            4 ->
                rebase <$>
                liftA5
                    (select5 lte rank)
                    (near 0)
                    (near 1)
                    (near 2)
                    (near 3)
                    (near 4)
            _ -> do
                p <- gather lo
                -- The value returned by 'partition' is taken to be the
                -- slot where the pivot ended up.
                mid <- partition lte xs lo hi p
                case compare k mid of
                    EQ -> pure k
                    LT -> search lo (mid - 1) k
                    GT -> search (mid + 1) hi k
      where
        -- Position of @k@ relative to the start of the current range.
        rank = k - lo
        -- Turn an offset from @lo@ back into an absolute index.
        rebase = (lo +)
        -- Read the slot @d@ places after @lo@.
        near d = slot (lo + d)
        width = hi - lo

        -- Walk the range in groups of five starting at @g@. For each
        -- group, swap its chosen representative (the median for groups
        -- of three or more, the first element otherwise) into the next
        -- free slot at the front of the range. Once the trailing group
        -- has been handled, recurse on those front slots to obtain the
        -- pivot index.
        gather !g =
            case hi - g of
                0 -> do
                    promote 0
                    finish
                1 -> do
                    promote 0
                    finish
                2 -> do
                    promote =<<
                        liftA3 (median3 lte) (at 0) (at 1) (at 2)
                    finish
                3 -> do
                    promote =<<
                        liftA4 (median4 lte) (at 0) (at 1) (at 2) (at 3)
                    finish
                4 -> do
                    promote =<< medianOfFive
                    finish
                _ -> do
                    promote =<< medianOfFive
                    gather (g + 5)
          where
            -- Front slot reserved for this group's representative.
            !dest = lo + ((g - lo) `div` 5)
            -- Read the slot @d@ places after the group start.
            at d = slot (g + d)
            -- Move the group member at offset @m@ into its front slot.
            promote m = exchange (g + m) dest
            medianOfFive =
                liftA5
                    (median5 lte)
                    (at 0)
                    (at 1)
                    (at 2)
                    (at 3)
                    (at 4)

        -- The representatives occupy [lo, lo + width `div` 5]; select
        -- the one at position lo + width `div` 10 among them.
        finish = search lo (lo + (width `div` 5)) (lo + (width `div` 10))
{-# INLINE select #-}