module Data.DisjointSet
(
DSet
, singletons
, find
, union
, classes
, singletonsIO
, singletonsST
, sameClass
)
where
import Control.Monad
import Control.Monad.ST
import Data.Ix
import Data.Array.MArray
import Data.Array.IO (IOUArray)
import Data.Array.ST (STUArray)
-- | A mutable disjoint-set (union-find) forest over 'Int' elements, backed by
-- three mutable arrays of kind @a@ (for example 'IOUArray' or @'STUArray' s@).
data DSet a = DSet
  { classesAr :: !(a () Int)
    -- ^ One-cell array holding the current number of distinct classes.
  , parents   :: !(a Int Int)
    -- ^ Parent pointer of every element; a class root points to itself.
  , ranks     :: !(a Int Int)
    -- ^ Per-element rank, compared by 'union' when linking two roots.
  }
-- | Build a disjoint-set structure over every index in the inclusive range
-- @bs@, with each index initially in its own singleton class.
--
-- Every element starts as its own parent with rank 0, and the class counter
-- starts at @'rangeSize' bs@ (0 for an empty range).
singletons :: (MArray a Int m) => (Int, Int) -> m (DSet a)
singletons bs =
  DSet
    <$> newArray ((), ()) (rangeSize bs)
    <*> newListArray bs (range bs)
    -- Ranks must all start at 0: union-by-rank uses them as upper bounds on
    -- tree height. Seeding them with any other values (such as the element
    -- indices) skews every link towards arbitrary roots and loses the
    -- logarithmic height bound.
    <*> newArray bs 0
-- | 'singletons' specialised to unboxed arrays in the 'ST' monad.
singletonsST :: (Int, Int) -> ST s (DSet (STUArray s))
singletonsST bounds = singletons bounds
-- | 'singletons' specialised to unboxed arrays in 'IO'.
singletonsIO :: (Int, Int) -> IO (DSet IOUArray)
singletonsIO bounds = singletons bounds
-- | Read the stored parent pointer of an element.
getParent :: (MArray a Int m) => DSet a -> Int -> m Int
getParent ds i = readArray (parents ds) i
-- | Overwrite the parent pointer of an element.
setParent :: (MArray a Int m) => DSet a -> Int -> Int -> m ()
setParent ds i p = writeArray (parents ds) i p
-- | Read the rank stored for an element.
getRank :: (MArray a Int m) => DSet a -> Int -> m Int
getRank ds i = readArray (ranks ds) i
-- | Overwrite the rank stored for an element.
setRank :: (MArray a Int m) => DSet a -> Int -> Int -> m ()
setRank ds i r = writeArray (ranks ds) i r
-- | Return the representative (root) of the class containing @i@.
--
-- Full path compression: after the call, every element visited on the way
-- up has its parent pointer set directly to the root. @i@ is expected to lie
-- within the bounds the structure was created with.
find :: (MArray a Int m) => DSet a -> Int -> m Int
find ds i = getParent ds i >>= step
  where
    -- A self-parented element is the root; otherwise resolve the parent's
    -- root first, then re-point @i@ at it.
    step p
      | p == i = return i
      | otherwise = do
          root <- find ds p
          root <$ setParent ds i root
-- | Test whether two elements currently belong to the same class.
--
-- Both lookups go through 'find', so they also compress the paths of @x@ and
-- @y@.
sameClass :: (MArray a Int m) => DSet a -> Int -> Int -> m Bool
sameClass ds x y = do
  rootX <- find ds x
  rootY <- find ds y
  return (rootX == rootY)
-- | Merge the classes containing @x@ and @y@ using union by rank.
--
-- Returns 'False' (and changes nothing beyond path compression) when the two
-- elements already share a root. Otherwise it links the lower-ranked root
-- under the higher-ranked one and decrements the class counter. On equal
-- ranks, @y@'s root goes under @x@'s root and @x@'s root's rank grows by
-- one. It then returns 'True'.
union :: (MArray a Int m) => DSet a -> Int -> Int -> m Bool
union ds x y = do
  rootX <- find ds x
  rootY <- find ds y
  if rootX == rootY
    then return False
    else do
      link rootX rootY
      remaining <- readArray (classesAr ds) ()
      writeArray (classesAr ds) () (remaining - 1)
      return True
  where
    -- Attach one root beneath the other according to their ranks.
    link rx ry = do
      rankX <- getRank ds rx
      rankY <- getRank ds ry
      case compare rankX rankY of
        LT -> setParent ds rx ry
        GT -> setParent ds ry rx
        EQ -> setParent ds ry rx >> setRank ds rx (rankX + 1)
-- | Current number of distinct classes. It starts at the size of the index
-- range and drops by one on every successful 'union'.
classes :: (MArray a Int m) => DSet a -> m Int
classes ds = readArray (classesAr ds) ()