{-# LANGUAGE NondecreasingIndentation #-} -- | A simple mutable union-find data structure. -- -- It is used in a unification algorithm for backpack mix-in linking. -- -- This implementation is based off of the one in \"The Essence of ML Type -- Inference\". (N.B. the union-find package is also based off of this.) -- module Distribution.Utils.UnionFind ( Point, fresh, find, union, equivalent, ) where import Data.STRef import Control.Monad import Control.Monad.ST -- | A variable which can be unified; alternately, this can be thought -- of as an equivalence class with a distinguished representative. newtype Point s a = Point (STRef s (Link s a)) deriving (Eq) -- | Mutable write to a 'Point' writePoint :: Point s a -> Link s a -> ST s () writePoint (Point v) = writeSTRef v -- | Read the current value of 'Point'. readPoint :: Point s a -> ST s (Link s a) readPoint (Point v) = readSTRef v -- | The internal data structure for a 'Point', which either records -- the representative element of an equivalence class, or a link to -- the 'Point' that actually stores the representative type. data Link s a -- NB: it is too bad we can't say STRef Int#; the weights remain boxed = Info {-# UNPACK #-} !(STRef s Int) {-# UNPACK #-} !(STRef s a) | Link {-# UNPACK #-} !(Point s a) -- | Create a fresh equivalence class with one element. fresh :: a -> ST s (Point s a) fresh desc = do weight <- newSTRef 1 descriptor <- newSTRef desc Point `fmap` newSTRef (Info weight descriptor) -- | Flatten any chains of links, returning a 'Point' -- which points directly to the canonical representation. repr :: Point s a -> ST s (Point s a) repr point = readPoint point >>= \r -> case r of Link point' -> do point'' <- repr point' when (point'' /= point') $ do writePoint point =<< readPoint point' return point'' Info _ _ -> return point -- | Return the canonical element of an equivalence -- class 'Point'. find :: Point s a -> ST s a find point = -- Optimize length 0 and 1 case at expense of -- general case readPoint point >>= \r -> case r of Info _ d_ref -> readSTRef d_ref Link point' -> readPoint point' >>= \r' -> case r' of Info _ d_ref -> readSTRef d_ref Link _ -> repr point >>= find -- | Unify two equivalence classes, so that they share -- a canonical element. Keeps the descriptor of point2. union :: Point s a -> Point s a -> ST s () union refpoint1 refpoint2 = do point1 <- repr refpoint1 point2 <- repr refpoint2 when (point1 /= point2) $ do l1 <- readPoint point1 l2 <- readPoint point2 case (l1, l2) of (Info wref1 dref1, Info wref2 dref2) -> do weight1 <- readSTRef wref1 weight2 <- readSTRef wref2 -- Should be able to optimize the == case separately if weight1 >= weight2 then do writePoint point2 (Link point1) -- The weight calculation here seems a bit dodgy writeSTRef wref1 (weight1 + weight2) writeSTRef dref1 =<< readSTRef dref2 else do writePoint point1 (Link point2) writeSTRef wref2 (weight1 + weight2) _ -> error "UnionFind.union: repr invariant broken" -- | Test if two points are in the same equivalence class. equivalent :: Point s a -> Point s a -> ST s Bool equivalent point1 point2 = liftM2 (==) (repr point1) (repr point2)