{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE ViewPatterns #-}
module Test.DejaFu.SCT.Internal.DPOR where
import Control.Applicative ((<|>))
import Control.DeepSeq (NFData(..))
import Control.Exception (MaskingState(..))
import qualified Data.Foldable as F
import Data.Function (on)
import Data.List (nubBy, partition, sortOn)
import Data.List.NonEmpty (toList)
import Data.Map.Strict (Map)
import qualified Data.Map.Strict as M
import Data.Maybe (isJust, isNothing, listToMaybe,
maybeToList)
import Data.Sequence (Seq, (|>))
import qualified Data.Sequence as Sq
import Data.Set (Set)
import qualified Data.Set as S
import GHC.Generics (Generic)
import GHC.Stack (HasCallStack)
import Test.DejaFu.Internal
import Test.DejaFu.Schedule (Scheduler(..))
import Test.DejaFu.Types
import Test.DejaFu.Utils (decisionOf, tidOf)
-- | A node of the DPOR tree. Each node describes one scheduling point:
-- which threads could run there, which have been tried, and which
-- still need to be tried. 'validateDPOR' checks the relationships
-- between the fields.
data DPOR = DPOR
{ dporRunnable :: Set ThreadId
-- ^ Threads runnable at this point.
, dporTodo :: Map ThreadId Bool
-- ^ Threads still to be tried from this point. The flag is copied
-- from 'bcktBacktracks' by 'incorporateBacktrackSteps' and returned
-- by 'findSchedulePrefix'; when 'True' the entry is added even if the
-- thread is in the sleep set (presumably "added conservatively" --
-- TODO confirm).
, dporNext :: Maybe (ThreadId, DPOR)
-- ^ The thread most recently taken from this point, and the subtree
-- reached by taking it.
, dporDone :: Set ThreadId
-- ^ Threads that have already been tried from this point.
, dporSleep :: Map ThreadId ThreadAction
-- ^ Sleep set: threads (with an associated action) that are not
-- added to the to-do set unless flagged (see 'dporTodo').
, dporTaken :: Map ThreadId ThreadAction
-- ^ Threads taken from this point together with the action each
-- performed; merged into the sleep set of later branches.
} deriving (Eq, Show, Generic, NFData)
-- | Check the structural invariants of a single 'DPOR' node, returning
-- the node unchanged when they all hold and calling 'fatal' with a
-- descriptive message otherwise.
--
-- The invariants, checked in this order, are:
--
--   * every to-do thread is runnable;
--   * every done thread is runnable;
--   * every taken thread is done;
--   * no thread is both to-do and done;
--   * the thread of the child in 'dporNext', if any, is done.
--
-- Only this node is inspected; the child subtree is not validated
-- recursively.
validateDPOR :: HasCallStack => DPOR -> DPOR
validateDPOR dpor
  | not (todo `S.isSubsetOf` runnable) = fatal "thread exists in todo set but not runnable set"
  | not (done `S.isSubsetOf` runnable) = fatal "thread exists in done set but not runnable set"
  | not (taken `S.isSubsetOf` done) = fatal "thread exists in taken set but not done set"
  -- This guard compares the to-do and done sets, so the message must
  -- name those two sets (it previously said "taken set").
  | not (todo `disjoint` done) = fatal "thread exists in both todo set and done set"
  | not (maybe True (`S.member` done) next) = fatal "taken thread does not exist in done set"
  | otherwise = dpor
  where
    done = dporDone dpor
    next = fst <$> dporNext dpor
    runnable = dporRunnable dpor
    taken = M.keysSet (dporTaken dpor)
    todo = M.keysSet (dporTodo dpor)

    -- Two sets are disjoint when their intersection is empty.
    disjoint s1 s2 = S.null (S.intersection s1 s2)
-- | One step of an execution, annotated with the information needed
-- to decide where to add backtracking points. Values are built by
-- 'findBacktrackSteps', one per trace entry.
data BacktrackStep = BacktrackStep
{ bcktThreadid :: ThreadId
-- ^ The thread that ran at this step.
, bcktDecision :: Decision
-- ^ The scheduling decision recorded in the trace for this step.
, bcktAction :: ThreadAction
-- ^ The action the thread performed.
, bcktRunnable :: Map ThreadId Lookahead
-- ^ The threads recorded as runnable at this step, with their
-- lookaheads.
, bcktBacktracks :: Map ThreadId Bool
-- ^ Threads to backtrack to from this step. The flag starts as
-- 'False' in 'findBacktrackSteps' and is later used by
-- 'incorporateBacktrackSteps' to override the sleep set.
, bcktState :: DepState
-- ^ The dependency state /before/ 'bcktAction' was applied.
} deriving (Eq, Show, Generic, NFData)
-- | Build the root 'DPOR' node from the list of initially runnable
-- threads.
--
-- The root has 'initialThread' as its only to-do entry (flagged
-- 'False'), no child, and empty done, sleep and taken sets.
--
-- Calls 'fatal' if 'initialThread' is not among the given threads.
initialState :: [ThreadId] -> DPOR
initialState threads
  | initialThread `elem` threads = DPOR
      { dporRunnable = S.fromList threads
      , dporTodo = M.singleton initialThread False
      , dporNext = Nothing
      , dporDone = S.empty
      , dporSleep = M.empty
      , dporTaken = M.empty
      }
  -- 'fatal' is applied to a single message everywhere else in this
  -- module. Passing a second string made the result type a function and
  -- silently dropped the descriptive text, so only the message is passed.
  | otherwise = fatal "Initial thread is not in initially runnable set"
-- | Find a schedule prefix leading to an unexplored to-do decision, or
-- 'Nothing' if there is none in the tree.
--
-- The result is the list of threads to schedule (the path of
-- 'dporNext' threads followed by one to-do thread), the flag stored
-- with that to-do entry, and the sleep set to use at the branching
-- node (that node's sleep set unioned with its taken set).
--
-- Deeper to-do entries are preferred: the child subtree is searched
-- before this node's own to-do set. Within a node, threads that
-- compare @>= initialThread@ are preferred over the others.
findSchedulePrefix
  :: DPOR
  -> Maybe ([ThreadId], Bool, Map ThreadId ThreadAction)
findSchedulePrefix dpor = viaChild <|> atThisNode
  where
    -- A prefix found further down, extended with the step to the child.
    viaChild = do
      (tid, child) <- dporNext dpor
      (ts, c, slp) <- findSchedulePrefix child
      pure (tid : ts, c, slp)

    -- A single-step prefix ending in one of this node's to-do entries.
    atThisNode =
      let (preferred, others) =
            partition ((>= initialThread) . fst) (M.toList (dporTodo dpor))
      in finish <$> (listToMaybe preferred <|> listToMaybe others)

    finish (t, c) = ([t], c, branchSleep)

    -- Sleep set for a new branch from this node.
    branchSleep = dporSleep dpor `M.union` dporTaken dpor
-- | Add a trace to the DPOR tree.
--
-- The trace and the tree are walked together for as long as the
-- tree's child ('dporNext') is for the same thread as the next trace
-- entry. At the first node where it is not, the thread is moved from
-- the to-do set to the done set and a new linear subtree, built from
-- the remainder of the trace, becomes that node's child.
--
-- Calls 'fatal' if the trace runs out before such a node is found, or
-- if the child that would be replaced still contains to-do entries.
-- Every node rebuilt on the way is passed through 'validateDPOR'.
incorporateTrace :: HasCallStack
=> MemType
-- ^ Memory model, passed on to 'updateDepState'.
-> Bool
-- ^ When 'True', the diverging thread is /not/ recorded in the
-- 'dporTaken' map of the node where the new branch is attached.
-> Trace
-- ^ The trace to add.
-> DPOR
-- ^ The tree to add it to.
-> DPOR
incorporateTrace memtype conservative trace dpor0 = grow initialDepState (initialDPORThread dpor0) trace dpor0 where
-- Walk the existing tree along the trace. 'tid' is the previously
-- scheduled thread, used by 'tidOf' to resolve the decision 'd'.
grow state tid trc@((d, _, a):rest) dpor =
let tid' = tidOf tid d
state' = updateDepState memtype state tid' a
in case dporNext dpor of
Just (t, child)
-- Same thread as the existing child: keep descending.
| t == tid' ->
validateDPOR $ dpor { dporNext = Just (tid', grow state' tid' rest child) }
-- A different thread, and the existing child still has work
-- recorded in it: refuse to discard it.
| hasTodos child -> fatal "replacing child with todos!"
-- No child, or a fully explored child for a different thread:
-- attach a fresh subtree here.
_ -> validateDPOR $
let taken = M.insert tid' a (dporTaken dpor)
-- The new branch sleeps on everything already slept on or taken
-- at this node (not including the thread being taken now).
sleep = dporSleep dpor `M.union` dporTaken dpor
in dpor { dporTaken = if conservative then dporTaken dpor else taken
, dporTodo = M.delete tid' (dporTodo dpor)
-- NOTE(review): 'trc' still starts with the current entry and
-- 'state'' already includes its action, so 'subtree' applies that
-- action to the dependency state a second time -- confirm this is
-- intended.
, dporNext = Just (tid', subtree state' tid' sleep trc)
, dporDone = S.insert tid' (dporDone dpor)
}
grow _ _ [] _ = fatal "trace exhausted without reading a to-do point!"
-- Does this node, or any node along its chain of children, have a
-- non-empty to-do set?
hasTodos dpor = not (M.null (dporTodo dpor)) || (case dporNext dpor of Just (_, dpor') -> hasTodos dpor'; _ -> False)
-- Build a linear chain of nodes from a non-empty trace suffix. The
-- node built here is the one reached after 'tid' performs 'a'; its
-- runnable, done and taken sets and its child are all derived from
-- the /next/ trace entry, and are empty when there is none.
subtree state tid sleep ((_, _, a):rest) = validateDPOR $
let state' = updateDepState memtype state tid a
-- Drop sleeping threads whose action is dependent with the action
-- just performed.
sleep' = M.filterWithKey (\t a' -> not $ dependent state' tid a t a') sleep
in DPOR
{ dporRunnable = S.fromList $ case rest of
((d', runnable, _):_) -> tidOf tid d' : map fst runnable
[] -> []
, dporTodo = M.empty
, dporNext = case rest of
((d', _, _):_) ->
let tid' = tidOf tid d'
in Just (tid', subtree state' tid' sleep' rest)
[] -> Nothing
, dporDone = case rest of
((d', _, _):_) -> S.singleton (tidOf tid d')
[] -> S.empty
, dporSleep = sleep'
, dporTaken = case rest of
((d', _, a'):_) -> M.singleton (tidOf tid d') a'
[] -> M.empty
}
subtree _ _ _ [] = fatal "subtree suffix empty!"
-- | Produce a list of 'BacktrackStep's from an execution, with
-- backtracking points added by the supplied 'BacktrackFunc'.
--
-- The recorded scheduling points and the trace are consumed in
-- lock-step, one 'BacktrackStep' per pair; processing stops as soon as
-- either runs out. After each step is appended, every thread enabled
-- at that step is checked against earlier steps for a dependent
-- action by another thread, and the indices found are handed to the
-- backtrack function.
findBacktrackSteps
:: MemType
-- ^ Memory model, passed on to 'updateDepState'.
-> BacktrackFunc
-- ^ Adds the requested backtracking points to the list of steps.
-> Bool
-- ^ When 'True', the final step of the trace is treated as dependent
-- with everything (see @killsEarly@ below).
-> Seq ([(ThreadId, Lookahead)], [ThreadId])
-- ^ Per-step pairs: the enabled threads with their lookaheads, and
-- the threads used to initialise that step's backtracking map.
-> Trace
-- ^ The execution trace.
-> [BacktrackStep]
findBacktrackSteps memtype backtrack boundKill = go initialDepState S.empty initialThread [] . F.toList where
-- Arguments: dependency state before this step, all threads seen
-- runnable so far, the previously scheduled thread, the steps built
-- so far, then the remaining scheduling points and trace entries.
go state allThreads tid bs ((e,i):is) ((d,_,a):ts) =
let tid' = tidOf tid d
state' = updateDepState memtype state tid' a
this = BacktrackStep
{ bcktThreadid = tid'
, bcktDecision = d
, bcktAction = a
, bcktRunnable = M.fromList e
, bcktBacktracks = M.fromList $ map (\i' -> (i', False)) i
-- The state stored is the one from /before/ the action.
, bcktState = state
}
bs' = doBacktrack killsEarly allThreads' e (bs++[this])
runnable = S.fromList (M.keys $ bcktRunnable this)
allThreads' = allThreads `S.union` runnable
-- True only on the last trace entry, and only if 'boundKill' is set.
killsEarly = null ts && boundKill
in go state' allThreads' tid' bs' is ts
go _ _ _ bs _ _ = bs
-- For every enabled thread @u@ and every other known thread @v@, find
-- the most recent step by @v@ that should be backtracked from, and
-- pass all such indices (flagged 'False') to the backtrack function.
doBacktrack killsEarly allThreads enabledThreads bs =
-- Steps paired with their index, most recent first.
let tagged = reverse $ zip [0..] bs
idxs = [ (i, False, u)
| (u, n) <- enabledThreads
, v <- S.toList allThreads
, u /= v
, i <- maybeToList (findIndex u n v tagged)]
-- Search backwards for a step by @v@ that is dependent with @u@'s
-- lookahead @n@. 'final' is 'True' only for the most recent step.
findIndex u n v = go' True where
{-# INLINE go' #-}
go' final ((i,b):rest)
-- Stop searching at a step matched by 'isSubC'.
| isSubC final b = Nothing
-- With 'killsEarly' set, any step by @v@ is accepted without
-- consulting the dependency relation.
| bcktThreadid b == v && (killsEarly || isDependent b) = Just i
| otherwise = go' False rest
go' _ [] = Nothing
-- A 'Subconcurrency' action by the initial thread, or a 'Stop' by the
-- initial thread that is not the most recent step.
{-# INLINE isSubC #-}
isSubC final b = case bcktAction b of
Stop -> not final && bcktThreadid b == initialThread
Subconcurrency -> bcktThreadid b == initialThread
_ -> False
{-# INLINE isDependent #-}
isDependent b
-- A blocking action followed by a barrier lookahead is never treated
-- as dependent here.
| isBlock (bcktAction b) && isBarrier (simplifyLookahead n) = False
| otherwise = dependent' (bcktState b) (bcktThreadid b) (bcktAction b) u n
in backtrack bs idxs
-- | Merge the backtracking points of a list of 'BacktrackStep's into
-- the to-do sets along the matching path of the DPOR tree.
--
-- Each step is paired with one tree node, following 'dporNext'. A
-- backtracking entry is added to the node's to-do set unless it is for
-- the thread of the node's current child, the thread is already done,
-- or the thread is asleep and the entry's flag is 'False'. Existing
-- to-do entries win over new ones for the same thread.
--
-- Calls 'fatal' if a node on the path has no child, or if its child is
-- for a different thread than the step. The child is only demanded
-- when the resulting subtree is, so these failures surface lazily.
incorporateBacktrackSteps :: HasCallStack
  => [BacktrackStep] -> DPOR -> DPOR
incorporateBacktrackSteps [] dpor = dpor
incorporateBacktrackSteps (step:steps) dpor = validateDPOR updated
  where
    stepTid = bcktThreadid step
    existing = dporNext dpor

    updated = dpor
      { dporTodo = dporTodo dpor `M.union` additions
      , dporNext = Just (stepTid, subtree)
      }

    -- The backtracking points of this step that are worth recording.
    additions = M.filterWithKey wanted (bcktBacktracks step)

    wanted t c =
      Just t /= (fst <$> existing)
        && S.notMember t (dporDone dpor)
        && (c || M.notMember t (dporSleep dpor))

    -- The rest of the steps, merged into the existing child.
    subtree = case existing of
      Nothing -> fatal "child is missing!"
      Just (t, d)
        | t == stepTid -> incorporateBacktrackSteps steps d
        | otherwise -> fatal "incorporating wrong trace!"
-- | The state threaded through 'dporSched'. The type parameter @k@ is
-- the state kept by the incremental bound function.
data DPORSchedState k = DPORSchedState
{ schedSleep :: Map ThreadId ThreadAction
-- ^ The sleep set: threads the scheduler will not choose once the
-- prefix is exhausted. Entries dependent with the prior action are
-- dropped as execution proceeds.
, schedPrefix :: [ThreadId]
-- ^ Decisions still to be replayed, consumed from the front.
, schedBPoints :: Seq ([(ThreadId, Lookahead)], [ThreadId])
-- ^ One entry appended per scheduling step: the runnable threads
-- that are within the bound (with lookaheads), and the alternative
-- choices that were not taken.
, schedIgnore :: Bool
-- ^ Set when scheduling stopped because every in-bound choice was
-- in the sleep set.
, schedBoundKill :: Bool
-- ^ Set when scheduling stopped because there were candidate threads
-- but none was within the bound.
, schedDepState :: DepState
-- ^ Dependency state, updated with the prior action at each step.
, schedBState :: Maybe k
-- ^ The bound function's state; 'Nothing' initially and after
-- scheduling stops.
} deriving (Eq, Show, Generic, NFData)
-- | The scheduler state at the start of an execution: the given sleep
-- set and schedule prefix, no recorded scheduling points, both stop
-- flags cleared, the initial dependency state, and no bound state.
initialDPORSchedState :: Map ThreadId ThreadAction
  -> [ThreadId]
  -> DPORSchedState k
initialDPORSchedState sleep prefix =
  DPORSchedState
    sleep            -- schedSleep
    prefix           -- schedPrefix
    Sq.empty         -- schedBPoints
    False            -- schedIgnore
    False            -- schedBoundKill
    initialDepState  -- schedDepState
    Nothing          -- schedBState
-- | An incremental bounding function. Given the bound state so far
-- (if any), the prior thread and action (if any), and a candidate
-- decision with its lookahead, it returns the new bound state.
-- 'dporSched' treats a 'Nothing' result as "outside the bound".
type IncrementalBoundFunc k
= Maybe k -> Maybe (ThreadId, ThreadAction) -> (Decision, Lookahead) -> Maybe k
-- | A backtracking function. Given the steps of an execution and a
-- list of requests, each made of a step index, a flag and a thread,
-- it returns the steps with backtracking points added. See
-- 'backtrackAt' for the implementation provided in this module.
type BacktrackFunc
= [BacktrackStep] -> [(Int, Bool, ThreadId)] -> [BacktrackStep]
-- | Build a 'BacktrackFunc' that adds each requested backtracking
-- point at its index.
--
-- Requests are sorted by index, and only the first request for each
-- index is kept. At a requested step, if the predicate is 'False' and
-- the thread is runnable there, just that thread is added; otherwise
-- every runnable thread is added.
--
-- Calls 'fatal' if a requested index lies beyond the end of the list
-- of steps.
backtrackAt :: HasCallStack
=> (ThreadId -> BacktrackStep -> Bool)
-- ^ When this returns 'True' for a request's thread and step, all
-- runnable threads are backtracked to rather than just the one.
-> BacktrackFunc
backtrackAt toAll bs0 = backtrackAt' . nubBy ((==) `on` fst') . sortOn fst' where
fst' (x,_,_) = x
backtrackAt' ((i,c,t):is) = go i bs0 i c t is
backtrackAt' [] = bs0
-- Arguments: the absolute index of the current request, the steps not
-- yet emitted, how many more steps to skip before reaching that
-- index, the request's flag and thread, and the remaining requests.
go i0 (b:bs) 0 c tid is
| not (toAll tid b) && tid `M.member` bcktRunnable b =
let val = M.lookup tid $ bcktBacktracks b
-- Only write the entry if the thread is absent, or if this request
-- would raise an existing 'False' flag to 'True'.
b' = if isNothing val || (val == Just False && c)
then b { bcktBacktracks = backtrackTo tid c b }
else b
in b' : case is of
-- Requests are sorted with distinct indices, so the next one is
-- @i' - i0 - 1@ steps further into the remaining list.
((i',c',t'):is') -> go i' bs (i'-i0-1) c' t' is'
[] -> bs
| otherwise =
let b' = b { bcktBacktracks = backtrackAll c b }
in b' : case is of
((i',c',t'):is') -> go i' bs (i'-i0-1) c' t' is'
[] -> bs
-- Not at the requested index yet: emit the step unchanged.
go i0 (b:bs) i c tid is = b : go i0 bs (i-1) c tid is
go _ [] _ _ _ _ = fatal "ran out of schedule whilst backtracking!"
-- Add (or overwrite) a single thread in the step's backtracking map.
backtrackTo tid c = M.insert tid c . bcktBacktracks
-- Replace the backtracking map with every runnable thread, all
-- carrying flag @c@.
backtrackAll c = M.map (const c) . bcktRunnable
-- | The DPOR scheduler.
--
-- While the schedule prefix in the state is non-empty, its decisions
-- are replayed one at a time. Once it is exhausted, a thread is
-- chosen from the runnable threads that are within the bound and not
-- in the sleep set; the untaken alternatives are recorded in
-- 'schedBPoints'. If no such thread exists the scheduler returns
-- 'Nothing' and records why in 'schedIgnore' and 'schedBoundKill'.
dporSched :: HasCallStack
=> MemType
-- ^ Memory model, passed on to 'updateDepState'.
-> IncrementalBoundFunc k
-- ^ Bound function; a 'Nothing' result marks a choice as out of bound.
-> Scheduler (DPORSchedState k)
dporSched memtype boundf = Scheduler $ \prior threads s ->
let
-- The state after this step: one more scheduling point recorded
-- (in-bound runnable threads plus the given untaken alternatives)
-- and the dependency state advanced by the prior action.
nextState rest = s
{ schedBPoints = schedBPoints s |> (restrictToBound fst threads', rest)
, schedDepState = nextDepState
}
nextDepState = let ds = schedDepState s in maybe ds (uncurry $ updateDepState memtype ds) prior
-- Candidate threads before bounding: just the prior thread if it did
-- not yield, is still runnable and is within the bound; otherwise
-- every runnable thread. Threads about to yield are moved to the end.
initialise = tryDaemons . yieldsToEnd $ case prior of
Just (tid, act)
| not (didYield act) && tid `elem` tids && isInBound tid -> [tid]
_ -> tids
-- If any candidate satisfies 'doesKill', use /all/ runnable threads
-- instead, with the killing ones moved to the end.
tryDaemons ts
| any doesKill ts = case partition doesKill tids of
(kills, nokills) -> nokills ++ kills
| otherwise = ts
doesKill t = killsDaemons t (action t)
restrictToBound f = filter (isInBound . f)
isInBound t = isJust $ boundf (schedBState s) prior (decision t, action t)
yieldsToEnd ts = case partition (willYield . action) ts of
(yields, noyields) -> noyields ++ yields
decision = decisionOf (fst <$> prior) (S.fromList tids)
-- Lookahead of a runnable thread. NOTE(review): 'efromJust' is
-- defined elsewhere; presumably it fails on 'Nothing', i.e. when the
-- thread is not runnable -- confirm.
action t = efromJust (lookup t threads')
tids = map fst threads'
threads' = toList threads
in case schedPrefix s of
-- Replay the next decision of the prefix; no alternatives recorded.
(t:ts) ->
let bstate' = boundf (schedBState s) prior (decision t, action t)
in (Just t, (nextState []) { schedPrefix = ts, schedBState = bstate' })
-- Prefix exhausted: choose among the bounded, non-sleeping candidates.
[] ->
let choices = restrictToBound id initialise
checkDep t a = case prior of
Just (tid, act) -> dependent (schedDepState s) tid act t a
Nothing -> False
-- Drop sleeping threads whose action is dependent with the prior
-- action.
ssleep' = M.filterWithKey (\t a -> not $ checkDep t a) $ schedSleep s
choices' = filter (`notElem` M.keys ssleep') choices
-- There were in-bound choices, but all of them are asleep.
signore' = not (null choices) && all (`elem` M.keys ssleep') choices
-- There were candidates, but none of them is within the bound.
sbkill' = not (null initialise) && null choices
in case choices' of
(nextTid:rest) ->
let bstate' = boundf (schedBState s) prior (decision nextTid, action nextTid)
in (Just nextTid, (nextState rest) { schedSleep = ssleep', schedBState = bstate' })
[] ->
(Nothing, (nextState []) { schedIgnore = signore', schedBoundKill = sbkill', schedBState = Nothing })
-- | Check whether two actions, performed by the given threads, are
-- independent.
--
-- Actions of the same thread are never independent. Neither are a
-- pair where one action pins the other (see @pins@ below), checked in
-- both directions. Anything else is independent exactly when it is
-- not 'dependent'.
independent :: DepState -> ThreadId -> ThreadAction -> ThreadId -> ThreadAction -> Bool
independent ds t1 a1 t2 a2 =
    t1 /= t2
      && not (pins a1 t2 a2)
      && not (pins a2 t1 a1)
      && not (dependent ds t1 a1 t2 a2)
  where
    -- Does @act@ rule out independence with the action @other@ of
    -- thread @tid@, regardless of the dependency relation?
    pins act tid other = aimedAt || syncedWrite
      where
        -- 'DontCheck' always pins; forks and throws pin the thread
        -- they name.
        aimedAt = case act of
          DontCheck _ -> True
          Fork t -> t == tid
          ForkOS t -> t == tid
          ThrowTo t _ -> t == tid
          BlockedThrowTo t -> t == tid
          _ -> False

        -- An unsynchronised write pins an action that synchronises
        -- the same reference.
        syncedWrite = case simplifyAction act of
          UnsynchronisedWrite r -> synchronises (simplifyAction other) r
          _ -> False
-- | Check whether two actions that have both already happened,
-- performed by the given threads, are dependent.
--
-- The alternatives are tried in order, and an alternative whose guard
-- fails falls through to the next, so the order matters.
dependent :: DepState -> ThreadId -> ThreadAction -> ThreadId -> ThreadAction -> Bool
dependent ds t1 a1 t2 a2 = case (a1, a2) of
-- Two threads throwing to each other: dependent if either action can
-- be interrupted.
(ThrowTo t _, ThrowTo u _)
| t == t2 && u == t1 -> canInterrupt ds t1 a1 || canInterrupt ds t2 a2
-- One thread throwing to the other: dependent if the target's action
-- can be interrupted and is not 'Stop'.
(ThrowTo t _, _) | t == t2 -> canInterrupt ds t2 a2 && a2 /= Stop
(_, ThrowTo t _) | t == t1 -> canInterrupt ds t1 a1 && a1 /= Stop
-- STM transactions (blocked or not) are compared by the TVars they
-- touched; see 'checkSTM'.
(STM _ _, STM _ _) -> checkSTM
(STM _ _, BlockedSTM _) -> checkSTM
(BlockedSTM _, STM _ _) -> checkSTM
(BlockedSTM _, BlockedSTM _) -> checkSTM
-- Everything else: convert each action to a lookahead with 'rewind'
-- and require 'dependent'' to hold in both directions.
_ -> dependent' ds t1 a1 t2 (rewind a2)
&& dependent' ds t2 a2 t1 (rewind a1)
where
-- Dependent if either transaction wrote a TVar the other used.
checkSTM = checkSTM' a1 a2 || checkSTM' a2 a1
checkSTM' a b = not . S.null $ tvarsWritten a `S.intersection` tvarsOf b
-- | Variant of 'dependent' in which the second action has not
-- happened yet and is known only as a 'Lookahead'.
--
-- The alternatives are tried in order, and an alternative whose guard
-- fails falls through to the next, so the order matters.
dependent' :: DepState -> ThreadId -> ThreadAction -> ThreadId -> Lookahead -> Bool
dependent' ds t1 a1 t2 l2 = case (a1, l2) of
-- Lifted IO is always treated as dependent with other lifted IO.
(LiftIO, WillLiftIO) -> True
-- Two threads throwing to each other: dependent if either side can
-- be interrupted.
(ThrowTo t _, WillThrowTo u)
| t == t2 && u == t1 -> canInterrupt ds t1 a1 || canInterruptL ds t2 l2
-- One thread throwing to the other: dependent if the target can be
-- interrupted and is not stopping.
(ThrowTo t _, _) | t == t2 -> canInterruptL ds t2 l2 && l2 /= WillStop
(_, WillThrowTo t) | t == t1 -> canInterrupt ds t1 a1 && a1 /= Stop
-- 'WillSTM' carries no TVar information, so any STM action is treated
-- as dependent with it.
(STM _ _, WillSTM) -> True
(BlockedSTM _, WillSTM) -> True
-- Reading or setting the number of capabilities is dependent with
-- setting it.
(GetNumCapabilities _, WillSetNumCapabilities _) -> True
(SetNumCapabilities _, WillGetNumCapabilities) -> True
(SetNumCapabilities _, WillSetNumCapabilities _) -> True
-- Otherwise compare the simplified forms of the two.
_ -> dependentActions ds (simplifyAction a1) (simplifyLookahead l2)
-- | Check whether two simplified actions are dependent, given the
-- current dependency state.
--
-- The alternatives are tried in order, and an alternative whose guard
-- fails falls through to the next, so the order matters.
dependentActions :: DepState -> ActionType -> ActionType -> Bool
dependentActions ds a1 a2 = case (a1, a2) of
-- Two unsynchronised reads are never dependent.
(UnsynchronisedRead _, UnsynchronisedRead _) -> False
-- An unsynchronised write and a commit to the same reference are not
-- dependent while that reference is recorded as buffered.
(UnsynchronisedWrite r1, PartiallySynchronisedCommit r2) | r1 == r2 && isBuffered ds r1 -> False
(PartiallySynchronisedCommit r1, UnsynchronisedWrite r2) | r1 == r2 && isBuffered ds r1 -> False
-- An unsynchronised read of a buffered reference is dependent with a
-- barrier.
(UnsynchronisedRead r1, _) | isBarrier a2 && isBuffered ds r1 -> True
(_, UnsynchronisedRead r2) | isBarrier a1 && isBuffered ds r2 -> True
-- A commit is dependent with anything that synchronises its reference.
(PartiallySynchronisedCommit r1, _) | synchronises a2 r1 -> True
(_, PartiallySynchronisedCommit r2) | synchronises a1 r2 -> True
-- Synchronised operations on the same variable: two writes are
-- dependent when it is not full, two reads when it is full, and a
-- read with a write always.
(SynchronisedWrite v1, SynchronisedWrite v2) | v1 == v2 -> not (isFull ds v1)
(SynchronisedRead v1, SynchronisedRead v2) | v1 == v2 -> isFull ds v1
(SynchronisedWrite v1, SynchronisedRead v2) | v1 == v2 -> True
(SynchronisedRead v1, SynchronisedWrite v2) | v1 == v2 -> True
-- Fallback: dependent only if both actions have a 'crefOf' result and
-- it is the same reference.
(_, _) -> maybe False (\r -> Just r == crefOf a2) (crefOf a1)
-- | The state consulted by the dependency functions; it is advanced
-- one action at a time by 'updateDepState'.
data DepState = DepState
{ depCRState :: Map CRefId Bool
-- ^ Per-reference flag read by 'isBuffered' (absent means 'False').
-- 'updateCRState' sets it to 'True' on a write, removes it on a
-- commit, and clears the whole map on a barrier.
, depMVState :: Set MVarId
-- ^ The variables for which 'isFull' holds: added on a put, removed
-- on a take (see 'updateMVState').
, depMaskState :: Map ThreadId MaskingState
-- ^ The recorded masking state of each thread; a thread with no
-- entry is treated as neither masked-interruptible nor
-- masked-uninterruptible.
} deriving (Eq, Show)
-- | Fully evaluates the reference and variable state. The masking map
-- has its keys fully evaluated, but each 'MaskingState' value is only
-- forced with 'seq' rather than through an 'NFData' instance.
instance NFData DepState where
  rnf (DepState crefs mvars masks) =
      rnf crefs `seq` rnf mvars `seq` rnf forcedMasks
    where
      forcedMasks = [(t, m `seq` ()) | (t, m) <- M.toList masks]
-- | The dependency state at the start of an execution: nothing
-- buffered, no full variables, and no recorded masking states.
initialDepState :: DepState
initialDepState = DepState
  { depCRState = M.empty
  , depMVState = S.empty
  , depMaskState = M.empty
  }
-- | Advance the dependency state by one action performed by the given
-- thread. Each of the three components is updated independently by
-- its own helper.
updateDepState :: MemType -> DepState -> ThreadId -> ThreadAction -> DepState
updateDepState memtype depstate tid act = DepState crefs mvars masks
  where
    crefs = updateCRState memtype act (depCRState depstate)
    mvars = updateMVState act (depMVState depstate)
    masks = updateMaskState tid act (depMaskState depstate)
-- | Update the buffered-reference map for one action.
--
-- Under 'SequentialConsistency' the map is always emptied. Otherwise
-- a commit removes its reference, a write marks its reference as
-- 'True', any action that simplifies to a barrier empties the map,
-- and every other action leaves it unchanged.
updateCRState :: MemType -> ThreadAction -> Map CRefId Bool -> Map CRefId Bool
updateCRState memtype act = case (memtype, act) of
  (SequentialConsistency, _) -> const M.empty
  (_, CommitCRef _ r) -> M.delete r
  (_, WriteCRef r) -> M.insert r True
  _ | isBarrier (simplifyAction act) -> const M.empty
    | otherwise -> id
-- | Update the set of full variables for one action: a put (or a
-- try-put whose flag is 'True') adds the variable, a take (or a
-- try-take whose flag is 'True') removes it, and anything else leaves
-- the set unchanged.
updateMVState :: ThreadAction -> Set MVarId -> Set MVarId
updateMVState act full = case act of
  PutMVar mvid _ -> S.insert mvid full
  TryPutMVar mvid True _ -> S.insert mvid full
  TakeMVar mvid _ -> S.delete mvid full
  TryTakeMVar mvid True _ -> S.delete mvid full
  _ -> full
-- | Update the per-thread masking states for one action performed by
-- the given thread.
--
-- A 'Fork' copies the forking thread's recorded state, if any, to the
-- new thread. Setting or resetting the masking state records the new
-- state for the acting thread. A 'Throw' or 'ThrowTo' whose flag is
-- 'True' removes the entry of the acting or target thread
-- respectively, and 'Stop' removes the acting thread's entry. All
-- other actions leave the map unchanged.
updateMaskState :: ThreadId -> ThreadAction -> Map ThreadId MaskingState -> Map ThreadId MaskingState
updateMaskState tid act masks = case act of
  Fork child -> maybe masks (\ms -> M.insert child ms masks) (M.lookup tid masks)
  SetMasking _ ms -> M.insert tid ms masks
  ResetMasking _ ms -> M.insert tid ms masks
  Throw True -> M.delete tid masks
  ThrowTo target True -> M.delete target masks
  Stop -> M.delete tid masks
  _ -> masks
-- | Whether the reference is recorded as 'True' in 'depCRState'. A
-- reference with no entry counts as not buffered.
isBuffered :: DepState -> CRefId -> Bool
isBuffered depstate r = M.lookup r (depCRState depstate) == Just True
-- | Whether the variable is recorded in 'depMVState', i.e. is
-- currently considered full.
isFull :: DepState -> MVarId -> Bool
isFull depstate v = v `S.member` depMVState depstate
-- | Whether a thread performing the given action can be interrupted,
-- according to its recorded masking state.
--
-- A masked-interruptible thread can only be interrupted in one of the
-- blocked actions listed below; a masked-uninterruptible thread never
-- can; any other thread (including one with no recorded state) always
-- can.
canInterrupt :: DepState -> ThreadId -> ThreadAction -> Bool
canInterrupt depstate tid act =
  case M.lookup tid (depMaskState depstate) of
    Just MaskedInterruptible -> blocked
    Just MaskedUninterruptible -> False
    _ -> True
  where
    blocked = case act of
      BlockedPutMVar _ -> True
      BlockedReadMVar _ -> True
      BlockedTakeMVar _ -> True
      BlockedSTM _ -> True
      BlockedThrowTo _ -> True
      _ -> False
-- | Variant of 'canInterrupt' for an action that has not happened yet
-- and is known only as a 'Lookahead'.
--
-- A masked-interruptible thread is treated as interruptible for each
-- of the lookaheads listed below; a masked-uninterruptible thread
-- never is; any other thread (including one with no recorded state)
-- always is.
canInterruptL :: DepState -> ThreadId -> Lookahead -> Bool
canInterruptL depstate tid lh =
  case M.lookup tid (depMaskState depstate) of
    Just MaskedInterruptible -> mayBlock
    Just MaskedUninterruptible -> False
    _ -> True
  where
    mayBlock = case lh of
      WillPutMVar _ -> True
      WillReadMVar _ -> True
      WillTakeMVar _ -> True
      WillSTM -> True
      WillThrowTo _ -> True
      _ -> False
-- | Whether the thread's recorded masking state is
-- 'MaskedInterruptible'. A thread with no recorded state is not.
isMaskedInterruptible :: DepState -> ThreadId -> Bool
isMaskedInterruptible depstate tid =
  case M.lookup tid (depMaskState depstate) of
    Just MaskedInterruptible -> True
    _ -> False
-- | Whether the thread's recorded masking state is
-- 'MaskedUninterruptible'. A thread with no recorded state is not.
isMaskedUninterruptible :: DepState -> ThreadId -> Bool
isMaskedUninterruptible depstate tid =
  case M.lookup tid (depMaskState depstate) of
    Just MaskedUninterruptible -> True
    _ -> False
-- | The smallest thread in the node's runnable set, used as the
-- starting thread when incorporating a trace. Partial: fails if the
-- runnable set is empty.
initialDPORThread :: DPOR -> ThreadId
initialDPORThread dpor = S.elemAt 0 (dporRunnable dpor)
-- | Whether the action was a yield; a thread delay counts as one.
didYield :: ThreadAction -> Bool
didYield act = case act of
  Yield -> True
  ThreadDelay _ -> True
  _ -> False
-- | Whether the upcoming action is a yield; a thread delay counts as
-- one.
willYield :: Lookahead -> Bool
willYield lh = case lh of
  WillYield -> True
  WillThreadDelay _ -> True
  _ -> False
-- | Whether the upcoming action will kill daemon threads: true only
-- when the initial thread is about to stop.
killsDaemons :: ThreadId -> Lookahead -> Bool
killsDaemons t lh = case lh of
  WillStop -> t == initialThread
  _ -> False