-- | Mutable union-find (disjoint-set) structure over 'Term's.
--
-- Terms are allocated from a 'PartitionAlloc' by 'newTerm'. Each term owns a
-- mutable parent cell: a root stores its rank, any other term points at its
-- parent. 'unionTerm' links roots by rank, and 'findTermRank' / 'findTerm'
-- compress paths as they walk to the root. The cells are plain 'IORef's that
-- are read and written without any locking.
module Heart.Core.UnionFind
( PartitionAlloc
, Term
, newPartitionAlloc
, newTerm
, findTermRank
, findTerm
, unionTerm
, decideEqTerm
, hashSplitClasses
, ordSplitClasses
) where
import Data.Function (on)
import Heart.Core.Alloc
import Heart.Core.MultiMap
import Heart.Core.Prelude
import UnliftIO.IORef (IORef, newIORef, readIORef, writeIORef)
-- | Opaque identity of a 'Term'. 'newTerm' draws one from a 'PartitionAlloc'
-- for every term it creates. The 'Eq', 'Ord' and 'Hashable' instances of
-- 'Term' are all defined in terms of it.
newtype Partition = Partition
  { unPartition :: Int
  } deriving (Eq, Ord, Enum, Show, Hashable)
-- | What a term's parent cell currently holds.
data Content
  = Root !Int
    -- ^ The term is the representative of its class. The 'Int' is its rank,
    -- which 'unionTerm' compares to decide which root goes under the other.
  | Child !Term
    -- ^ The term is not a root; it points at another term of its class.
-- | An element of the union-find structure. Its identity is its 'Partition'.
-- Its parent cell is mutable so that 'unionTerm' can re-link it and
-- 'findTermRank' can compress paths through it.
data Term = Term
  { _termPartition :: !Partition
    -- ^ Identity used by the 'Eq', 'Ord' and 'Hashable' instances.
  , _termParent :: !(IORef Content)
    -- ^ Either this term's rank (when it is a root) or its parent term.
  }
-- | Two terms are equal exactly when they carry the same 'Partition'. This
-- tests identity of the terms themselves, not membership in the same class;
-- use 'decideEqTerm' for the latter.
instance Eq Term where
  x == y = _termPartition x == _termPartition y
-- | Terms are ordered by their 'Partition', consistently with the 'Eq'
-- instance.
--
-- 'compare' is defined directly rather than only '(<=)'. With only '(<=)',
-- the class default for 'compare' runs an '(==)' followed by a '(<=)'. Ordered
-- containers keyed by 'Term' (such as the result of 'ordSplitClasses') go
-- through 'compare', so this halves the work per comparison. The remaining
-- methods fall back to their defaults in terms of 'compare'.
instance Ord Term where
  compare = compare `on` _termPartition
-- | Hashes a 'Term' by its 'Partition', matching the 'Eq' instance.
instance Hashable Term where
  hash t = hash (_termPartition t)
  hashWithSalt salt t = hashWithSalt salt (_termPartition t)
-- | Source of 'Partition's for 'newTerm'; wraps an 'Alloc' from
-- "Heart.Core.Alloc".
newtype PartitionAlloc = PartitionAlloc
  { unPartitionAlloc :: Alloc Partition
  }
-- | Create a new allocator for terms, backed by 'newEnumAlloc'.
-- NOTE(review): presumably partitions are handed out in 'Enum' order —
-- confirm in "Heart.Core.Alloc".
newPartitionAlloc :: MonadIO m => m PartitionAlloc
newPartitionAlloc = PartitionAlloc <$> newEnumAlloc
-- | Allocate a new term. It takes its 'Partition' from the given allocator
-- and starts as the root of its own class, with rank 0.
newTerm :: MonadIO m => PartitionAlloc -> m Term
newTerm (PartitionAlloc alloc) =
  Term <$> incAlloc alloc <*> newIORef (Root 0)
-- | Follow parent pointers from a term to the root of its class. Returns the
-- root's rank together with the root itself.
--
-- Every non-root term visited is re-pointed directly at the root (path
-- compression), so later lookups through it take a single step.
findTermRank :: MonadIO m => Term -> m (Int, Term)
findTermRank term = do
  let cell = _termParent term
  content <- readIORef cell
  case content of
    Root rank -> pure (rank, term)
    Child parent -> do
      result <- findTermRank parent
      writeIORef cell (Child (snd result))
      pure result
-- | The root (representative) of a term's class, discarding its rank.
-- Paths are compressed along the way, as in 'findTermRank'.
findTerm :: MonadIO m => Term -> m Term
findTerm t = snd <$> findTermRank t
-- | Merge the classes of two terms (union by rank).
--
-- Both terms are first resolved to their roots with 'findTermRank'.
--
-- * If they already share a root, nothing is written.
-- * Otherwise the root of lower rank becomes a child of the other.
-- * On a tie, the first term's root is placed under the second's, and the
--   surviving root's rank grows by one.
unionTerm :: MonadIO m => Term -> Term -> m ()
unionTerm m n = do
  (mrank, mroot) <- findTermRank m
  (nrank, nroot) <- findTermRank n
  let mref = _termParent mroot
      nref = _termParent nroot
  -- Without this check, a redundant union would hit the tie branch with
  -- mref == nref. It would write the root as a child of itself and then bump
  -- its rank, inflating ranks without merging anything.
  if mroot == nroot
    then pure ()
    else case compare mrank nrank of
      -- The surviving root's cell already holds its (unchanged) rank.
      LT -> writeIORef mref (Child nroot)
      GT -> writeIORef nref (Child mroot)
      EQ -> do
        writeIORef mref (Child nroot)
        writeIORef nref (Root (nrank + 1))
-- | Whether two terms currently belong to the same class, i.e. resolve to
-- the same root. Resolves the first term, then the second (compressing
-- paths in both).
decideEqTerm :: MonadIO m => Term -> Term -> m Bool
decideEqTerm m n = do
  mroot <- findTerm m
  nroot <- findTerm n
  pure (mroot == nroot)
-- | Invert a map from keys to terms with 'invertHashMapWith', resolving each
-- term to its class root via 'findTerm'.
-- NOTE(review): presumably this groups keys by the root of their class —
-- confirm against 'invertHashMapWith' in "Heart.Core.MultiMap".
hashSplitClasses
  :: (Eq k, Hashable k, MonadIO m)
  => HashMap k Term
  -> m (HashMultiMap Term k)
hashSplitClasses termsByKey = invertHashMapWith findTerm termsByKey
-- | Ordered-map counterpart of 'hashSplitClasses'. It inverts a map from keys
-- to terms with 'invertOrdMapWith', resolving each term to its class root via
-- 'findTerm'.
-- NOTE(review): presumably this groups keys by the root of their class —
-- confirm against 'invertOrdMapWith' in "Heart.Core.MultiMap".
ordSplitClasses
  :: (Ord k, MonadIO m)
  => Map k Term
  -> m (OrdMultiMap Term k)
ordSplitClasses termsByKey = invertOrdMapWith findTerm termsByKey