{-# LANGUAGE RecordWildCards #-}

-- | Internal implementation of the strongly connected components (SCC) decomposition. Use "AtCoder.Scc" instead.
--
-- @since 1.0.0
module AtCoder.Internal.Scc
  ( -- * Internal SCC
    SccGraph (nScc),

    -- * Constructor
    new,

    -- * Modifying the graph
    addEdge,

    -- * SCC calculation
    sccIds,
    scc,
  )
where

import AtCoder.Internal.Csr qualified as ACICSR
import AtCoder.Internal.GrowVec qualified as ACIGV
import Control.Monad (unless, when)
import Control.Monad.Fix (fix)
import Control.Monad.Primitive (PrimMonad, PrimState)
import Data.Foldable (for_)
import Data.Maybe (fromJust)
import Data.Vector qualified as V
import Data.Vector.Generic qualified as VG
import Data.Vector.Generic.Mutable qualified as VGM
import Data.Vector.Unboxed qualified as VU
import Data.Vector.Unboxed.Mutable qualified as VUM

-- | Graph for collecting strongly connected components.
--
-- Directed edges are accumulated with 'addEdge'. The decomposition is computed
-- on demand by 'sccIds' or 'scc'.
--
-- @since 1.0.0
data SccGraph s = SccGraph
  { -- | The number of vertices.
    --
    -- @since 1.0.0
    nScc :: {-# UNPACK #-} !Int,
    -- | Directed edges @(from, to)@, in insertion order.
    edgesScc :: !(ACIGV.GrowVec s (Int, Int))
  }

-- | \(O(n)\) Creates `SccGraph` of \(n\) vertices.
--
-- The graph starts with no edges (the edge buffer is created with capacity 0).
--
-- @since 1.0.0
{-# INLINE new #-}
new :: (PrimMonad m) => Int -> m (SccGraph (PrimState m))
new nScc = do
  edgesScc <- ACIGV.new 0
  pure SccGraph {..}

-- | \(O(1)\) amortized. Adds a directed edge @from -> to@ to the graph.
--
-- The edge is only appended to the edge buffer; nothing is validated here.
-- NOTE(review): callers are presumably expected to pass vertices in
-- @[0, n)@ (the public @AtCoder.Scc@ wrapper likely checks this) — confirm.
--
-- @since 1.0.0
{-# INLINE addEdge #-}
addEdge :: (PrimMonad m) => SccGraph (PrimState m) -> Int -> Int -> m ()
addEdge SccGraph {edgesScc} from to = ACIGV.pushBack edgesScc (from, to)

-- | \(O(n + m)\) Returns a pair of @(# of scc, scc id)@.
--
-- Uses Tarjan's algorithm. The returned ids lie in @[0, # of scc)@ and are
-- topologically sorted: for an edge @u -> v@ whose endpoints belong to
-- different components, the id of @u@ is smaller than the id of @v@.
--
-- @since 1.0.0
{-# INLINE sccIds #-}
sccIds :: (PrimMonad m) => SccGraph (PrimState m) -> m (Int, VU.Vector Int)
sccIds SccGraph {..} = do
  -- see also the Wikipedia: https://en.wikipedia.org/wiki/Tarjan%27s_strongly_connected_components_algorithm#The_algorithm_in_pseudocode
  -- `unsafeFreeze` is acceptable here: nothing pushes to @edgesScc@ while this
  -- function runs.
  g <- ACICSR.build' nScc <$> ACIGV.unsafeFreeze edgesScc
  -- next SCC ID, kept in a one-element mutable cell
  groupNum <- VUM.replicate 1 (0 :: Int)
  -- stack of vertices that have been visited but not yet assigned to an SCC
  visited <- ACIGV.new nScc
  -- vertex -> low-link: the smallest index of any node on the stack known to be reachable from
  -- v through v's DFS subtree, including v itself.
  low <- VUM.replicate nScc (0 :: Int)
  -- vertex -> order of the visit (0, 1, ..).
  -- @-1@ marks an unvisited vertex. @nScc@ marks a vertex that already belongs
  -- to an SCC; it exceeds every visit order, so it never lowers a low-link.
  ord <- VUM.replicate nScc (-1 :: Int)
  -- vertex -> scc id
  ids <- VUM.replicate nScc (0 :: Int)

  -- Visits @v@ with visit order @ord0@ and returns the next unused visit order
  -- once the DFS subtree of @v@ is finished.
  let dfs v ord0 = do
        VGM.write low v ord0
        VGM.write ord v ord0
        ACIGV.pushBack visited v
        -- look around @v@, folding their low-link onto the low-link of @v@.
        ord' <-
          VU.foldM'
            ( \curOrd to -> do
                ordTo <- VGM.read ord to
                if ordTo == -1
                  then do
                    -- not visited yet: recurse, then inherit its low-link.
                    nextOrd <- dfs to curOrd
                    lowTo <- VGM.read low to
                    VGM.modify low (min lowTo) v
                    pure nextOrd
                  else do
                    -- already visited: look back and update the low-link.
                    VGM.modify low (min ordTo) v
                    pure curOrd
            )
            (ord0 + 1)
            (g `ACICSR.adj` v)

        lowV <- VGM.read low v
        ordV <- VGM.read ord v
        when (lowV == ordV) $ do
          -- it's the root of a SCC, no more to look back: pop the stack down
          -- to @v@ and label every popped vertex with a fresh SCC id.
          sccId <- VGM.unsafeRead groupNum 0
          fix $ \loop -> do
            -- cannot be Nothing: @v@ itself is still on the stack.
            u <- fromJust <$> ACIGV.popBack visited
            VGM.write ord u nScc
            VGM.write ids u sccId
            unless (u == v) loop
          VGM.unsafeWrite groupNum 0 $ sccId + 1
        pure ord'

  -- Start a DFS from every vertex that has not been visited yet, threading
  -- the next unused visit order through the fold.
  VU.foldM'_
    ( \curOrd i -> do
        o <- VGM.read ord i
        if o == -1
          then dfs i curOrd
          else pure curOrd
    )
    (0 :: Int)
    (VU.generate nScc id)

  num <- VGM.unsafeRead groupNum 0
  -- The SCCs are reverse topologically sorted, e.g., [0, 1] <- [2] <- [3]
  -- Now reverse the SCC IDs so that they will be topologically sorted: [3] -> [2] -> [0, 1]
  for_ [0 .. nScc - 1] $ \i -> do
    VGM.modify ids ((num - 1) -) i

  ids' <- VU.unsafeFreeze ids
  pure (num, ids')

-- | \(O(n + m)\) Returns the strongly connected components.
--
-- The components are listed in the order of the ids returned by 'sccIds'
-- (topological order). The vertices inside each component are in ascending
-- order.
--
-- @since 1.0.0
{-# INLINE scc #-}
scc :: (PrimMonad m) => SccGraph (PrimState m) -> m (V.Vector (VU.Vector Int))
scc g = do
  (!groupNum, !ids) <- sccIds g
  -- size of each component
  let counts = VU.create $ do
        vec <- VUM.replicate groupNum (0 :: Int)
        VU.forM_ ids $ \x -> do
          VGM.modify vec (+ 1) x
        pure vec
  -- one exactly-sized buffer per component; every slot is written below
  groups <- V.mapM VUM.unsafeNew $ VU.convert counts
  -- per-component write cursor
  is <- VUM.replicate groupNum (0 :: Int)
  -- visiting vertices in ascending order keeps each component sorted
  VU.iforM_ ids $ \v sccId -> do
    i <- VGM.read is sccId
    VGM.write is sccId $ i + 1
    VGM.write (groups VG.! sccId) i v
  V.mapM VU.unsafeFreeze groups