```{-# LANGUAGE ScopedTypeVariables #-}
{-# OPTIONS_GHC -Wall #-}
-----------------------------------------------------------------------------
-- |
-- Module      :  Algorithm.LPSolver
-- Copyright   :  (c) Masahiro Sakai 2011
--
-- Maintainer  :  masahiro.sakai@gmail.com
-- Stability   :  provisional
-- Portability :  non-portable (ScopedTypeVariables)
--
-- Naïve implementation of Simplex method
--
-- Reference:
--
-- * <http://www.math.cuhk.edu.hk/~wei/lpch3.pdf>
--
-----------------------------------------------------------------------------

module Algorithm.LPSolver where

import Control.Monad
import Control.Monad.State
import qualified Data.IntMap as IM
import qualified Data.IntSet as IS
import Data.OptDir
import Data.VectorSpace

import Data.ArithRel
import qualified Data.LA as LA
import qualified Data.Interval as Interval
import Data.Var
import qualified Algorithm.Simplex as Simplex
import qualified Algorithm.BoundsInference as BI

-- ---------------------------------------------------------------------------

-- | Solver state threaded through the 'LP' monad, as a 4-tuple of:
--
-- 1. the next unused variable (returned and incremented by 'gensym'),
-- 2. the current simplex tableau,
-- 3. the set of artificial variables registered so far,
-- 4. definitions mapping a variable to the linear expression substituted
--    for it (see 'define' and 'expandDefs').
type Solver r = (Var, Simplex.Tableau r, VarSet, VarMap (LA.Expr r))

-- | State monad whose state is a 'Solver'.
type LP r = State (Solver r)

-- | Initial solver state for a problem whose variables are @vs@: an empty
-- tableau, no artificial variables and no definitions.  The fresh-variable
-- counter starts one past the largest variable in @vs@, and never below 0.
emptySolver :: VarSet -> Solver r
emptySolver vs = (firstFresh, IM.empty, IS.empty, IM.empty)
  where
    -- Seeding the fold with -1 makes the empty set yield 0.
    firstFresh = 1 + foldr max (-1) (IS.toList vs)

-- | Return the current fresh variable and advance the counter by one.
gensym :: LP r Var
gensym =
  get >>= \(fresh, tbl, avs, defs) ->
    put (fresh + 1, tbl, avs, defs) >> return fresh

-- | Read the tableau component of the solver state.
getTableau :: LP r (Simplex.Tableau r)
getTableau = get >>= \(_, tbl, _, _) -> return tbl

-- | Replace the tableau component of the solver state, leaving the
-- fresh-variable counter, artificial variables and definitions untouched.
putTableau :: Simplex.Tableau r -> LP r ()
putTableau tbl =
  get >>= \(next, _, avs, defs) -> put (next, tbl, avs, defs)

-- | Register @v@ in the set of artificial variables kept in the solver
-- state.  'phaseI' later passes this set to 'Simplex.phaseI'.
addArtificialVariable :: Var -> LP r ()
addArtificialVariable v = do
  (x, tbl, avs, defs) <- get
  put (x, tbl, IS.insert v avs, defs)

-- | Read the set of artificial variables registered so far.
getArtificialVariables :: LP r VarSet
getArtificialVariables = get >>= \(_, _, avs, _) -> return avs

-- | Forget every registered artificial variable; the rest of the solver
-- state is kept as is.
clearArtificialVariables :: LP r ()
clearArtificialVariables =
  get >>= \(next, tbl, _, defs) -> put (next, tbl, IS.empty, defs)

-- | Record that variable @v@ stands for the expression @e@.  An existing
-- definition of @v@ is overwritten.
define :: Var -> LA.Expr r -> LP r ()
define v e =
  get >>= \(next, tbl, avs, defs) ->
    put (next, tbl, avs, IM.insert v e defs)

-- | Read the map of variable definitions recorded by 'define'.
getDefs :: LP r (VarMap (LA.Expr r))
getDefs = get >>= \(_, _, _, defs) -> return defs

-- ---------------------------------------------------------------------------

-- | Add one constraint to the tableau.
--
-- The constraint is first rewritten with the current definitions
-- ('expandDefs'') and normalised by 'normalizeConstraint' so that its
-- right-hand side is non-negative.  Depending on the relation:
--
-- * @<=@: a new row keyed by a fresh slack variable is added;
-- * @>=@: a surplus variable is subtracted and the row is keyed by a fresh
--   artificial variable;
-- * @==@: the row is keyed by a fresh artificial variable.
--
-- Artificial variables are registered with 'addArtificialVariable' so that
-- 'phaseI' can hand them to 'Simplex.phaseI'.
--
-- Calls 'error' for any other relational operator.
addConstraint :: Real r => LA.Atom r -> LP r ()
addConstraint c = do
  c2 <- expandDefs' c
  let (e, rop, b) = normalizeConstraint c2
  tbl <- getTableau
  case rop of
    -- A constraint x >= b with b <= 0 is not added.  This is ad hoc and
    -- should be generalised.  (Presumably sound because tableau variables
    -- are non-negative -- confirm.)
    Ge | isSingleVar e && b <= 0 -> return ()

    Le -> do
      v <- gensym -- slack variable
      putTableau $ Simplex.setRow v tbl (LA.coeffMap e, b)
    Ge -> do
      v1 <- gensym -- surplus variable
      v2 <- gensym -- artificial variable
      putTableau $ Simplex.setRow v2 tbl (LA.coeffMap (e ^-^ LA.var v1), b)
      addArtificialVariable v2
    Eql -> do
      v <- gensym -- artificial variable
      putTableau $ Simplex.setRow v tbl (LA.coeffMap e, b)
      addArtificialVariable v
    _ -> error $ "addConstraint does not support " ++ show rop
  where
    -- True when the expression is exactly one term with coefficient 1.
    isSingleVar expr =
      case LA.terms expr of
        [(1, _)] -> True
        _ -> False

-- | Add one constraint to the tableau using slack variables only.
--
-- After expanding definitions, the constraint @lhs op rhs@ is turned into
-- @e <= b@ form, where @e@ is the non-constant part of @lhs - rhs@ and @b@
-- is the negated constant part:
--
-- * @<=@ adds @e <= b@;
-- * @>=@ adds @-e <= -b@;
-- * @==@ adds both of the above.
--
-- No right-hand-side sign normalisation and no artificial variables are
-- used here, unlike 'addConstraint'.
--
-- Calls 'error' for any other relational operator.
addConstraint2 :: Real r => LA.Atom r -> LP r ()
addConstraint2 c = do
  Rel lhs rop rhs <- expandDefs' c
  let (b', e) = LA.extract LA.unitVar (lhs ^-^ rhs)
      b = - b'
  case rop of
    Le -> f e b
    Ge -> f (negateV e) (negate b)
    Eql -> do
      f e b
      f (negateV e) (negate b)
    _ -> error $ "addConstraint2 does not support " ++ show rop
  where
    -- A constraint -x <= b with -b <= 0 is not added.  This is ad hoc and
    -- should be generalised.
    f e b | isSingleNegatedVar e && -b <= 0 = return ()
    f e b = do
      tbl <- getTableau
      v <- gensym -- slack variable
      putTableau $ Simplex.setRow v tbl (LA.coeffMap e, b)

    -- True when the expression is exactly one term with coefficient -1.
    isSingleNegatedVar e =
      case LA.terms e of
        [(-1, _)] -> True
        _ -> False

-- | Apply the recorded definitions to @e@ as a substitution
-- ('LA.applySubst').
expandDefs :: (Num r, Eq r) => LA.Expr r -> LP r (LA.Expr r)
expandDefs e =
  getDefs >>= \substitution -> return (LA.applySubst substitution e)

-- | Apply 'expandDefs' to both sides of an atom, keeping its operator.
expandDefs' :: (Num r, Eq r) => LA.Atom r -> LP r (LA.Atom r)
expandDefs' (Rel lhs op rhs) =
  expandDefs lhs >>= \newLhs ->
  expandDefs rhs >>= \newRhs ->
  return (Rel newLhs op newRhs)

-- | Build the tableau for the constraints @cs@.
--
-- Variables that 'collectNonnegVars' cannot show to be non-negative are
-- treated as free: each such @v@ is defined as the difference @v1 - v2@ of
-- two fresh variables.  The constraints returned by 'collectNonnegVars'
-- are then added one by one with 'addConstraint', which expands those
-- definitions.
tableau :: (RealFrac r) => [LA.Atom r] -> LP r ()
tableau cs = do
  let (nonnegVars, cs') = collectNonnegVars cs IS.empty
      fvs = vars cs `IS.difference` nonnegVars
  forM_ (IS.toList fvs) $ \v -> do
    v1 <- gensym
    v2 <- gensym
    define v (LA.var v1 ^-^ LA.var v2)
  -- The definitions must be in place before any constraint is added,
  -- because 'addConstraint' substitutes them into each constraint.
  mapM_ addConstraint cs'

-- | Extract an assignment for the variables @vs@ from the current tableau.
--
-- Every variable keyed in the tableau (except the objective row) takes the
-- second component of its row; every other variable of interest defaults
-- to 0.  Defined variables are evaluated from their definitions under that
-- assignment.  The result is restricted to @vs@.
getModel :: Fractional r => VarSet -> LP r (Model r)
getModel vs = do
  tbl <- getTableau
  defs <- getDefs
  let
    rowValues = IM.map snd (IM.delete Simplex.objRow tbl)
    -- Requested variables plus every variable occurring in a definition.
    relevant = vs `IS.union` IS.unions (map vars (IM.elems defs))
    zeros = IM.fromList (zip (IS.toList relevant) (repeat 0))
    -- 'IM.union' is left-biased, so tableau values win over the 0 default.
    assignment = rowValues `IM.union` zeros
    -- Likewise, evaluated definitions win over the plain assignment.
    defined = IM.map (LA.evalExpr assignment) defs
  return (IM.filterWithKey (\k _ -> k `IS.member` vs) (defined `IM.union` assignment))

-- | Run 'Simplex.phaseI' on the current tableau with the registered
-- artificial variables.  The resulting tableau is always stored; the
-- artificial-variable set is cleared only when the result flag is 'True'.
-- Returns that flag.
phaseI :: (Fractional r, Real r) => LP r Bool
phaseI = do
  tbl <- getTableau
  avs <- getArtificialVariables
  let outcome = Simplex.phaseI tbl avs
      ok = fst outcome
  putTableau (snd outcome)
  when ok clearArtificialVariables
  return ok

-- | Install @obj@ (with definitions substituted) as the objective function
-- and run 'Simplex.simplex' in direction @optdir@.  The resulting tableau
-- is stored and the flag returned by 'Simplex.simplex' is passed through.
simplex :: (Fractional r, Real r) => OptDir -> LA.Expr r -> LP r Bool
simplex optdir obj = do
  tbl <- getTableau
  defs <- getDefs
  let objective = LA.applySubst defs obj
      (ok, solved) = Simplex.simplex optdir (Simplex.setObjFun tbl objective)
  putTableau solved
  return ok

-- | Install @obj@ (with definitions substituted) as the objective function
-- and run 'Simplex.dualSimplex' in direction @optdir@.  The resulting
-- tableau is stored and the flag returned by 'Simplex.dualSimplex' is
-- passed through.
dualSimplex :: (Fractional r, Real r) => OptDir -> LA.Expr r -> LP r Bool
dualSimplex optdir obj = do
  tbl <- getTableau
  defs <- getDefs
  let objective = LA.applySubst defs obj
      (ok, solved) = Simplex.dualSimplex optdir (Simplex.setObjFun tbl objective)
  putTableau solved
  return ok

-- ---------------------------------------------------------------------------

-- | Rewrite @a op b@ as @(e, op', k)@, meaning @e op' k@, where @e@ has no
-- constant term and @k >= 0@.  When the constant side would be negative,
-- both sides are negated and the operator is flipped with 'flipOp'.
normalizeConstraint :: forall r. Real r => LA.Atom r -> (LA.Expr r, RelOp, r)
normalizeConstraint (Rel a op b) =
  if constant < 0
    then (negateV linear, flipOp op, negate constant)
    else (linear, op, constant)
  where
    -- Split a - b into its constant coefficient and the remaining terms.
    (k, linear) = LA.extract LA.unitVar (a ^-^ b)
    -- Moving the constant to the right-hand side changes its sign.
    constant = - k

-- | Find the variables of @cs@ whose inferred lower bound is finite and
-- non-negative.
--
-- Bounds are obtained from 'BI.inferBounds', starting from the whole
-- interval for every variable and passing @ivs@ and the limit 1000
-- through.  The constraint list is returned unchanged as the second
-- component.
collectNonnegVars :: forall r. (RealFrac r) => [LA.Atom r] -> VarSet -> (VarSet, [LA.Atom r])
collectNonnegVars cs ivs = (nonnegVars, cs)
  where
    vs = vars cs

    bounds = BI.inferBounds initialBounds cs ivs 1000
      where
        initialBounds = IM.fromList (zip (IS.toList vs) (repeat Interval.whole))

    nonnegVars = IS.filter hasNonnegLowerBound vs

    hasNonnegLowerBound v =
      case Interval.lowerBound (bounds IM.! v) of
        Interval.Finite lo -> 0 <= lo
        _ -> False

    -- Whether the atom already holds for every value allowed by @bounds@.
    -- NOTE(review): not referenced anywhere in the visible code.
    isTriviallyTrue :: LA.Atom r -> Bool
    isTriviallyTrue (Rel a op b) =
      case op of
        Le -> atMostZero False
        Lt -> atMostZero True
        Ge -> atLeastZero False
        Gt -> atLeastZero True
        Eql -> isTriviallyTrue (c .<=. zeroV) && isTriviallyTrue (c .>=. zeroV)
        NEq -> isTriviallyTrue (c .<. zeroV) || isTriviallyTrue (c .>. zeroV)
      where
        c = a ^-^ b
        i = LA.computeInterval bounds c
        (lb, inLB) = Interval.lowerBound' i
        (ub, inUB) = Interval.upperBound' i

        -- Does the upper bound of @c@ force c <= 0 (c < 0 when strict)?
        atMostZero strict =
          case ub of
            Interval.PosInf -> False
            Interval.Finite val
              | strict -> val < 0 || (not inUB && val <= 0)
              | otherwise -> val <= 0
            Interval.NegInf -> True -- should not happen

        -- Does the lower bound of @c@ force c >= 0 (c > 0 when strict)?
        atLeastZero strict =
          case lb of
            Interval.NegInf -> False
            Interval.Finite val
              | strict -> val > 0 || (not inLB && val >= 0)
              | otherwise -> val >= 0
            Interval.PosInf -> True -- should not happen

-- ---------------------------------------------------------------------------
```