-- | The Munkres version of the Hungarian Method for weighted minimal
-- bipartite matching.
-- The implementation is based on Robert A. Pilgrim's notes
-- (the links to the notes and to their mirror are missing from this copy
-- of the file).

{-# LANGUAGE CPP, MultiParamTypeClasses, FlexibleContexts #-}
module Data.Algorithm.Munkres
  ( -- hungarianMethod
    hungarianMethodInt
  , hungarianMethodFloat
  , hungarianMethodDouble
  , hungarianMethodBoxed
#ifdef MUNKRES_DEBUG
  , tesztek, teszt1, teszt2, teszt3, teszt4
  , bruteforce
  , randomArray
  , doATest, doFloatTest
  , doManyTests, doManyFloatTests
  , main
#endif
  ) where

import Prelude hiding (flip)

import Control.Monad
import Control.Monad.ST
import Data.List hiding (insert)
import Data.STRef
import Data.Array.ST
import Data.Array.IArray ()
import Data.Array.MArray
import Data.Array.Unboxed

#ifdef MUNKRES_DEBUG
import Data.Ord (comparing)
import Debug.Trace
import System.Random
#endif

-------------------------------------------------------

-- | Exchanges the two components of an index pair.
swap :: (Int,Int) -> (Int,Int)
swap (a,b) = (b,a)

{-
complementSort :: Int -> [Int] -> [Int]
complementSort n xs = complement n (sort xs)
-}

-- | @complement n xs@ lists, in increasing order, the elements of @[1..n]@
-- which do not occur in @xs@.
-- Assumes that the input is sorted.
complement :: Int -> [Int] -> [Int]
complement n list = go 1 list
  where
    go k rest@(x:xs)
      | k > n     = []
      | k == x    = go (k+1) xs
      | k <  x    = k : go (k+1) rest
      | otherwise = go k xs          -- x is below k: drop it
    go k [] = [k..n]

{-
merge :: [Int] -> [Int] -> [Int]
merge xxs@(x:xs) yys@(y:ys) = if x <= y
  then x : merge xs yys
  else y : merge xxs ys
merge xs [] = xs
merge [] ys = ys
-}

-- | Union of two sets represented as sorted lists; an element present in
-- both inputs appears once in the result.
-- Assumes that the inputs are sorted sets.
mergeUnion :: [Int] -> [Int] -> [Int]
mergeUnion ls@(l:ls') rs@(r:rs')
  | l <  r    = l : mergeUnion ls' rs
  | l == r    = l : mergeUnion ls' rs'
  | otherwise = r : mergeUnion ls  rs'
mergeUnion ls [] = ls
mergeUnion [] rs = rs

-- | Inserts an element into a sorted list, keeping it sorted; the list is
-- returned unchanged if the element is met during the scan.
insert :: Int -> [Int] -> [Int]
insert v rest@(x:xs)
  | v <  x    = v : rest
  | v == x    = rest
  | otherwise = x : insert v xs
insert v [] = [v]

-- | Removes one occurrence of an element from a sorted list; the list is
-- returned unchanged if a larger element is met first.
remove :: Int -> [Int] -> [Int]
remove v rest@(x:xs)
  | v <  x    = rest
  | v == x    = xs
  | otherwise = x : remove v xs
remove _ [] = []

{-# SPECIALIZE firstJust :: [ ST s (Maybe (Int,Int)) ] -> ST s (Maybe (Int,Int)) #-}
-- | Runs the actions from left to right and stops at the first one which
-- yields a 'Just'; the remaining actions are not executed.
firstJust :: Monad m => [ m (Maybe a) ] -> m (Maybe a)
firstJust (act:others) = act >>= maybe (firstJust others) (return . Just)
firstJust []           = return Nothing

{-# SPECIALISE alternate :: [Int] -> ([Int],[Int]) #-}
-- | Deals the elements of a list alternately into two lists: the 1st, 3rd,
-- 5th, ... elements go to the first component, the rest to the second.
-- The relative order is preserved.
alternate :: [a] -> ([a],[a])
alternate list = go True list [] []
  where
    -- the Bool tells whether the next element belongs to the first list
    go _       []     firsts seconds = (reverse firsts, reverse seconds)
    go toFirst (x:xs) firsts seconds
      | toFirst   = go False xs (x:firsts) seconds
      | otherwise = go True  xs firsts     (x:seconds)

-------------------------------------------------------
-- polymorphicity problem workaround experiment...

-- | 'thaw' with the result type fixed to a boxed 'STArray'.
thawST :: (IArray a e, MArray (STArray s) e (ST s)) => a (Int,Int) e -> ST s (STArray s (Int,Int) e)
thawST = thaw

-- | 'thaw' with the types fixed to 'UArray' and 'STUArray'.
thawSTU :: (IArray UArray e, MArray (STUArray s) e (ST s)) => UArray (Int,Int) e -> ST s (STUArray s (Int,Int) e)
thawSTU = thaw

-- | 'newArray_' with the result type fixed to a boxed 'STArray'.
newSTArray_ :: MArray (STArray s) e (ST s) => ((Int,Int),(Int,Int)) -> ST s (STArray s (Int,Int) e)
newSTArray_ = newArray_

-- | 'newArray_' with the result type fixed to an unboxed 'STUArray'.
newSTUArray_ :: MArray (STUArray s) e (ST s) => ((Int,Int),(Int,Int)) -> ST s (STUArray s (Int,Int) e)
newSTUArray_ = newArray_

-------------------------------------------------------

{- SPECIALISE hungarianMethod :: UArray (Int,Int) Int    -> ([(Int,Int)],Int   ) -}
{- SPECIALISE hungarianMethod :: UArray (Int,Int) Float  -> ([(Int,Int)],Float ) -}
{- SPECIALISE hungarianMethod :: UArray (Int,Int) Double -> ([(Int,Int)],Double) -}

-- | Needs a rectangular array of /nonnegative/ weights, which
-- encode the weights on the edges of a (complete) bipartite graph.
-- The indexing should start from @(1,1)@.
-- Returns a minimal matching, and the cost of it.
--
-- Unfortunately, GHC is opposing hard the polymorphicity of this function. I think
-- the main reason for that is that there is no @Unboxed@ type class, and
-- thus the contexts @IArray UArray e@ and @MArray (STUArray s) e (ST s)@ do not
-- know about each other.
-- (And I have problems with the @forall s@ part, too).
hungarianMethodInt :: UArray (Int,Int) Int -> ([(Int,Int)],Int)
hungarianMethodInt input = runST $ do
  let ((1,1),(rows,cols)) = bounds input
  matching <-
    if rows <= cols
      then thawSTU input >>= hungarianMethodShared
      else do
        -- more rows than columns: solve the transposed problem and
        -- swap the resulting index pairs back afterwards
        flipped <- newSTUArray_ ((1,1),(cols,rows))
        forM_ (assocs input) $ \((r,c),w) -> writeArray flipped (c,r) w
        fmap (map swap) (hungarianMethodShared flipped)
  -- the cost is computed from the untouched input array
  return (matching, sum (map (input !) matching))

-- | The 'Float' version of 'hungarianMethodInt'; the code is the same.
hungarianMethodFloat :: UArray (Int,Int) Float -> ([(Int,Int)],Float)
hungarianMethodFloat input = runST $ do
  let ((1,1),(rows,cols)) = bounds input
  matching <-
    if rows <= cols
      then thawSTU input >>= hungarianMethodShared
      else do
        -- more rows than columns: solve the transposed problem and
        -- swap the resulting index pairs back afterwards
        flipped <- newSTUArray_ ((1,1),(cols,rows))
        forM_ (assocs input) $ \((r,c),w) -> writeArray flipped (c,r) w
        fmap (map swap) (hungarianMethodShared flipped)
  -- the cost is computed from the untouched input array
  return (matching, sum (map (input !) matching))

-- | The 'Double' version of 'hungarianMethodInt'; the code is the same.
hungarianMethodDouble :: UArray (Int,Int) Double -> ([(Int,Int)],Double)
hungarianMethodDouble input = runST $ do
  let ((1,1),(rows,cols)) = bounds input
  matching <-
    if rows <= cols
      then thawSTU input >>= hungarianMethodShared
      else do
        -- more rows than columns: solve the transposed problem and
        -- swap the resulting index pairs back afterwards
        flipped <- newSTUArray_ ((1,1),(cols,rows))
        forM_ (assocs input) $ \((r,c),w) -> writeArray flipped (c,r) w
        fmap (map swap) (hungarianMethodShared flipped)
  -- the cost is computed from the untouched input array
  return (matching, sum (map (input !) matching))

-- | The same as 'hungarianMethod', but uses boxed values (thus works with
-- any data type which is an instance of 'Real').
-- The usage of one of the unboxed versions is recommended where possible,
-- for performance reasons.
hungarianMethodBoxed :: (Real e, IArray a e) => a (Int,Int) e -> ([(Int,Int)],e)
hungarianMethodBoxed input = runST $ do
  let ((1,1),(rows,cols)) = bounds input
  matching <-
    if rows <= cols
      then thawST input >>= hungarianMethodShared   -- works on an STArray copy
      else do
        -- more rows than columns: fill a boxed STArray with the transposed
        -- weights, solve that, and swap the index pairs back
        flipped <- newSTArray_ ((1,1),(cols,rows))
        forM_ (range ((1,1),(cols,rows))) $ \(c,r) ->
          writeArray flipped (c,r) (input ! (r,c))
        fmap (map swap) (hungarianMethodShared flipped)
  -- the cost is computed from the untouched input array
  return (matching, sum (map (input !) matching))

{-# SPECIALISE hungarianMethodShared :: STUArray s (Int,Int) Int    -> ST s [(Int,Int)] #-}
{-# SPECIALISE hungarianMethodShared :: STUArray s (Int,Int) Float  -> ST s [(Int,Int)] #-}
{-# SPECIALISE hungarianMethodShared :: STUArray s (Int,Int) Double -> ST s [(Int,Int)] #-}
-- | Common driver of the exported functions: allocates the (initially empty)
-- bookkeeping references for the starred and primed entries and for the
-- covered rows and columns, then runs 'munkers' on the given mutable array.
-- The lower bound of the array has to be @(1,1)@.
hungarianMethodShared :: (Real e, MArray a e (ST s)) => a (Int,Int) e -> ST s [(Int,Int)]
hungarianMethodShared ar = do
  ((1,1),size) <- getBounds ar
  stars    <- newSTRef []
  primes   <- newSTRef []
  rowCover <- newSTRef []
  colCover <- newSTRef []
  munkers ar size stars primes rowCover colCover

-- the meat comes here...
{-# SPECIALISE munkers :: STUArray s (Int,Int) Int    -> (Int,Int) -> STRef s [(Int,Int)] -> STRef s [(Int,Int)] -> STRef s [Int] -> STRef s [Int] -> ST s [(Int,Int)] #-}
{-# SPECIALISE munkers :: STUArray s (Int,Int) Float  -> (Int,Int) -> STRef s [(Int,Int)] -> STRef s [(Int,Int)] -> STRef s [Int] -> STRef s [Int] -> ST s [(Int,Int)] #-}
{-# SPECIALISE munkers :: STUArray s (Int,Int) Double -> (Int,Int) -> STRef s [(Int,Int)] -> STRef s [(Int,Int)] -> STRef s [Int] -> STRef s [Int] -> ST s [(Int,Int)] #-}
-- | The Munkres iteration itself.
--
-- Arguments, in order:
--
-- * the weight array, indexed from @(1,1)@; it is modified in place;
--
-- * the upper bound @(n,m)@ of the array (number of rows and columns);
--
-- * the list of starred entries;
--
-- * the list of primed entries (most recently primed first);
--
-- * the sorted list of covered rows;
--
-- * the sorted list of covered columns.
--
-- The result is the final list of starred entries, that is, the matching.
-- The computation ends in step 3, as soon as the starred entries occupy
-- @min n m@ distinct columns.
munkers :: (Real e, MArray a e (ST s)) => a (Int,Int) e -> (Int,Int) -> STRef s [(Int,Int)] -> STRef s [(Int,Int)] -> STRef s [Int] -> STRef s [Int] -> ST s [(Int,Int)]
munkers ar (n,m) starred primed coveredRows coveredCols = (step1 >> step2 >> step3) where

  -- the number of starred entries a complete matching has
  kk = min n m

  -- Step 3: cover every column containing a starred entry; stop if that
  -- gives kk covered columns, otherwise go on with step 4.
  step3 = do
    --printArray "step3"
    colsC <- readSTRef coveredCols
    star <- readSTRef starred
    let colsC' = mergeUnion colsC (sort $ map snd star)   -- nub $ colsC ++ (map snd star)
    if length colsC' == kk
      then return star
      else do
        writeSTRef coveredCols colsC'
        step4

  -- Step 4: look for an uncovered zero.  If there is none, step 6 is run
  -- with the smallest uncovered value.  Otherwise the zero is primed; if its
  -- row has no starred entry we go to step 5, else the row gets covered, the
  -- column of that starred entry gets uncovered, and step 4 is repeated.
  step4 = do
    --printArray "step4"
    --printPrimStar "step4"
    rowsC <- readSTRef coveredRows
    colsC <- readSTRef coveredCols
    let rowsNC = complement n rowsC
        colsNC = complement m colsC
    star <- readSTRef starred
    let f ij = do
          x <- readArray ar ij
          if x==0 then return (Just ij) else return Nothing
    mp <- firstJust [ f (i,j) | i<-rowsNC, j<-colsNC ]
    --print mp
    case mp of
      Nothing -> do
        es <- forM [ (i,j) | i<-rowsNC, j<-colsNC ] $ \ij -> readArray ar ij
        step6 (minimum es)
      Just ij@(i,_) -> do
        modifySTRef primed (ij:)
        case find (\(p,_) -> p==i) star of
          Nothing -> step5 ij
          Just (_,q) -> do
            modifySTRef coveredRows (insert i)
            modifySTRef coveredCols (remove q)
            step4
{-
        case filter (\(p,_) -> p==i) star of
          [] -> step5 ij
          [(p,q)] -> do
            modifySTRef coveredRows (insert i)
            modifySTRef coveredCols (remove q)
            step4
          _ -> error "Munkres/step4: should not happen"
-}

  -- Step 5: build the alternating series of primed and starred entries
  -- ending in the given primed entry, unstar the starred members and star
  -- the primed ones; then forget all primes and covers and return to step 3.
  step5 pq = do
    --printArray "step5"
    --printPrimStar "step5"
    star <- readSTRef starred
    prim <- readSTRef primed
    alt <- step5a star prim pq [pq]
    -- the series starts with a primed entry, so 'alternate' puts the
    -- primed entries into ps and the starred ones into ss
    let (ps,ss) = alternate alt
    writeSTRef starred $ (star \\ ss) ++ ps
    writeSTRef primed []
    writeSTRef coveredRows []
    writeSTRef coveredCols []
    step3

  -- Extends the series: given the column of the latest primed entry, finds
  -- the starred entry in that column (if there is none, the series is
  -- complete), then the primed entry in the row of that star, and continues
  -- from there.
  step5a :: [(Int,Int)] -> [(Int,Int)] -> (Int,Int) -> [(Int,Int)] -> ST s [(Int,Int)]
  step5a star prim (_,q) xs =
    case findStarred q of
      Just (i,_) -> do
        let (_,j) = findPrimed i
        step5a star prim (i,j) ((i,j):(i,q):xs)
      Nothing -> return xs
    where
      findStarred j = find (\(_,c) -> (c==j)) star
      findPrimed  i = case find (\(r,_) -> (r==i)) prim of
        Just x  -> x
        Nothing -> error $ "Munkres/findPrimed: should not happen (" ++ show prim ++ " " ++ show i ++ ")"

  -- Step 2: scan the array row by row and star every zero which has no
  -- starred entry in its row or column yet.
  step2 = do
      --printArray "step2"
      s <- foldM worker [] [ (i,j) | i<-[1..n], j<-[1..m] ]
      writeSTRef starred s
    where
      worker star ij@(i,j) = do
        x <- readArray ar ij
        let clashes (a,b) = (a==i) || (b==j)
        if x==0 && not (any clashes star)
          then return (ij : star)
          else return star

  -- Step 6: subtract c from the entries lying in an uncovered row and an
  -- uncovered column, add c to the entries lying in a covered row and a
  -- covered column, then return to step 4.
  step6 c = do
    --printArray "step6"
    --printPrimStar "step6"
    rowsC <- readSTRef coveredRows
    colsC <- readSTRef coveredCols
    let rowsNC = complement n rowsC
        colsNC = complement m colsC
    -- forM_ instead of forM: the results are not used
    forM_ rowsNC $ \i -> forM_ colsNC $ \j -> do
      x <- readArray ar (i,j)
      writeArray ar (i,j) (x-c)
    forM_ rowsC $ \i -> forM_ colsC $ \j -> do
      x <- readArray ar (i,j)
      writeArray ar (i,j) (x+c)
    step4

  -- Step 1: subtract the minimum of each row from every entry of that row.
  step1 = mapM_ subRow [1..n]

  subRow i = do
    row <- forM [1..m] $ \j -> readArray ar (i,j)
    let y = minimum row
    forM_ [1..m] $ \j -> do
      let ij = (i,j)
      x <- readArray ar ij
      writeArray ar ij (x-y)

{-
  -- debugging

  printArray s = do
    putStrLn ""
    x <- freeze ar :: IO (UArray (Int,Int) Int)
    print (s,x)

  printPrimStar s = do
    star <- readSTRef starred
    prim <- readSTRef primed
    crows <- readSTRef coveredRows
    ccols <- readSTRef coveredCols
    putStrLn s
    print ("starred",star)
    print ("primed",prim)
    print ("cov. rows",crows)
    print ("cov. cols",ccols)
-}

-------------------------------------------------------

#ifdef MUNKRES_DEBUG

debug x y = trace (show x) y

-- brute-force algorithm for sanity checking
bruteforce :: UArray (Int,Int) Int -> ([(Int,Int)],Int)
bruteforce input = {- debug all $ -} minimumBy (comparing snd) allWithCosts where
  ((1,1),(n,m)) = bounds input
  k = min n m
  -- NOTE(review): the source text is truncated in this copy from here on:
  -- the rest of 'bruteforce' (including the definition of allWithCosts),
  -- the helper randomR' used below, and the beginning of the type signature
  -- of 'randomArray' are missing.  The next line is kept exactly as found;
  -- restore it from the original source before building with MUNKRES_DEBUG.
  g = if n Int -> Int -> g -> (UArray (Int,Int) Int , g)
randomArray maxsize maxelem rnd0 = (ar,rnd3) where
  (n,rnd1) = randomR (1,maxsize) rnd0
  (m,rnd2) = randomR (1,maxsize) rnd1
  (rnd3,es) = mapAccumL randomR' rnd2 $ replicate (n*m) (0::Int,maxelem)
  ar = listArray ((1,1),(n,m)) es

-- correctness testing
doATest maxsize maxelem rnd0 _ = do
  let (ar,rnd1) = randomArray maxsize maxelem rnd0
      (xs,c) = hungarianMethodInt ar
      sol1 = (sortBy (comparing fst) xs, c)
      sol2 = bruteforce ar
  -- NOTE(review): the layout was lost in this copy; both putStrLn lines are
  -- assumed to belong to the body of 'when' - confirm against the original.
  when (snd sol1 /= snd sol2) $ do
    print ar
    putStrLn $ show (snd $ bounds ar) ++ ": " ++ show (snd sol1) ++ " " ++ show (snd sol2)
    putStrLn $ "hun -> " ++ show (fst sol1) ++ "\nbrt -> " ++ show (fst sol2)
  return rnd1

-- int vs float testing (mainly because of the copypasted code)
doFloatTest maxsize maxelem rnd0 _ = do
  let (ar,rnd1) = randomArray maxsize maxelem rnd0
      sol1 = hungarianMethodInt ar
      sol2 = hungarianMethodFloat (amap fromIntegral ar)
      sol3 = hungarianMethodDouble (amap fromIntegral ar)
  print (snd sol1, snd sol2, snd sol3)
  return rnd1

-- do lots of tests
doManyTests n maxsize maxelem =
  getStdGen >>= \rnd -> foldM_ (doATest maxsize maxelem) rnd [1..n]

doManyFloatTests n maxsize maxelem =
  getStdGen >>= \rnd -> foldM_ (doFloatTest maxsize maxelem) rnd [1..n]

main = do
  putStrLn "a"
  doManyTests 50 10 10
  putStrLn "b"
  doManyTests 50 10 50
  putStrLn "c"
  doManyTests 50 10 100
  putStrLn "d"
  doManyTests 100 10 10
  putStrLn "e"
  doManyTests 100 10 50
  putStrLn "f"
  doManyTests 100 10 100

#endif

-------------------------------------------------------
-- some test cases

#ifdef MUNKRES_DEBUG

-- ("teszt" is Hungarian for "test", "tesztek" is its plural)
tesztek = [ teszt1, teszt2, teszt3, teszt4 ]

teszt1 :: UArray (Int,Int) Int
teszt1 = listArray ((1,1),(3,3)) $ concat $ transpose $
  [ [ 1,2,3 ]
  , [ 2,4,6 ]
  , [ 3,6,9 ]
  ]

teszt2 :: UArray (Int,Int) Int
teszt2 = listArray ((1,1),(4,4)) $ concat $ transpose $
  [ [ 14,5,8,7 ]
  , [ 2,12,6,5 ]
  , [ 7,8,3,9 ]
  , [ 2,4,6,10 ]
  ]

teszt3 :: UArray (Int,Int) Int
teszt3 = listArray ((1,1),(5,5)) $ concat $ transpose
  [ [4,5,3,2,3]
  , [3,2,4,3,4]
  , [3,3,4,4,3]
  , [2,4,3,2,4]
  , [2,1,3,4,3]
  ]

teszt4 :: UArray (Int,Int) Int
teszt4 = listArray ((1,1),(6,6)) $ concat $ transpose
  [ [ 3,4,5,6,2,1 ]
  , [ 3,0,1,2,3,4 ]
  , [ 7,6,0,2,1,1 ]
  , [ 4,4,5,0,1,2 ]
  , [ 0,1,0,1,0,0 ]
  , [ 0,3,2,2,2,0 ]
  ]

#endif

-------------------------------------------------------