module Math.Optimization.Kantorovich
  ( 
    RandomVariable
  , KantorovichValue
  , KantorovichSolution
  , KantorovichResult
  , kantorovich
  , prettyKantorovichSolution
  ) where
import           Prelude                hiding   ( EQ )
import           Control.Monad.Logger            (
                                                   runStdoutLoggingT
                                                 , filterLogger
                                                 )
import           Data.List.Extra                 (
                                                   nubSort
                                                 )
import           Data.Map.Strict                 ( 
                                                   fromList
                                                 , mapKeys
                                                 , singleton
                                                 , Map
                                                 )
import qualified Data.Map.Strict                 as DM
import           Data.Matrix                     (
                                                   fromLists
                                                 , prettyMatrix
                                                 )
import           Data.Maybe                      (
                                                   isJust
                                                 , fromJust
                                                 )
import           Data.Tuple.Extra                (
                                                   (***)
                                                 )
import           Linear.Simplex.Solver.TwoPhase  (
                                                   twoPhaseSimplex
                                                 )
import           Linear.Simplex.Types            (
                                                   Result ( .. )
                                                 , PolyConstraint ( .. )
                                                 , ObjectiveFunction ( .. )
                                                 )
import           Linear.Simplex.Util             (
                                                   simplifySystem
                                                 )
-- | A discrete random variable: each value is mapped to its (exact, rational)
-- probability mass.
type RandomVariable a = Map a Rational

-- | A joint distribution on pairs of values, i.e. a transport plan between two
-- random variables.
type KantorovichSolution a b = RandomVariable (a, b)

-- | The optimal value of a Kantorovich problem (minimal expected cost).
type KantorovichValue = Rational

-- | The optimal value paired with an optimal transport plan.
type KantorovichResult a b = (KantorovichValue, KantorovichSolution a b)
-- | Row-major, 1-based numbering of the cells of a matrix with @ncol@ columns.
--
-- Cell @(i, j)@ (1-based row @i@, 1-based column @j@ with @1 <= j <= ncol@)
-- is numbered @(i - 1) * ncol + j@, so the cells of an @m x ncol@ matrix get
-- the numbers @1 .. m * ncol@. These numbers are the LP variable indices of
-- the transport plan.
stack :: Int -> (Int, Int) -> Int
stack ncol (i, j) = (i - 1) * ncol + j
-- | Decode a cell number produced by 'stack' into __0-based__ indices.
--
-- For @1 <= j <= ncol@: @unstack ncol (stack ncol (i, j)) == (i - 1, j - 1)@.
-- The result is deliberately 0-based so it can be used directly to index
-- 0-based sequences of keys. @ncol@ must be positive (it is the divisor).
unstack :: Int -> Int -> (Int, Int)
unstack ncol k = (k - 1) `quotRem` ncol
-- | Render the transport plan of a Kantorovich result as a matrix.
--
-- Rows are the distinct first components of the solution's keys and columns
-- the distinct second components, both in ascending order. A pair that is
-- absent from the solution map carries no mass and is shown as @0@.
-- Returns the empty string for 'Nothing' or for an empty solution, because
-- a matrix cannot be built from zero rows.
prettyKantorovichSolution
  :: (Ord a, Ord b)
  => Maybe (KantorovichResult a b)
  -> String
prettyKantorovichSolution Nothing = ""
prettyKantorovichSolution (Just (_, kantorovichSolution))
  | null rows = ""
  | otherwise = prettyMatrix (fromLists cells)
  where
    -- distinct row / column labels, sorted
    (rows, cols) =
      (nubSort *** nubSort) (unzip (DM.keys kantorovichSolution))
    -- the label grid may contain pairs the (possibly sparse) map lacks
    cells =
      [ [ DM.findWithDefault 0 (i, j) kantorovichSolution | j <- cols ]
      | i <- rows
      ]
-- | Solve the Kantorovich (discrete optimal transport) problem with the
-- two-phase simplex method.
--
-- @kantorovich rvA rvB dist info@ minimises @sum P(a, b) * dist (a, b)@ over
-- nonnegative joint distributions @P@ whose row sums are the probabilities of
-- @rvA@ and whose column sums are those of @rvB@.
--
-- Returns 'Nothing' when the solver reports no solution (or its result lacks
-- the objective value); otherwise the optimal value and an optimal transport
-- plan keyed by @(a, b)@ pairs.
kantorovich
  :: (Ord a, Ord b)
  => RandomVariable a     -- ^ first margin
  -> RandomVariable b     -- ^ second margin
  -> ((a, b) -> Rational) -- ^ cost of moving mass from @a@ to @b@
  -> Bool                 -- ^ 'True': print the solver's log messages to stdout
  -> IO (Maybe (KantorovichResult a b))
