module Util.UnionFind(
Element,
T,
find,
fromElement,
getW,
new,
new_,
putW,
union,
union_,
updateW
) where
import Control.Monad.Trans
import Data.IORef
import Data.Unique
import Control.Monad (when)
data Element w a = Element a !Unique !(IORef (Link w a))
data Link w a = Weight !Int w | Next (Element w a)
type T = Element
new :: MonadIO m => w -> a -> m (Element w a)
new w x = liftIO $ do
r <- newIORef (Weight 1 w)
n <- newUnique
return $ Element x n r
new_ :: MonadIO m => a -> m (Element () a)
new_ x = new () x
find :: MonadIO m => Element w a -> m (Element w a)
find x@(Element a _ r) = liftIO $ do
e <- readIORef r
case e of
Weight _ _ -> return x
Next next -> do
last <- Util.UnionFind.find next
when (next /= last) $ writeIORef r (Next last)
return last
getW :: MonadIO m => Element w a -> m w
getW x = liftIO $ do
Element _ _ r <- find x
Weight _ w <- readIORef r
return w
updateW :: MonadIO m => (w -> w) -> Element w a -> m ()
updateW f x = liftIO $ do
Element _ _ r <- find x
modifyIORef r (\ (Weight s w) -> Weight s (f w))
putW :: MonadIO m => Element w a -> w -> m ()
putW e w = liftIO $ do
Element _ _ r <- find e
modifyIORef r (\ (Weight s _) -> Weight s w)
union :: MonadIO m => (w -> w -> w) -> Element w a -> Element w a -> m ()
union comb e1 e2 = liftIO $ do
e1'@(Element _ _ r1) <- find e1
e2'@(Element _ _ r2) <- find e2
when (r1 /= r2) $ do
Weight w1 x1 <- readIORef r1
Weight w2 x2 <- readIORef r2
if w1 <= w2 then do
writeIORef r1 (Next e2')
writeIORef r2 $! (Weight (w1 + w2) (comb x1 x2))
else do
writeIORef r1 $! (Weight (w1 + w2) (comb x1 x2))
writeIORef r2 (Next e1')
union_ :: MonadIO m => Element () a -> Element () a -> m ()
union_ x y = union (\_ _ -> ()) x y
fromElement :: Element w a -> a
fromElement (Element a _ _) = a
instance Eq (Element w a) where
Element _ x _ == Element _ y _ = x == y
Element _ x _ /= Element _ y _ = x /= y
instance Ord (Element w a) where
Element _ x _ `compare` Element _ y _ = x `compare` y
Element _ x _ <= Element _ y _ = x <= y
Element _ x _ >= Element _ y _ = x >= y
instance Show a => Show (Element w a) where
showsPrec n (Element x _ _) = showsPrec n x