-----------------------------------------------------------------------------
-- |
-- Module    : Documentation.SBV.Examples.KnuckleDragger.Induction
-- Copyright : (c) Levent Erkok
-- License   : BSD3
-- Maintainer: erkokl@gmail.com
-- Stability : experimental
--
-- Example use of the KnuckleDragger, for some inductive proofs
-----------------------------------------------------------------------------

{-# LANGUAGE DataKinds        #-}
{-# LANGUAGE TypeAbstractions #-}

{-# OPTIONS_GHC -Wall -Werror #-}

module Documentation.SBV.Examples.KnuckleDragger.Induction where

import Prelude hiding (sum, length)

import Data.SBV
import Data.SBV.Tools.KnuckleDragger

-- | Prove that adding the constant @c@ to itself @n@ times yields @n*c@, for every @n >= 0@.
--
-- The recursive @sum c n@ contributes one copy of @c@ for each of @1, 2, ..., n@
-- (so @sum c 0 = 0@); the lemma is discharged with the induction principle for
-- the property @p@ supplied as a helper.
--
-- We have:
--
-- >>> sumConstProof
-- Lemma: sumConst_correct                 Q.E.D.
-- [Proven] sumConst_correct
sumConstProof :: IO Proof
sumConstProof = runKD $ do
   let -- Recursive definition, given to the solver as an SMT function:
       -- zero when @n@ is zero, otherwise one more @c@ than the @n-1@ case.
       sum :: SInteger -> SInteger -> SInteger
       sum = smtFunction "sum" $ \c n -> ite (n .== 0) 0 (c + sum c (n-1))

       -- Closed form that @sum@ should agree with.
       spec :: SInteger -> SInteger -> SInteger
       spec c n = c * n

       -- Implementation equals specification. Both sides are tagged with
       -- 'observe' ("imp" / "spec"), presumably so a counterexample would
       -- report both values.
       p :: SInteger -> SInteger -> SBool
       p c n = observe "imp" (sum c n) .== observe "spec" (spec c n)

   -- The equation only holds on the non-negative range the recursion covers.
   lemma "sumConst_correct" (\(Forall @"c" c) (Forall @"n" n) -> n .>= 0 .=> p c n) [induct p]

-- | Prove that the sum of the numbers from @0@ to @n@ is @n*(n+1)/2@, for every @n >= 0@.
--
-- We have:
--
-- >>> sumProof
-- Lemma: sum_correct                      Q.E.D.
-- [Proven] sum_correct
sumProof :: IO Proof
sumProof = runKD $ do
   let -- Recursive definition, given to the solver as an SMT function:
       -- 0 at 0, otherwise @n@ plus the sum up to @n-1@.
       sum :: SInteger -> SInteger
       sum = smtFunction "sum" $ \n -> ite (n .== 0) 0 (n + sum (n - 1))

       -- Closed form. @n*(n+1)@ is a product of consecutive integers and hence
       -- always even, so the division by 2 is exact.
       spec :: SInteger -> SInteger
       spec n = (n * (n+1)) `sDiv` 2

       -- Implementation equals specification; both sides are tagged with 'observe'.
       p :: SInteger -> SBool
       p n = observe "imp" (sum n) .== observe "spec" (spec n)

   -- Restricted to the non-negative range the recursion is meant for.
   lemma "sum_correct" (\(Forall @"n" n) -> n .>= 0 .=> p n) [induct p]

-- | Prove that the sum of the squares of the numbers from @0@ to @n@ equals
-- @n*(n+1)*(2n+1)/6@, for every @n >= 0@.
--
-- We have:
--
-- >>> sumSquareProof
-- Lemma: sumSquare_correct                Q.E.D.
-- [Proven] sumSquare_correct
sumSquareProof :: IO Proof
sumSquareProof =
    runKD $ lemma "sumSquare_correct"
                  (\(Forall @"n" n) -> n .>= 0 .=> p n)
                  [induct p]
  where
    -- Recursive definition handed to the solver: 0 at 0, otherwise the
    -- square of the argument plus the sum of squares one step lower.
    sumSquare :: SInteger -> SInteger
    sumSquare = smtFunction "sumSquare" $ \k ->
                  ite (k .== 0)
                      0
                      (k * k + sumSquare (k - 1))

    -- Closed form; n*(n+1)*(2n+1) is always a multiple of 6, so the
    -- division is exact.
    spec :: SInteger -> SInteger
    spec k = (k * (k + 1) * (2 * k + 1)) `sDiv` 6

    -- The property proven: both sides agree, each labelled via 'observe'.
    p :: SInteger -> SBool
    p k = observe "imp" (sumSquare k) .== observe "spec" (spec k)

-- | Prove that @11^n - 4^n@ is divisible by @7@ for every @n >= 0@. The power operator
-- is hard for SMT solvers because it is non-linear, so this example uses cvc5 to
-- discharge the final goal, on which z3 can't converge.
--
-- We have:
--
-- >>> elevenMinusFour
-- Lemma: pow0                             Q.E.D.
-- Lemma: powN                             Q.E.D.
-- Lemma: elevenMinusFour                  Q.E.D.
-- [Proven] elevenMinusFour
elevenMinusFour :: IO Proof
elevenMinusFour = runKD $ do
    -- Base and step equations of 'pow', proven straight from its definition
    -- (no helper lemmas needed).
    pow0 <- lemma "pow0" (\(Forall @"x" x) -> x `pow` 0 .== 1) []
    powN <- lemma "powN"
                  (\(Forall @"x" x) (Forall @"n" n) ->
                       n .>= 0 .=> x `pow` (n+1) .== x * x `pow` n)
                  []

    -- The main goal goes to cvc5, armed with both power facts and the
    -- induction principle for 'emf'.
    lemmaWith cvc5 "elevenMinusFour"
              (\(Forall @"n" n) -> n .>= 0 .=> emf n)
              [pow0, powN, induct emf]
  where
    -- Exponentiation by repeated multiplication, given to the solver as an
    -- SMT function: 1 when the exponent is 0, otherwise one more factor of x.
    pow :: SInteger -> SInteger -> SInteger
    pow = smtFunction "pow" $ \x y ->
            ite (y .== 0)
                1
                (x * pow x (y - 1))

    -- The claim itself: 7 divides 11^n - 4^n.
    emf :: SInteger -> SBool
    emf n = 7 `sDivides` (11 `pow` n - 4 `pow` n)