-----------------------------------------------------------------------------
-- |
-- Module    : Documentation.SBV.Examples.Lists.BoundedMutex
-- Copyright : (c) Levent Erkok
-- License   : BSD3
-- Maintainer: erkokl@gmail.com
-- Stability : experimental
--
-- Demonstrates use of bounded list utilities, proving a simple
-- mutex algorithm correct up to given bounds.
-----------------------------------------------------------------------------

{-# LANGUAGE DeriveAnyClass      #-}
{-# LANGUAGE DeriveDataTypeable  #-}
{-# LANGUAGE OverloadedLists     #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving  #-}
{-# LANGUAGE TemplateHaskell     #-}

{-# OPTIONS_GHC -Wall -Werror #-}

module Documentation.SBV.Examples.Lists.BoundedMutex where

import Data.SBV
import Data.SBV.Control

import Prelude hiding ((!!))
import Data.SBV.List ((!!))
import qualified Data.SBV.List              as L
import qualified Data.SBV.Tools.BoundedList as L

-- | The three states each agent can be in.
data State
  = Idle      -- ^ Regular work
  | Ready     -- ^ Intention to enter critical state
  | Critical  -- ^ In the critical state

-- | Make 'State' a symbolic enumeration, so it can be used in symbolic lists and
-- constraints (the symbolic constants @sIdle@, @sReady@ and @sCritical@ used
-- below come from this splice).
mkSymbolicEnumeration ''State

-- | A bounded mutex property holds for two sequences of state transitions if the agents are
-- never in their critical sections at the same time, up to the given bound.
--
-- The two sequences are zipped position-wise (up to the bound) and every position must have
-- at least one agent outside 'Critical'.
mutex :: Int -> SList State -> SList State -> SBool
mutex i p1s p2s = L.band i $ L.bzipWith i notBothCritical p1s p2s
  where -- At a single step, at least one of the two agents is not in 'Critical'.
        notBothCritical :: SBV State -> SBV State -> SBool
        notBothCritical p1 p2 = p1 ./= sCritical .|| p2 ./= sCritical

-- | A sequence is valid up to a bound if it has exactly that length, starts at 'Idle', and follows
-- the mutex rules. That is:
--
--    * From 'Idle' it can switch to 'Ready' or stay 'Idle'
--    * From 'Ready' it must switch to 'Critical' if it's its turn, and otherwise stays 'Ready'
--    * From 'Critical' it can either stay in 'Critical' or go back to 'Idle'
--
-- The variable @me@ identifies the agent id. The turn consulted for each transition is the one at
-- the same position as the state being entered (the first transition is checked against an
-- implicit preceding 'Idle').
--
-- For a negative bound the length constraint is unsatisfiable; the transition walk simply yields
-- 'sTrue' instead of recursing forever.
validSequence :: Int -> Integer -> SList Integer -> SList State -> SBool
validSequence b me pturns proc = sAnd [ L.length proc .== fromIntegral b
                                      , sIdle .== L.head proc
                                      , check b pturns proc sIdle
                                      ]
   where -- Walk the first @i@ positions, checking each state against its predecessor @prev@.
         check :: Int -> SList Integer -> SList State -> SBV State -> SBool
         check i ts ps prev
           -- Use <= rather than matching on 0 so that a negative bound terminates.
           | i <= 0    = sTrue
           | otherwise = ok .&& check (i - 1) turns rest cur
           where (cur,  rest)  = L.uncons ps
                 (turn, turns) = L.uncons ts
                 ok = ite (prev .== sIdle)                          (cur `sElem` [sIdle, sReady])
                    $ ite (prev .== sReady .&& turn .== literal me) (cur `sElem` [sCritical])
                    $ ite (prev .== sCritical)                      (cur `sElem` [sCritical, sIdle])
                                                                    (cur `sElem` [prev])

-- | The mutex algorithm, coded implicitly as an assignment to turns. Turns start at @1@, and at each
-- stage are either @1@ or @2@, giving preference to that process. The only condition is that if
-- either process is in its critical section, then the turn value stays the same as at the previous
-- step. Note that this is sufficient to satisfy safety (i.e., mutual exclusion), though it does not
-- guarantee liveness.
--
-- For a negative bound the length constraint is unsatisfiable; the step-by-step walk simply yields
-- 'sTrue' instead of recursing forever.
validTurns :: Int -> SList Integer -> SList State -> SList State -> SBool
validTurns b turns process1 process2 = sAnd [ L.length turns .== fromIntegral b
                                            , 1 .== L.head turns
                                            , check b turns process1 process2 1
                                            ]
   where -- Walk the first @i@ positions of the turn list alongside both process traces;
         -- @prev@ is the turn value at the previous position.
         check :: Int -> SList Integer -> SList State -> SList State -> SInteger -> SBool
         check i ts proc1 proc2 prev
           -- Use <= rather than matching on 0 so that a negative bound terminates.
           | i <= 0    = sTrue
           | otherwise =   cur `sElem` [1, 2]
                       .&& (p1 .== sCritical .|| p2 .== sCritical .=> cur .== prev)
                       .&& check (i - 1) rest p1s p2s cur
           where (cur, rest) = L.uncons ts
                 (p1,  p1s)  = L.uncons proc1
                 (p2,  p2s)  = L.uncons proc2

-- | Check that we have the mutex property so long as 'validSequence' and 'validTurns' hold; i.e.,
-- so long as both the agents and the arbiter act according to the rules. The check is bounded up
-- to the given concrete bound, so this is an example of a bounded-model-checking style proof.
-- If the solver finds a counterexample, the offending traces are printed. We have:
--
-- >>> checkMutex 20
-- All is good!
checkMutex :: Int -> IO ()
checkMutex b = runSMT $ do
        p1    :: SList State   <- sList "p1"
        p2    :: SList State   <- sList "p2"
        turns :: SList Integer <- sList "turns"

        -- Ensure that both sequences and the turns are valid
        constrain $ validSequence b 1 turns p1
        constrain $ validSequence b 2 turns p2
        constrain $ validTurns    b turns p1 p2

        -- Try to assert that mutex does not hold. If we get a
        -- counter example, we would've found a violation!
        constrain $ sNot $ mutex b p1 p2

        query $ do
          cs <- checkSat
          case cs of
            Unk    -> error "Solver said Unknown!"
            DSat{} -> error "Solver said delta-satisfiable!"
            Unsat  -> io . putStrLn $ "All is good!"
            Sat    -> do
              io . putStrLn $ "Violation detected!"
              p1V <- getValue p1
              p2V <- getValue p2
              ts  <- getValue turns

              io . putStrLn $ "P1: " ++ show p1V
              io . putStrLn $ "P2: " ++ show p2V
              io . putStrLn $ "Ts: " ++ show ts

-- | Our algorithm is correct, but it is not fair. It does not guarantee that a process that
-- wants to enter its critical-section will always do so eventually. Demonstrate this by
-- trying to show a bounded trace of length 10, such that the second process is ready but
-- never transitions to critical. We have:
--
-- > ghci> notFair 10
-- > Fairness is violated at bound: 10
-- > P1: [Idle,Idle,Ready,Critical,Idle,Idle,Ready,Critical,Idle,Idle]
-- > P2: [Idle,Ready,Ready,Ready,Ready,Ready,Ready,Ready,Ready,Ready]
-- > Ts: [1,2,1,1,1,1,1,1,1,1]
--
-- As expected, P2 gets ready but never goes critical since the arbiter keeps picking
-- P1 unfairly. (You might get a different trace depending on what z3 happens to produce!)
--
-- The scenario constrains position @1@ of P2's trace, so the bound should be at least @2@.
-- NOTE(review): for smaller bounds that index is past the end of the list; presumably SBV
-- leaves such an element unconstrained, making the reported trace meaningless — confirm.
--
-- Exercise for the reader: Change the 'validTurns' function so that it alternates the turns
-- from the previous value if neither process is in critical. Show that this makes the 'notFair'
-- function below no longer exhibit the issue. Is this sufficient? Concurrent programming is tricky!
notFair :: Int -> IO ()
notFair b = runSMT $ do
        p1    :: SList State   <- sList "p1"
        p2    :: SList State   <- sList "p2"
        turns :: SList Integer <- sList "turns"

        -- Ensure that both sequences and the turns are valid
        constrain $ validSequence b 1 turns p1
        constrain $ validSequence b 2 turns p2
        constrain $ validTurns    b turns p1 p2

        -- Ensure that the second process becomes ready in the second cycle:
        constrain $ p2 !! 1 .== sReady

        -- Find a trace where p2 never goes critical within the bound. If the
        -- solver finds one, fairness is violated!
        constrain $ sNot $ L.belem b sCritical p2

        query $ do
          cs <- checkSat
          case cs of
            Unk    -> error "Solver said Unknown!"
            DSat{} -> error "Solver said delta-satisfiable!"
            Unsat  -> error "Solver couldn't find a violating trace!"
            Sat    -> do
              io . putStrLn $ "Fairness is violated at bound: " ++ show b
              p1V <- getValue p1
              p2V <- getValue p2
              ts  <- getValue turns

              io . putStrLn $ "P1: " ++ show p1V
              io . putStrLn $ "P2: " ++ show p2V
              io . putStrLn $ "Ts: " ++ show ts

{-# ANN module ("HLint: ignore Reduce duplication" :: String) #-}