{-# LANGUAGE CPP #-}
module Agda.Utils.Cluster
  ( C
  , cluster
  , cluster'
  ) where
import Control.Monad
import Data.Equivalence.Monad (runEquivT, equateAll, classDesc)
import Data.List.NonEmpty (NonEmpty(..))
import qualified Data.List.NonEmpty as NonEmpty
import qualified Data.IntMap as IntMap
#if __GLASGOW_HASKELL__ < 804
import Data.Semigroup
#endif
import Agda.Utils.Functor
import Agda.Utils.Singleton
import Agda.Utils.Fail
-- | Characteristics by which 'cluster' and 'cluster'' group elements.
--   They are plain 'Int's, so the equivalence-class descriptors computed
--   from them can be used directly as 'IntMap.IntMap' keys.
type C = Int
-- | Partition a list into clusters of related elements.
--
--   The function argument assigns each element a non-empty list of
--   characteristics.  Elements that share a characteristic end up in
--   the same cluster.  The relation is closed transitively: if @x@
--   shares a characteristic with @y@, and @y@ shares one with @z@, then
--   all three are grouped together.  The grouping itself is delegated
--   to 'cluster''.
cluster :: (a -> NonEmpty C) -> [a] -> [NonEmpty a]
cluster f xs = cluster' [ (x, f x) | x <- xs ]
-- | Variant of 'cluster' where each element comes already paired with
--   its non-empty list of characteristics.
--
--   Two elements land in the same cluster when their characteristics
--   are connected through the transitive closure of "shares a
--   characteristic".  Within a cluster, elements keep their relative
--   order from the input list.  Clusters are returned in ascending order
--   of their class descriptor.  That descriptor is one of the class's
--   characteristics: it starts as the characteristic itself ('id'), and
--   merging two classes keeps one of their descriptors ('const').
cluster' :: [(a, NonEmpty C)] -> [NonEmpty a]
cluster' acs = runFail_ $ runEquivT id const $ do
  -- Put all characteristics of each element into a single equivalence
  -- class.  This links elements that share any characteristic.
  forM_ acs $ \ (_, c :| cs) -> equateAll $ c:cs
  -- Tag each element with the descriptor of its class.  All of an
  -- element's characteristics are now in the same class, so the first
  -- one is enough to identify it.
  kas <- forM acs $ \ (a, c :| _) -> classDesc c <&> \ k -> (k, singleton a)
  -- Group elements by class.  'IntMap.fromListWith' combines duplicates
  -- as @f new old@, so the pairs are fed in reverse.  Each step then
  -- prepends a one-element list in O(1), and the final order matches the
  -- input.  (Left-folding @old <> new@, as 'IntMap.unionsWith' does,
  -- re-traverses the growing cluster on every insertion, which is
  -- quadratic in cluster size.)
  let m = IntMap.fromListWith (<>) $ reverse kas
  return $ IntMap.elems m