-----------------------------------------------------------------------------
-- |
-- Module    : Documentation.SBV.Examples.Misc.Definitions
-- Copyright : (c) Levent Erkok
-- License   : BSD3
-- Maintainer: erkokl@gmail.com
-- Stability : experimental
--
-- Demonstrates how we can add actual SMT-definitions for functions
-- that cannot otherwise be defined in SBV. Typically, these are used
-- for recursive definitions.
-----------------------------------------------------------------------------

{-# LANGUAGE OverloadedLists  #-}

{-# OPTIONS_GHC -Wall -Werror #-}

module Documentation.SBV.Examples.Misc.Definitions where

import Data.SBV
import Data.SBV.Tuple
import qualified Data.SBV.List as L

-------------------------------------------------------------------------
-- * Simple functions
-------------------------------------------------------------------------

-- | Add one to its argument. The definition is emitted as a named SMTLib
-- function @add1@ instead of being inlined at each use site.
add1 :: SInteger -> SInteger
add1 = smtFunction "add1" $ \n -> n + 1

-- | Run the 'add1' function in reverse: ask the solver for an @x@ such that
-- @add1 x == 5@. The generated SMTLib contains a definition of @add1@ itself,
-- which you can observe by running this in verbose mode.
--
-- >>> add1Example
-- Satisfiable. Model:
--   x = 4 :: Integer
add1Example :: IO SatResult
add1Example = sat $ (5 .==) . add1 <$> sInteger "x"

-------------------------------------------------------------------------
-- * Basic recursive functions
-------------------------------------------------------------------------

-- | Sum of the numbers from 0 up to the given number. Non-positive inputs
-- yield 0. Because the definition is recursive, symbolically simulating it
-- would not terminate. Instead we use the function-generation facility to
-- define it directly in SMTLib. Recursive calls go through 'sumToN' itself,
-- so they refer to the SMT-level definition.
sumToN :: SInteger -> SInteger
sumToN = smtFunction "sumToN" body
  where
    -- Base case at or below zero; otherwise add n to the sum of its predecessor.
    body n = ite (n .<= 0) 0 (n + sumToN (n - 1))

-- | Check 'sumToN' at a concrete point: fix the input to 5 and let the solver
-- compute the corresponding sum.
--
-- We have:
--
-- >>> sumToNExample
-- Satisfiable. Model:
--   s0 =  5 :: Integer
--   s1 = 15 :: Integer
sumToNExample :: IO SatResult
sumToNExample = sat query
  where
    query :: SInteger -> SInteger -> SBool
    query n total = n .== 5 .&& total .== sumToN n

-- | Length of a symbolic list, defined recursively. As with 'sumToN', the
-- definition is mapped directly to an SMTLib function (named @list_length@).
len :: SList Integer -> SInteger
len = smtFunction "list_length" measure
  where
    -- Empty list has length 0; otherwise one more than the length of the tail.
    measure lst = ite (L.null lst) 0 (1 + len (L.tail lst))

-- | Compute the length of a concrete list through the recursive 'len'
-- definition.
--
-- We have:
--
-- >>> lenExample
-- Satisfiable. Model:
--   s0 = [1,2,3] :: [Integer]
--   s1 =       3 :: Integer
lenExample :: IO SatResult
lenExample = sat query
  where
    query :: SList Integer -> SInteger -> SBool
    query lst n = lst .== [1, 2, 3 :: Integer] .&& n .== len lst

-------------------------------------------------------------------------
-- * Mutual recursion
-------------------------------------------------------------------------

-- | A small mutual-recursion example, adapted from the z3 documentation.
-- @ping@ and @pong@ call each other, negating the boolean flag on every
-- step. We have:
--
-- >>> pingPong
-- Satisfiable. Model:
--   s0 = 1 :: Integer
pingPong :: IO SatResult
pingPong = sat goal
  where
    goal :: SInteger -> SBool
    goal x = x .> 0 .&& ping x sTrue .> x

    -- With the flag set, hand off to pong on the successor; otherwise step down.
    ping :: SInteger -> SBool -> SInteger
    ping = smtFunction "ping" $ \n flag -> ite flag (pong (n + 1) (sNot flag)) (n - 1)

    -- With the flag set, hand back to ping on the predecessor; otherwise stop.
    pong :: SInteger -> SBool -> SInteger
    pong = smtFunction "pong" $ \n flag -> ite flag (ping (n - 1) (sNot flag)) n

-- | Usual way to define even-odd mutually recursively. Unfortunately, while this goes through,
-- the backend solver does not terminate on this example. See 'evenOdd2' for an alternative
-- technique to handle such definitions, which seems to be more solver friendly.
--
-- Note the asymmetric base cases: @0@ is even but /not/ odd. If @isO@ also
-- accepted @0@, every natural number would be both even and odd, because
-- @isE 1 = isO 0@, @isO 1 = isE 0@, and so on.
evenOdd :: IO SatResult
evenOdd = satWith z3{verbose=True} $ \a r -> a .== 20 .&& r .== isE a
  where isE, isO :: SInteger -> SBool
        -- Negative inputs are reflected to their negation before recursing.
        isE = smtFunction "isE" $ \x -> ite (x .< 0) (isE (-x)) (x .== 0 .|| isO (x - 1))
        -- 0 is not odd; a positive x is odd exactly when x - 1 is even.
        isO = smtFunction "isO" $ \x -> ite (x .< 0) (isO (-x)) (x ./= 0 .&& isE (x - 1))

-- | Another technique to handle mutually recursive definitions is to define the functions together,
-- and pull the results out individually. From a solver perspective, this usually works better
-- than defining the functions separately.
--
-- The first component of the result is the "even" flag and the second is the "odd" flag.
isEvenOdd :: SInteger -> STuple Bool Bool
isEvenOdd = smtFunction "isEvenOdd" step
  where
    -- Reflect negatives; 0 is (even, not odd); otherwise swap the flags of the predecessor.
    step n = ite (n .< 0)  (isEvenOdd (negate n))
           $ ite (n .== 0) (tuple (sTrue, sFalse))
           $ swap (isEvenOdd (n - 1))

-- | Evenness test, taken from the first component of 'isEvenOdd'.
isEven :: SInteger -> SBool
isEven = (^. _1) . isEvenOdd

-- | Oddness test, taken from the second component of 'isEvenOdd'.
isOdd :: SInteger -> SBool
isOdd = (^. _2) . isEvenOdd

-- | Using the combined 'isEvenOdd' definition, the solver determines that 20 is
-- even and not odd:
--
-- >>> evenOdd2
-- Satisfiable. Model:
--   s0 =    20 :: Integer
--   s1 =  True :: Bool
--   s2 = False :: Bool
evenOdd2 :: IO SatResult
evenOdd2 = sat query
  where
    query :: SInteger -> SBool -> SBool -> SBool
    query n evenFlag oddFlag =
          n .== 20
      .&& evenFlag .== isEven n
      .&& oddFlag  .== isOdd n

-------------------------------------------------------------------------
-- * Nested recursion
-------------------------------------------------------------------------

-- | Ackermann function, demonstrating nested recursion: in the last case the
-- inner recursive call's result is itself an argument to the outer call.
ack :: SInteger -> SInteger -> SInteger
ack = smtFunction "ack" ackBody
  where
    ackBody m n = ite (m .== 0) (n + 1)
                $ ite (n .== 0) (ack (m - 1) 1)
                $ ack (m - 1) (ack m (n - 1))

-- | We can check ground (constant) instances of the equality @ack 1 y == y + 2@,
-- here for @y = 5@:
--
-- >>> ack1y
-- Satisfiable. Model:
--   s0 = 5 :: Integer
--   s1 = 7 :: Integer
--
-- Expecting the prover to handle the general case for arbitrary @y@ is beyond the current
-- scope of what SMT solvers do out-of-the-box for the time being.
ack1y :: IO SatResult
ack1y = sat query
  where
    query :: SInteger -> SInteger -> SBool
    query arg res = arg .== 5 .&& res .== ack 1 arg