-- Linear Diophantine Equation solver
--
-- Copyright (c) 2009 The MITRE Corporation
--
-- This program is free software: you can redistribute it and/or modify
-- it under the terms of the GNU General Public License as published by
-- the Free Software Foundation, either version 3 of the License, or
-- (at your option) any later version.

-- This program is distributed in the hope that it will be useful,
-- but WITHOUT ANY WARRANTY; without even the implied warranty of
-- MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
-- GNU General Public License for more details.

-- You should have received a copy of the GNU General Public License
-- along with this program.  If not, see <http://www.gnu.org/licenses/>.

-- |
-- Module      : Algebra.CommutativeMonoid.LinDiaphEq
-- Copyright   : (C) 2009 John D. Ramsdell
-- License     : GPL
--
-- Linear Diophantine Equation solver.
--
-- The solver uses the algorithm of Contejean and Devie as specified
-- in \"An Efficient Incremental Algorithm for Solving Systems of
-- Linear Diophantine Equations\", Information and Computation
-- Vol. 113, pp. 143-174, 1994 after a modification explained below.
--
-- The algorithm for systems of homogeneous linear Diophantine
-- equations follows.  Let e[k] be the kth basis vector for 1 <= k <=
-- n.  To find the minimal, non-negative solutions M to the system of
-- equations sum(i=1,n,a[i]*v[i]) = 0, the modified algorithm of
-- Contejean and Devie is:
--
--  1. [init] A := {e[k] | 1 <= k <= n}; M := {}
--
--  2. [new minimal results] M := M + {a in A | a is a solution}
--
--  3. [breadth-first search] A := {a + e[k] | a in A, 1 <= k <= n,
-- \<sum(i=1,n,a[i]*v[i]),v[k]> \< 0}
--
--  4. [unnecessary branches] A := {a in A | all m in M : some
--     1 <= k <= n : a[k] < m[k]}, i.e. drop every a that is
--     component-wise greater than or equal to some m in M
--
--  5. [test] If A = {}, stop, else go to 2.
--
-- The original algorithm reversed steps 3 and 4.
--
-- This module provides a solver for a single linear Diophantine
-- equation a*v = b, where a and v are vectors, not matrices.
-- Conceptually, it uses the homogeneous solver after appending -b as
-- the last element of v and by appending 1 to a at each step in the
-- computation.  The extra 1 is omitted when an answer is produced.
--
-- Steps 3 and 4 were switched because the use of the original
-- algorithm for the problem 2x + y - z = 2 produces a non-minimal
-- solution.  linDiaphEq [2,1,-1] 2 = [[1,0,0],[0,2,0]], but the
-- original algorithm produces [[1,0,0],[0,2,0],[1,1,1]].
--
-- The algorithm is likely to be Fortenbacher's algorithm, the one
-- generalized to systems of equations by Contejean and Devie, but I
-- have not been able to verify this fact.  I learned how to extend
-- Contejean and Devie's results to an inhomogeneous equation by
-- reading \"Effective Solutions of Linear Diophantine Equation
-- Systems with an Application to Chemistry\" by David Papp and Bela
-- Vizvari, Rutcor Research Report RRR 28-2004, September, 2004,
-- <http://rutcor.rutgers.edu/pub/rrr/reports2004/28_2004.ps>.
--
-- The example that shows a problem with the original algorithm
-- follows.  For the problem linDiaphEq [2,1,-1] 2, the values of a and
-- m at the beginning of each iteration of the loop are:
--
-- @
--                    a                                 m
--    [[0, 0, 1], [0, 1, 0], [1, 0, 0]]       []
--    [[0, 1, 1], [0, 2, 0]]                  [[1, 0, 0]]
--    []                                      [[1, 0, 0], [0, 2, 0]]
-- @
--
-- Consider [0, 1, 1] in a.  If you remove unnecessary branches first,
-- the element will stay in a.  After performing breadth-first search,
-- a will contain [1, 1, 1], which is the unwanted, non-minimal
-- solution.

module Algebra.CommutativeMonoid.LinDiaphEq (linDiaphEq) where

import Data.Array
import Data.Set (Set)
import qualified Data.Set as S

{-- Debugging hack
import System.IO.Unsafe

z :: Show a => a -> b -> b
z x y = seq (unsafePerformIO (print x)) y

zz :: Show a => a -> a
zz x = z x x

pr :: Set (Vector Int) -> [[Int]]
pr s = map elems $ S.toList s

zzz :: Set (Vector Int) -> Set (Vector Int)
zzz s = z (pr s) s
--}

-- | Vectors are 'Int'-indexed arrays; those built by 'vector' start at
-- index 0.
type Vector a = Array Int a

-- | @vector n xs@ packs the first @n@ elements of @xs@ into a 'Vector'
-- with bounds @(0, n - 1)@.  Extra list elements are ignored; the list
-- is expected to supply at least @n@ elements.
vector :: Int -> [a] -> Vector a
vector n = listArray (0, n - 1)

