---------------------------------------------------------------------------------
-- |
-- Module      :  Math.LinearEquationSolver
-- Copyright   :  (c) Levent Erkok
-- License     :  BSD3
-- Maintainer  :  erkokl@gmail.com
-- Stability   :  experimental
--
-- (The linear equation solver library is hosted at <http://github.com/LeventErkok/linearEqSolver>.
-- Comments, bug reports, and patches are always welcome.)
--
-- Solvers for systems of linear equations over the integers. Both
-- single-solution and all-solutions variants are supported.
---------------------------------------------------------------------------------

module Math.LinearEquationSolver (
       -- * Finding a solution
       solveIntegerLinearEqs
       -- * Finding all solutions
    ,  solveIntegerLinearEqsAll
    ) where

import Data.SBV

-- | Solve a system of linear equations over the integers, returning one
-- solution if any exists.
--
-- The first argument is the coefficient matrix @A@, given as a list of @m@
-- rows of @n@ coefficients each (so @A@ has size @mxn@). The second argument
-- is the right-hand-side vector @B@, of size @mx1@, with one entry per row of
-- @A@. The result is `Nothing` when no integer solution exists, and @Just x@
-- otherwise, where @x@ is a vector of size @nx1@ such that @Ax = B@ holds.
--
-- Ill-formed input is rejected with `error`. This covers a matrix with no
-- rows, rows of differing lengths, and a right-hand side whose length differs
-- from the number of rows.
--
-- For example, to solve the following system of equations:
--
-- @
--     2x + 3y + 4z = 20
--     6x - 3y + 9z = -6
--     2x      +  z = 8
-- @
--
-- >>> solveIntegerLinearEqs [[2,3,4],[6,-3,9],[2,0,1]] [20,-6,8]
-- Just [5,6,-2]
--
-- When there is no solution, we get `Nothing`:
--
-- >>> solveIntegerLinearEqs [[1], [1]] [2,3]
-- Nothing
--
-- This second system is overspecified: it requires the single unknown to be
-- equal to both 2 and 3, so no solution exists.
solveIntegerLinearEqs :: [[Integer]] -> [Integer] -> IO (Maybe [Integer])
solveIntegerLinearEqs coeffs res =
  case check coeffs res of
    Nothing -> error "solveIntegerLinearEqs: Received ill-formed input"
    Just n  -> fmap extractModel (sat (buildConstraints n coeffs res))

-- | Like `solveIntegerLinearEqs`, but returns every solution instead of just
-- one. The arguments have the same meaning, and ill-formed input is likewise
-- rejected with `error`.
--
-- An underspecified system may have infinitely many integer solutions. In
-- that case the result is a lazy list of solutions, and the caller can
-- consume as much of it as needed.
--
-- For example, dropping the third equation from the system used in
-- `solveIntegerLinearEqs` leaves an underspecified system with infinitely
-- many solutions:
--
-- @
--     2x + 3y + 4z = 20
--     6x - 3y + 9z = -6
-- @
--
-- Here we only take the first 3 solutions, for testing purposes, although
-- all of them can be computed lazily:
--
-- >>> take 3 `fmap` solveIntegerLinearEqsAll [[2,3,4],[6,-3,9]] [20,-6]
-- [[5,6,-2],[-8,4,6],[18,8,-10]]
solveIntegerLinearEqsAll :: [[Integer]] -> [Integer] -> IO [[Integer]]
solveIntegerLinearEqsAll coeffs res =
  maybe (error "solveIntegerLinearEqsAll: Received ill-formed input")
        (\n -> fmap extractModels (allSat (buildConstraints n coeffs res)))
        (check coeffs res)

-- | Validate the shape of a linear system.
--
-- The input is well-formed when all of the following hold:
--
--   * the coefficient matrix has at least one row,
--   * every row has the same length, and
--   * there is exactly one right-hand-side value per row.
--
-- Returns @Just n@ when the input is well-formed, where @n@ is the common row
-- length (i.e. the number of unknowns). Returns `Nothing` otherwise.
-- Rows of length zero are accepted, giving @Just 0@.
check :: [[Integer]] -> [Integer] -> Maybe Int
-- A matrix with no rows is never well-formed.
check [] _ = Nothing
-- Matching on the first row replaces the original partial 'head' and makes
-- the separate non-emptiness test unnecessary.
check a@(firstRow : otherRows) b
  | all ((== n) . length) otherRows && length a == length b
  = Just n
  | True
  = Nothing
  where n = length firstRow

-- | Build the symbolic problem for the system @Ax = B@.
--
-- Creates @n@ free symbolic unknowns. For each row of @A@ it builds the
-- equation "dot product of the row with the unknowns equals the matching
-- entry of @B@", then passes the list of equations to `solve`.
--
-- Rows are paired with entries of @B@ positionally via 'zip', so any excess
-- on either side is silently dropped. Callers are expected to validate the
-- shapes first (as the solvers in this module do via 'check').
buildConstraints :: Int -> [[Integer]] -> [Integer] -> Symbolic SBool
buildConstraints n coeffs res = do
  unknowns <- mkFreeVars n
  let dotWith row = sum (zipWith (*) unknowns (map literal row))
      equations   = [dotWith row .== literal rhs | (row, rhs) <- zip coeffs res]
  solve equations