module Data.Union.ST (
UnionST,
runUnionST,
new,
grow,
copy,
lookup,
annotate,
merge,
flatten,
size,
unsafeFreeze,
) where
import qualified Data.Union.Type as U
import Prelude hiding (lookup)
import Control.Monad.ST
import Control.Monad
import Control.Applicative
import Data.Array.Base hiding (unsafeFreeze)
import Data.Array.ST hiding (unsafeFreeze)
import qualified Data.Array.Base as A (unsafeFreeze)
-- | A mutable union-find (disjoint-set) structure over the nodes
-- @0 .. size - 1@, living in the 'ST' state thread @s@, where each node
-- carries a label of type @l@.
data UnionST s l = UnionST
    { up :: STUArray s Int Int
      -- ^ Parent pointer of each node; a node whose entry equals its own
      -- index is the root of its set.
    , rank :: STUArray s Int Int
      -- ^ Rank of each root, consulted by 'merge' to decide which root is
      -- attached beneath the other.
    , label :: STArray s Int l
      -- ^ Per-node label; 'lookup' only ever reads the entry of a root.
    , size :: !Int
      -- ^ Number of nodes held in the arrays.
    , def :: l
      -- ^ Label given to freshly created nodes.
    }
#if __GLASGOW_HASKELL__ < 702
-- | Compatibility shim, compiled only when the version check above holds
-- (GHC older than 7.2). The module applies the applicative operators to ST
-- computations (for example in unsafeFreeze), so an instance is supplied
-- here in terms of the monad operations. Presumably the base library of
-- those compilers lacked it; confirm before dropping support.
instance Applicative (ST s) where
    pure x = return x
    mf <*> mx = do
        g <- mf
        x <- mx
        return (g x)
#endif
-- | Run a computation that builds a 'UnionST' and return the result as an
-- immutable 'U.Union'.
--
-- The arrays are frozen in place (see 'unsafeFreeze'). This is sound here
-- because the rank-2 argument type, enforced by 'runST', keeps the mutable
-- structure from escaping and being modified after the freeze.
runUnionST :: (forall s. ST s (UnionST s l)) -> U.Union l
runUnionST action = runST (unsafeFreeze =<< action)
-- | Convert a 'UnionST' into an immutable 'U.Union' without copying the
-- arrays (uses "Data.Array.Base" @unsafeFreeze@).
--
-- Only the node count, the parent array and the label array are carried
-- over; the rank array is not part of the result. @u@ must not be mutated
-- afterwards, since the frozen arrays share storage with it.
unsafeFreeze :: UnionST s l -> ST s (U.Union l)
unsafeFreeze u = do
    parents <- A.unsafeFreeze (up u)
    labels <- A.unsafeFreeze (label u)
    return (U.Union (size u) parents labels)
-- | Create a fresh union-find structure with @n@ nodes, numbered
-- @0 .. n - 1@.
--
-- Every node starts as its own singleton set: its parent is itself, its
-- rank is 0 and its label is @dflt@. The default label is also stored in
-- the structure so that nodes added later by 'grow' get the same label.
new :: Int -> l -> ST s (UnionST s l)
new n dflt = do
    -- Inclusive index bounds; an empty range when n <= 0.
    let bnds = (0, n - 1)
    parents <- newListArray bnds [0 ..]
    ranks <- newArray bnds 0
    labels <- newArray bnds dflt
    return UnionST
        { up = parents
        , rank = ranks
        , label = labels
        , size = n
        , def = dflt
        }
-- | Ensure the structure has room for at least @size'@ nodes.
--
-- If @u@ is already large enough it is returned as is, so the result then
-- shares its arrays with @u@. Otherwise a larger copy is allocated; the new
-- nodes are singletons labelled with the structure's default label, and
-- @u@ itself is left untouched.
grow :: UnionST s l -> Int -> ST s (UnionST s l)
grow u size'
    | size u < size' = grow' u size'
    | otherwise = return u
-- | Make a copy of @u@ with the same number of nodes, backed by freshly
-- allocated arrays. Later updates to either structure do not affect the
-- other.
copy :: UnionST s l -> ST s (UnionST s l)
copy u = grow' u n
  where
    n = size u
