```
-- | The Munkres version of the Hungarian Method for weighted minimal
-- bipartite matching.
-- The implementation is based on Robert A. Pilgrim's notes,
-- <http://216.249.163.93/bob.pilgrim/445/munkres.html>
-- (mirror: <http://www.public.iastate.edu/~ddoty/HungarianAlgorithm.html>).

{-# LANGUAGE CPP, MultiParamTypeClasses, FlexibleContexts #-}
module Data.Algorithm.Munkres
  ( -- hungarianMethod
    hungarianMethodInt
  , hungarianMethodFloat
  , hungarianMethodDouble
  , hungarianMethodBoxed

#ifdef MUNKRES_DEBUG
  , tesztek, teszt1, teszt2, teszt3, teszt4
  , bruteforce
  , randomArray
  , doATest, doFloatTest
  , doManyTests, doManyFloatTests
  , main
#endif

  ) where

import Prelude hiding (flip)

import Control.Monad
import Control.Monad.ST

import Data.List hiding (insert)

import Data.STRef
import Data.Array.ST

import Data.Array.IArray ()
import Data.Array.MArray
import Data.Array.Unboxed

#ifdef MUNKRES_DEBUG
import Data.Ord (comparing)
import Debug.Trace
import System.Random
#endif

-------------------------------------------------------

-- | Exchange the two components of an index pair. Used to turn a solution
-- of the transposed problem back into @(row,column)@ order.
swap :: (Int,Int) -> (Int,Int)
swap (row, col) = (col, row)

{-
complementSort :: Int -> [Int] -> [Int]
complementSort n xs = complement n (sort xs)
-}

-- | @complement n xs@ lists, in increasing order, the numbers of @[1..n]@
-- that do not occur in @xs@. Assumes that @xs@ is sorted in ascending
-- order; duplicate and out-of-range entries are simply skipped.
complement :: Int -> [Int] -> [Int]
complement n = go 1
  where
    go k [] = [k .. n]
    go k ys@(y:rest)
      | k > n     = []
      | k < y     = k : go (k + 1) ys
      | k == y    = go (k + 1) rest
      | otherwise = go k rest

{-
merge :: [Int] -> [Int] -> [Int]
merge xxs@(x:xs) yys@(y:ys) = if x <= y
then x : merge xs yys
else y : merge xxs ys
merge xs [] = xs
merge [] ys = ys
-}

-- | Union of two ascending lists without duplicates (sorted sets). The
-- result is ascending; an element present in both inputs appears once.
mergeUnion :: [Int] -> [Int] -> [Int]
mergeUnion left@(l:ls) right@(r:rs)
  | l < r     = l : mergeUnion ls right
  | l > r     = r : mergeUnion left rs
  | otherwise = l : mergeUnion ls rs
mergeUnion left [] = left
mergeUnion [] right = right

-- | Insert an element into an ascending list without duplicates, keeping
-- it ascending; the list is returned unchanged if the element is present.
insert :: Int -> [Int] -> [Int]
insert y [] = [y]
insert y list@(x:rest)
  | y < x     = y : list
  | y == x    = list
  | otherwise = x : insert y rest

-- | Delete an element from an ascending list. The scan stops as soon as a
-- larger element is seen, so the list is assumed to be sorted.
remove :: Int -> [Int] -> [Int]
remove _ [] = []
remove y list@(x:rest)
  | y < x     = list
  | y == x    = rest
  | otherwise = x : remove y rest

{-# SPECIALIZE firstJust :: [ ST s (Maybe (Int,Int)) ] -> ST s (Maybe (Int,Int)) #-}
-- | Run the actions from left to right until one yields a 'Just', and
-- return that result; later actions are not run. Yields 'Nothing' when
-- no action produces a 'Just' (including for an empty list).
firstJust :: Monad m => [ m (Maybe a) ] -> m (Maybe a)
firstJust = foldr tryNext (return Nothing)
  where
    tryNext action fallback = action >>= maybe fallback (return . Just)

{-# SPECIALISE alternate :: [Int] -> ([Int],[Int]) #-}
-- | Split a list into the elements at even positions (0, 2, 4, ...) and
-- the elements at odd positions (1, 3, 5, ...), preserving their order.
alternate :: [a] -> ([a],[a])
alternate = foldr step ([], [])
  where
    -- prepending x shifts every later element by one position, so the
    -- previous even/odd halves swap roles
    step x (here, there) = (x : there, here)

-------------------------------------------------------

-- Workaround for GHC's trouble with polymorphism here: these wrappers fix
-- the mutable array type to a boxed 'STArray' or unboxed 'STUArray' indexed
-- by @(Int,Int)@, leaving only the element type open.

-- | 'thaw' with a boxed 'STArray' result.
thawST :: (IArray a e, MArray (STArray s) e (ST s)) => a (Int,Int) e -> ST s (STArray s (Int,Int) e)
thawST frozen = thaw frozen

-- | 'thaw' from a 'UArray' into an unboxed 'STUArray'.
thawSTU :: (IArray UArray e, MArray (STUArray s) e (ST s)) => UArray (Int,Int) e -> ST s (STUArray s (Int,Int) e)
thawSTU frozen = thaw frozen

-- | 'newArray_' producing a boxed 'STArray'.
newSTArray_ :: MArray (STArray s) e (ST s) => ((Int,Int),(Int,Int)) -> ST s (STArray s (Int,Int) e)
newSTArray_ range = newArray_ range

-- | 'newArray_' producing an unboxed 'STUArray'.
newSTUArray_ :: MArray (STUArray s) e (ST s) => ((Int,Int),(Int,Int)) -> ST s (STUArray s (Int,Int) e)
newSTUArray_ range = newArray_ range

-------------------------------------------------------

{- SPECIALISE hungarianMethod :: UArray (Int,Int) Int    -> ([(Int,Int)],Int   ) -}
{- SPECIALISE hungarianMethod :: UArray (Int,Int) Float  -> ([(Int,Int)],Float ) -}
{- SPECIALISE hungarianMethod :: UArray (Int,Int) Double -> ([(Int,Int)],Double) -}

-- | Needs a rectangular array of /nonnegative/ weights, which
-- encode the weights on the edges of a (complete) bipartite graph.
-- The indexing must start from @(1,1)@; any other lower bound is rejected
-- with an error.
-- Returns a minimal matching (as a list of @(row,column)@ pairs), and its cost.
--
-- Unfortunately, GHC strongly resists making this function polymorphic. I think
-- the main reason for that is that there is no @Unboxed@ type class, and
-- thus the contexts @IArray UArray e@ and @MArray (STUArray s) e (ST s)@ do not
-- know about each other. (And I have problems with the @forall s@ part, too.)
hungarianMethodInt :: UArray (Int,Int) Int -> ([(Int,Int)],Int)
hungarianMethodInt input =
  case bounds input of
    ((1,1),(n,m)) -> runST $ do
      star <- if m >= n
        then thawSTU input >>= hungarianMethodShared
        else do
          -- More rows than columns: the shared solver is always handed at
          -- least as many columns as rows, so solve the transposed matrix
          -- and swap the resulting index pairs back.
          ar <- newSTUArray_ ((1,1),(m,n))
          forM_ [ (j,i) | j <- [1..m], i <- [1..n] ] $ \(j,i) ->
            writeArray ar (j,i) (input ! (i,j))
          map swap <$> hungarianMethodShared ar
      -- the cost is read from the caller's (unmodified) input array
      return (star, sum [ input ! ij | ij <- star ])
    _ -> error "Munkres.hungarianMethodInt: array indexing must start at (1,1)"

-- | The same as 'hungarianMethodInt', for 'Float' weights. The indexing
-- must start from @(1,1)@; any other lower bound is rejected with an error.
hungarianMethodFloat :: UArray (Int,Int) Float -> ([(Int,Int)],Float)
hungarianMethodFloat input =
  case bounds input of
    ((1,1),(n,m)) -> runST $ do
      star <- if m >= n
        then thawSTU input >>= hungarianMethodShared
        else do
          -- More rows than columns: the shared solver is always handed at
          -- least as many columns as rows, so solve the transposed matrix
          -- and swap the resulting index pairs back.
          ar <- newSTUArray_ ((1,1),(m,n))
          forM_ [ (j,i) | j <- [1..m], i <- [1..n] ] $ \(j,i) ->
            writeArray ar (j,i) (input ! (i,j))
          map swap <$> hungarianMethodShared ar
      -- the cost is read from the caller's (unmodified) input array
      return (star, sum [ input ! ij | ij <- star ])
    _ -> error "Munkres.hungarianMethodFloat: array indexing must start at (1,1)"

-- | The same as 'hungarianMethodInt', for 'Double' weights. The indexing
-- must start from @(1,1)@; any other lower bound is rejected with an error.
hungarianMethodDouble :: UArray (Int,Int) Double -> ([(Int,Int)],Double)
hungarianMethodDouble input =
  case bounds input of
    ((1,1),(n,m)) -> runST $ do
      star <- if m >= n
        then thawSTU input >>= hungarianMethodShared
        else do
          -- More rows than columns: the shared solver is always handed at
          -- least as many columns as rows, so solve the transposed matrix
          -- and swap the resulting index pairs back.
          ar <- newSTUArray_ ((1,1),(m,n))
          forM_ [ (j,i) | j <- [1..m], i <- [1..n] ] $ \(j,i) ->
            writeArray ar (j,i) (input ! (i,j))
          map swap <$> hungarianMethodShared ar
      -- the cost is read from the caller's (unmodified) input array
      return (star, sum [ input ! ij | ij <- star ])
    _ -> error "Munkres.hungarianMethodDouble: array indexing must start at (1,1)"

-- | The same as 'hungarianMethodInt', 'hungarianMethodFloat' and
-- 'hungarianMethodDouble', but uses boxed values (thus works with any
-- element type that is an instance of 'Real').
-- Using one of the unboxed versions is recommended where possible,
-- for performance reasons.
-- The indexing must start from @(1,1)@; any other lower bound is rejected
-- with an error.
hungarianMethodBoxed :: (Real e, IArray a e) => a (Int,Int) e -> ([(Int,Int)],e)
hungarianMethodBoxed input =
  case bounds input of
    ((1,1),(n,m)) -> runST $ do
      star <- if m >= n
        then thawST input >>= hungarianMethodShared
        else do
          -- More rows than columns: the shared solver is always handed at
          -- least as many columns as rows, so solve the transposed matrix
          -- and swap the resulting index pairs back.
          ar <- newSTArray_ ((1,1),(m,n))
          forM_ [ (j,i) | j <- [1..m], i <- [1..n] ] $ \(j,i) ->
            writeArray ar (j,i) (input ! (i,j))
          map swap <$> hungarianMethodShared ar
      -- the cost is read from the caller's (unmodified) input array
      return (star, sum [ input ! ij | ij <- star ])
    _ -> error "Munkres.hungarianMethodBoxed: array indexing must start at (1,1)"

{-# SPECIALISE hungarianMethodShared :: STUArray s (Int,Int) Int    -> ST s [(Int,Int)] #-}
{-# SPECIALISE hungarianMethodShared :: STUArray s (Int,Int) Float  -> ST s [(Int,Int)] #-}
{-# SPECIALISE hungarianMethodShared :: STUArray s (Int,Int) Double -> ST s [(Int,Int)] #-}

-- | Element-type-generic entry into the solver. Expects a mutable matrix
-- indexed from @(1,1)@; creates empty solver state (starred zeros, primed
-- zeros, covered rows, covered columns) and runs 'munkers' on the matrix.
-- Returns the starred @(row,column)@ positions found by 'munkers'.
hungarianMethodShared :: (Real e, MArray a e (ST s)) => a (Int,Int) e -> ST s [(Int,Int)]
hungarianMethodShared ar = do
  ((1,1), dims) <- getBounds ar
  stars    <- newSTRef []
  primes   <- newSTRef []
  rowCover <- newSTRef []
  colCover <- newSTRef []
  munkers ar dims stars primes rowCover colCover

-- the meat comes here...

{-# SPECIALISE munkers ::
      STUArray s (Int,Int) Int -> (Int,Int)
      -> STRef s [(Int,Int)] -> STRef s [(Int,Int)]
      -> STRef s [Int] -> STRef s [Int]
      -> ST s [(Int,Int)] #-}

{-# SPECIALISE munkers ::
      STUArray s (Int,Int) Float -> (Int,Int)
      -> STRef s [(Int,Int)] -> STRef s [(Int,Int)]
      -> STRef s [Int] -> STRef s [Int]
      -> ST s [(Int,Int)] #-}

{-# SPECIALISE munkers ::
      STUArray s (Int,Int) Double -> (Int,Int)
      -> STRef s [(Int,Int)] -> STRef s [(Int,Int)]
      -> STRef s [Int] -> STRef s [Int]
      -> ST s [(Int,Int)] #-}

-- | The Munkres iteration on a mutable @n@ x @m@ cost matrix @ar@ whose
-- indices start at @(1,1)@. The matrix is modified in place.
--
-- The 'STRef's hold, in order: starred zeros, primed zeros, covered rows
-- and covered columns (row/column lists are kept sorted). Step 2
-- overwrites the starred list; the other three are presumably expected to
-- be empty on entry (callers in this module pass fresh empty refs).
--
-- Returns the starred @(row,column)@ positions once @min n m@ columns are
-- covered by them.
munkers :: (Real e, MArray a e (ST s))
        => a (Int,Int) e -> (Int,Int)
        -> STRef s [(Int,Int)] -> STRef s [(Int,Int)]
        -> STRef s [Int] -> STRef s [Int]
        -> ST s [(Int,Int)]
munkers ar (n,m) starred primed coveredRows coveredCols = step1 >> step2 >> step3
  where
    -- number of covered columns at which step 3 stops
    kk = min n m

    -- apply f to the matrix entry at ij, in place
    adjust ij f = readArray ar ij >>= writeArray ar ij . f

    -- Step 1: subtract the minimum of each row from every entry in that row.
    step1 = forM_ [1..n] $ \i -> do
      row <- forM [1..m] $ \j -> readArray ar (i,j)
      let y = minimum row
      forM_ [1..m] $ \j -> adjust (i,j) (subtract y)

    -- Step 2: scan the matrix row by row and star every zero that has no
    -- starred zero in its row or column yet.
    step2 = do
      stars <- foldM starIfFree [] [ (i,j) | i <- [1..n], j <- [1..m] ]
      writeSTRef starred stars

    starIfFree stars ij@(i,j) = do
      x <- readArray ar ij
      let conflict = any (\(a,b) -> a == i || b == j) stars
      return (if x == 0 && not conflict then ij : stars else stars)

    -- Step 3: cover every column that contains a starred zero. Once kk
    -- columns are covered, the starred zeros are the answer.
    step3 = do
      stars <- readSTRef starred
      colsC <- readSTRef coveredCols
      let colsC' = mergeUnion colsC (sort (map snd stars))
      if length colsC' == kk
        then return stars
        else do
          writeSTRef coveredCols colsC'
          step4

    -- Step 4: find an uncovered zero and prime it. If its row has no starred
    -- zero, continue with step 5. Otherwise cover that row, uncover the
    -- starred zero's column and repeat. If no uncovered zero is left, go to
    -- step 6 with the smallest uncovered value.
    step4 = do
      rowsC <- readSTRef coveredRows
      colsC <- readSTRef coveredCols
      let uncovered = [ (i,j) | i <- complement n rowsC, j <- complement m colsC ]
          zeroAt ij = do
            x <- readArray ar ij
            return (if x == 0 then Just ij else Nothing)
      mp <- firstJust (map zeroAt uncovered)
      case mp of
        Nothing -> do
          es <- mapM (readArray ar) uncovered
          step6 (minimum es)
        Just ij@(i,_) -> do
          modifySTRef primed (ij:)
          stars <- readSTRef starred
          case find (\(p,_) -> p == i) stars of
            Nothing    -> step5 ij
            Just (_,q) -> do
              modifySTRef coveredRows (insert i)
              modifySTRef coveredCols (remove q)
              step4

    -- Step 5: build the alternating path of primed and starred zeros that
    -- starts at the primed zero pq. Unstar its starred zeros and star its
    -- primed zeros, then clear all primes and covers and return to step 3.
    step5 pq = do
      stars  <- readSTRef starred
      primes <- readSTRef primed
      let (ps, ss) = alternate (augmentingPath stars primes pq [pq])
      writeSTRef starred ((stars \\ ss) ++ ps)
      writeSTRef primed []
      writeSTRef coveredRows []
      writeSTRef coveredCols []
      step3

    -- Extend the path from a primed zero in column q. Take the starred zero
    -- in that column, then the primed zero in that starred zero's row. The
    -- path ends when column q has no starred zero. The result lists the most
    -- recently added zero first, so primes and stars alternate starting
    -- with a prime.
    augmentingPath :: [(Int,Int)] -> [(Int,Int)] -> (Int,Int) -> [(Int,Int)] -> [(Int,Int)]
    augmentingPath stars primes (_,q) path =
      case find (\(_,c) -> c == q) stars of
        Nothing    -> path
        Just (i,_) ->
          let (_,j) = primeInRow i
          in augmentingPath stars primes (i,j) ((i,j) : (i,q) : path)
      where
        primeInRow i = case find (\(r,_) -> r == i) primes of
          Just x  -> x
          Nothing -> error $ "Munkres/findPrimed: should not happen (" ++ show primes ++ " " ++ show i ++ ")"

    -- Step 6: subtract c from every entry that lies in an uncovered row and
    -- an uncovered column. Add c to every entry that lies in a covered row
    -- and a covered column. Then return to step 4.
    step6 c = do
      rowsC <- readSTRef coveredRows
      colsC <- readSTRef coveredCols
      forM_ [ (i,j) | i <- complement n rowsC, j <- complement m colsC ] $ \ij ->
        adjust ij (subtract c)
      forM_ [ (i,j) | i <- rowsC, j <- colsC ] $ \ij ->
        adjust ij (+ c)
      step4

{-
-- debugging

printArray s = do
putStrLn ""
x <- freeze ar :: IO (UArray (Int,Int) Int)
print (s,x)

printPrimStar s = do
putStrLn s
print ("starred",star)
print ("primed",prim)
print ("cov. rows",crows)
print ("cov. cols",ccols)
-}

-------------------------------------------------------

#ifdef MUNKRES_DEBUG

-- | Return the second argument, emitting the 'show'n first argument via
-- Debug.Trace's 'trace' when the result is demanded.
debug :: Show a => a -> b -> b
debug x = trace (show x)

-- | Brute-force solver for sanity checking. It enumerates every injective
-- assignment of the smaller dimension's indices into the larger one's,
-- and returns a cheapest assignment. On ties, the first one in enumeration
-- order wins. The indexing must start from @(1,1)@.
bruteforce :: UArray (Int,Int) Int -> ([(Int,Int)],Int)
bruteforce input = minimumBy (comparing snd) (map withCost assignments)
  where
    ((1,1),(n,m)) = bounds input
    -- Pairs are generated as (smaller-dimension index, larger-dimension
    -- index); swap them to (row, column) unless there are fewer rows.
    orient = if n < m then id else \(a,b) -> (b,a)
    assignments = go [1 .. min n m] [1 .. max n m]
    go [] _ = [[]]
    go _ [] = [[]]
    -- js stays ascending and duplicate-free, so 'delete' drops exactly j
    go (i:is) js = concat [ map ((i,j):) (go is (delete j js)) | j <- js ]
    withCost ijs' = let ijs = map orient ijs' in (ijs, sum (map (input !) ijs))

-- random array

-- 'randomR' takes the generator last and returns (value, generator), whereas
-- 'mapAccumL' wants a step function that takes the accumulator first and
-- returns (accumulator, value). This adapter threads the generator as the
-- accumulator.
randomR' gen range = (gen', value)
  where
    (value, gen') = randomR range gen

-- | A random matrix: both dimensions are drawn from @[1..maxsize]@ and every
-- entry from @[0..maxelem]@. Also returns the advanced generator.
randomArray :: RandomGen g => Int -> Int -> g -> (UArray (Int,Int) Int , g)
randomArray maxsize maxelem rnd0 =
  let (rows, g1)    = randomR (1, maxsize) rnd0
      (cols, g2)    = randomR (1, maxsize) g1
      (g3, entries) = mapAccumL randomR' g2 (replicate (rows * cols) (0 :: Int, maxelem))
  in (listArray ((1,1),(rows,cols)) entries, g3)

-- correctness testing

-- Compare the cost found by 'hungarianMethodInt' with the brute-force
-- optimum on one random matrix; print the matrix and both solutions on a
-- mismatch. Returns the advanced generator (the last argument is ignored).
doATest maxsize maxelem rnd0 _ = do
  let (ar, rnd1)               = randomArray maxsize maxelem rnd0
      (matching, cost)         = hungarianMethodInt ar
      (bestMatching, bestCost) = bruteforce ar
      sortedMatching           = sortBy (comparing fst) matching
  when (cost /= bestCost) $ do
    print ar
    putStrLn $ show (snd $ bounds ar) ++ ": " ++ show cost ++ " " ++ show bestCost
    putStrLn $ "hun -> " ++ show sortedMatching ++ "\nbrt -> " ++ show bestMatching
  return rnd1

-- Int vs. Float vs. Double testing: the three unboxed entry points are
-- copy-pasted, so print the costs they find on the same random matrix
-- for comparison. Returns the advanced generator (last argument ignored).
doFloatTest maxsize maxelem rnd0 _ = do
  let (ar, rnd1)      = randomArray maxsize maxelem rnd0
      (_, costInt)    = hungarianMethodInt ar
      (_, costFloat)  = hungarianMethodFloat  (amap fromIntegral ar)
      (_, costDouble) = hungarianMethodDouble (amap fromIntegral ar)
  print (costInt, costFloat, costDouble)
  return rnd1

-- do lots of tests

-- Run 'doATest' n times, threading the global standard generator through.
doManyTests n maxsize maxelem = do
  rnd <- getStdGen
  foldM_ (doATest maxsize maxelem) rnd [1..n]

-- Run 'doFloatTest' n times, threading the global standard generator through.
doManyFloatTests n maxsize maxelem = do
  rnd <- getStdGen
  foldM_ (doFloatTest maxsize maxelem) rnd [1..n]

-- Run six labelled batches of random correctness tests (matrices up to
-- 10 x 10), varying the batch size and the maximal entry.
main = mapM_ runBatch
  [ ("a",  50,  10)
  , ("b",  50,  50)
  , ("c",  50, 100)
  , ("d", 100,  10)
  , ("e", 100,  50)
  , ("f", 100, 100)
  ]
  where
    runBatch (label, count, maxelem) = do
      putStrLn label
      doManyTests count 10 maxelem

#endif

-------------------------------------------------------

-- some test cases

#ifdef MUNKRES_DEBUG

tesztek = [ teszt1, teszt2, teszt3, teszt4 ]

-- Each matrix below is written as a list of rows, which is transposed before
-- 'listArray' fills the array in row-major order. So array entry (i,j) holds
-- the i-th number of the j-th written row.

teszt1 :: UArray (Int,Int) Int
teszt1 = listArray ((1,1),(3,3)) (concat (transpose rows))
  where
    rows =
      [ [ 1,2,3 ]
      , [ 2,4,6 ]
      , [ 3,6,9 ]
      ]

teszt2 :: UArray (Int,Int) Int
teszt2 = listArray ((1,1),(4,4)) (concat (transpose rows))
  where
    rows =
      [ [ 14,5,8,7 ]
      , [ 2,12,6,5 ]
      , [ 7,8,3,9  ]
      , [ 2,4,6,10 ]
      ]

teszt3 :: UArray (Int,Int) Int
teszt3 = listArray ((1,1),(5,5)) (concat (transpose rows))
  where
    rows =
      [ [4,5,3,2,3]
      , [3,2,4,3,4]
      , [3,3,4,4,3]
      , [2,4,3,2,4]
      , [2,1,3,4,3]
      ]

teszt4 :: UArray (Int,Int) Int
teszt4 = listArray ((1,1),(6,6)) (concat (transpose rows))
  where
    rows =
      [ [ 3,4,5,6,2,1 ]
      , [ 3,0,1,2,3,4 ]
      , [ 7,6,0,2,1,1 ]
      , [ 4,4,5,0,1,2 ]
      , [ 0,1,0,1,0,0 ]
      , [ 0,3,2,2,2,0 ]
      ]

#endif

-------------------------------------------------------
```