{-# OPTIONS_GHC -Wall #-}
-----------------------------------------------------------------------------
-- |
-- Module      :  Algorithm.FourierMotzkin.Core
-- Copyright   :  (c) Masahiro Sakai 2011-2013
-- License     :  BSD-style
-- 
-- Maintainer  :  masahiro.sakai@gmail.com
-- Stability   :  provisional
-- Portability :  portable
--
-- Naïve implementation of Fourier-Motzkin Variable Elimination
-- 
-- Reference:
--
-- * <http://users.cecs.anu.edu.au/~michaeln/pubs/arithmetic-dps.pdf>
--
-----------------------------------------------------------------------------
module Algorithm.FourierMotzkin.Core
    ( ExprZ
    , Rat
    , toRat
    , fromRat

    , Lit (..)
    , fromLAAtom
    , toLAAtom

    , project
    , project'
    , projectN
    , projectN'
    , solve
    , solve'
    ) where

import Control.Monad
import Data.List
import Data.Maybe
import Data.Ratio
import qualified Data.IntMap as IM
import qualified Data.IntSet as IS
import Data.VectorSpace hiding (project)

import Algebra.Lattice.Boolean

import Data.ArithRel
import Data.DNF
import qualified Data.LA as LA
import qualified Data.Interval as Interval
import Data.Interval (Interval, EndPoint (..), (<=..<), (<..<=), (<..<))
import Data.Var

-- ---------------------------------------------------------------------------

-- | Linear expressions with 'Integer' coefficients; all literals ('Lit') in
-- this module are stated over this type.
type ExprZ = LA.Expr Integer

-- | Scale an integer linear expression down by the absolute gcd of the
-- coefficients reported by 'LA.terms'.
--
-- Dividing by a non-negative common factor does not change the solution set
-- of @e ≥ 0@ / @e > 0@, so this is used to keep literals small.
normalizeExprR :: ExprZ -> ExprZ
normalizeExprR e = LA.mapCoeff (\k -> k `div` g) e
  where
    g = abs (gcd' [k | (k, _) <- LA.terms e])

-- ---------------------------------------------------------------------------

-- | A rational linear expression in fraction form: @(t, c)@ represents
-- @t/c@, where @t@ has integer coefficients and @c@ must be @> 0@.
-- (The comparison helpers rely on @c > 0@ so that cross-multiplying
-- preserves the direction of an inequality.)
type Rat = (ExprZ, Integer)

-- | Convert a rational linear expression into fraction form ('Rat').
--
-- The denominator is the lcm of the denominators of all coefficients, so
-- every coefficient multiplied by it is an integer. It is always @>= 1@.
toRat :: LA.Expr Rational -> Rat
toRat e = denom `seq` (LA.mapCoeff scale e, denom)
  where
    denom = foldl' lcm 1 (map (denominator . fst) (LA.terms e))
    scale q = numerator (fromInteger denom * q)

-- | Convert fraction form back into a rational linear expression by
-- dividing every integer coefficient by the denominator.
fromRat :: Rat -> LA.Expr Rational
fromRat (num, den) = LA.mapCoeff (\k -> k % den) num

-- | Evaluate a 'Rat' under a model.
--
-- Variables are looked up with 'IM.!', so the model must assign every
-- variable occurring in the numerator (otherwise this errors).
evalRat :: Model Rational -> Rat -> Rational
evalRat model (e, d) = numer / fromIntegral d
  where
    numer = LA.lift1 1 (model IM.!) (LA.mapCoeff fromIntegral e)

-- ---------------------------------------------------------------------------

-- | A literal over an integer linear expression:
-- @Nonneg e@ stands for @e ≥ 0@ and @Pos e@ stands for @e > 0@.
data Lit
  = Nonneg ExprZ
  | Pos ExprZ
  deriving (Show, Eq, Ord)

instance Variables Lit where
  vars lit =
    case lit of
      Nonneg t -> vars t
      Pos t    -> vars t

instance Complement Lit where
  -- Negation flips both strictness and sign:
  -- ¬(e > 0) ⇔ -e ≥ 0   and   ¬(e ≥ 0) ⇔ -e > 0.
  notB lit =
    case lit of
      Nonneg t -> Pos (negateV t)
      Pos t    -> Nonneg (negateV t)

-- | Simplification of a constraint set (a conjunction of literals).
--
-- Literals whose expression is a constant are discharged: a satisfied one
-- is dropped, and a violated one makes the whole set inconsistent. It
-- returns 'Nothing' when an inconsistency is detected. Non-constant
-- literals are kept unchanged and in order.
simplify :: [Lit] -> Maybe [Lit]
simplify lits = concat <$> traverse step lits
  where
    step :: Lit -> Maybe [Lit]
    step lit =
      case lit of
        Pos e    -> decide (> 0) e
        Nonneg e -> decide (>= 0) e
      where
        -- Non-constant: keep it. Constant: drop if it holds, fail otherwise.
        decide holds e =
          case LA.asConst e of
            Nothing -> Just [lit]
            Just x
              | holds x   -> Just []
              | otherwise -> Nothing

-- ---------------------------------------------------------------------------

-- | Translate a rational linear atom into a DNF of integer literals by
-- putting both sides into fraction form first.
fromLAAtom :: LA.Atom Rational -> DNF Lit
fromLAAtom (Rel lhs op rhs) = atomR' op l r
  where
    l = toRat lhs
    r = toRat rhs

-- | Turn a literal back into a rational linear atom:
-- @Nonneg e@ becomes @e ≥ 0@ and @Pos e@ becomes @e > 0@.
toLAAtom :: Lit -> LA.Atom Rational
toLAAtom lit =
  case lit of
    Nonneg e -> rational e .>=. LA.constant 0
    Pos e    -> rational e .>.  LA.constant 0
  where
    rational :: ExprZ -> LA.Expr Rational
    rational = LA.mapCoeff fromInteger

-- | The conjunction of all given atoms, as a DNF of literals.
constraintsToDNF :: [LA.Atom Rational] -> DNF Lit
constraintsToDNF atoms = andB [fromLAAtom atom | atom <- atoms]

-- | Encode @a op b@ between two fractions as a DNF of literals.
--
-- Equality is the conjunction of @≤@ and @≥@; disequality is the
-- disjunction of @<@ and @>@.
atomR' :: RelOp -> Rat -> Rat -> DNF Lit
atomR' op a b =
  DNF $ case op of
    Le  -> [[le]]
    Lt  -> [[lt]]
    Ge  -> [[ge]]
    Gt  -> [[gt]]
    Eql -> [[le, ge]]
    NEq -> [[lt], [gt]]
  where
    le = a `leR` b
    lt = a `ltR` b
    ge = a `geR` b
    gt = a `gtR` b

-- | Build normalized literals that compare two fractions.
--
-- Denominators are positive, so for @(e1,c)@ and @(e2,d)@:
--
-- > e1/c ≤ e2/d  ⇔  0 ≤ c*e2 - d*e1
--
-- 'leR' and 'ltR' produce @Nonneg@ and @Pos@ of that difference.
-- 'geR' and 'gtR' are the same comparisons with the arguments swapped.
leR, ltR, geR, gtR :: Rat -> Rat -> Lit
leR (e1,c) (e2,d) = Nonneg $ normalizeExprR $ c *^ e2 ^-^ d *^ e1
ltR (e1,c) (e2,d) = Pos $ normalizeExprR $ c *^ e2 ^-^ d *^ e1
geR = flip leR
-- a > b  ⇔  b < a. This previously read @flip gtR@, which recursed forever
-- and made every @Gt@ and @NEq@ atom diverge.
gtR = flip ltR

-- ---------------------------------------------------------------------------

-- | Bounds on a single variable @x@. The tuple @(ls1, ls2, us1, us2)@
-- represents
--
-- > { x | ∀(M,c)∈ls1. M/c≤x, ∀(M,c)∈ls2. M/c<x, ∀(M,c)∈us1. x≤M/c, ∀(M,c)∈us2. x<M/c }
--
-- that is: non-strict lower bounds, strict lower bounds, non-strict upper
-- bounds and strict upper bounds, each given as a 'Rat'.
type BoundsR = ([Rat], [Rat], [Rat], [Rat])

-- | Eliminate variable @v@ from a conjunction of rational atoms.
--
-- The result is a list of alternatives (a disjunction). Each alternative
-- pairs the projected constraints with a function that extends a model of
-- those constraints by a value for @v@.
project :: Var -> [LA.Atom Rational] -> [([LA.Atom Rational], Model Rational -> Model Rational)]
project v xs =
  [ (map toLAAtom zs, mt)
  | ys <- unDNF (constraintsToDNF xs)
  , (zs, mt) <- project' v ys
  ]

-- | Eliminate variable @v@ from a conjunction of literals (one
-- Fourier–Motzkin step).
--
-- Literals not mentioning @v@ are kept. The bounds on @v@ are replaced by
-- the pairwise lower/upper compatibility conditions from 'boundConditions'.
-- The result is empty when those conditions are already inconsistent.
--
-- The returned function extends a model of the remaining variables with a
-- value for @v@ picked from the interval its bounds describe. The model must
-- assign every variable used by those bounds, since 'evalRat' looks them up
-- with 'IM.!'.
project' :: Var -> [Lit] -> [([Lit], Model Rational -> Model Rational)]
project' v xs =
  [ (rest ++ cond, extend) | cond <- unDNF (boundConditions bnd) ]
  where
    (bnd, rest) = collectBounds v xs
    extend m =
      case Interval.pickup (evalBounds m bnd) of
        Nothing  -> error "FourierMotzkin.project: should not happen"
        Just val -> IM.insert v val m

-- | Eliminate every variable in @vs@ from a conjunction of rational atoms.
--
-- Each alternative pairs the projected constraints with a function that
-- extends a model of them by values for the eliminated variables.
projectN :: VarSet -> [LA.Atom Rational] -> [([LA.Atom Rational], Model Rational -> Model Rational)]
projectN vs xs =
  [ (map toLAAtom zs, mt)
  | ys <- unDNF (constraintsToDNF xs)
  , (zs, mt) <- projectN' vs ys
  ]

-- | Eliminate every variable of @vs2@ from the conjunction @xs2@, one at a
-- time (in 'IS.toList' order), using 'project''.
--
-- Each result pairs the remaining literals with a model transformer that
-- assigns the eliminated variables. The transformers are composed as
-- @mt1 . mt2@, so the variable eliminated last is assigned first. Its
-- bounds only mention variables that were still present at that point.
projectN' :: VarSet -> [Lit] -> [([Lit], Model Rational -> Model Rational)]
projectN' vs2 xs2 = go (IS.toList vs2) xs2
  where
    go :: [Var] -> [Lit] -> [([Lit], Model Rational -> Model Rational)]
    go [] xs = [(xs, id)]
    go (v:vs) xs = do
      (ys, mt1) <- project' v xs
      (zs, mt2) <- go vs ys
      return (zs, mt1 . mt2)

-- | Split literals into bounds on @v@ and literals that do not mention @v@.
--
-- A literal @0 ≤ cx + M@ (or @0 < cx + M@) with @c > 0@ becomes the lower
-- bound @-M/c@. With @c < 0@ it becomes the upper bound @M/-c@. Strictness
-- is kept. Literals with coefficient 0 on @v@ are returned unchanged. The
-- input order is preserved within each output list.
collectBounds :: Var -> [Lit] -> (BoundsR, [Lit])
collectBounds v = foldr step (([], [], [], []), [])
  where
    step :: Lit -> (BoundsR, [Lit]) -> (BoundsR, [Lit])
    step lit acc =
      case lit of
        Nonneg t -> classify False lit t acc
        Pos t    -> classify True lit t acc

    classify :: Bool -> Lit -> ExprZ -> (BoundsR, [Lit]) -> (BoundsR, [Lit])
    classify strict lit t ((ls1, ls2, us1, us2), others)
      | c == 0        = ((ls1, ls2, us1, us2), lit : others)
      | c > 0, strict = ((ls1, lower : ls2, us1, us2), others) -- 0 < cx + M ⇔ -M/c <  x
      | c > 0         = ((lower : ls1, ls2, us1, us2), others) -- 0 ≤ cx + M ⇔ -M/c ≤ x
      | strict        = ((ls1, ls2, us1, upper : us2), others) -- 0 < cx + M ⇔ x < M/-c
      | otherwise     = ((ls1, ls2, upper : us1, us2), others) -- 0 ≤ cx + M ⇔ x ≤ M/-c
      where
        (c, t') = LA.extract v t
        lower = (negateV t', c)
        upper = (t', negate c)

boundConditions :: BoundsR -> DNF Lit
boundConditions  (ls1, ls2, us1, us2) = DNF $ maybeToList $ simplify $ 
  [ x `leR` y | x <- ls1, y <- us1 ] ++
  [ x `ltR` y | x <- ls1, y <- us2 ] ++ 
  [ x `ltR` y | x <- ls2, y <- us1 ] ++
  [ x `ltR` y | x <- ls2, y <- us2 ]

-- | Find a model of a conjunction of rational atoms over the variables
-- @vs@. Each disjunct of the DNF translation is tried in turn, and the
-- first one that yields a model wins.
solve :: VarSet -> [LA.Atom Rational] -> Maybe (Model Rational)
solve vs cs = msum (map (solve' vs) (unDNF (constraintsToDNF cs)))

-- | Find a model of a conjunction of literals by eliminating every variable
-- in @vs@.
--
-- A branch succeeds only if, after projection, every remaining literal is a
-- satisfied constant. A leftover non-constant literal (a variable outside
-- @vs@) makes that branch fail. The model is built by running the
-- back-substitution function on 'IM.empty'.
solve' :: VarSet -> [Lit] -> Maybe (Model Rational)
solve' vs cs = listToMaybe
  [ mt IM.empty
  | cs' <- maybeToList (simplify cs)
  , (ys, mt) <- projectN' vs cs'
  , simplify ys == Just []
  ]

-- | Evaluate symbolic bounds under a model. The result is the interval of
-- values they allow: the intersection of one half-line per bound.
evalBounds :: Model Rational -> BoundsR -> Interval Rational
evalBounds model (ls1, ls2, us1, us2) =
  foldl' Interval.intersection Interval.whole (lowers ++ uppers)
  where
    endpoint = Finite . evalRat model
    lowers =
      [ endpoint x <=..< PosInf | x <- ls1 ] ++
      [ endpoint x <..<  PosInf | x <- ls2 ]
    uppers =
      [ NegInf <..<= endpoint x | x <- us1 ] ++
      [ NegInf <..<  endpoint x | x <- us2 ]

-- ---------------------------------------------------------------------------

-- | Greedy gcd of a list, folding 'gcd' from the left starting at the first
-- element. The empty list yields 1. A singleton yields its element as-is,
-- which may be negative.
gcd' :: [Integer] -> Integer
gcd' ns =
  case ns of
    []         -> 1
    (n : rest) -> foldl' gcd n rest

-- ---------------------------------------------------------------------------