{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE DeriveFoldable #-}
{-# LANGUAGE DeriveTraversable #-}
{-# LANGUAGE CPP #-}
{-# OPTIONS_GHC -Wall #-}
module Data.DisjointMap
( DisjointMap
, empty
, singleton
, singletons
, insert
, union
, unionWeakly
, lookup
, lookup'
, representative
, representative'
, toLists
, toSets
, fromSets
, pretty
, prettyList
, foldlWithKeys'
) where
import Prelude hiding (lookup)
import Control.Monad.Trans.State.Strict
import Control.Monad.Trans.Maybe
import Control.Monad.Trans.Class
import Control.Monad
import Data.Map (Map)
import Data.Set (Set)
import Data.Bifunctor (first)
import Data.Foldable (Foldable)
import Data.Maybe (fromMaybe)
import Data.Aeson (ToJSON(..),FromJSON(..))
import Data.Foldable (foldlM)
import qualified Data.Semigroup as SG
import qualified Data.Map.Strict as M
import qualified Data.Set as S
import qualified GHC.OldList as L
import qualified Data.Foldable as F
-- | A map whose keys are partitioned into disjoint classes, each class
-- carrying exactly one value of type @v@.
--
-- Internally this is a union-find forest. Lookups walk parent pointers
-- until they reach a key that is its own parent (the class representative).
-- Per-class data is kept only under representatives.
data DisjointMap k v = DisjointMap
  !(Map k k)
  -- ^ parent pointer for every key; a representative points at itself
  !(Map k (Ranked k v))
  -- ^ rank, full key set and value of each class, keyed by representative
  deriving (Functor,Foldable,Traversable)
-- | Payload stored at a class representative. It holds the union-by-rank
-- rank, the set of every key in the class, and the class's value.
data Ranked k v = Ranked {-# UNPACK #-} !Int !(Set k) !v
  deriving (Functor,Foldable,Traversable)
-- | The identity is the map with no keys and no classes.
instance (Ord k, Monoid v) => Monoid (DisjointMap k v) where
  mempty = DisjointMap M.empty M.empty
-- | Merging two maps fuses every pair of classes that share a key. The
-- values of fused classes are combined with the value type's '<>'.
instance (Ord k, Semigroup v) => SG.Semigroup (DisjointMap k v) where
  l <> r = append l r
-- | Two maps are equal when they contain the same classes with the same
-- values. The choice of representative and the shape of the internal
-- forest are ignored.
instance (Ord k, Ord v) => Eq (DisjointMap k v) where
  a == b = classes a == classes b
    where
      classes dm = S.fromList (toSets dm)
-- | Orders maps by their set of (key set, value) classes. This agrees with
-- the 'Eq' instance and ignores the internal forest shape.
instance (Ord k, Ord v) => Ord (DisjointMap k v) where
  compare a b = classes a `compare` classes b
    where
      classes dm = S.fromList (toSets dm)
-- | Shows the map as a list of (ascending key list, value) pairs.
instance (Show k, Ord k, Show v) => Show (DisjointMap k v) where
  show dm = showDisjointSet dm
-- | Encodes the map as the JSON form of its list of (key set, value) classes.
instance (ToJSON k, ToJSON v) => ToJSON (DisjointMap k v) where
  toJSON dm = toJSON (toSets dm)
-- | Decodes the shape produced by the 'ToJSON' instance: the JSON form of a
-- list of (key set, value) pairs.
--
-- Parsing fails if two of the key sets share a key. Pairs whose key set is
-- empty are accepted, but they and their values are dropped. The failure
-- message names this type ('DisjointMap') and states the condition that is
-- actually checked, namely that the key sets must be disjoint.
instance (FromJSON k, FromJSON v, Ord k) => FromJSON (DisjointMap k v) where
  parseJSON x = do
    theSets <- parseJSON x
    maybe
      (fail "the sets comprising the DisjointMap were not disjoint")
      pure
      (fromSets theSets)
-- | Builds a map from (key set, value) pairs, one class per pair.
--
-- Returns 'Nothing' if any key occurs in more than one of the sets. Pairs
-- with an empty key set contribute nothing, and their value is discarded.
fromSets :: Ord k => [(Set k,v)] -> Maybe (DisjointMap k v)
fromSets xs = unsafeFromSets xs empty <$ unionDistinctAll (map fst xs)
-- | Adds each (key set, value) pair to the accumulator as a brand-new class.
-- It does not check that the sets are disjoint.
--
-- Every key of a set is pointed directly at the set's minimum, which becomes
-- the representative. Pairs whose key set is empty are skipped, and their
-- value is dropped. Callers must ensure the sets are pairwise disjoint. The
-- parent map is merged with a left-biased 'M.union', so a key already present
-- in the accumulator would be silently repointed. NOTE(review): in this file
-- the only caller, 'fromSets', always starts from 'empty'.
--
-- The initial rank matches 'singletons'. A one-key class is a bare root
-- (rank 0). A larger class is a tree of height one and so gets rank 1.
-- Giving such a class rank 0 would under-report its height. Union by rank
-- could then hang it beneath a singleton and deepen the forest needlessly.
unsafeFromSets :: Ord k => [(Set k,v)] -> DisjointMap k v -> DisjointMap k v
unsafeFromSets ys !ds@(DisjointMap p r) = case ys of
  [] -> ds
  (x,v) : xs -> case setLookupMin x of
    Nothing -> unsafeFromSets xs ds
    Just m ->
      let rank
            | S.size x == 1 = 0
            | otherwise = 1
      in unsafeFromSets xs $ DisjointMap
           (M.union (M.fromSet (\_ -> m) x) p)
           (M.insert m (Ranked rank x v) r)
-- | Unions a list of sets. Returns 'Nothing' as soon as some element turns
-- up in more than one of them.
unionDistinctAll :: Ord a => [Set a] -> Maybe (Set a)
unionDistinctAll sets = foldlM unionDistinct S.empty sets
-- | Unions two sets, but only if they are disjoint. Overlap is detected by
-- the union being smaller than the sum of the two sizes.
unionDistinct :: Ord a => Set a -> Set a -> Maybe (Set a)
unionDistinct a b
  | S.size merged == S.size a + S.size b = Just merged
  | otherwise = Nothing
  where
    merged = S.union a b
-- | Renders the map with 'show' applied to its 'toLists' form.
showDisjointSet :: (Show k, Ord k, Show v) => DisjointMap k v -> String
showDisjointSet dm = show (toLists dm)
-- | Lists every class as (its keys in ascending order, its value). Classes
-- appear in ascending order of their representative key.
toLists :: DisjointMap k v -> [([k],v)]
toLists dm = map (first S.toList) (toSets dm)
-- | Lists every class as (its key set, its value). Classes appear in
-- ascending order of their representative key.
toSets :: DisjointMap k v -> [(Set k,v)]
toSets (DisjointMap _ ranks) = [ (ks, v) | Ranked _ ks v <- M.elems ranks ]
-- | Human-readable rendering, e.g. @{{1,2} -> "a", {3} -> "b"}@.
pretty :: (Show k, Show v) => DisjointMap k v -> String
pretty dm = concat ["{", L.intercalate ", " (prettyList dm), "}"]
-- | Renders each class as @{k1,k2,...} -> value@, one string per class.
prettyList :: (Show k, Show v) => DisjointMap k v -> [String]
prettyList dm =
  [ "{" ++ commafied ks ++ "} -> " ++ show v
  | (ks, v) <- toSets dm
  ]
-- | Shows the elements of a set in ascending order, separated by commas
-- with no spaces.
commafied :: Show k => Set k -> String
commafied ks = L.intercalate "," (map show (S.toList ks))
-- | Strict left fold over the classes, in ascending representative order.
-- The step function receives each class's key set and value.
foldlWithKeys' :: (a -> Set k -> v -> a) -> a -> DisjointMap k v -> a
foldlWithKeys' f a0 (DisjointMap _ r) = M.foldl' step a0 r
  where
    step acc (Ranked _ ks v) = f acc ks v
-- | @union x y dm@ puts @x@ and @y@ into the same class.
--
-- If either key is absent, it is first inserted as a singleton class holding
-- 'mempty'. Both lookups compress paths. If the keys then share a
-- representative, nothing further changes. Otherwise:
--
-- * the class values are combined with 'mappend', with @x@'s class on the left;
-- * the key sets are unioned;
-- * union by rank chooses the surviving representative.
union :: (Ord k, Monoid v) => k -> k -> DisjointMap k v -> DisjointMap k v
union !x !y set = flip execState set $ runMaybeT $ do
    rootX <- lift $ state $ lookupCompressAdd x
    rootY <- lift $ state $ lookupCompressAdd y
    -- Already in the same class: short-circuit the MaybeT. The State changes
    -- made so far (insertions, path compression) are kept.
    guard $ rootX /= rootY
    current <- lift get
    lift $ put $! link rootX rootY current
  where
    -- The lower-ranked root goes beneath the higher-ranked one. On a tie,
    -- x's root goes beneath y's root and y's rank grows by one.
    link rx ry (DisjointMap parents ranks) =
      case compare rankX rankY of
        LT -> attach rx ry rankY
        GT -> attach ry rx rankX
        EQ -> attach rx ry (rankY + 1)
      where
        Ranked rankX keysX valX = ranks M.! rx
        Ranked rankY keysY valY = ranks M.! ry
        mergedKeys = mappend keysX keysY
        mergedVal = mappend valX valY
        -- Repoint the absorbed root and keep only the winner's payload.
        attach child root rank =
          DisjointMap
            (M.insert child root parents)
            (M.delete child $! M.insert root (Ranked rank mergedKeys mergedVal) ranks)
-- | Like 'union', but requires only a 'Semigroup' and never invents values.
--
-- * Neither key present: the map is unchanged.
-- * Exactly one present: the absent key joins the other key's class, and
--   that class's value is left untouched.
-- * Both present in different classes: the classes merge, values are
--   combined with '<>' (@x@'s class on the left), and union by rank picks
--   the surviving representative.
--
-- Present keys have their paths compressed along the way.
unionWeakly :: (Ord k, Semigroup v) => k -> k -> DisjointMap k v -> DisjointMap k v
unionWeakly !x !y set = flip execState set $ runMaybeT $ do
    foundX <- lift $ state $ lookupCompress x
    foundY <- lift $ state $ lookupCompress y
    case (foundX, foundY) of
      (Nothing, Nothing) -> pure ()
      (Nothing, Just ry) -> do
        !current <- lift get
        lift $ put $ adopt x ry current
      (Just rx, Nothing) -> do
        !current <- lift get
        lift $ put $ adopt y rx current
      (Just rx, Just ry) -> do
        guard $ rx /= ry
        !current <- lift get
        lift $ put $! link rx ry current
  where
    -- Point a key that is not yet in the map at an existing root, and record
    -- it in that root's key set.
    adopt newKey root (DisjointMap parents ranks) =
      let Ranked rank keys val =
            fromMaybe (error "Data.DisjointMap.unionWeakly") (M.lookup root ranks)
      in DisjointMap
           (M.insert newKey root parents)
           (M.insert root (Ranked rank (S.insert newKey keys) val) ranks)
    -- The lower-ranked root goes beneath the higher-ranked one. On a tie,
    -- x's root goes beneath y's root and y's rank grows by one.
    link rx ry (DisjointMap parents ranks) =
      case compare rankX rankY of
        LT -> attach rx ry rankY
        GT -> attach ry rx rankX
        EQ -> attach rx ry (rankY + 1)
      where
        Ranked rankX keysX valX = ranks M.! rx
        Ranked rankY keysY valY = ranks M.! ry
        mergedVal = valX <> valY
        mergedKeys = mappend keysX keysY
        -- Repoint the absorbed root and keep only the winner's payload.
        attach child root rank =
          DisjointMap
            (M.insert child root parents)
            (M.delete child $! M.insert root (Ranked rank mergedKeys mergedVal) ranks)
-- | The representative key of the class containing the given key, or
-- 'Nothing' if the key is absent. Paths are not compressed.
representative :: Ord k => k -> DisjointMap k v -> Maybe k
representative key dm = find key dm
-- | @insert x v dm@ adds @x@ as a new singleton class holding @v@. If @x@
-- is already present, its class's value becomes @v <> old@.
insert :: (Ord k, Semigroup v) => k -> v -> DisjointMap k v -> DisjointMap k v
insert !x = insertInternal x keys
  where
    keys = S.singleton x
-- | Worker for 'insert'.
--
-- If @x@ is new, it becomes its own representative with rank 0, key set
-- @ks@ and value @v@. If @x@ is already present, the class containing it
-- (after path compression) gets @ks@ merged into its key set and its value
-- replaced by @v <> old@.
--
-- NOTE(review): in the fresh-key branch only @x@ is added to the parent map,
-- so any other keys in @ks@ would be recorded in the class's key set without
-- being reachable. The only caller in this file ('insert') passes
-- @S.singleton x@, which avoids this.
insertInternal :: (Ord k, Semigroup v) => k -> Set k -> v -> DisjointMap k v -> DisjointMap k v
insertInternal !x !ks !v set@(DisjointMap p r) =
  case M.insertLookupWithKey keepExisting x x p of
    -- Fresh key: a new root with rank 0.
    (Nothing, p') -> DisjointMap p' (M.insert x (Ranked 0 ks v) r)
    -- Existing key: fold ks and v into the class it already belongs to.
    (Just _, _) -> case representative' x set of
      (Nothing, _) -> error "DisjointMap insert: invariant violated"
      (Just root, DisjointMap p2 r') -> DisjointMap p2 (M.adjust absorb root r')
  where
    keepExisting _ _ old = old
    absorb (Ranked rank oldKs vOld) = Ranked rank (mappend oldKs ks) (v <> vOld)
-- | A map with exactly one key, forming its own class with the given value.
singleton :: k -> v -> DisjointMap k v
singleton !x !v = DisjointMap parents ranks
  where
    parents = M.singleton x x
    ranks = M.singleton x (Ranked 0 (S.singleton x) v)
-- | The map with no keys and no classes.
empty :: DisjointMap k v
empty = DisjointMap noParents noRanks
  where
    noParents = M.empty
    noRanks = M.empty
-- | Merges two maps by folding the one with fewer keys into the other. On a
-- tie, the left map is folded into the right map.
--
-- NOTE(review): values from the folded-in map are combined as
-- @incoming <> existing@. When the left argument is strictly larger, a shared
-- class therefore ends up as @right <> left@ rather than @left <> right@.
-- This only matters for non-commutative semigroups; confirm whether that
-- case is meant to be supported.
append :: (Ord k, Semigroup v) => DisjointMap k v -> DisjointMap k v -> DisjointMap k v
append s1@(DisjointMap m1 r1) s2@(DisjointMap m2 r2)
  | M.size m1 > M.size m2 = absorbInto s1 r2 m2
  | otherwise = absorbInto s2 r1 m1
  where
    absorbInto target ranks parents =
      appendPhase2 (appendPhase1 ranks target parents) parents
-- | First half of 'append'. The arguments are the smaller map's payloads
-- (by representative), the accumulator, and the smaller map's parent map.
--
-- For every representative (a key that is its own parent), insert it with
-- its class value into the accumulator. Then weakly union every key of its
-- class with it.
appendPhase1 :: (Ord k, Semigroup v)
  => Map k (Ranked k v)
  -> DisjointMap k v
  -> Map k k
  -> DisjointMap k v
appendPhase1 !ranks = M.foldlWithKey' step
  where
    step acc key parent
      | key == parent =
          maybe
            (error "Data.DisjointMap.appendParents: invariant violated")
            (absorbClass acc key)
            (M.lookup key ranks)
      | otherwise = acc
    absorbClass acc root (Ranked _ ks v) =
      F.foldl' (\dm k -> unionWeakly k root dm) (insert root v acc) ks
-- | Second half of 'append'. It weakly unions every non-representative key
-- of the smaller map with its parent.
--
-- NOTE(review): 'appendPhase1' already unions every key in each class's key
-- set with that class's representative. This pass therefore looks redundant,
-- since each call here should find both keys in the same class. Confirm
-- before removing it.
appendPhase2 :: (Ord k, Semigroup v) => DisjointMap k v -> Map k k -> DisjointMap k v
appendPhase2 = M.foldlWithKey' step
  where
    step acc key parent
      | key == parent = acc
      | otherwise = unionWeakly key parent acc
-- | A map with one class containing every key of the set, holding the given
-- value. The set's minimum becomes the representative. An empty set yields
-- 'empty'.
singletons :: Eq k => Set k -> v -> DisjointMap k v
singletons s v = maybe empty build (setLookupMin s)
  where
    build root =
      DisjointMap (M.fromSet (const root) s) (M.singleton root (Ranked initialRank s v))
    -- A lone key is a bare root; more keys form a tree of height one.
    initialRank
      | S.size s == 1 = 0
      | otherwise = 1
-- | Smallest element of a set, or 'Nothing' for the empty set. This is a
-- compatibility shim for containers releases that predate 'S.lookupMin'.
setLookupMin :: Set a -> Maybe a
#if MIN_VERSION_containers(0,5,9)
setLookupMin s = S.lookupMin s
#else
setLookupMin s
  | S.null s = Nothing
  | otherwise = Just (S.findMin s)
#endif
-- | Like 'representative', but also returns the map with the path from the
-- key to its representative compressed. An absent key leaves the map
-- unchanged. This is the same operation as 'lookupCompress'.
representative' :: Ord k => k -> DisjointMap k v -> (Maybe k, DisjointMap k v)
representative' !x set = lookupCompress x set
-- | Returns the representative of @x@, compressing its path. An absent key
-- is first inserted as a singleton class holding 'mempty', in which case it
-- is its own representative. The compressed map is forced before returning.
lookupCompressAdd :: (Ord k, Monoid v) => k -> DisjointMap k v -> (k, DisjointMap k v)
lookupCompressAdd !x set = maybe fresh compressed (find x set)
  where
    fresh = (x, insert x mempty set)
    compressed rep =
      let set' = compress rep x set
      in set' `seq` (rep, set')
-- | Returns the representative of @x@ (if present) together with the map
-- with @x@'s path compressed. The compressed map is forced before
-- returning. An absent key leaves the map unchanged.
lookupCompress :: Ord k => k -> DisjointMap k v -> (Maybe k, DisjointMap k v)
lookupCompress !x set = maybe (Nothing, set) compressed (find x set)
  where
    compressed rep =
      let set' = compress rep x set
      in set' `seq` (Just rep, set')
-- | Follows parent pointers from @x@ to its class representative, which is
-- the first key that is its own parent. Returns 'Nothing' if @x@ is absent.
-- The walk is forced before the result is wrapped, and paths are not
-- modified.
find :: Ord k => k -> DisjointMap k v -> Maybe k
find !x (DisjointMap parents _) =
  case M.lookup x parents of
    Nothing -> Nothing
    Just parent -> Just $! climb x parent
  where
    climb child parent
      | child == parent = parent
      | otherwise = climb parent (parents M.! parent)
-- | The value of the class containing the key, or 'mempty' if the key is
-- absent.
lookup :: (Ord k, Monoid v) => k -> DisjointMap k v -> v
lookup k dm = fromMaybe mempty (lookup' k dm)
-- | The value of the class containing the key, or 'Nothing' if the key is
-- absent. Paths are not compressed.
lookup' :: Ord k => k -> DisjointMap k v -> Maybe v
lookup' !x dm@(DisjointMap _ ranks) = do
  root <- find x dm
  Ranked _ _ v <- M.lookup root ranks
  pure v
-- | @compress rep x dm@ repoints every key on the parent path from @x@ up
-- to (not including) @rep@ directly at @rep@. The per-class payloads are
-- left untouched.
--
-- @rep@ must be the representative of @x@; the callers in this file pass
-- the result of 'find'.
compress :: Ord k => k -> k -> DisjointMap k v -> DisjointMap k v
compress !rep = go
  where
    go !node dm@(DisjointMap parents ranks)
      | node == rep = dm
      | otherwise =
          go (parents M.! node) (DisjointMap (M.insert node rep parents) ranks)