module Referees.Solver.Internal where
import Referees.Solver.Types.Internal
    ( Bounds(Bounds),
      ProfitFunction,
      ProfitMatrix,
      Col,
      Row,
      Index(Index, _idx),
      Copies,
      Capacity )
import Control.Exception ( SomeException, handle )
import Control.Monad ( guard )
import Control.Monad.LPMonad
    ( constrain,
      equalTo,
      execLPM,
      leqTo,
      setDirection,
      setObjective,
      setVarKind )
import Data.LinearProgram
    ( mipDefaults,
      ReturnCode,
      GLPOpts(msgLev),
      LP,
      VarKind(BinVar),
      Direction(Max),
      MsgLev(MsgAll, MsgOff),
      LinFunc,
      glpSolveVars,
      var,
      (*&),
      gsum )
import qualified Data.LinearProgram as LP ( Bounds(Bound) )
import Data.Map as Map ( Map, toList )
import Data.Matrix ( getElem, matrix, ncols, nrows )
import System.Exit ( exitFailure )
-- | Build the assignment integer program for a profit matrix.
--
-- Every cell (@i@, @j@) of the matrix gets one binary decision variable,
-- named by 'x'. Row @i@ is an item and column @j@ is a bin.
--
-- The program contains:
--
--   * objective: maximise 'objFun' (sum of profit times variable);
--   * per column @j@: the number of rows assigned to it is at most the
--     @j@-th capacity (1-based column, 0-based capacity list);
--   * per row @i@: the number of columns it is assigned to equals
--     @maxCops@ when both bounds coincide, otherwise lies in
--     [@minCops@, @maxCops@];
--   * every variable is binary;
--   * variables whose profit entry is exactly 0 are fixed to 0.
--
-- Calls 'error' when @cap@ has fewer entries than the matrix has columns.
lpGAP :: ProfitMatrix -> [Capacity] -> Bounds Copies -> LP String Double
lpGAP profitM cap (Bounds minCops maxCops) = execLPM $ do
  setDirection Max
  setObjective $ objFun profitM
  let (m, n) =
        (Index $ nrows profitM, Index $ ncols profitM) :: (Row, Col)
      -- Column indices are 1-based, the capacity list is 0-based.
      capOf j = case drop (_idx j - 1) cap of
        c : _ -> c
        [] -> error $ "lpGAP: no capacity given for bin " ++ show (_idx j)
      setCap j =
        subFunCap profitM j `leqTo` fromIntegral (capOf j)
      setMult i =
        if minCops == maxCops
           then
             subFunMult profitM i `equalTo` fromIntegral maxCops
           else
             let bounds = LP.Bound (fromIntegral minCops) (fromIntegral maxCops)
             in subFunMult profitM i `constrain` bounds
      setVK (i, j) =
        setVarKind (x i j) BinVar
      -- A zero-profit cell can never be chosen.
      setOff (i, j) =
        var (x i j) `equalTo` 0
  mapM_ setCap [1 .. n]
  mapM_ setMult [1 .. m]
  mapM_ setVK
    $ [(i, j) | i <- [1 .. m]
              , j <- [1 .. n]]
  mapM_ setOff
    $ [(i, j) | i <- [1 .. m]
              , j <- [1 .. n]
              , profitM `safeGetElem` (i, j) == 0]
-- | Solver variable name for matrix cell (@i@, @j@).
--
-- The name is the 'show' of the index pair. 'fromGLPKtoList' parses
-- these names back with 'read', so the two must stay in sync.
x :: Row -> Col -> String
x = curry show
-- | Objective function: the sum over every cell of the profit times that
-- cell's decision variable. Rows are the outer index, columns the inner one.
objFun :: ProfitMatrix -> LinFunc String Double
objFun pM = gsum terms
  where
    rowCount = Index $ nrows pM :: Row
    colCount = Index $ ncols pM :: Col
    terms =
      [ safeGetElem pM (row, col) *& x row col
      | row <- [1 .. rowCount]
      , col <- [1 .. colCount]
      ]
-- | Left-hand side of the capacity constraint for column @j@: the sum of
-- the decision variables of every row in that column.
subFunCap :: ProfitMatrix -> Col -> LinFunc String Double
subFunCap pM j = gsum [ 1.0 *& x row j | row <- [1 .. lastRow] ]
  where
    lastRow = Index (nrows pM) :: Row
