module AtCoder.Extra.Vector
( argsort,
unsafePermuteInPlace,
unsafePermuteInPlaceST,
)
where
import AtCoder.Internal.Assert qualified as ACIA
import Control.Monad (unless)
import Control.Monad.Primitive (PrimMonad, PrimState, stToPrim)
import Control.Monad.ST (ST)
import Data.Vector.Algorithms.Intro qualified as VAI
import Data.Vector.Generic qualified as VG
import Data.Vector.Generic.Mutable qualified as VGM
import Data.Vector.Unboxed qualified as VU
{-# INLINE argsort #-}

-- | \(O(n \log n)\) Returns the permutation of indices that sorts @xs@ in
-- ascending order: reading @xs@ at the returned indices, left to right, yields
-- the elements of @xs@ in sorted order.
--
-- Equal elements are ordered by their index, so the result is the same as the
-- one a stable sort would produce.
argsort :: (Ord a, VU.Unbox a) => VU.Vector a -> VU.Vector Int
argsort xs = VU.modify (VAI.sortBy byKeyThenIndex) indices
  where
    -- The identity permutation [0 .. length xs - 1], sorted in place by 'VU.modify'.
    indices = VU.enumFromN 0 (VU.length xs)

    -- Primary key: the element value; secondary key: the index itself.
    byKeyThenIndex i j = case compare (xs VG.! i) (xs VG.! j) of
      EQ -> compare i j
      ordering -> ordering
{-# INLINE unsafePermuteInPlace #-}

-- | Permutes the mutable vector @vec@ in place according to the index vector
-- @is@.
--
-- This is 'unsafePermuteInPlaceST' lifted from 'ST' into any 'PrimMonad'
-- (e.g. 'IO') through 'stToPrim'; it has the same constraints on @is@.
unsafePermuteInPlace :: (PrimMonad m, VGM.MVector v a) => v (PrimState m) a -> VU.Vector Int -> m ()
unsafePermuteInPlace vec = stToPrim . unsafePermuteInPlaceST vec
{-# INLINEABLE unsafePermuteInPlaceST #-}

-- | \(O(n)\) Moves every element of @vec@ to the position named by @is@.
-- After the call, @vec'[is[i]] = vec[i]@ for every index @i@.
--
-- ==== Constraints
-- - The lengths of @vec@ and @is@ must be equal. This is checked with
--   'ACIA.runtimeAssert'.
-- - @is@ must be a permutation of @[0 .. n - 1]@. This is /not/ checked. For
--   any other index vector the behaviour is unspecified: the loop may access
--   out of bounds through the unsafe read and write operations, or fail to
--   terminate.
--
-- The permutation is applied cycle by cycle. A boolean vector of length @n@
-- records which positions have already been filled, so each position is
-- written exactly once. An empty vector is a no-op.
unsafePermuteInPlaceST :: (VGM.MVector v a) => v s a -> VU.Vector Int -> ST s ()
unsafePermuteInPlaceST vec is = do
  let !_ = ACIA.runtimeAssert (VGM.length vec == VG.length is) "AtCoder.Extra.Vector.unsafePermuteInPlaceST: the length of the index array must be equal to the length of the permuted vector"
  let n = VG.length is
  -- done[i] becomes True once position i holds its final value.
  done <- VU.thaw (VU.replicate n False)
  let -- Place @x@ at position @i@ and carry the displaced value on to
      -- position @is[i]@. Stop once the cycle is back at its start. The old
      -- value must be read BEFORE the write, otherwise it would be lost.
      rotate start i x = do
        displaced <- VGM.unsafeRead vec i
        VGM.unsafeWrite vec i x
        VGM.unsafeWrite done i True
        unless (i == start) $ rotate start (VG.unsafeIndex is i) displaced
      -- Begin a cycle at every position that has not been filled yet, so
      -- permutations with more than one cycle are fully applied.
      loop start = unless (start >= n) $ do
        visited <- VGM.unsafeRead done start
        unless visited $ do
          x0 <- VGM.unsafeRead vec start
          rotate start (VG.unsafeIndex is start) x0
        loop (start + 1)
  loop 0