-- | @'linDiaphEq' coeffs c@ returns the minimal, non-negative, non-zero
-- integer vectors @x@ (each as a list as long as @coeffs@) satisfying
-- @sum (zipWith (*) coeffs x) == c@.  The result list is produced
-- lazily, one breadth-first level at a time, so solutions with a
-- smaller component sum come first.
--
-- Only solutions of the given equation are returned.  For an
-- inhomogeneous equation (@c /= 0@), the complete solution set is
-- obtained by also solving the related homogeneous equation (@c == 0@)
-- and adding those solutions in.  An empty coefficient list has no
-- solutions.
--
-- >>> linDiaphEq [2,1,-1] 2
-- [[1,0,0],[0,2,0]]
linDiaphEq :: [Int] -> Int -> [[Int]]
linDiaphEq coeffs c
    | null coeffs = []
    | otherwise   = newMinimalResults coeffVec c (basis dim) S.empty
    where
      dim      = length coeffs
      coeffVec = vector dim coeffs

-- | The standard basis e[0] .. e[n-1] of an @n@-dimensional space: one
-- vector per index @k@, holding 1 at @k@ and 0 everywhere else.  This is
-- the initial candidate set A of step 1.  A non-positive @n@ gives the
-- empty set.
basis :: Int -> Set (Vector Int)
basis n =
    S.fromList (map unit [0 .. n - 1])
    where
      unit k = vector n [ if i == k then 1 else 0 | i <- [0 .. n - 1] ]

-- | The main loop, covering steps 2-5 of the module description.
--
-- @newMinimalResults v c a m@ takes the coefficient vector @v@, the
-- constant @c@, the current candidate set @a@ and the minimal solutions
-- @m@ found so far.  It works in four stages:
--
--  * It emits, in set order, every candidate that solves the equation and
--    is not yet in @m@, adding each one to the solution set (step 2).
--  * It expands the candidates (step 3).
--  * It prunes the expansion against the enlarged solution set (step 4).
--  * It recurses on the pruned set.
--
-- The output ends when the candidate set becomes empty (step 5).
newMinimalResults :: Vector Int -> Int -> Set (Vector Int) ->
                     Set (Vector Int) -> [[Int]]
newMinimalResults v c a m
    | S.null a  = []
    | otherwise = scan m (S.toList a)
    where
      isSolution x = prod v x == c

      -- Test each candidate, threading the growing solution set along.
      scan found (x:rest)
          | isSolution x && S.notMember x found =
              elems x : scan (S.insert x found) rest   -- answer found
          | otherwise = scan found rest
      -- Every candidate has been tested: prepare the next level.  Here
      -- the expansion (step 3) happens before the pruning (step 4);
      -- Contejean and Devie prune first, and the module header explains
      -- why that order can yield a non-minimal solution.
      scan found [] =
          let expanded = breadthFirstSearch v c a
              pruned   = unnecessaryBranches expanded found
          in newMinimalResults v c pruned found

-- | Step 3: breadth-first expansion of the candidate set, following the
-- algorithm of Contejean and Devie.
--
-- For every candidate @x@ in @a@, let @r = prod v x - c@ be its
-- residual.  The successors of @x@ are the vectors @x + e[k]@ for each
-- index @k@ with @r * v!k < 0@, i.e. those where incrementing component
-- @k@ moves the residual in the direction of zero (the Fortenbacher
-- restriction).  A candidate with residual zero, i.e. a solution, has no
-- successors.  A successor reached from several parents appears once,
-- because the result is a 'Set'.
breadthFirstSearch :: Vector Int -> Int -> Set (Vector Int) -> Set (Vector Int)
breadthFirstSearch v c a =
    S.fromList
      [ x // [(k, x ! k + 1)]
      | x <- S.toList a
        -- The residual does not depend on k, so compute it once per
        -- candidate instead of once per index.
      , let r = prod v x - c
      , k <- indices x
      , r * v ! k < 0 ]           -- Fortenbacher contribution

-- | Inner product of two vectors, summed over the index range of the
-- first argument.
prod :: Vector Int -> Vector Int -> Int
prod x y =
    sum [ xi * (y ! i) | (i, xi) <- assocs x ]

-- | Step 4: prune candidates that cannot lead to new minimal solutions.
-- A candidate is discarded when it is component-wise greater than or
-- equal to some already-found minimal solution in @m@.  Expansion only
-- ever increments components, so anything reached from such a candidate
-- could not be minimal either.
unnecessaryBranches :: Set (Vector Int) -> Set (Vector Int) -> Set (Vector Int)
unnecessaryBranches a m =
    S.filter (not . dominatesSome) a
    where
      solutions = S.toList m
      -- True when x is at or above some known minimal solution.
      dominatesSome x = any (`lessEq` x) solutions

-- | Component-wise order: @lessEq x y@ holds when, at every index of @x@,
-- the element of @x@ is at most the element of @y@.
lessEq :: Vector Int -> Vector Int -> Bool
lessEq x y =
    and [ xi <= y ! i | (i, xi) <- assocs x ]