-- | Miscellaneous vector methods.
--
-- @since 1.2.2.0
module AtCoder.Extra.Vector
  ( argsort,
    unsafePermuteInPlace,
    unsafePermuteInPlaceST,
  )
where

import AtCoder.Internal.Assert qualified as ACIA
import Control.Monad (unless)
import Control.Monad.Primitive (PrimMonad, PrimState, stToPrim)
import Control.Monad.ST (ST)
import Data.Vector.Algorithms.Intro qualified as VAI
import Data.Vector.Generic qualified as VG
import Data.Vector.Generic.Mutable qualified as VGM
import Data.Vector.Unboxed qualified as VU

-- TODO: test `unsafePermuteInPlace`
-- TODO: is `unsafePermuteInPlace` fast enough as specialized one?

-- | \(O(n \log n)\) Returns the indices of the vector ordered by the values they
-- point to. The order is stable: indices of equal values keep their original
-- relative order.
--
-- ==== Example
-- >>> import Data.Vector.Unboxed qualified as VU
-- >>> argsort $ VU.fromList [0, 1, 0, 1, 0]
-- [0,2,4,1,3]
{-# INLINE argsort #-}
argsort :: (Ord a, VU.Unbox a) => VU.Vector a -> VU.Vector Int
argsort xs = VU.modify (VAI.sortBy byValueThenIndex) identityIndices
  where
    -- The indices @0 .. n - 1@ that get sorted in place.
    identityIndices = VU.generate (VU.length xs) id
    -- Introsort is not stable on its own; breaking ties on the value by the
    -- index itself makes the resulting order stable.
    byValueThenIndex i j = compare (xs VG.! i) (xs VG.! j) <> compare i j

-- | \(O(n)\) Applies a permutation to a mutable vector in-place.
--
-- Works in any 'PrimMonad' by lifting 'unsafePermuteInPlaceST' with 'stToPrim'.
--
-- ==== Constraints
-- - The index array must be a permutation (0-based).
{-# INLINE unsafePermuteInPlace #-}
unsafePermuteInPlace :: (PrimMonad m, VGM.MVector v a) => v (PrimState m) a -> VU.Vector Int -> m ()
unsafePermuteInPlace mvec perm = stToPrim (unsafePermuteInPlaceST mvec perm)

-- | \(O(n)\) Applies a permutation to a mutable vector in-place: the element
-- originally at index @i@ is moved to index @is VG.! i@.
--
-- Every cycle of the permutation is rotated exactly once. An auxiliary unboxed
-- @Bool@ vector of length \(n\) records which positions already hold their
-- final value, so permutations made of several cycles are handled.
--
-- ==== Constraints
-- - The index array must be a permutation (0-based).
-- - The index array must have the same length as the vector.
--
-- ==== Example
-- >>> import Data.Vector.Unboxed qualified as VU
-- >>> VU.modify (\v -> unsafePermuteInPlaceST v (VU.fromList [2, 0, 1, 4, 3])) (VU.fromList [10, 20, 30, 40, 50 :: Int])
-- [20,30,10,50,40]
{-# INLINEABLE unsafePermuteInPlaceST #-}
unsafePermuteInPlaceST :: (VGM.MVector v a) => v s a -> VU.Vector Int -> ST s ()
unsafePermuteInPlaceST vec is = do
  let !_ = ACIA.runtimeAssert (VGM.length vec == VG.length is) "AtCoder.Extra.Vector.unsafePermuteInPlaceST: the length of the index array must be equal to the length of the permuted vector"
  let n = VG.length is
  -- placed[j] becomes True once vec[j] holds its final value.
  placed <- VU.thaw (VU.replicate n False)
  -- Walks the cycle that contains @start@. @carried@ is the value that belongs
  -- at position @j@. The value it evicts is read BEFORE the write and then
  -- carried on to @is ! j@. The walk stops after writing back into @start@.
  -- @carried@ is deliberately left lazy so boxed elements are not forced.
  let rotate !start !j carried = do
        VGM.unsafeWrite placed j True
        evicted <- VGM.unsafeRead vec j
        VGM.unsafeWrite vec j carried
        unless (j == start) $ rotate start (VG.unsafeIndex is j) evicted
  -- Visits every index once and starts a rotation at each index that is not
  -- yet placed. This also covers the empty vector (n == 0).
  let loop !i = unless (i >= n) $ do
        done <- VGM.unsafeRead placed i
        unless done $ do
          x <- VGM.unsafeRead vec i
          rotate i (VG.unsafeIndex is i) x
        loop (i + 1)
  loop 0