-----------------------------------------------------------------------------
-- |
-- Module    : Documentation.SBV.Examples.Existentials.Diophantine
-- Copyright : (c) Levent Erkok
-- License   : BSD3
-- Maintainer: erkokl@gmail.com
-- Stability : experimental
--
-- Finding minimal natural number solutions to linear Diophantine equations,
-- using explicit quantification.
-----------------------------------------------------------------------------

{-# LANGUAGE DataKinds           #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications    #-}

{-# OPTIONS_GHC -Wall -Werror #-}

module Documentation.SBV.Examples.Existentials.Diophantine where

import Data.List (intercalate, transpose)

import Data.SBV
import Data.Proxy

import GHC.TypeLits

--------------------------------------------------------------------------------------------------
-- * Representing solutions
--------------------------------------------------------------------------------------------------
-- | The solution set of a linear Diophantine problem.
--
-- For a homogeneous problem, the solution is any linear combination of the resulting vectors.
-- For a non-homogeneous problem, the solution is any linear combination of the vectors in the
-- second component plus one of the vectors in the first component.
data Solution = Homogeneous    [[Integer]]
              | NonHomogeneous [[Integer]] [[Integer]]

-- | Renders a solution symbolically, one parenthesized vector per line. Each generator
-- vector is multiplied by a fresh parameter named @k@, @k'@, @k''@, ...; for a
-- non-homogeneous solution one line is produced per constant vector, with that
-- constant vector added in unscaled.
instance Show Solution where
  show s = case s of
             Homogeneous        xss -> comb supplyH (generators xss)
             NonHomogeneous css xss -> intercalate "\n" [comb supplyNH ((True, cs) : generators xss) | cs <- css]
    where -- Infinite supply of parameter names: k, k', k'', ...
          supplyH :: [String]
          supplyH = ['k' : replicate i '\'' | i <- [0 ..]]

          -- Same supply, preceded by an empty name that lines up with the constant vector.
          supplyNH :: [String]
          supplyNH = "" : supplyH

          -- Tag every vector as a generator, i.e., not a constant.
          generators :: [[Integer]] -> [(Bool, [Integer])]
          generators = zip (repeat False)

          -- Pair each vector with a parameter name, scale it, and add the scaled vectors
          -- up component-wise (hence the transpose).
          comb :: [String] -> [(Bool, [Integer])] -> String
          comb supply vss = vec $ map add (transpose (zipWith muls supply vss))

          -- Symbolically multiply every coefficient of a vector by the parameter x.
          -- Constant vectors are rendered as plain numbers.
          muls :: String -> (Bool, [Integer]) -> [String]
          muls x (isConst, cs) = map mul cs
            where mul :: Integer -> String
                  mul 0 = "0"
                  mul 1 | isConst   = "1"
                        | otherwise = x
                  mul k | isConst   = show k
                        | otherwise = show k ++ x

          -- Sum of terms; the empty sum is "0". foldr1 is safe, the empty case is handled first.
          add :: [String] -> String
          add [] = "0"
          add xs = foldr1 plus xs

          -- Symbolic addition that drops "0" terms.
          plus :: String -> String -> String
          plus "0" y   = y
          plus x   "0" = x
          plus x   y   = x ++ "+" ++ y

          -- Tuple-like rendering of the components.
          vec :: [String] -> String
          vec xs = "(" ++ intercalate ", " xs ++ ")"

--------------------------------------------------------------------------------------------------
-- * Solving diophantine equations
--------------------------------------------------------------------------------------------------
-- | ldn: Solve a (L)inear (D)iophantine equation, returning minimal solutions over (N)aturals.
-- The input is given as rows of equations, with the rhs values separated into a tuple. The first
-- argument must be a proxy of a natural number: the total number of columns in the system
-- (i.e., # of variables + 1). The second parameter limits the search to a bound: In case there are
-- too many solutions, you might want to limit your search space.
--
-- When at least one rhs value is non-zero, each row is extended with a leading column holding
-- the negated rhs. Basis vectors whose leading component is 1 become the constant vectors of a
-- 'NonHomogeneous' result, and those whose leading component is 0 become its generators.
--
-- NOTE(review): when every rhs is zero no column is added, so the proxy presumably has to be the
-- plain number of variables in that case; confirm against 'basis'.
ldn :: forall proxy n. KnownNat n => proxy n -> Maybe Int -> [([Integer], Integer)] -> IO Solution
ldn pn mbLim problem = do
    solution <- basis pn mbLim (map (map literal) m)
    pure $ if homogeneous
              then Homogeneous solution
              else NonHomogeneous [xs | (1:xs) <- solution]   -- leading 1: constant vectors
                                  [xs | (0:xs) <- solution]   -- leading 0: generators
  where lhs :: [[Integer]]
        rhs :: [Integer]
        (lhs, rhs) = unzip problem

        -- The system is homogeneous when all right-hand sides are zero.
        homogeneous :: Bool
        homogeneous = all (== 0) rhs

        -- The coefficient matrix handed to 'basis', where every row must sum to zero.
        m :: [[Integer]]
        m | homogeneous = lhs
          | otherwise   = zipWith (\r row -> negate r : row) rhs lhs

-- | Find the basis solution. By definition, the basis has all non-trivial (i.e., non-0) solutions
-- that cannot be written as the sum of two other solutions. We use the mathematically equivalent
-- statement that a solution is in the basis if it's least according to the natural partial
-- order using the ordinary less-than relation.
--
-- The proxy argument is not inspected; it only fixes the type-level count @n@ used for the
-- universally quantified vector. The second argument is passed to the solver as the maximum
-- number of models to enumerate. The third argument is the coefficient matrix: a candidate is
-- a solution when its dot-product with every row is zero.
basis :: forall proxy n. KnownNat n => proxy n -> Maybe Int -> [[SInteger]] -> IO [[Integer]]
basis _ mbLim m = extractModels <$> allSatWith z3{allSatMaxModelCount = mbLim} cond
 where -- Free variables form the candidate; it must be a solution, and the quantified
       -- solutions are related to it through equality and the 'less' ordering.
       cond :: Symbolic ()
       cond = do as <- mkFreeVars n

                 constrain $ \(ForallN bs :: ForallN n nm Integer) ->
                        ok as .&& (ok bs .=> as .== bs .|| sNot (bs `less` as))

       -- Number of columns, taken from the first row; an empty matrix has none.
       n :: Int
       n = case m of
             []      -> 0
             (r : _) -> length r

       -- A valid solution: not all zero, no negative component, and every equation holds.
       ok :: [SInteger] -> SBool
       ok xs = sAny (.> 0) xs .&& sAll (.>= 0) xs .&& sAnd [sum (zipWith (*) r xs) .== 0 | r <- m]

       -- Component-wise ordering: nowhere greater, and strictly smaller somewhere.
       less :: [SInteger] -> [SInteger] -> SBool
       xs `less` ys = sAnd (zipWith (.<=) xs ys) .&& sOr (zipWith (.<) xs ys)

--------------------------------------------------------------------------------------------------
-- * Examples
--------------------------------------------------------------------------------------------------

-- | Solve the equation:
--
--    @2x + y - z = 2@
--
-- We have:
--
-- >>> test
-- (k, 2+k', 2k+k')
-- (1+k, k', 2k+k')
--
-- That is, for arbitrary @k@ and @k'@, we have two different solutions. (An infinite family.)
-- You can verify these solutions by substituting the values for @x@, @y@ and @z@ in the above, for each choice.
-- It's harder to see that they cover all possibilities, but a moment's thought reveals that this is indeed the case.
test :: IO Solution
test = ldn (Proxy @4) Nothing equation
  where -- Coefficients of x, y and z paired with the right-hand side. With the extra
        -- column 'ldn' adds for a non-zero right-hand side this makes 4 columns.
        equation :: [([Integer], Integer)]
        equation = [([2, 1, -1], 2)]

-- | A puzzle: Five sailors and a monkey escape from a naufrage and reach an island with
-- coconuts. Before dawn, they gather a few of them and decide to sleep first and share
-- the next day. At night, however, one of them awakes, counts the nuts, makes five parts,
-- gives the remaining nut to the monkey, saves his share away, and sleeps. All other
-- sailors do the same, one by one. When they all wake up in the morning, they again make 5 shares,
-- and give the last remaining nut to the monkey. How many nuts were there at the beginning?
--
-- We can model this as a series of diophantine equations:
--
-- @
--       x_0 = 5 x_1 + 1
--     4 x_1 = 5 x_2 + 1
--     4 x_2 = 5 x_3 + 1
--     4 x_3 = 5 x_4 + 1
--     4 x_4 = 5 x_5 + 1
--     4 x_5 = 5 x_6 + 1
-- @
--
-- We need to solve for x_0, over the naturals. If you run this program, z3 takes its time (quite long!)
-- but, it eventually computes: [15621,3124,2499,1999,1599,1279,1023] as the answer.
--
-- That is:
--
-- @
--   * There was a total of 15621 coconuts
--   * 1st sailor: 15621 = 3124*5+1, leaving 15621-3124-1 = 12496
--   * 2nd sailor: 12496 = 2499*5+1, leaving 12496-2499-1 =  9996
--   * 3rd sailor:  9996 = 1999*5+1, leaving  9996-1999-1 =  7996
--   * 4th sailor:  7996 = 1599*5+1, leaving  7996-1599-1 =  6396
--   * 5th sailor:  6396 = 1279*5+1, leaving  6396-1279-1 =  5116
--   * In the morning, they had: 5116 = 1023*5+1.
-- @
--
-- Note that this is the minimum solution, that is, we are guaranteed that there's
-- no solution with fewer coconuts. In fact, any member of @[15625*k-4 | k <- [1..]]@
-- is a solution, i.e., so are @31246@, @46871@, @62496@, @78121@, etc.
--
-- Note that we iteratively deepen our search by requesting increasing number of
-- solutions to avoid the all-sat pitfall.
sailors :: IO [Integer]
sailors = search 1
  where -- Ask for at most i models; if that yields no constant vector, retry with i+1.
        search :: Int -> IO [Integer]
        search i = do soln <- ldn (Proxy @8) (Just i) equations
                      case soln of
                        NonHomogeneous (xs:_) _ -> return xs
                        _                       -> search (i+1)

        -- One row per equation, over x_0 .. x_6, each with right-hand side 1.
        equations :: [([Integer], Integer)]
        equations = [ ([1, -5,  0,  0,  0,  0,  0], 1)
                    , ([0,  4, -5,  0,  0,  0,  0], 1)
                    , ([0,  0,  4, -5,  0,  0,  0], 1)
                    , ([0,  0,  0,  4, -5,  0,  0], 1)
                    , ([0,  0,  0,  0,  4, -5,  0], 1)
                    , ([0,  0,  0,  0,  0,  4, -5], 1)
                    ]