kantorovich rvA rvB dist info = do
  maybeResult <-
    runStdoutLoggingT $
      filterLogger (\_ _ -> info) $
        twoPhaseSimplex objFunc polyConstraints
  return (maybeResult >>= getObjectiveValueAndSolution)
  where
    nrow :: Int
    nrow = DM.size rvA
    ncol :: Int
    ncol = DM.size rvB
    mu :: [Rational]
    mu = DM.elems rvA
    nu :: [Rational]
    nu = DM.elems rvB
    objFunc :: ObjectiveFunction
    objFunc = kantorovichObjectiveFunction (DM.keys rvA) (DM.keys rvB) dist
    polyConstraints :: [PolyConstraint]
    polyConstraints = kantorovichConstraints mu nu
    -- Split the solver output into (optimal value, transport plan).
    -- Variables outside the cell range 1 .. nrow * ncol are skipped.
    -- NOTE(review): defensive; confirm whether twoPhaseSimplex ever reports
    -- such variables. The guard also avoids 'unstack' dividing by zero when
    -- either margin is empty.
    getObjectiveValueAndSolution (Result var varLitMap) = do
      value <- DM.lookup var varLitMap
      let cellValues = DM.delete var varLitMap
      Just
        ( value
        , DM.fromList
            [ (cellOf k, p)
            | (k, p) <- DM.toList cellValues
            , k >= 1
            , k <= nrow * ncol
            ]
        )
    -- Map an LP variable number back to its (a, b) pair. 'unstack' gives
    -- 0-based indices, which 'DM.elemAt' expects (O(log n) vs O(n) for '!!').
    cellOf k =
      let (i, j) = unstack ncol k
      in (fst (DM.elemAt i rvA), fst (DM.elemAt j rvB))
-- | Objective of the transport LP: minimise
-- @sum_(i, j) dist (as_i, bs_j) * x_(i, j)@.
--
-- The variable for the pair @(as_i, bs_j)@ (1-based @i@, @j@) is numbered
-- @stack (length bs) (i, j)@.
kantorovichObjectiveFunction
  :: [a] -> [b] -> ((a, b) -> Rational) -> ObjectiveFunction
kantorovichObjectiveFunction as bs dist = Min
  { objective = fromList
      [ (stack n (i, j), dist (x, y))
      | (i, x) <- zip [1 ..] as   -- enumerate instead of the O(n) as !! (i-1)
      , (j, y) <- zip [1 ..] bs
      ]
  }
  where
    n :: Int
    n = length bs
-- | Constraints of the transport LP for row margins @mu@ and column margins
-- @nu@. Variables are numbered with 'stack' (row-major, 1-based):
--
-- * every cell is nonnegative: @x_(i, j) >= 0@;
-- * row @i@ sums to @mu_i@: @sum_j x_(i, j) == mu_i@;
-- * column @j@ sums to @nu_j@: @sum_i x_(i, j) == nu_j@.
--
-- The system is passed through 'simplifySystem' before being returned.
kantorovichConstraints :: [Rational] -> [Rational] -> [PolyConstraint]
kantorovichConstraints mu nu =
  simplifySystem $
    positivityConstraints ++ rowMarginsConstraints ++ colMarginsConstraints
  where
    n :: Int
    n = length nu
    rows :: [Int]
    rows = [1 .. length mu]
    cols :: [Int]
    cols = [1 .. n]
    positivityConstraints :: [PolyConstraint]
    positivityConstraints =
      [ GEQ { lhs = singleton (stack n (i, j)) 1, rhs = 0 }
      | i <- rows
      , j <- cols
      ]
    -- zip indices with margins instead of the partial, O(n) mu !! (i - 1)
    rowMarginsConstraints :: [PolyConstraint]
    rowMarginsConstraints =
      [ EQ { lhs = fromList [ (stack n (i, j), 1) | j <- cols ], rhs = p }
      | (i, p) <- zip rows mu
      ]
    colMarginsConstraints :: [PolyConstraint]
    colMarginsConstraints =
      [ EQ { lhs = fromList [ (stack n (i, j), 1) | i <- rows ], rhs = q }
      | (j, q) <- zip cols nu
      ]