-- | Union-find (disjoint-set) structure over mutable 'Term' nodes, using
-- union by rank and path compression.
module Heart.Core.UnionFind
  ( PartitionAlloc
  , Term
  , newPartitionAlloc
  , newTerm
  , findTermRank
  , findTerm
  , unionTerm
  , decideEqTerm
  , hashSplitClasses
  , ordSplitClasses
  ) where

import Data.Function (on)
import Heart.Core.Alloc
import Heart.Core.MultiMap
import Heart.Core.Prelude
import UnliftIO.IORef (IORef, newIORef, readIORef, writeIORef)

-- | Identifier stored in every 'Term'; 'newTerm' draws it from a
-- 'PartitionAlloc'.
newtype Partition = Partition { unPartition :: Int }
  deriving (Eq, Ord, Enum, Show, Hashable)

-- Content/Term encoding simplified from
-- https://github.com/ekmett/guanxi/blob/master/src/Equality.hs

-- | Contents of a term's parent cell.
data Content
  = Root !Int
    -- ^ The term is the root of its class; the 'Int' is its rank.
  | Child !Term
    -- ^ The term is linked to another term of the same class.

-- | A node of the union-find forest.
data Term = Term
  { _termPartition :: !Partition
    -- ^ Identifier used for equality, ordering and hashing.
  , _termParent :: !(IORef Content)
    -- ^ Mutable cell holding either the root rank or the parent link.
  }

-- | Terms are compared by partition identifier only; the mutable parent
-- cell is ignored.
instance Eq Term where
  (==) = (==) `on` _termPartition

-- | Terms are ordered by partition identifier only.
instance Ord Term where
  (<=) = (<=) `on` _termPartition

-- | Terms are hashed by partition identifier only, matching the 'Eq'
-- instance.
instance Hashable Term where
  hash = hash . _termPartition
  hashWithSalt d = hashWithSalt d . _termPartition

-- | Source of the partition identifiers handed to new terms.
newtype PartitionAlloc = PartitionAlloc { unPartitionAlloc :: Alloc Partition }

-- | Create a fresh allocator (a wrapped 'newEnumAlloc').
newPartitionAlloc :: MonadIO m => m PartitionAlloc
newPartitionAlloc = PartitionAlloc <$> newEnumAlloc

-- | Create a new term in a singleton class: it takes the next identifier
-- from the allocator and starts out as a root of rank 0.
newTerm :: MonadIO m => PartitionAlloc -> m Term
newTerm (PartitionAlloc alloc) = do
  p <- incAlloc alloc
  r <- newIORef (Root 0)
  pure (Term p r)

-- | Find the root of the class containing the given term, together with
-- that root's rank.
--
-- Performs path compression: every non-root term visited on the way up is
-- re-pointed directly at the root.
findTermRank :: MonadIO m => Term -> m (Int, Term)
findTermRank t@(Term _ r) = do
  content <- readIORef r
  case content of
    Root rank -> pure (rank, t)
    Child parent -> do
      found@(_, root) <- findTermRank parent
      -- Path compression: link this term straight to the root.
      writeIORef r (Child root)
      pure found

-- | Find the root of the class containing the given term (with path
-- compression, see 'findTermRank').
findTerm :: MonadIO m => Term -> m Term
findTerm = fmap snd . findTermRank

-- | Merge the classes of the two terms using union by rank: the root of
-- lower rank is linked under the root of higher rank; on a tie the first
-- root is linked under the second, whose rank is incremented.
--
-- If both terms already belong to the same class nothing is changed. In
-- particular the root's rank is not bumped by a redundant union.
--
-- The update is a sequence of plain 'IORef' reads and writes, with no atomic
-- operations.
unionTerm :: MonadIO m => Term -> Term -> m ()
unionTerm m n = do
  (mrank, mroot) <- findTermRank m
  (nrank, nroot) <- findTermRank n
  let mref = _termParent mroot
      nref = _termParent nroot
  -- Same parent cell means same root: the terms are already unified.
  -- Without this guard the EQ branch below would raise the root's rank
  -- on every redundant union.
  if mref == nref
    then pure ()
    else case compare mrank nrank of
      -- The surviving root keeps its rank, so only the losing root's cell
      -- needs to be written.
      LT -> writeIORef mref (Child nroot)
      GT -> writeIORef nref (Child mroot)
      EQ -> do
        writeIORef mref (Child nroot)
        writeIORef nref (Root (nrank + 1))

-- | Decide whether two terms are currently in the same class by comparing
-- their roots.
decideEqTerm :: MonadIO m => Term -> Term -> m Bool
decideEqTerm m n = (==) <$> findTerm m <*> findTerm n

-- | Invert a 'HashMap' of terms through 'findTerm' using
-- 'invertHashMapWith'; presumably this groups the keys by the root of
-- their term's class (verify against "Heart.Core.MultiMap").
hashSplitClasses :: (Eq k, Hashable k, MonadIO m) => HashMap k Term -> m (HashMultiMap Term k)
hashSplitClasses = invertHashMapWith findTerm

-- | Invert a 'Map' of terms through 'findTerm' using 'invertOrdMapWith';
-- presumably this groups the keys by the root of their term's class
-- (verify against "Heart.Core.MultiMap").
ordSplitClasses :: (Ord k, MonadIO m) => Map k Term -> m (OrdMultiMap Term k)
ordSplitClasses = invertOrdMapWith findTerm