-- |
-- Module:     Data.Vector.Algorithms.Heapsort
-- Copyright:  (c) Sergey Vinokurov 2023
-- License:    Apache-2.0 (see LICENSE)
-- Maintainer: serg.foo@gmail.com

module Data.Vector.Algorithms.Heapsort
  ( heapSort
  ) where

import Control.Monad.Primitive
import Data.Bits
import Data.Vector.Generic.Mutable qualified as GM

{-# INLINABLE shiftDown #-}
-- | Sift the element at the given index down a binary max-heap stored in
-- the whole of the given vector.
--
-- The children of index @p@ live at @2 * p + 1@ and @2 * p + 2@. While the
-- larger child is strictly greater than the element at @p@, the two are
-- exchanged and the process continues from that child's position. It stops
-- when the element is not smaller than its largest child or has no children
-- within the vector.
--
-- Indices are accessed with unchecked reads and writes; child indices are
-- compared against the vector length before being used, the starting index
-- itself is only read once its first child is known to be in range.
shiftDown :: (PrimMonad m, Ord a, GM.MVector v a) => v (PrimState m) a -> Int -> m ()
shiftDown !v = go
  where
    -- Heap size; fixed for the whole traversal.
    !end = GM.length v
    go !p
      | c1 < end
      = do
        let !c2 = c1 + 1
        c1Val <- GM.unsafeRead v c1
        -- Select the larger child. The right child may not exist, and on a
        -- tie between children the right one is chosen.
        (maxIdx, maxVal) <-
          if c2 < end
          then do
            c2Val <- GM.unsafeRead v c2
            pure $ if c1Val > c2Val then (c1, c1Val) else (c2, c2Val)
          else pure (c1, c1Val)
        pVal <- GM.unsafeRead v p
        if maxVal > pVal
        then do
          -- Both values are already in hand, so two writes perform the swap.
          GM.unsafeWrite v p maxVal
          GM.unsafeWrite v maxIdx pVal
          go maxIdx
        else
          pure ()
      -- No children inside the vector: nothing to do.
      | otherwise
      = pure ()
      where
        -- Index of the left child of @p@.
        !c1 = p * 2 + 1

{-# INLINABLE heapify #-}
-- | Rearrange the whole vector in place into a binary max-heap by calling
-- 'shiftDown' for every index from @length `div` 2@ down to @0@.
heapify :: (PrimMonad m, Ord a, GM.MVector v a) => v (PrimState m) a -> m ()
heapify !v =
  -- Index @length / 2@ has no children inside the vector, so the first
  -- 'shiftDown' call does nothing. Starting there keeps the countdown
  -- non-negative for empty and single-element vectors.
  go (GM.length v `unsafeShiftR` 1)
  where
    go 0 = shiftDown v 0
    go n = shiftDown v n *> go (n - 1)

{-# INLINABLE heapSort #-}
-- | O(N * log(N)) regular heapsort using a 2-way heap (whereas the one in
-- vector-algorithms is 4-way). Sorts the vector in place.
--
-- Can be used as a standalone sort, but its main purpose is to serve as the
-- fallback sort for quicksort.
--
-- Depending on GHC, this may be a good candidate for a SPECIALIZE pragma.
heapSort :: (PrimMonad m, Ord a, GM.MVector v a) => v (PrimState m) a -> m ()
heapSort !v = do
  heapify v
  go (GM.length v)
  where
    -- @n@ is the number of elements still in the heap prefix.
    go 0 = pure ()
    go n = do
      let !k = n - 1
      -- Move the current heap root to position @k@, the end of the prefix.
      GM.unsafeSwap v 0 k
      -- Re-sift the new root within the first @k@ elements only; 'shiftDown'
      -- takes its bound from the length of the slice it is given.
      shiftDown (GM.unsafeSlice 0 k v) 0
      go k