-----------------------------------------------------------------------------
-- |
-- Module    : Documentation.SBV.Examples.KnuckleDragger.ListLen
-- Copyright : (c) Levent Erkok
-- License   : BSD3
-- Maintainer: erkokl@gmail.com
-- Stability : experimental
--
-- Example use of KnuckleDragger, proving properties about the lengths of lists.
-----------------------------------------------------------------------------

{-# LANGUAGE CPP                 #-}
{-# LANGUAGE DataKinds           #-}
{-# LANGUAGE DeriveAnyClass      #-}
{-# LANGUAGE DeriveDataTypeable  #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving  #-}
{-# LANGUAGE TemplateHaskell     #-}
{-# LANGUAGE TypeAbstractions    #-}
{-# LANGUAGE TypeApplications    #-}

{-# OPTIONS_GHC -Wall -Werror -Wno-unused-do-bind #-}

module Documentation.SBV.Examples.KnuckleDragger.ListLen where

import Prelude hiding (sum, length, reverse, (++))

import Data.SBV
import Data.SBV.Tools.KnuckleDragger

import qualified Data.SBV.List as SL

#ifndef HADDOCK
-- $setup
-- >>> -- For doctest purposes only:
-- >>> :set -XScopedTypeVariables
-- >>> import Control.Exception
#endif

-- | Element type for the lists in the generic lemmas of this module.
-- It is registered with SBV as an uninterpreted sort, which keeps the
-- lemmas below independent of any concrete element type.
data Elt
$(mkUninterpretedSort ''Elt)

-- | Prove that a recursively defined length function agrees with SBV's
-- built-in 'SL.length' on every list. The recursive definition gives the
-- empty list length 0, and any other list a length one more than its tail's.
--
-- We have:
--
-- >>> listLengthProof
-- Lemma: length_correct                   Q.E.D.
-- [Proven] length_correct
listLengthProof :: IO Proof
listLengthProof = runKD $ do
   -- Recursive length, registered as an SMT function named "length".
   let length :: SList Elt -> SInteger
       length = smtFunction "length" $ \xs -> ite (SL.null xs) 0 (1 + length (SL.tail xs))

       -- Reference implementation: SBV's built-in list length.
       spec :: SList Elt -> SInteger
       spec = SL.length

       -- The property to prove. Both sides are wrapped in 'observe' so that
       -- a counter-example would report them under "imp" and "spec".
       p :: SList Elt -> SBool
       p xs = observe "imp" (length xs) .== observe "spec" (spec xs)

   -- Prove p for all lists, supplying the induction principle for p as a helper.
   lemma "length_correct" (\(Forall @"xs" xs) -> p xs) [induct p]

-- | It is instructive to see what kind of counter-example we get if a lemma fails to prove.
-- Below, we do a variant of the 'listLengthProof', but with a bad implementation over integers,
-- and see the counter-example. Our implementation returns an incorrect answer if the given list
-- has more than 5 elements and contains 42. We have:
--
-- >>> badProof `catch` (\(_ :: SomeException) -> pure ())
-- Lemma: bad
-- *** Failed to prove bad.
-- Falsifiable. Counter-example:
--   xs   = [8,25,26,27,28,42] :: [Integer]
--   imp  =                 42 :: Integer
--   spec =                  6 :: Integer
badProof :: IO ()
badProof = runKD $ do
   -- A correct recursive length over integer lists, defined as in 'listLengthProof'.
   let length :: SList Integer -> SInteger
       length = smtFunction "length" $ \xs -> ite (SL.null xs) 0 (1 + length (SL.tail xs))

       -- Deliberately wrong: returns 42 for any list with more than 5 elements
       -- that contains 42, and defers to the correct 'length' otherwise.
       badLength :: SList Integer -> SInteger
       badLength xs = ite (SL.length xs .> 5 .&& 42 `SL.elem` xs) 42 (length xs)

       -- Reference implementation: SBV's built-in list length.
       spec :: SList Integer -> SInteger
       spec = SL.length

       -- The observed names "imp" and "spec" label the values in the counter-example.
       p :: SList Integer -> SBool
       p xs = observe "imp" (badLength xs) .== observe "spec" (spec xs)

   -- This lemma is expected to fail. The doctest above catches the resulting exception.
   lemma "bad" (\(Forall @"xs" xs) -> p xs) [induct p]

   pure ()

-- | Length distributes over append: @length (xs ++ ys) == length xs + length ys@
--
-- We have:
--
-- >>> lenAppend
-- Lemma: lenAppend                        Q.E.D.
-- [Proven] lenAppend
lenAppend :: IO Proof
lenAppend = runKD $
   lemma "lenAppend"
         -- The annotation on xs fixes the element type; the type of ys follows from SL.++.
         (\(Forall @"xs" (xs :: SList Elt)) (Forall @"ys" ys) ->
               SL.length (xs SL.++ ys) .== SL.length xs + SL.length ys)
         -- No helper lemmas are supplied.
         []

-- | If two lists have the same length, their concatenation is twice as long:
-- @length xs == length ys -> length (xs ++ ys) == 2 * length xs@
--
-- We have:
--
-- >>> lenAppend2
-- Lemma: lenAppend2                       Q.E.D.
-- [Proven] lenAppend2
lenAppend2 :: IO Proof
lenAppend2 = runKD $
   lemma "lenAppend2"
         -- The annotation on xs fixes the element type; the type of ys follows from SL.++.
         (\(Forall @"xs" (xs :: SList Elt)) (Forall @"ys" ys) ->
               -- '.=>' takes the two 'SBool' equalities as its operands:
               -- (length xs .== length ys) .=> (length (xs ++ ys) .== 2 * length xs)
                   SL.length xs .== SL.length ys
               .=> SL.length (xs SL.++ ys) .== 2 * SL.length xs)
         -- No helper lemmas are supplied.
         []