{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE ExplicitForAll #-}
{-# LANGUAGE FlexibleContexts #-}
module Data.Massiv.Array.Ops.Sort
  ( quicksort
  , quicksortM_
  , unsafeUnstablePartitionRegionM
  ) where
import Control.Monad (when)
import Control.Scheduler
import Data.Massiv.Array.Mutable
import Data.Massiv.Core.Common
import System.IO.Unsafe
-- | Rearrange, in place and without preserving relative order, the inclusive
-- region @[start, end]@ of a mutable vector so that every element satisfying
-- the predicate comes before every element that does not. Returns the index of
-- the first element of the region that does not satisfy the predicate (or
-- @end + 1@ if all of them do).
--
-- /Unsafe/: indices are not bounds-checked ('unsafeRead' / 'unsafeWrite' are
-- used). The caller must ensure both indices are valid and that
-- @start <= end + 1@; an empty region (@start == end + 1@) returns @start@.
unsafeUnstablePartitionRegionM ::
     forall r e m. (Mutable r Ix1 e, PrimMonad m)
  => MArray (PrimState m) r Ix1 e
  -> (e -> Bool)
  -> Ix1 -- ^ Start index of the region (inclusive)
  -> Ix1 -- ^ End index of the region (inclusive)
  -> m Ix1
unsafeUnstablePartitionRegionM marr f start end = advanceFront start (end + 1)
  where
    -- Invariant: everything in [start, front) satisfies @f@ and everything in
    -- [back, end] does not; @back@ is an exclusive bound in this loop.
    advanceFront front back =
      if front == back
        then pure front
        else do
          e <- unsafeRead marr front
          if f e
            then advanceFront (front + 1) back
            else retreatBack front (back - 1)
    -- Entered once @marr[front]@ is known not to satisfy @f@. Here @back@ is
    -- the next (inclusive) index to inspect: it walks down until it meets an
    -- element that satisfies @f@, which is then swapped into @front@.
    retreatBack front back =
      if front == back
        then pure front
        else do
          e <- unsafeRead marr back
          if not (f e)
            then retreatBack front (back - 1)
            else do
              unsafeRead marr front >>= unsafeWrite marr back
              unsafeWrite marr front e
              advanceFront (front + 1) back
{-# INLINE unsafeUnstablePartitionRegionM #-}
-- | Sort a one-dimensional array in ascending order with an unstable
-- quicksort. The input is left untouched: sorting happens on a mutable copy,
-- using the array's own computation strategy ('getComp') for parallelism.
--
-- 'unsafePerformIO' is sound here because 'withMArray'' thaws a fresh copy of
-- the input before mutating it. No mutable state is shared with the caller, so
-- the result depends only on the argument.
quicksort ::
     (Mutable r Ix1 e, Ord e) => Array r Ix1 e -> Array r Ix1 e
quicksort arr = unsafePerformIO sortCopy
  where
    sortCopy = withMArray' arr quicksortM_
{-# INLINE quicksort #-}
-- | Thaw a mutable copy of an immutable array and run an action over it
-- together with a scheduler built from the array's computation strategy.
-- Afterwards the copy is frozen and returned. The action's own result is
-- discarded.
--
-- Relies on 'withScheduler_' waiting for every job submitted to the scheduler
-- before it returns, so the copy is no longer being mutated when it is frozen.
withMArray' ::
     (Mutable r ix e, MonadUnliftIO m)
  => Array r ix e
  -> (Scheduler m () -> MArray RealWorld r ix e -> m a)
  -> m (Array r ix e)
withMArray' arr action = do
  let comp = getComp arr
  mutableCopy <- thaw arr
  withScheduler_ comp (`action` mutableCopy)
  liftIO (unsafeFreeze comp mutableCopy)
{-# INLINE withMArray' #-}
-- | Sort a mutable one-dimensional array in place, in ascending order, using an
-- unstable quicksort. It picks the pivot by median-of-three and uses a
-- three-way (< pivot / == pivot / > pivot) partition.
--
-- This function only /schedules/ work on the supplied 'Scheduler'. The array
-- is sorted once every job submitted to @scheduler@ has completed. The top
-- levels of the recursion are submitted as separate jobs; they operate on
-- disjoint sub-ranges. Below that depth each job recurses sequentially.
quicksortM_ ::
     (Ord e, Mutable r Ix1 e, PrimMonad m)
  => Scheduler m ()
  -> MArray (PrimState m) r Ix1 e
  -> m ()
quicksortM_ scheduler marr =
  scheduleWork scheduler $ qsort depthPar 0 (unSz (msize marr) - 1)
  where
    -- Number of recursion levels whose sub-partitions are scheduled as
    -- separate jobs. Using the raw worker count means that machines with many
    -- workers end up scheduling a separate job for nearly every tiny partition.
    -- 2^10 leaf jobs is already ample parallelism.
    depthPar = min (numWorkers scheduler) 10
    -- Order the elements at @i@ and @j@ so that @marr[i] >= marr[j]@ and
    -- return the (not larger) element that ends up at @j@.
    leSwap i j = do
      ei <- unsafeRead marr i
      ej <- unsafeRead marr j
      if ei < ej
        then do
          unsafeWrite marr i ej
          unsafeWrite marr j ei
          pure ei
        else pure ej
    {-# INLINE leSwap #-}
    -- Median-of-three. Afterwards (when lo, mid and hi are distinct)
    -- @marr[lo]@ holds the smallest of the three elements, @marr[mid]@ the
    -- largest and @marr[hi]@ the median. The median is returned as the pivot.
    getPivot lo hi = do
      -- Computed as lo + (hi - lo) / 2 instead of (hi + lo) / 2. The sum can
      -- overflow 'Int' for very large arrays (e.g. on 32-bit platforms), and
      -- the resulting negative index would be used unchecked below.
      let !mid = lo + (hi - lo) `div` 2
      _ <- leSwap mid lo
      _ <- leSwap hi lo
      leSwap mid hi
    {-# INLINE getPivot #-}
    -- Sort the inclusive range [lo, hi]. @n@ is how many more recursion
    -- levels may still be submitted to the scheduler as separate jobs.
    qsort !n !lo !hi =
      when (lo < hi) $ do
        p <- getPivot lo hi
        -- Afterwards [lo, l) < p. The pivot stays at hi, outside this region.
        l <- unsafeUnstablePartitionRegionM marr (< p) lo (hi - 1)
        -- Afterwards [l, h) == p and [h, hi] > p. Because marr[hi] == p we get
        -- h > l, so both recursive ranges are strictly smaller than [lo, hi].
        h <- unsafeUnstablePartitionRegionM marr (== p) l hi
        if n > 0
          then do
            let !n' = n - 1
            scheduleWork scheduler $ qsort n' lo (l - 1)
            scheduleWork scheduler $ qsort n' h hi
          else do
            qsort n lo (l - 1)
            qsort n h hi
{-# INLINE quicksortM_ #-}