{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE RankNTypes #-}
module StableMarriage.GaleShapley
       ( Men(..)
       , Women(..)
       , World
       , meets

       -- re-export
       , PO.Ordering(..)
       ) where

import Prelude hiding (Ordering(..), compare)
import Control.Arrow ((&&&))
import Data.List (sortOn, groupBy, splitAt)
import Data.Poset as PO (Ordering(..), sortBy')
import Data.Function (on)

-- | The proposing side of the matching.
--
-- The algorithm proposes on a man's behalf to the head of 'loves'; a man
-- whose 'loves' is empty makes no proposal and simply stays free.
class Men m where
  -- | The type of the women this kind of man is matched with.
  type W m :: *
  -- | Applied to every man a woman turns down in a round.
  --
  -- NOTE(review): presumably removes the woman who just rejected him from
  -- 'loves'. The main loop only ends once every free man's 'loves' is empty,
  -- so instances must make this eventually shrink it — confirm.
  forget :: m -> m
  -- | The women this man is still willing to court; the head of the list is
  -- the one he proposes to next.
  loves :: m -> [W m]

-- | The receiving side of the matching, for women of type @w@ courted by
-- men of type @m@ (where @w@ must be @'W' m@).
class (Ord w, Men m, w ~ W m) => Women m w where
  -- | Whether @w@ will consider suitor @m@ at all. Suitors for which this is
  -- 'False' are always turned away.
  acceptable :: w -> m -> Bool
  -- | Partial-order comparison of two suitors from @w@'s point of view. It is
  -- used to sort her suitors, and the ones that sort first are kept (see
  -- 'limit').
  compare :: w -> m -> m -> PO.Ordering
  -- | How many of her current suitors @w@ keeps in a round (her capacity).
  -- Defaults to one, i.e. ordinary one-to-one marriage.
  limit :: w -> [m] -> Int
  limit _ _ = 1

-- | The state threaded through the algorithm: every woman paired with the
-- men she currently holds, plus the men who are not held by anyone.
--
-- The class constraint is part of the synonym itself, so functions whose
-- signature mentions only 'World' (such as 'attack') can still use the
-- 'Men' / 'Women' methods.
type World w m =
  (Men m, Women m w, w ~ W m) =>
    ( [(w, [m])] -- each woman with the suitors she currently holds
    , [m]        -- men not held by any woman
    )

-- | Run rounds of the algorithm until the world is stable.
--
-- One round is 'attack' (free men propose) followed by 'counter' (each
-- woman keeps her preferred suitors and turns the rest away). As soon as
-- 'stable' holds for the new world it is returned; otherwise another round
-- is run on it.
marriage :: World w m -> World w m
marriage x =
  -- Plain application rather than '$': the argument of 'counter' has a
  -- constrained (higher-rank) type.
  let x' = counter (attack x)
  in if stable x'
       then x'
       else marriage x'

-- | A world is stable once no free man has anybody left to propose to,
-- i.e. every man in the free list has an empty 'loves'. The couples
-- component is not inspected.
stable :: (Men m, Women m w, w ~ W m) => World w m -> Bool
stable (_, ms) = all (null . loves) ms

-- | The proposal half of a round.
--
-- Every free man with someone left in 'loves' proposes ('propose'), and
-- those proposals are merged into the women's current suitor lists
-- ('join'). Only men with nobody left to court remain in the free list
-- ('despair').
attack :: World w m -> World w m
attack (cs, ms) = (suitors, idle)
    where
      -- the women's suitor lists extended with this round's proposals
      suitors = join cs (propose ms)
      -- men who made no proposal because their 'loves' is empty
      idle = despair ms

-- | Collect each free man's next proposal into per-woman suitor lists.
--
-- A man with a non-empty 'loves' proposes to its head; a man with nobody
-- left makes no proposal. The result has one entry per woman who received
-- at least one proposal, in ascending order of the women's 'Ord' instance.
-- Because 'sortOn' is stable, a woman's suitors keep their input order.
propose :: (Ord w, Men m, Women m w, w ~ W m) => [m] -> [(w, [m])]
propose = gather . competes
    where
      -- Proposals sorted and grouped by the woman they are addressed to.
      competes :: (Men m, Women m w, w ~ W m, Ord w) => [m] -> [[(w, m)]]
      competes = groupBy ((==) `on` fst) . sortOn fst . concatMap next
          where
            -- A man's single proposal, if he still has someone to court.
            next :: (Men m, Women m w, w ~ W m) => m -> [(w, m)]
            next m = case loves m of
                       []      -> []
                       (w : _) -> [(w, m)]
      gather :: (Men m, Women m w, w ~ W m) => [[(w, m)]] -> [(w, [m])]
      gather = map sub
          where
            -- 'groupBy' never yields an empty group, so this match cannot
            -- fail on anything 'competes' produces.
            sub :: (Men m, Women m w, w ~ W m) => [(w, m)] -> (w, [m])
            sub grp@((w, _) : _) = (w, map snd grp)

-- | Merge two suitor tables into one.
--
-- Entries for the same woman are combined by concatenating their suitor
-- lists. Because 'sortOn' is stable, suitors from the first table come
-- before those from the second. The result has one entry per woman, in
-- ascending 'Ord' order.
join :: (Men m, Women m w, w ~ W m) => [(w, [m])] -> [(w, [m])] -> [(w, [m])]
join cs xs = gather . groupBy ((==) `on` fst) . sortOn fst $ cs ++ xs
    where
      gather :: (Men m, Women m w, w ~ W m) => [[(w, [m])]] -> [(w, [m])]
      gather = map sub
          where
            -- Groups from 'groupBy' are never empty, so this match cannot fail.
            sub :: (Men m, Women m w, w ~ W m) => [(w, [m])] -> (w, [m])
            sub grp@((w, _) : _) = (w, concatMap snd grp)

-- | Keep only the men who have run out of women to court (empty 'loves').
-- These are exactly the men who make no proposal in 'propose', so they stay
-- in the free list.
despair :: Men m => [m] -> [m]
despair = filter (null . loves)

-- | The rejection half of a round.
--
-- Every woman trims her suitor list with 'choice'. Each man she turns away
-- has 'forget' applied and is appended to the men who were already free.
counter :: World w m -> World w m
counter (cs, ms) = (kept, ms ++ map forget rejected)
    where
      -- the couples that survive this round, and the men turned down
      (kept, rejected) = choice cs


-- | Let every woman decide among the suitors she currently holds.
--
-- For a woman @w@ with suitors @ms@, the suitors are ordered with
-- 'sortBy'' using @'acceptable' w@ and @'compare' w@, and the first
-- @'limit' w ms@ are kept. The overflow after that cut, followed by every
-- suitor @w@ finds unacceptable, is returned as the list of rejected men.
--
-- 'choice' and its local 'gather' take explicit arguments because their
-- result type 'World' carries a class constraint. The point-free forms
-- (@gather . map judge@, @map fst &&& concatMap snd@) only type-check with
-- deep subsumption (GHC < 9.0, or the DeepSubsumption extension).
choice :: (Men m, Women m w, w ~ W m) => [(w, [m])] -> World w m
choice cs = gather (map judge cs)
    where
      judge :: (Men m, Women m w, w ~ W m) => (w, [m]) -> ((w, [m]), [m])
      judge (w, ms) = ((w, kept), overflow ++ unacceptable)
          where
            ok = acceptable w
            (kept, overflow) = splitAt (limit w ms) (sortBy' (ok, compare w) ms)
            -- NOTE(review): this assumes 'sortBy'' discards suitors failing
            -- @ok@; otherwise an unacceptable man could be kept, or be listed
            -- twice among the rejected. Data.Poset is not visible here — confirm.
            unacceptable = filter (not . ok) ms
      gather :: (Men m, Women m w, w ~ W m) => [((w, [m]), [m])] -> World w m
      gather judged = (map fst &&& concatMap snd) judged

-- | Run the matching from scratch: every woman starts with no suitors and
-- every man starts free. 'marriage' then iterates rounds until the world is
-- stable.
meets :: (Men m, Women m w, w ~ W m) => [m] -> [w] -> World w m
meets ms ws = marriage ([(w, []) | w <- ws], ms)