-- | Left-hand side of the multiplicity constraint for row @i@: the sum of
-- the decision variables of every column in that row.
subFunMult :: ProfitMatrix -> Row -> LinFunc String Double
subFunMult pM i = gsum (map (var . x i) [1 .. lastCol])
  where
    lastCol = Index (ncols pM) :: Col
-- | Tabulate a profit function over every (item, bin) pair.
--
-- Row @i@ of the result corresponds to @items !! (i - 1)@ and column @b@
-- to @bins !! (b - 1)@. Matrix indices are 1-based, list indices 0-based.
-- The same @quality@ argument is passed to every call of @f@.
mkProfitMatrix :: ProfitFunction a b c
               -> [a]
               -> [b]
               -> Maybe c
               -> ProfitMatrix
mkProfitMatrix f items bins quality =
  safeMatrix itemsNumber binsNumber profitFnMatrixWrapper
    where
      itemsNumber = Index $ length items :: Row
      binsNumber = Index $ length bins :: Col
      -- Convert the 1-based matrix position back to 0-based list indices.
      -- These stay in range because 'safeMatrix' only asks for positions
      -- up to itemsNumber / binsNumber.
      profitFnMatrixWrapper (i, b) = f i' b' quality
        where
          i' = items !! (_idx i - 1)
          b' = bins !! (_idx b - 1)
-- | Build an @r@ by @c@ matrix whose generator receives typed 'Row' /
-- 'Col' indices instead of raw 'Int's. Positions are passed as produced
-- by 'matrix', without any shifting.
safeMatrix :: Row
           -> Col
           -> ((Row, Col) -> Double)
           -> ProfitMatrix
safeMatrix rows cols gen =
  matrix (_idx rows) (_idx cols) $ \(i, j) -> gen (Index i, Index j)
-- | Read the entry at a typed (row, column) position.
--
-- The raw indices are handed unchanged to 'getElem', so bounds checking
-- (and its failure mode) is whatever 'getElem' provides.
safeGetElem :: ProfitMatrix -> (Row, Col) -> Double
safeGetElem pM (row, col) = getElem rowIx colIx pM
  where
    rowIx = _idx row
    colIx = _idx col
-- | Build the program with 'lpGAP' and solve it with 'glpSolveVars'.
--
-- Solver output is enabled only when compiled with DEBUG defined.
-- Any exception raised while solving is reported on stdout, and the
-- process then exits with a failure status.
run_lpGAP :: ProfitMatrix -> [Capacity] -> Bounds Copies
          -> IO (ReturnCode, Maybe (Double, Map String Double))
run_lpGAP profitM cap bnds =
  handle fuse $
#ifdef DEBUG
    glpSolveVars mipDefaults { msgLev = MsgAll }
#else
    glpSolveVars mipDefaults { msgLev = MsgOff }
#endif
      $ lpGAP profitM cap bnds
  where
    -- The signature fixes which exception type handle catches.
    fuse :: SomeException -> IO a
    fuse e = do
      putStrLn $
        "Fatal error: unknown exception in solver. If this was "
        ++ "unexpected, please report the issue.\n"
        ++ "Error message: " ++ show e
      exitFailure
-- | Extract the chosen (bin, item) pairs from a solver result.
--
-- Each variable name (produced by 'x') is parsed back into its 1-based
-- (row, column) pair. A variable counts as selected when its value is
-- above 0.5. The variables are binary, so this tolerates tiny floating
-- point deviations from exactly 1.0. Selected pairs are returned swapped
-- to (column, row) and shifted to 0-based indices.
--
-- The 'ReturnCode' is ignored. When the solver produced no solution the
-- result is @Just []@, not 'Nothing'.
fromGLPKtoList :: (ReturnCode, Maybe (Double, Map String Double))
               -> Maybe [(Col, Row)]
fromGLPKtoList (_, output) =
  case output of
    Just (_, output') ->
      Just $ do
        (s, v) <- Map.toList output'
        guard $ v > 0.5
        let (i, j) = read s :: (Row, Col)
        return (j - 1, i - 1)
    Nothing ->
      Just []