{-# LANGUAGE ForeignFunctionInterface #-} module Algorithms.Hungarian ( hungarian , hungarianScore , unsafeHungarian , unsafeHungarianScore ) where import Data.List import Foreign import Foreign.C import System.IO.Unsafe foreign import ccall "hungarian" c_hungarian :: Ptr CDouble -> CInt -> CInt -> Ptr CSize -> Ptr CSize -> IO Double -- | solve the LSAP by hungarian algorithm, return assignment and score. hungarian :: [Double] -- ^ row majored flat matrix -> Int -- ^ number of rows -> Int -- ^ number of columns -> ([(Int, Int)], Double) hungarian costMatrix rows cols | length costMatrix /= rows * cols = error "Algorithms.Hungarian.hungarian: incorrect size" | otherwise = unsafeHungarian costMatrix rows cols {-# INLINE hungarian #-} -- | solve the LSAP by hungarian algorithm, return score only hungarianScore :: [Double] -> Int -> Int -> Double hungarianScore costMatrix rows cols | length costMatrix /= rows * cols = error "Algorithms.Hungarian.hungarian: incorrect size" | otherwise = unsafePerformIO $do withArray (map realToFrac costMatrix)$ \input -> do fmap realToFrac $c_hungarian input (fromIntegral rows) (fromIntegral cols) nullPtr nullPtr {-# INLINE hungarianScore #-} -- | doesn't check if the input is a valid matrix unsafeHungarian :: [Double] -- ^ row majored flat matrix -> Int -- ^ number of rows -> Int -- ^ number of columns -> ([(Int, Int)], Double) unsafeHungarian costMatrix rows cols = unsafePerformIO$ do withArray (map realToFrac costMatrix) $\input -> allocaArray n$ \from -> allocaArray n $\to -> do cost <- c_hungarian input (fromIntegral rows) (fromIntegral cols) from to froms <- peekArray n from tos <- peekArray n to return (zipWith f froms tos, realToFrac cost) where f x y = (fromIntegral x, fromIntegral y) n = min rows cols {-# INLINE unsafeHungarian #-} -- | solve the LSAP by hungarian algorithm, return score only unsafeHungarianScore :: [Double] -> Int -> Int -> Double unsafeHungarianScore costMatrix rows cols = unsafePerformIO$ do withArray (map realToFrac costMatrix) $\input -> do fmap realToFrac$ c_hungarian input (fromIntegral rows) (fromIntegral cols) nullPtr nullPtr {-# INLINE unsafeHungarianScore #-}