{-# LANGUAGE RecordWildCards #-}
module AtCoder.Internal.Scc
(
SccGraph (nScc),
new,
addEdge,
sccIds,
scc,
)
where
import AtCoder.Internal.Csr qualified as ACICSR
import AtCoder.Internal.GrowVec qualified as ACIGV
import Control.Monad (unless, when)
import Control.Monad.Fix (fix)
import Control.Monad.Primitive (PrimMonad, PrimState)
import Data.Foldable (for_)
import Data.Maybe (fromJust)
import Data.Vector qualified as V
import Data.Vector.Generic qualified as VG
import Data.Vector.Generic.Mutable qualified as VGM
import Data.Vector.Unboxed qualified as VU
import Data.Vector.Unboxed.Mutable qualified as VUM
-- | Mutable graph on which strongly connected components are computed.
--
-- The type parameter @s@ is the state token of the monad that owns the edge
-- buffer.
data SccGraph s = SccGraph
  { -- | Number of vertices; used as the size of every per-vertex array.
    nScc :: {-# UNPACK #-} !Int,
    -- | Growable buffer of edges stored as @(from, to)@ pairs.
    edgesScc :: !(ACIGV.GrowVec s (Int, Int))
  }
-- | Creates a graph with the given number of vertices and an empty edge
-- buffer.
--
-- The edge buffer is created with @ACIGV.new 0@ and is filled later through
-- 'addEdge'.
{-# INLINE new #-}
new :: (PrimMonad m) => Int -> m (SccGraph (PrimState m))
new nScc = do
  edgesScc <- ACIGV.new 0
  -- Fields are named explicitly so it is visible what ends up in the record.
  pure SccGraph {nScc, edgesScc}
-- | Appends the edge @(from, to)@ to the graph's edge buffer.
--
-- No bounds check is performed on the endpoints here.
{-# INLINE addEdge #-}
addEdge :: (PrimMonad m) => SccGraph (PrimState m) -> Int -> Int -> m ()
addEdge SccGraph {edgesScc} from to =
  ACIGV.pushBack edgesScc (from, to)
-- | Computes the strongly connected components with a low-link depth-first
-- search and returns @(number of components, component id of each vertex)@.
--
-- Component ids are handed out in the order the components are completed and
-- are then flipped to @(num - 1) - id@, so the component completed first gets
-- the largest id.
--
-- NOTE(review): the edge buffer is handed to @ACIGV.unsafeFreeze@, which
-- presumably shares memory with the graph rather than copying it — confirm
-- before calling 'addEdge' while the result is still in use.
{-# INLINE sccIds #-}
sccIds :: (PrimMonad m) => SccGraph (PrimState m) -> m (Int, VU.Vector Int)
sccIds SccGraph {nScc, edgesScc} = do
  g <- ACICSR.build' nScc <$> ACIGV.unsafeFreeze edgesScc
  -- One-element vector used as a mutable counter of completed components.
  groupNum <- VUM.replicate 1 (0 :: Int)
  -- Stack of vertices that were visited but not yet assigned to a component.
  visited <- ACIGV.new nScc
  low <- VUM.replicate nScc (0 :: Int)
  -- @-1@ marks a vertex that has not been visited yet.
  ord <- VUM.replicate nScc (-1 :: Int)
  ids <- VUM.replicate nScc (0 :: Int)

  -- @dfs v ord0@ visits @v@ with visiting order @ord0@ and returns the next
  -- unused visiting order.
  let dfs v ord0 = do
        VGM.write low v ord0
        VGM.write ord v ord0
        ACIGV.pushBack visited v
        ord' <-
          VU.foldM'
            ( \curOrd to -> do
                ordTo <- VGM.read ord to
                if ordTo == -1
                  then do
                    -- Unvisited neighbour: recurse, then pull its low-link up.
                    nextOrd <- dfs to curOrd
                    lowTo <- VGM.read low to
                    VGM.modify low (min lowTo) v
                    pure nextOrd
                  else do
                    -- Visited neighbour: finished vertices carry @nScc@ as
                    -- their order, so they never lower the low-link.
                    VGM.modify low (min ordTo) v
                    pure curOrd
            )
            (ord0 + 1)
            (g `ACICSR.adj` v)
        lowV <- VGM.read low v
        ordV <- VGM.read ord v
        -- @v@ is the root of a component: pop the stack down to @v@.
        when (lowV == ordV) $ do
          sccId <- VGM.unsafeRead groupNum 0
          fix $ \loop -> do
            -- @v@ was pushed at the start of this call and the loop stops as
            -- soon as it is popped, so the stack is not expected to be empty.
            u <- fromJust <$> ACIGV.popBack visited
            VGM.write ord u nScc
            VGM.write ids u sccId
            unless (u == v) loop
          VGM.unsafeWrite groupNum 0 $ sccId + 1
        pure ord'

  -- Start a search from every vertex that is still unvisited, threading the
  -- next free visiting order through the fold.
  VU.foldM'_
    ( \curOrd i -> do
        o <- VGM.read ord i
        if o == -1
          then dfs i curOrd
          else pure curOrd
    )
    (0 :: Int)
    (VU.generate nScc id)

  num <- VGM.unsafeRead groupNum 0
  -- Flip the numbering: the component completed first gets id @num - 1@.
  for_ [0 .. nScc - 1] $ \i ->
    VGM.modify ids ((num - 1) -) i
  ids' <- VU.unsafeFreeze ids
  pure (num, ids')
-- | Returns the strongly connected components as a vector of vertex lists,
-- indexed by the component id computed by 'sccIds'.
--
-- Vertices inside one component appear in ascending vertex order, because
-- they are distributed in a single pass over the vertex indices.
{-# INLINE scc #-}
scc :: (PrimMonad m) => SccGraph (PrimState m) -> m (V.Vector (VU.Vector Int))
scc g = do
  (!groupNum, !ids) <- sccIds g
  -- First pass: count how many vertices belong to each component.
  let counts = VU.create $ do
        vec <- VUM.replicate groupNum (0 :: Int)
        VU.forM_ ids $ \x ->
          VGM.modify vec (+ 1) x
        pure vec
  -- One uninitialised buffer per component, sized exactly by its count, so
  -- the second pass fills every slot.
  groups <- V.mapM VUM.unsafeNew $ VU.convert counts
  -- Next free slot in each component's buffer.
  cursors <- VUM.replicate groupNum (0 :: Int)
  -- Second pass: drop each vertex into its component's buffer.
  VU.iforM_ ids $ \v sccId -> do
    i <- VGM.read cursors sccId
    VGM.write cursors sccId $ i + 1
    VGM.write (groups VG.! sccId) i v
  V.mapM VU.unsafeFreeze groups