-- This file is part of the 'union-find-array' library. It is licensed
-- under an MIT license. See the accompanying 'LICENSE' file for details.
--
-- Authors: Bertram Felgenhauer

{-# LANGUAGE CPP #-}
{-# LANGUAGE RankNTypes #-}

-- |
-- Low-level interface for managing a disjoint set data structure, based on
-- 'Control.Monad.ST'. For a higher level convenience interface, look at
-- 'Control.Monad.Union'.
module Data.Union.ST (
    UnionST,
    runUnionST,
    new,
    grow,
    copy,
    lookup,
    annotate,
    merge,
    equals,
    flatten,
    size,
    unsafeFreeze,
) where

import qualified Data.Union.Type as U

import Prelude hiding (lookup)
import Control.Monad.ST
import Control.Monad
import Control.Applicative
import Data.Array.Base hiding (unsafeFreeze)
import Data.Array.ST hiding (unsafeFreeze)
import qualified Data.Array.Base as A (unsafeFreeze)

-- | A disjoint set forest, with nodes numbered from 0, which can carry labels.
data UnionST s l = UnionST {
    up :: STUArray s Int Int,
    -- ^ Parent pointer of every node; a node whose parent is itself is the
    -- representative of its class.
    rank :: STUArray s Int Int,
    -- ^ Rank of every node, consulted and updated by 'merge' (union by rank).
    label :: STArray s Int l,
    -- ^ Per-node labels; only the label stored at a representative is
    -- reported by 'lookup'.
    size :: !Int,
    -- ^ Capacity of the forest, i.e. the number of nodes (@0 .. size - 1@).
    def :: l
    -- ^ Default label given to freshly created nodes.
}

#if __GLASGOW_HASKELL__ < 702
instance Applicative (ST s) where
    (<*>) = ap
    pure = return
#endif

-- Use http://www.haskell.org/pipermail/libraries/2008-March/009465.html ?

-- | Analogous to 'Data.Array.ST.runSTArray': run a computation that builds
-- a forest and return the frozen result.
runUnionST :: (forall s. ST s (UnionST s l)) -> U.Union l
runUnionST a = runST $ a >>= unsafeFreeze

-- | Analogous to 'Data.Array.Base.unsafeFreeze': the parent and label arrays
-- are frozen in place (the rank array is dropped), so the 'UnionST' must not
-- be modified afterwards.
unsafeFreeze :: UnionST s l -> ST s (U.Union l)
unsafeFreeze u =
    U.Union (size u) <$> A.unsafeFreeze (up u) <*> A.unsafeFreeze (label u)

-- What about thawing?

-- | Create a new disjoint set forest, of given capacity. Every node starts
-- as its own singleton class with rank 0, labelled with the given default.
new :: Int -> l -> ST s (UnionST s l)
new n dflt = do
    -- Each node is initially its own parent.
    parents <- newListArray (0, n - 1) [0 ..]
    ranks <- newArray (0, n - 1) 0
    labels <- newArray (0, n - 1) dflt
    return UnionST {
        up = parents,
        rank = ranks,
        label = labels,
        size = n,
        def = dflt
    }

-- | Grow the capacity of a disjoint set forest. Shrinking is not possible:
-- if the requested capacity does not exceed the current one, the same
-- forest is returned unmodified. Otherwise fresh arrays are allocated (the
-- argument forest is only read); existing nodes keep their links, ranks and
-- labels, and each new node is a singleton class carrying the default label.
grow :: UnionST s l -> Int -> ST s (UnionST s l)
grow u size'
    | size' > size u = grow' u size'
    | otherwise = return u

-- | Copy a disjoint set forest. The result has the same capacity, links,
-- ranks and labels, but owns fresh arrays, so later updates to either
-- forest do not affect the other.
copy :: UnionST s l -> ST s (UnionST s l)
copy u = grow' u $ size u

-- Allocate arrays of capacity @n@ and copy every existing node's parent,
-- rank and label into them. Nodes beyond the old capacity stay singletons
-- with rank 0 and the default label. Callers only pass @n >= size u@.
grow' :: UnionST s l -> Int -> ST s (UnionST s l)
grow' u n = do
    parents <- newListArray (0, n - 1) [0 ..]
    ranks <- newArray (0, n - 1) 0
    labels <- newArray (0, n - 1) (def u)
    let copyNode i = do
            readArray (up u) i >>= writeArray parents i
            readArray (rank u) i >>= writeArray ranks i
            readArray (label u) i >>= writeArray labels i
    mapM_ copyNode [0 .. size u - 1]
    return u { up = parents, rank = ranks, label = labels, size = n }

-- | Annotate a node with a new label. Note that 'lookup' reports the label
-- stored at a node's representative, so a label written to a
-- non-representative node is not observed through 'lookup'.
annotate :: UnionST s l -> Int -> l -> ST s ()
annotate u = writeArray (label u)

-- | Look up the representative of a given node.
--
-- Performs path compression: every node visited on the way to the root is
-- re-pointed directly at the root.
lookup' :: UnionST s l -> Int -> ST s Int
lookup' u = go
  where
    go node = do
        parent <- readArray (up u) node
        if parent == node
            then return node
            else do
                root <- go parent
                writeArray (up u) node root
                return root

-- | Look up the representative of a given node together with the label
-- stored at that representative.
lookup :: UnionST s l -> Int -> ST s (Int, l)
lookup u i = lookup' u i >>= \root -> (,) root <$> readArray (label u) root

-- | Check whether two nodes are in the same set (i.e. have the same
-- representative). Compresses the paths of both nodes as a side effect.
equals :: UnionST s l -> Int -> Int -> ST s Bool
equals u a b = (==) <$> lookup' u a <*> lookup' u b

-- | Merge two nodes if they are in distinct equivalence classes. The
-- passed function combines the two classes' labels: its first component
-- becomes the label of the merged class and its second component is
-- returned in 'Just'. If the nodes are already equivalent nothing changes
-- and 'Nothing' is returned.
merge
    :: UnionST s l
    -- ^ The forest to modify.
    -> (l -> l -> (l, a))
    -- ^ Combining function, always called as @f la lb@ where @la@ and @lb@
    -- are the labels of the classes of @a@ and @b@ respectively, no matter
    -- which representative survives. The first component of the result
    -- becomes the label of the merged class; the second is returned.
    -> Int
    -- ^ Node @a@.
    -> Int
    -- ^ Node @b@.
    -> ST s (Maybe a)
    -- ^ 'Nothing' if @a@ and @b@ were already in the same class.
merge u f a b = do
    (a', va) <- lookup u a
    (b', vb) <- lookup u b
    if a' == b'
        then return Nothing
        else do
            ra <- readArray (rank u) a'
            rb <- readArray (rank u) b'
            -- Combine in argument order. Previously the order depended on
            -- the ranks (f vb va when b's root survived), which made
            -- non-commutative combining functions behave unpredictably.
            let (v, w) = f va vb
                -- @root@ is the surviving representative, @child@ the root
                -- that has just been linked below it.
                finish root child = do
                    -- The child's label is no longer reachable via 'lookup';
                    -- presumably overwritten so the old label is not retained.
                    writeArray (label u) child (error "invalid entry")
                    writeArray (label u) root v
                    return (Just w)
            -- Union by rank: hang the lower-ranked root below the other.
            case ra `compare` rb of
                LT -> do
                    writeArray (up u) a' b'
                    finish b' a'
                GT -> do
                    writeArray (up u) b' a'
                    finish a' b'
                EQ -> do
                    writeArray (up u) a' b'
                    writeArray (rank u) b' (ra + 1)
                    finish b' a'

-- | Flatten a disjoint set forest, for faster lookups. Runs path compression
-- from every node, so afterwards every node points directly at its
-- representative.
flatten :: UnionST s l -> ST s ()
flatten u = forM_ [0 .. size u - 1] $ lookup' u