-- | Copy @u@ into freshly allocated arrays with room for @size'@ nodes.
--
-- Nodes @0 .. size u - 1@ keep their parent, rank and label. Nodes beyond
-- the old size start as singletons with rank 0 and label @def u@. The
-- arrays of @u@ are only read, never written.
--
-- Precondition: @size' >= size u@ ('grow' and 'copy' both guarantee this).
-- Otherwise the copy loop would index past the end of the new arrays.
grow' :: UnionST s l -> Int -> ST s (UnionST s l)
grow' u size' = do
    let bnds = (0, size' - 1)
    up' <- newListArray bnds [0 ..]
    rank' <- newArray bnds 0
    label' <- newArray bnds (def u)
    forM_ [0 .. size u - 1] $ \i -> do
        readArray (up u) i >>= writeArray up' i
        readArray (rank u) i >>= writeArray rank' i
        readArray (label u) i >>= writeArray label' i
    return u{ up = up', rank = rank', label = label', size = size' }
-- | Overwrite the label stored for node @i@.
--
-- NOTE(review): 'lookup' only reads the label of a set's root, so
-- annotating a non-root node has no visible effect there. Presumably
-- callers pass a root (e.g. one obtained from 'lookup'); confirm.
annotate :: UnionST s l -> Int -> l -> ST s ()
annotate u = writeArray (label u)
-- | Find the root of the set containing node @i@.
--
-- A root is a node whose parent entry equals its own index. As the
-- recursion unwinds, every node on the visited path has its parent entry
-- overwritten with the root (path compression), shortening later lookups.
lookup' :: UnionST s l -> Int -> ST s Int
lookup' u = go
  where
    go node = do
        parent <- readArray (up u) node
        if parent == node
            then return node
            else do
                root <- go parent
                root <$ writeArray (up u) node root
-- | Return the root of node @i@'s set together with the label stored at
-- that root. Path compression is applied along the way (see 'lookup'').
lookup :: UnionST s l -> Int -> ST s (Int, l)
lookup u i =
    lookup' u i >>= \root -> (,) root <$> readArray (label u) root
-- | Test whether nodes @a@ and @b@ belong to the same set, i.e. share a
-- root. Both lookups perform path compression, @a@ first.
equals :: UnionST s l -> Int -> Int -> ST s Bool
equals u a b = (==) <$> lookup' u a <*> lookup' u b
-- | Union the sets containing nodes @a@ and @b@.
--
-- Returns 'Nothing', without modifying anything beyond the path
-- compression done by the lookups, when both nodes already share a root.
--
-- Otherwise the root of lower rank is attached beneath the root of higher
-- rank. On a tie, @a@'s root goes beneath @b@'s root, whose rank is then
-- incremented. @f@ is applied to the two roots' labels. The first component
-- of its result becomes the label of the surviving root; the second is
-- returned wrapped in 'Just'. The absorbed root's label slot is overwritten
-- with an error thunk, so it must not be read afterwards.
--
-- NOTE(review): @f@ receives (surviving root's label, absorbed root's
-- label). That order is decided by the ranks, not by the order of @a@ and
-- @b@, so a non-commutative @f@ may see its arguments swapped.
merge :: UnionST s l -> (l -> l -> (l, a)) -> Int -> Int -> ST s (Maybe a)
merge u f a b = do
    (rootA, labA) <- lookup u a
    (rootB, labB) <- lookup u b
    if rootA == rootB
        then return Nothing
        else do
            rankA <- readArray (rank u) rootA
            rankB <- readArray (rank u) rootB
            case compare rankA rankB of
                GT -> attach rootB labB rootA labA
                LT -> attach rootA labA rootB labB
                EQ -> do
                    writeArray (rank u) rootB (rankA + 1)
                    attach rootA labA rootB labB
  where
    -- Hang child beneath parent, invalidate the child's label, and store
    -- the combined label on parent. The combination stays lazy.
    attach child childLab parent parentLab = do
        writeArray (up u) child parent
        writeArray (label u) child (error "invalid entry")
        let (merged, result) = f parentLab childLab
        writeArray (label u) parent merged
        return (Just result)
-- | Apply path compression to every node. Afterwards each node's parent
-- entry points directly at the root of its set, so subsequent lookups take
-- a single step.
flatten :: UnionST s l -> ST s ()
flatten u = forM_ [0 .. size u - 1] (lookup' u)