---------------------------------------------------------------------------------
-- |
-- Module      :  Math.LinearEquationSolver
-- Copyright   :  (c) Levent Erkok
-- License     :  BSD3
-- Maintainer  :  erkokl@gmail.com
-- Stability   :  stable
--
-- (The linear equation solver library is hosted at <http://github.com/LeventErkok/linearEqSolver>.
-- Comments, bug reports, and patches are always welcome.)
--
-- Solvers for linear equations over integers and rationals. Both single solution and all
-- solution variants are supported.
---------------------------------------------------------------------------------

{-# LANGUAGE FlexibleContexts #-}

module Math.LinearEquationSolver (
       -- * Available SMT solvers
       -- $solverInfo
       Solver(..)
       -- * Solutions over Integers
    ,  solveIntegerLinearEqs
    ,  solveIntegerLinearEqsAll
       -- * Solutions over Rationals
    ,  solveRationalLinearEqs
    ,  solveRationalLinearEqsAll
    ) where

import Data.SBV

-- | Solve a system of linear integer equations. The first argument is
-- the matrix of coefficients, known as @A@, of size @mxn@. The second argument
-- is the vector of results, known as @b@, of size @mx1@. The result will be
-- either `Nothing`, if there is no solution, or @Just x@ -- such that @Ax = b@ holds.
-- (Naturally, the result @x@ will be a vector of size @nx1@ in this case.)
--
-- Here's an example call, to solve the following system of equations:
--
-- @
--     2x + 3y + 4z = 20
--     6x - 3y + 9z = -6
--     2x      +  z = 8
-- @
--
-- >>> solveIntegerLinearEqs Z3 [[2, 3, 4],[6, -3, 9],[2, 0, 1]] [20, -6, 8]
-- Just [5,6,-2]
--
-- The first argument picks the SMT solver to use, for instance 'Z3' or 'CVC4'
-- (see the notes on available solvers for which ones work). Naturally, you should
-- have the chosen solver installed on your system.
--
-- In case there are no solutions, we will get `Nothing`:
--
-- >>> solveIntegerLinearEqs Z3 [[1], [1]] [2, 3]
-- Nothing
--
-- Note that there are no solutions to this second system as it stipulates the unknown is
-- equal to both 2 and 3. (Overspecified.)
solveIntegerLinearEqs :: Solver                -- ^ SMT Solver to use
                      -> [[Integer]]           -- ^ Coefficient matrix (A)
                      -> [Integer]             -- ^ Result vector (b)
                      -> IO (Maybe [Integer])  -- ^ A solution to @Ax = b@, if any
solveIntegerLinearEqs whichSolver matrix rhs =
    extractModel <$> satWith (defaultSolverConfig whichSolver) constraints
  where
    -- The system @Ax = b@ as a single SBV predicate over one unknown per
    -- column of @A@; the model (if any) is read back by 'extractModel'.
    constraints :: Symbolic SBool
    constraints = buildConstraints "solveIntegerLinearEqs" matrix rhs

-- | Similar to `solveIntegerLinearEqs`, except that it returns a list of solutions
-- instead of at most one. A linear system has no solutions, exactly one, or
-- infinitely many (an underspecified system that has any solution at all has
-- infinitely many); in the infinite case, the requested number of solutions is
-- returned. So the result is empty, a singleton, or precisely the number requested;
-- when more than one solution is requested, the last case indicates there are
-- infinitely many solutions.
--
-- Here's an example call, where we underspecify the system and hence there are
-- multiple (in this case an infinite number of) solutions. Here, we ask for the first
-- 3 elements for testing purposes.
--
-- @
--     2x + 3y + 4z = 20
--     6x - 3y + 9z = -6
-- @
--
-- We have:
--
-- >>> solveIntegerLinearEqsAll Z3 3 [[2, 3, 4],[6, -3, 9]] [20, -6]
-- [[-34,0,22],[-21,2,14],[-8,4,6]]
--
-- The solutions you get might differ, depending on what the solver returns. (Though they'll be correct!)
solveIntegerLinearEqsAll :: Solver          -- ^ SMT Solver to use
                         -> Int             -- ^ Maximum number of solutions to return, in case infinite
                         -> [[Integer]]     -- ^ Coefficient matrix (A)
                         -> [Integer]       -- ^ Result vector (b)
                         -> IO [[Integer]]  -- ^ All solutions to @Ax = b@
solveIntegerLinearEqsAll whichSolver maxNo matrix rhs =
    extractModels <$> allSatWith limitedCfg constraints
  where
    -- Cap the number of models enumerated, so that a system with infinitely
    -- many solutions still yields a finite answer.
    limitedCfg :: SMTConfig
    limitedCfg = (defaultSolverConfig whichSolver) { allSatMaxModelCount = Just maxNo }

    constraints :: Symbolic SBool
    constraints = buildConstraints "solveIntegerLinearEqsAll" matrix rhs

-- | Solve a system of linear equations over rationals. Same as the integer
-- version `solveIntegerLinearEqs`, except it takes rational coefficients
-- and returns rational results.
--
-- Here's an example call, to solve the following system of equations:
--
-- @
--     2.4x + 3.6y = 12
--     7.2x - 5y   = -8.5
-- @
--
-- >>> solveRationalLinearEqs Z3 [[2.4, 3.6],[7.2, -5]] [12, -8.5]
-- Just [245 % 316,445 % 158]
solveRationalLinearEqs :: Solver                  -- ^ SMT Solver to use
                       -> [[Rational]]            -- ^ Coefficient matrix (A)
                       -> [Rational]              -- ^ Result vector (b)
                       -> IO (Maybe [Rational])   -- ^ A solution to @Ax = b@, if any
solveRationalLinearEqs whichSolver matrix rhs =
    fmap fromAlgReals . extractModel <$> satWith (defaultSolverConfig whichSolver) constraints
  where
    -- The constraints are stated over SBV's 'AlgReal', so the rational
    -- inputs are lifted into that type first ...
    toAlgReals :: [Rational] -> [AlgReal]
    toAlgReals = map fromRational

    -- ... and a model, when found, is converted back to 'Rational'. Solutions
    -- of a linear system with rational coefficients are themselves rational.
    fromAlgReals :: [AlgReal] -> [Rational]
    fromAlgReals = map toRational

    constraints :: Symbolic SBool
    constraints = buildConstraints "solveRationalLinearEqs" (map toAlgReals matrix) (toAlgReals rhs)

-- | Solve a system of linear equations over rationals. Similar to `solveRationalLinearEqs`,
-- except that if the system has infinitely many solutions (as any consistent underspecified
-- system does), it returns the number of solutions requested.
--
-- Example system:
--
-- @
--     2.4x + 3.6y = 12
-- @
--
-- In this case, the system has infinitely many solutions. We can compute three of them as follows:
--
-- >>> solveRationalLinearEqsAll Z3 3 [[2.4, 3.6]] [12]
-- [[(-1) % 1,4 % 1],[0 % 1,10 % 3],[5 % 1,0 % 1]]
--
-- The solutions you get might differ, depending on what the solver returns. (Though they'll be correct!)
solveRationalLinearEqsAll :: Solver             -- ^ SMT Solver to use
                          -> Int                -- ^ Maximum number of solutions to return, in case infinite
                          -> [[Rational]]       -- ^ Coefficient matrix (A)
                          -> [Rational]         -- ^ Result vector (b)
                          -> IO [[Rational]]    -- ^ All solutions to @Ax = b@
solveRationalLinearEqsAll whichSolver maxNo matrix rhs =
    map fromAlgReals . extractModels <$> allSatWith limitedCfg constraints
  where
    -- The constraints are stated over SBV's 'AlgReal', so the rational
    -- inputs are lifted into that type first ...
    toAlgReals :: [Rational] -> [AlgReal]
    toAlgReals = map fromRational

    -- ... and every model found is converted back to 'Rational'. Solutions
    -- of a linear system with rational coefficients are themselves rational.
    fromAlgReals :: [AlgReal] -> [Rational]
    fromAlgReals = map toRational

    -- Cap the number of models enumerated, so that a system with infinitely
    -- many solutions still yields a finite answer.
    limitedCfg :: SMTConfig
    limitedCfg = (defaultSolverConfig whichSolver) { allSatMaxModelCount = Just maxNo }

    constraints :: Symbolic SBool
    constraints = buildConstraints "solveRationalLinearEqsAll" (map toAlgReals matrix) (toAlgReals rhs)

-- | Build the constraints as given by the coefficient matrix and the result vector.
--
-- The first argument is the name of the calling function. It is used only to
-- prefix the error message raised on ill-formed input. The input is rejected
-- (via 'error') when:
--
--   * the coefficient matrix has no rows,
--   * its rows do not all have the same length, or
--   * the number of rows differs from the length of the result vector.
--
-- On well-formed input, one fresh unknown is created per column of the matrix.
-- One equation @sum (zipWith (*) xs row) .== r@ is built per row, and the list
-- of equations is handed to SBV's 'solve'.
buildConstraints :: (Ord a, Num a, Num (SBV a), SymVal a) => String -> [[a]] -> [a] -> Symbolic SBool
buildConstraints f coeffs res =
  case map length coeffs of
    -- Matching on the row lengths (instead of a partial irrefutable @n:ns@
    -- binding) makes the empty-matrix case explicit and the function total.
    n : ns
      | all (== n) ns && length coeffs == length res
      -> do xs <- mkFreeVars n
            let rowEq row r = sum (zipWith (*) xs row) .== r
            solve $ zipWith rowEq (map (map literal) coeffs) (map literal res)
    _ -> error $ f ++ ": received ill-formed input."

{- $solverInfo
Note that while we allow all SMT solvers supported by SBV to be used, not all will work. In particular,
the backend solver needs to understand unbounded integers and rationals. Currently, the following
solvers provide the required capability: 'Z3', 'CVC4', and 'MathSAT'. Passing other solvers will result
in an "unsupported" error, though this can of course change as the SBV package itself evolves.
-}