-----------------------------------------------------------------------------
-- |
-- Module    : Documentation.SBV.Examples.KnuckleDragger.CaseSplit
-- Copyright : (c) Levent Erkok
-- License   : BSD3
-- Maintainer: erkokl@gmail.com
-- Stability : experimental
--
-- Use KnuckleDragger to prove @2n^2 + n + 1@ is never divisible by @3@.
-----------------------------------------------------------------------------

{-# LANGUAGE DataKinds        #-}
{-# LANGUAGE TypeAbstractions #-}

{-# OPTIONS_GHC -Wall -Werror #-}

module Documentation.SBV.Examples.KnuckleDragger.CaseSplit where

import Prelude hiding (sum, length)

import Data.SBV
import Data.SBV.Tools.KnuckleDragger

-- | A z3 configuration with @auto_config@ turned off.
--
-- With its default settings z3 has trouble completing the proof in 'notDiv3'
-- out-of-the-box, so we pass @auto_config=false@ to z3 via 'extraArgs'.
-- Every other setting is inherited unchanged from the stock 'z3' configuration.
z3NoAutoConfig :: SMTConfig
z3NoAutoConfig = z3{extraArgs = ["auto_config=false"]}

-- | Prove that @2n^2 + n + 1@ is not divisible by @3@.
--
-- The proof is a case split on the (Euclidean) residue of @n@ modulo @3@.
-- For each residue @r@ in @{0, 1, 2}@ we rewrite @n@ as @3k + r@ with
-- @k = n \`sEDiv\` 3@. The solver then checks that @2n^2 + n + 1@ is not a
-- multiple of @3@. Arithmetically, @2r^2 + r + 1@ is @1@, @1@ and @2@
-- (mod @3@) for @r = 0, 1, 2@ respectively. A final lemma combines the
-- three case lemmas into the unconditional statement.
--
-- We have:
--
-- >>> notDiv3
-- Chain: case_n_mod_3_eq_0
--   Lemma: case_n_mod_3_eq_0.1            Q.E.D.
--   Lemma: case_n_mod_3_eq_0.2            Q.E.D.
-- Lemma: case_n_mod_3_eq_0                Q.E.D.
-- Chain: case_n_mod_3_eq_1
--   Lemma: case_n_mod_3_eq_1.1            Q.E.D.
--   Lemma: case_n_mod_3_eq_1.2            Q.E.D.
-- Lemma: case_n_mod_3_eq_1                Q.E.D.
-- Chain: case_n_mod_3_eq_2
--   Lemma: case_n_mod_3_eq_2.1            Q.E.D.
--   Lemma: case_n_mod_3_eq_2.2            Q.E.D.
-- Lemma: case_n_mod_3_eq_2                Q.E.D.
-- Lemma: notDiv3                          Q.E.D.
-- [Proven] notDiv3
notDiv3 :: IO Proof
notDiv3 = runKDWith z3NoAutoConfig $ do

   -- The polynomial under study, and the property we want to establish for it.
   let s :: SInteger -> SInteger
       s n = 2 * n * n + n + 1

       p :: SInteger -> SBool
       p n = s n `sEMod` 3 ./= 0

   -- Do the proof in 3 phases; one each for the possible value of n `mod` 3 being 0, 1, and 2.
   -- Note that we use the Euclidean definition of division/modulus (sEDiv/sEMod).
   -- Each chain walks from the residue hypothesis to n = 3k + r and then to s n = s (3k + r).

   -- Case 0: n = 0 (mod 3)
   case0 <- chainLemma "case_n_mod_3_eq_0"
                       (\(Forall @"n" n) -> (n `sEMod` 3 .== 0) .=> p n)
                       (\n -> let k = n `sEDiv` 3
                              in [ n `sEMod` 3 .== 0
                                 , n .== 3 * k
                                 , s n .== s (3 * k)
                                 ])
                       []

   -- Case 1: n = 1 (mod 3)
   case1 <- chainLemma "case_n_mod_3_eq_1"
                       (\(Forall @"n" n) -> (n `sEMod` 3 .== 1) .=> p n)
                       (\n -> let k = n `sEDiv` 3
                              in [ n `sEMod` 3 .== 1
                                 , n .== 3 * k + 1
                                 , s n .== s (3 * k + 1)
                                 ])
                       []

   -- Case 2: n = 2 (mod 3)
   case2 <- chainLemma "case_n_mod_3_eq_2"
                       (\(Forall @"n" n) -> (n `sEMod` 3 .== 2) .=> p n)
                       (\n -> let k = n `sEDiv` 3
                              in [ n `sEMod` 3 .== 2
                                 , n .== 3 * k + 2
                                 , s n .== s (3 * k + 2)
                                 ])
                       []

   -- Note that z3 is smart enough to figure out the above cases are complete, so
   -- no extra completeness helper is needed.
   lemma "notDiv3"
         (\(Forall @"n" n) -> p n)
         [case0, case1, case2]