-- |
-- Module:     Data.Vector.Algorithms.Quicksort.Parameterised
-- Copyright:  (c) Sergey Vinokurov 2023
-- License:    Apache-2.0 (see LICENSE)
-- Maintainer: serg.foo@gmail.com
--
-- This module provides a fully generic quicksort that, for now, lets
-- the caller decide how to parallelize and how to select the median.
-- More things may be parameterised in the future, likely by
-- introducing new functions taking more arguments.
--
-- === Example
-- This is how you’d define parallel sort that uses sparks on unboxed vectors of integers:
--
-- >>> import Control.Monad.ST
-- >>> import Data.Int
-- >>> import Data.Vector.Algorithms.Quicksort.Parameterised
-- >>> import Data.Vector.Unboxed qualified as U
-- >>> :{
-- let myParallelSort :: U.MVector s Int64 -> ST s ()
--     myParallelSort = sortInplaceFM defaultParStrategies (Median3or5 @Int64)
-- in U.modify myParallelSort $ U.fromList @Int64 [20, 19 .. 0]
-- :}
-- [0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20]
--
-- === Design considerations
-- Because of reliance on specialisation, this package doesn't provide
-- sort functions that take a comparator function as an argument. They
-- rely on the 'Ord' instance instead. While somewhat limiting, this makes
-- it possible to offload optimization to the @SPECIALIZE@ pragmas even if
-- the compiler wasn't smart enough to monomorphise automatically.
--
-- === Performance considerations
-- Compared to the default sort this one is even more sensitive to
-- specialisation. Users caring about performance are advised to dump
-- core and ensure that the sort is monomorphised. GHC 9.6.1 was seen
-- to specialize automatically, but 9.4 wasn't as good and required
-- pragmas both for the main sort function and for its helpers, like this:
--
-- > -- Either use the flag to specialize everything, ...
-- > {-# OPTIONS_GHC -fspecialise-aggressively #-}
-- >
-- > -- ... or the pragmas for specific functions
-- > import Control.Monad.ST
-- > import Data.Int
-- > import Data.Vector.Algorithms.FixedSort
-- > import Data.Vector.Algorithms.Heapsort
-- > import Data.Vector.Algorithms.Quicksort.Parameterised
-- > import Data.Vector.Unboxed qualified as U
-- >
-- > {-# SPECIALIZE heapSort      :: U.MVector s Int64 -> ST s ()        #-}
-- > {-# SPECIALIZE bitonicSort   :: Int -> U.MVector s Int64 -> ST s () #-}
-- > {-# SPECIALIZE sortInplaceFM :: Sequential -> Median3 Int64 -> U.MVector s Int64 -> ST s () #-}
--
-- === Speeding up compilation
-- In order to speed up compilation it's a good idea to introduce a
-- dedicated module where all the sorts will reside and import it
-- instead of calling @sort@ or @sortInplaceFM@ in modules with other logic.
-- This way the sort functions, which can take a while to compile, will be
-- recompiled rarely.
--
-- > module MySorts (mySequentialSort) where
-- >
-- > import Control.Monad.ST
-- > import Data.Int
-- > import Data.Vector.Unboxed qualified as U
-- >
-- > import Data.Vector.Algorithms.Quicksort.Parameterised
-- >
-- > {-# NOINLINE mySequentialSort #-}
-- > mySequentialSort :: U.MVector s Int64 -> ST s ()
-- > mySequentialSort = sortInplaceFM Sequential (Median3or5 @Int64)
--
-- === Reducing code bloat
-- Avoid using sorts with both 'ST' and 'IO' monads. Stick to the 'ST'
-- monad as much as possible because it can be easily converted to
-- 'IO' via the safe 'stToIO' function. Using the same sort in both 'IO' and
-- 'ST' monads will compile two versions of it along with all its
-- helper sorts, which can be pretty big (especially the bitonic sort).

-- So that haddock will resolve references in the documentation.
{-# OPTIONS_GHC -Wno-unused-imports #-}

module Data.Vector.Algorithms.Quicksort.Parameterised
  ( sortInplaceFM
  -- * Reexports
  , module E
  ) where

import Prelude hiding (last, pi)

import Control.Monad
import Control.Monad.Primitive
import Data.Bits
import Data.Vector.Generic.Mutable qualified as GM

import Data.Vector.Algorithms.FixedSort
import Data.Vector.Algorithms.Heapsort
import Data.Vector.Algorithms.Quicksort.Fork2 as E
import Data.Vector.Algorithms.Quicksort.Median as E

-- For haddock
import Control.Monad.ST

{-# INLINABLE sortInplaceFM #-}
-- | Quicksort parameterised by median selection method and
-- parallelisation strategy.
--
-- The vector is sorted in place. Short subvectors are handed to
-- 'bitonicSort', and subvectors on which the recursion got too deep are
-- handed to 'heapSort'; everything else is partitioned around a pivot
-- obtained from 'selectMedian' and the two halves are processed via
-- 'fork2'.
sortInplaceFM
  :: forall p med x m a v.
     (Fork2 p x m, Median med a m (PrimState m), PrimMonad m, Ord a, GM.MVector v a)
  => p                 -- ^ Parallelisation strategy.
  -> med               -- ^ Median (pivot) selection method.
  -> v (PrimState m) a -- ^ Vector to sort in place.
  -> m ()
sortInplaceFM !p !med !vector = do
  !releaseToken <- startWork p
  -- ParStrategies requires forcing the unit, otherwise we may return
  -- while some sparks are still working.
  () <- qsortLoop 0 releaseToken vector
  pure ()
  where
    -- If we select bad median 4 times in a row then fall back to heapsort.
    -- (The value itself is the length of the whole input vector; it is
    -- compared against the scaled subvector length in 'qsortLoop'.)
    !cutoffLen = GM.length vector

    -- Binary logarithm of the input length.
    !logLen = binlog2 (GM.length vector)

    -- Recursion depth at which we unconditionally switch to heapsort.
    !threshold = 2 * logLen

    -- Sort one subvector. @depth@ is the current recursion depth and
    -- @releaseToken@ is the token that is either passed on to 'fork2' or
    -- released with 'endWork' once this subvector is done.
    qsortLoop :: Int -> x -> v (PrimState m) a -> m ()
    qsortLoop !depth !releaseToken !v
      -- Short vectors are sorted directly.
      | len < 17
      = bitonicSort len v *> endWork p releaseToken

      -- Too deep: either the hard depth limit was hit, or we are past
      -- @logLen@ levels and the subvector length scaled by
      -- @2 ^ depthDiff@ still exceeds the original input length.
      | depth == threshold || (depthDiff > 0 && len `unsafeShiftL` depthDiff > cutoffLen)
      = heapSort v *> endWork p releaseToken

      | otherwise = do
        let !last = len - 1
            -- Everything except the last element; the median is selected
            -- from this slice.
            v'    = GM.unsafeSlice 0 last v
        res <- selectMedian med v'

        (!pi', !pv) <- case res of
          -- The pivot value is only a guess and is never written into
          -- the vector by this function.
          Guess pv -> do
            (_, !pi') <- partitionTwoWaysGuessedPivot pv last v
            pure (pi', pv)

          -- The pivot is reported together with an index @pi@
          -- (presumably its position in the vector - confirm against
          -- 'selectMedian').
          ExistingValue pv pi -> do
            -- Put the pivot into the last slot, moving the previous
            -- last element to index @pi@.
            when (pi /= last) $ do
              GM.unsafeWrite v pi =<< GM.unsafeRead v last
              GM.unsafeWrite v last pv
            -- Partition everything before the last slot ...
            (!xi, !pi') <- partitionTwoWaysPivotAtEnd pv (last - 1) v
            -- ... then store the pivot at the returned index and move
            -- the element returned by the partition to the end.
            GM.unsafeWrite v pi' pv
            GM.unsafeWrite v last xi
            pure (pi' + 1, pv)

        -- Skip the run of elements equal to the pivot so that it is
        -- excluded from the right half.
        !pi'' <- skipEq pv pi' v

        let !left   = GM.unsafeSlice 0 pi' v
            !right  = GM.unsafeSlice pi'' (len - pi'') v
            !depth' = depth + 1
        -- Note: 'fork2' receives the current depth while the recursive
        -- calls run at @depth'@.
        fork2
          p
          releaseToken
          depth
          (qsortLoop depth')
          (qsortLoop depth')
          left
          right
      where
        !len       = GM.length v
        -- How far past @logLen@ levels of recursion we are.
        !depthDiff = depth - logLen

{-# INLINE partitionTwoWaysGuessedPivot #-}
-- Two-way partition of indices @[0 .. lastIdx]@ of the vector around
-- the pivot value @pv@, which is not required to be present in the
-- vector.
--
-- Scans forward from the left over elements @< pv@ and backward from
-- the right over elements @>= pv@, swapping the two elements the scans
-- stopped on, until the scans meet. Returns the value the forward scan
-- last stopped on (or @pv@ itself if that scan ran past the right
-- bound) together with the index where it stopped.
partitionTwoWaysGuessedPivot
  :: (PrimMonad m, Ord a, GM.MVector v a)
  => a -> Int -> v (PrimState m) a -> m (a, Int)
partitionTwoWaysGuessedPivot !pv !lastIdx !v =
  go 0 lastIdx
  where
    -- @i@ and @j@ are the current left and right bounds (inclusive).
    go !i !j = do
      !(i', xi) <- goLT i
      !(j', xj) <- goGT j
      if i' < j'
      then do
        -- Swap the two out-of-place elements and continue inwards.
        GM.unsafeWrite v j' xi
        GM.unsafeWrite v i' xj
        go (i' + 1) (j' - 1)
      else pure (xi, i')
      where
        -- Advance right while elements are strictly less than the pivot.
        -- The bound is checked before reading because nothing in the
        -- vector is guaranteed to stop the scan.
        goLT !k = do
          if k <= j
          then do
            !x <- GM.unsafeRead v k
            if x < pv
            then goLT (k + 1)
            else pure (k, x)
          -- Be careful not to write this pv into array - pv may not exist there.
          else pure (k, pv)
        -- Retreat left while elements are not less than the pivot and
        -- we are still to the right of @i@.
        goGT !k = do
          !x <- GM.unsafeRead v k
          if x >= pv && i < k
          then goGT (k - 1)
          else pure (k, x)

{-# INLINE partitionTwoWaysPivotAtEnd #-}
-- Two-way partition of indices @[0 .. lastIdx]@ of the vector around
-- the pivot value @pv@.
--
-- Unlike 'partitionTwoWaysGuessedPivot', the forward scan reads the
-- element before checking the right bound, so it may read index
-- @j + 1@. The caller in this module places the pivot in the slot
-- right after @lastIdx@ before calling, which keeps that read within
-- the vector.
--
-- Returns the element the forward scan last stopped on together with
-- the index where it stopped.
partitionTwoWaysPivotAtEnd
  :: (PrimMonad m, Ord a, GM.MVector v a)
  => a -> Int -> v (PrimState m) a -> m (a, Int)
partitionTwoWaysPivotAtEnd !pv !lastIdx !v =
  go 0 lastIdx
  where
    -- @i@ and @j@ are the current left and right bounds (inclusive).
    go !i !j = do
      !(i', xi) <- goLT i
      !(j', xj) <- goGT j
      if i' < j'
      then do
        -- Swap the two out-of-place elements and continue inwards.
        GM.unsafeWrite v j' xi
        GM.unsafeWrite v i' xj
        go (i' + 1) (j' - 1)
      else pure (xi, i')
      where
        -- Advance right while elements are strictly less than the pivot
        -- and we have not passed @j@.
        goLT !k = do
          !x <- GM.unsafeRead v k
          if x < pv && k <= j
          then goLT (k + 1)
          else pure (k, x)
        -- Retreat left while elements are not less than the pivot and
        -- we are still to the right of @i@.
        goGT !k = do
          !x <- GM.unsafeRead v k
          if x >= pv && i < k
          then goGT (k - 1)
          else pure (k, x)

{-# INLINE skipEq #-}
-- Identify multiple pivots that are equal to the one we were partitioning with so that
-- whole run of equal pivots can be excluded from recursion.
--
-- Starting at index @start@, walks forward while elements are equal to
-- @x@ and returns the index of the first element that differs, or the
-- vector length if the run extends to the end of the vector.
skipEq :: (PrimMonad m, Eq a, GM.MVector v a) => a -> Int -> v (PrimState m) a -> m Int
skipEq !x !start !v = go start
  where
    -- One past the last valid index.
    !len = GM.length v
    go !k
      | k < len
      = do
        !y <- GM.unsafeRead v k
        if y == x
        then go (k + 1)
        else pure k
      | otherwise
      = pure k

{-# INLINE binlog2 #-}
-- Position of the highest set bit of the argument, which for positive
-- inputs is the binary logarithm rounded down. For @0@ the result is
-- @-1@ because every bit counts as a leading zero.
binlog2 :: Int -> Int
binlog2 x = finiteBitSize x - 1 - countLeadingZeros x