-----------------------------------------------------------------------------
-- |
-- Module    : Documentation.SBV.Examples.BitPrecise.BrokenSearch
-- Copyright : (c) Levent Erkok
-- License   : BSD3
-- Maintainer: erkokl@gmail.com
-- Stability : experimental
--
-- The classic "binary-searches are broken" example:
--     <http://ai.googleblog.com/2006/06/extra-extra-read-all-about-it-nearly.html>
-----------------------------------------------------------------------------

{-# OPTIONS_GHC -Wall -Werror #-}

module Documentation.SBV.Examples.BitPrecise.BrokenSearch where

import Data.SBV
import Data.SBV.Tools.Overflow

-- | Model the mid-point computation of the binary search, which is broken due to arithmetic overflow.
-- Note how we use the overflow checking variants of the arithmetic operators. We have:
--
-- >>> checkArithOverflow midPointBroken
-- Documentation/SBV/Examples/BitPrecise/BrokenSearch.hs:35:28:+!: SInt32 addition overflows: Violated. Model:
--   low  = 2147483647 :: Int32
--   high = 2147483647 :: Int32
--
-- Indeed:
--
-- >>> (2147483647 + 2147483647) `div` (2::Int32)
-- -1
--
-- giving us a negative mid-point value!
midPointBroken :: SInt32 -> SInt32 -> SInt32
midPointBroken low high = (low +! high) /! 2
-- NOTE: the expected doctest output above contains the source position 35:28
-- reported for '+!'; moving or reformatting this definition requires updating
-- that expected output.

-- | The correct version of how to compute the mid-point: add half of the
-- distance between @low@ and @high@ to @low@, using the overflow-checking
-- operators throughout. As expected, this version doesn't have any underflow
-- or overflow issues:
--
-- >>> checkArithOverflow midPointFixed
-- No violations detected.
--
-- As expected, the value is computed correctly too:
--
-- >>> checkCorrectMidValue midPointFixed
-- Q.E.D.
midPointFixed :: SInt32 -> SInt32 -> SInt32
midPointFixed low high = low +! ((high -! low) /! 2)

-- | Show that the variant suggested by the blog post is good as well:
--
--       @mid = ((unsigned int)low + (unsigned int)high) >> 1;@
--
-- In this case the overflow is eliminated by doing the computation at a wider
-- range: both arguments are first converted to 'SWord32' (with the checked
-- conversion), added with the checked addition, shifted right by one bit, and
-- the result is converted back to 'SInt32'.
--
-- >>> checkArithOverflow midPointAlternative
-- No violations detected.
--
-- And the value computed is indeed correct:
--
-- >>> checkCorrectMidValue midPointAlternative
-- Q.E.D.
midPointAlternative :: SInt32 -> SInt32 -> SInt32
midPointAlternative low high = sFromIntegral ((low' +! high') `shiftR` 1)
  where low', high' :: SWord32
        -- Unsigned views of the inputs; the checked conversion is used so the
        -- signed-to-unsigned cast itself is subject to the safety check.
        low'  = sFromIntegralChecked low
        high' = sFromIntegralChecked high

-------------------------------------------------------------------------------------
-- * Helpers
-------------------------------------------------------------------------------------

-- | A helper action to check safety of a mid-point function under the
-- conditions that @low@ is at least 0 and @high@ is at least @low@.
--
-- The function is applied to two symbolic 32-bit signed inputs named
-- @\"low\"@ and @\"high\"@ and handed to 'safe'. If every resulting
-- 'SafeResult' is safe, the message @No violations detected.@ is printed;
-- otherwise each unsafe result is printed on its own line.
checkArithOverflow :: (SInt32 -> SInt32 -> SInt32) -> IO ()
checkArithOverflow f = do
  sr <- safe $ do
    low  <- sInt32 "low"
    high <- sInt32 "high"

    -- Restrict attention to valid binary-search bounds: 0 <= low <= high.
    constrain $ low .>= 0
    constrain $ low .<= high

    output $ f low high

  -- Report only the checks that did not come back as safe.
  case filter (not . isSafe) sr of
    [] -> putStrLn "No violations detected."
    xs -> mapM_ print xs

-- | Another helper to show that the result is actually the correct value, as
-- if it was computed over 64-bit integers, which are large enough to hold the
-- sum of two 32-bit values.
--
-- Under the constraints that @low@ is at least 0 and @high@ is at least
-- @low@, this asks the prover to establish that the given function's result,
-- widened to 'SInt64', equals @(low + high) \`sDiv\` 2@ computed entirely in
-- 'SInt64'.
checkCorrectMidValue :: (SInt32 -> SInt32 -> SInt32) -> IO ThmResult
checkCorrectMidValue f = prove $ do
  low  <- sInt32 "low"
  high <- sInt32 "high"

  -- Restrict attention to valid binary-search bounds: 0 <= low <= high.
  constrain $ low .>= 0
  constrain $ low .<= high

  let low', high' :: SInt64
      low'  = sFromIntegral low
      high' = sFromIntegral high

      -- Reference mid-point, computed at 64 bits.
      mid'  = (low' + high') `sDiv` 2

      -- Mid-point produced by the function under test, at 32 bits.
      mid   = f low high

  return $ sFromIntegral mid .== mid'

{-# ANN module ("HLint: ignore Reduce duplication" :: String) #-}