{-# LANGUAGE CPP #-}
module Agda.Utils.Cluster
( C
, cluster
, cluster'
) where
import Control.Monad
import Data.Equivalence.Monad (runEquivT, equateAll, classDesc)
import Data.List.NonEmpty (NonEmpty(..))
import qualified Data.List.NonEmpty as NonEmpty
import qualified Data.IntMap as IntMap
#if __GLASGOW_HASKELL__ < 804
import Data.Semigroup
#endif
import Agda.Utils.Functor
import Agda.Utils.Singleton
import Agda.Utils.Fail
-- | A characteristic of an element: the key 'cluster' uses to group
--   elements together.
--
--   Equivalence-class descriptors in 'cluster'' are characteristics
--   themselves (the union-find structure is run with 'id' as the initial
--   descriptor). Because 'C' is 'Int', those descriptors can be used
--   directly as 'IntMap.IntMap' keys.
type C = Int
-- | Partition a list into groups of elements that share characteristics.
--
--   The function @f@ assigns every element a non-empty list of
--   characteristics. Two elements end up in the same group when their
--   characteristics are linked, either directly or through a chain of
--   other elements' characteristics.
--
--   This is 'cluster'' applied after pairing each element with @f@ of it.
cluster :: (a -> NonEmpty C) -> [a] -> [NonEmpty a]
cluster f xs = cluster' [ (x, f x) | x <- xs ]
-- | Like 'cluster', but each element comes already paired with its
--   non-empty list of characteristics.
--
--   All characteristics listed for a single element are equated in a
--   union-find structure, so two elements share a group exactly when
--   their characteristics are connected by the transitive closure of
--   that relation.
--
--   Each class is identified by its descriptor: one of its
--   characteristics (the structure starts from 'id' and merges with
--   'const'; which side survives a merge is up to the library).
--   Groups are returned in ascending order of that descriptor. Within
--   a group, elements keep their relative order from the input.
cluster' :: [(a, NonEmpty C)] -> [NonEmpty a]
cluster' acs = runFail_ $ runEquivT id const $ do
  -- Merge all characteristics of each element into one equivalence class.
  forM_ acs $ \ (_, c :| cs) -> equateAll $ c:cs
  -- Tag each element with the descriptor of its class. Its first
  -- characteristic suffices, since all of them were just equated.
  kas <- forM acs $ \ (a, c :| _) -> classDesc c <&> \ k -> (k, singleton a)
  -- Group by descriptor. 'IntMap.fromListWith' calls the combining
  -- function as @f new old@, so @(<>)@ prepends each newcomer in O(1).
  -- The groups therefore come out reversed and are flipped back below.
  -- (Folding singleton maps with @unionsWith (<>)@ appends on the right
  -- instead, building left-nested '++' chains that are quadratic to
  -- traverse for large groups.)
  let m = IntMap.fromListWith (<>) kas
  return $ map NonEmpty.reverse $ IntMap.elems m