{-# LANGUAGE ScopedTypeVariables, BangPatterns #-}
-----------------------------------------------------------------------------
-- |
-- Module      :  Algorithm.BoundsInference
-- Copyright   :  (c) Masahiro Sakai 2011
-- License     :  BSD-style
-- 
-- Maintainer  :  masahiro.sakai@gmail.com
-- Stability   :  provisional
-- Portability :  non-portable (ScopedTypeVariables, BangPatterns)
--
-- Tightening variable bounds by constraint propagation.
-- 
-----------------------------------------------------------------------------
module Algorithm.BoundsInference
  ( BoundsEnv
  , inferBounds
  , LA.computeInterval
  ) where

import Control.Monad
import qualified Data.IntMap as IM
import qualified Data.IntSet as IS
import Data.VectorSpace

import Data.ArithRel
import Data.Interval
import Data.LA (BoundsEnv)
import qualified Data.LA as LA
import Data.Var
import Util (isInteger)

type C r = (RelOp, LA.Expr r)

-- | tightening variable bounds by constraint propagation.
inferBounds :: forall r. (RealFrac r)
  => LA.BoundsEnv r -- ^ initial bounds
  -> [LA.Atom r]    -- ^ constraints
  -> VarSet         -- ^ integral variables
  -> Int            -- ^ limit of iterations
  -> LA.BoundsEnv r
inferBounds bounds constraints ivs limit = loop 0 bounds
  where
    cs :: VarMap [C r]
    cs = IM.fromListWith (++) $ do
      Rel lhs op rhs <- constraints
      let m = LA.coeffMap (lhs ^-^ rhs)
      (v,c) <- IM.toList m
      guard $ v /= LA.unitVar
      let op' = if c < 0 then flipOp op else op
          rhs' = (-1/c) *^ LA.fromCoeffMap (IM.delete v m)
      return (v, [(op', rhs')])

    loop  :: Int -> LA.BoundsEnv r -> LA.BoundsEnv r
    loop !i b = if (limit>=0 && i>=limit) || b==b' then b else loop (i+1) b'
      where
        b' = refine b

    refine :: LA.BoundsEnv r -> LA.BoundsEnv r
    refine b = IM.mapWithKey (\v i -> tighten v $ f b (IM.findWithDefault [] v cs) i) b

    -- tighten bounds of integer variables
    tighten :: Var -> Interval r -> Interval r
    tighten v x =
      if v `IS.notMember` ivs
        then x
        else tightenToInteger x

f :: (Real r, Fractional r) => LA.BoundsEnv r -> [C r] -> Interval r -> Interval r
f b cs i = foldr intersection i $ do
  (op, rhs) <- cs
  let i' = LA.computeInterval b rhs
      lb = lowerBound' i'
      ub = upperBound' i'
  case op of
    Eql -> return i'
    Le -> return $ interval (NegInf, False) ub
    Ge -> return $ interval lb (PosInf, False)
    Lt -> return $ interval (NegInf, False) (strict ub)
    Gt -> return $ interval (strict ub) (PosInf, False)
    NEq -> []

strict :: (EndPoint r, Bool) -> (EndPoint r, Bool)
strict (x, _) = (x, False)

-- | tightening intervals by ceiling lower bounds and flooring upper bounds.
tightenToInteger :: forall r. (RealFrac r) => Interval r -> Interval r
tightenToInteger ival = interval lb2 ub2
  where
    lb@(x1, in1) = lowerBound' ival
    ub@(x2, in2) = upperBound' ival
    lb2 =
      case x1 of
        Finite x ->
          ( if isInteger x && not in1
            then Finite (x + 1)
            else Finite (fromInteger (ceiling x))
          , True
          )
        _ -> lb
    ub2 =
      case x2 of
        Finite x ->
          ( if isInteger x && not in2
            then Finite (x - 1)
            else Finite (fromInteger (floor x))
          , True
          )
        _ -> ub