{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE ScopedTypeVariables #-}

-- | An implementation of bipartite matching using the Ford-Fulkerson algorithm.
module Test.Predicates.Internal.FlowMatcher where

import Control.Monad (forM_, when)
import Control.Monad.ST (ST)
import Data.Array.IArray (Array, assocs, elems)
import Data.Array.ST
  ( MArray (newArray),
    STArray,
    newListArray,
    readArray,
    runSTArray,
    writeArray,
  )
import Data.List ((\\))
import Data.Maybe (catMaybes)

-- $setup
-- >>> :set -Wno-type-defaults

-- | Computes the best bipartite matching of the elements in the two lists,
-- given the compatibility function.
--
-- \"Best\" means maximum cardinality. As many elements as possible are
-- paired, and each element is used at most once. The matching is found with
-- the augmenting-path form of Ford-Fulkerson (Kuhn's algorithm). Each rhs
-- element in turn tries to claim a compatible lhs element. It may displace
-- an earlier claimant only if that claimant can be re-matched elsewhere.
--
-- The compatibility function is tabulated up front. The search therefore
-- never re-evaluates it for the same pair.
--
-- Returns matched pairs, then unmatched lhs elements, then unmatched rhs
-- elements. Matched pairs are ordered by the position of their lhs element.
-- Both unmatched lists preserve input order.
--
-- >>> bipartiteMatching (==) [1 .. 5] [6, 5 .. 2]
-- ([(2,2),(3,3),(4,4),(5,5)],[1],[6])
bipartiteMatching ::
  forall a b. (a -> b -> Bool) -> [a] -> [b] -> ([(a, b)], [a], [b])
bipartiteMatching compatible xs ys = (matchedPairs, unmatchedX, unmatchedY)
  where
    -- Indices come straight out of 'matches', so the (!!) lookups below are
    -- always in range. Their O(n) cost is dominated by the matching itself.
    matchedPairs :: [(a, b)]
    matchedPairs = [(xs !! i, ys !! j) | (i, Just j) <- assocs matches]

    unmatchedX :: [a]
    unmatchedX = [xs !! i | (i, Nothing) <- assocs matches]

    -- Every rhs index that no lhs slot points at. Matched rhs indices are
    -- distinct, so (\\) removes exactly the matched ones, keeping ascending order.
    unmatchedY :: [b]
    unmatchedY = [ys !! j | j <- [0 .. numYs - 1] \\ catMaybes (elems matches)]

    -- Final assignment: entry i is the rhs index paired with lhs i, if any.
    matches :: Array Int (Maybe Int)
    matches = runSTArray st

    st :: forall s. ST s (STArray s Int (Maybe Int))
    st = do
      -- Row-major table of the predicate. The comprehension's x-outer/y-inner
      -- order matches the Ix ordering of ((0, 0), (numXs - 1, numYs - 1)).
      compatArray <-
        newListArray
          ((0, 0), (numXs - 1, numYs - 1))
          [compatible x y | x <- xs, y <- ys] ::
          ST s (STArray s (Int, Int) Bool)
      -- matchArray[i] == Just j  <=>  lhs i is currently paired with rhs j.
      matchArray <-
        newArray (0, numXs - 1) Nothing ::
          ST s (STArray s Int (Maybe Int))
      -- One augmenting-path attempt per rhs element. Each attempt gets a
      -- fresh "seen" set, so a single search visits each lhs at most once.
      -- The Bool result is ignored: a failed search writes nothing.
      forM_ [0 .. numYs - 1] $ \j -> do
        seen <-
          newArray (0, numXs - 1) False :: ST s (STArray s Int Bool)
        go compatArray j matchArray seen
      return matchArray

    numXs, numYs :: Int
    numXs = length xs
    numYs = length ys

    -- Depth-first search for an augmenting path starting at rhs index j.
    -- Scans lhs indices in ascending order for one that is compatible with j
    -- and not yet visited in this search. If that lhs is free, or its current
    -- partner can itself be re-matched recursively, the lhs is reassigned to
    -- j and True is returned. matchArray is written only on success, so a
    -- False result leaves it unchanged.
    go ::
      forall s.
      STArray s (Int, Int) Bool ->
      Int ->
      STArray s Int (Maybe Int) ->
      STArray s Int Bool ->
      ST s Bool
    go compatArray j matchArray seen = loop False 0
      where
        -- The first argument records whether j has already been placed.
        -- Once it has, the scan stops immediately.
        loop :: Bool -> Int -> ST s Bool
        loop True _ = return True
        loop _ i
          | i == numXs = return False
          | otherwise = do
            compat <- readArray compatArray (i, j)
            isSeen <- readArray seen i
            replace <-
              if isSeen || not compat
                then return False
                else do
                  writeArray seen i True
                  matchNum <- readArray matchArray i
                  case matchNum of
                    -- lhs i is free: take it.
                    Nothing -> return True
                    -- lhs i is taken by rhs n: try to move n elsewhere.
                    Just n -> go compatArray n matchArray seen
            when replace $ writeArray matchArray i (Just j)
            loop replace (i + 1)