{-# LANGUAGE BangPatterns #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE RankNTypes #-} -- | -- Module : Test.DejaFu.SCT.Internal -- Copyright : (c) 2018--2019 Michael Walker -- License : MIT -- Maintainer : Michael Walker -- Stability : experimental -- Portability : BangPatterns, FlexibleContexts, LambdaCase, RankNTypes -- -- Internal types and functions for SCT. This module is NOT -- considered to form part of the public interface of this library. module Test.DejaFu.SCT.Internal where import Data.Coerce (Coercible, coerce) import qualified Data.IntMap.Strict as I import Data.List (find, mapAccumL) import Data.Maybe (fromMaybe) import GHC.Stack (HasCallStack) import Test.DejaFu.Conc import Test.DejaFu.Conc.Internal (Context(..)) import Test.DejaFu.Conc.Internal.Memory (commitThreadId) import Test.DejaFu.Conc.Internal.Program import Test.DejaFu.Internal import Test.DejaFu.Schedule (Scheduler(..)) import Test.DejaFu.SCT.Internal.DPOR import Test.DejaFu.Types import Test.DejaFu.Utils ------------------------------------------------------------------------------- -- * Exploration -- | General-purpose SCT function. sct :: (MonadDejaFu n, HasCallStack) => Settings n a -- ^ The SCT settings ('Way' is ignored) -> ([ThreadId] -> s) -- ^ Initial state -> (s -> Maybe t) -- ^ State predicate -> (ConcurrencyState -> (Scheduler g -> g -> n (Either Condition a, g, Trace)) -> s -> t -> n (s, Maybe (Either Condition a, Trace))) -- ^ Run the computation and update the state -> Program pty n a -> n [(Either Condition a, Trace)] sct settings s0 sfun srun conc = recordSnapshot conc >>= \case Just (Right snap, _) -> sct'Snap snap Just (Left f, trace) -> pure [(Left f, trace)] Nothing -> sct'Full where sct'Full = sct' settings initialCState (s0 [initialThread]) sfun (srun initialCState runFull) runFull (toId 1) (toId 1) sct'Snap snap = let idsrc = cIdSource (contextFromSnapshot snap) cstate = cCState (contextFromSnapshot snap) in sct' settings cstate (s0 (fst (threadsFromSnapshot snap))) sfun (srun cstate (runSnap snap)) (runSnap snap) (toId $ 1 + fst (_tids idsrc)) (toId $ 1 + fst (_iorids idsrc)) runFull sched s = runConcurrent sched (_memtype settings) s conc runSnap snap sched s = runSnapshot sched (_memtype settings) s snap -- | Like 'sct' but given a function to run the computation. sct' :: (MonadDejaFu n, HasCallStack) => Settings n a -- ^ The SCT settings ('Way' is ignored) -> ConcurrencyState -- ^ The initial concurrency state -> s -- ^ Initial state -> (s -> Maybe t) -- ^ State predicate -> (s -> t -> n (s, Maybe (Either Condition a, Trace))) -- ^ Run the computation and update the state -> (forall x. Scheduler x -> x -> n (Either Condition a, x, Trace)) -- ^ Just run the computation -> ThreadId -- ^ The first available @ThreadId@ -> IORefId -- ^ The first available @IORefId@ -> n [(Either Condition a, Trace)] sct' settings cstate0 s0 sfun srun run nTId nCRId = go Nothing [] s0 where go (Just res) _ _ | earlyExit res = pure [] go res0 seen !s = case sfun s of Just t -> srun s t >>= \case (s', Just (Left Abort, _)) | hideAborts -> go res0 seen s' (s', Just (res, trace)) -> case discard res of Just DiscardResultAndTrace -> go (Just res) seen s' Just DiscardTrace -> result res [] seen s' Nothing -> result res trace seen s' (s', Nothing) -> go Nothing seen s' Nothing -> pure [] -- Sadly, we have to use a list to store the set of unique results, -- as we don't have an @Ord a@ dict hanging around. I suspect that -- most test cases will have a relatively small number of unique -- results, compared to the number of executions, however. -- Pathological cases (like IORef ones in dejafu-tests which produce -- a different result on every execution) are probably uncommon. result = case _equality settings of Just f -> \res trace seen s -> let eq cmp (Right a1) (Right a2) = cmp a1 a2 eq _ (Left e1) (Left e2) = e1 == e2 eq _ _ _ = False in if any (eq f res) seen then go (Just res) seen s else dosimplify res trace (res:seen) s Nothing -> dosimplify dosimplify res [] seen s = ((res, []) :) <$> go (Just res) seen s dosimplify res trace seen s | not (_simplify settings) = ((res, trace) :) <$> go (Just res) seen s | otherwise = do shrunk <- simplifyExecution settings cstate0 run nTId nCRId res trace (shrunk :) <$> go (Just res) seen s earlyExit = fromMaybe (const False) (_earlyExit settings) discard = fromMaybe (const Nothing) (_discard settings) hideAborts = not (_showAborts settings) -- | Given a result and a trace, produce a more minimal trace. -- -- In principle, simplification is semantics preserving and can be -- done without needing to execute the computation again. However, -- there are two good reasons to do so: -- -- * It's a sanity check that there are no bugs. -- * It's easier to generate a reduced sequence of scheduling -- decisions and let dejafu generate the full trace, than to -- generate a reduced trace directly -- -- Unlike shrinking in randomised property-testing tools like -- QuickCheck or Hedgehog, we only run the test case /once/, at the -- end, rather than after every simplification step. simplifyExecution :: (MonadDejaFu n, HasCallStack) => Settings n a -- ^ The SCT settings ('Way' is ignored) -> ConcurrencyState -- ^ The initial concurrency state -> (forall x. Scheduler x -> x -> n (Either Condition a, x, Trace)) -- ^ Just run the computation -> ThreadId -- ^ The first available @ThreadId@ -> IORefId -- ^ The first available @IORefId@ -> Either Condition a -- ^ The expected result -> Trace -> n (Either Condition a, Trace) simplifyExecution settings cstate0 run nTId nCRId res trace | tidTrace == simplifiedTrace = do debugPrint ("Simplifying new result '" ++ p res ++ "': no simplification possible!") pure (res, trace) | otherwise = do debugPrint ("Simplifying new result '" ++ p res ++ "': OK!") (res', _, trace') <- replay run (fixup simplifiedTrace) case (_equality settings, res, res') of (Just f, Right a1, Right a2) | f a1 a2 -> pure (res', trace') (_, Left e1, Left e2) | e1 == e2 -> pure (res', trace') (Nothing, Right _, Right _) -> pure (res', trace') -- this is a risky case! _ -> do debugFatal ("Got a different result after simplifying: '" ++ p res ++ "' /= '" ++ p res' ++ "'") pure (res, trace) where tidTrace = toTIdTrace trace simplifiedTrace = simplify (_safeIO settings) (_memtype settings) cstate0 tidTrace fixup = renumber (_memtype settings) (fromId nTId) (fromId nCRId) debugFatal = if _debugFatal settings then fatal else debugPrint debugPrint = fromMaybe (const (pure ())) (_debugPrint settings) debugShow = fromMaybe (const "_") (_debugShow settings) p = either show debugShow -- | Replay an execution. replay :: MonadDejaFu n => (forall x. Scheduler x -> x -> n (Either Condition a, x, Trace)) -- ^ Run the computation -> [(ThreadId, ThreadAction)] -- ^ The reduced sequence of scheduling decisions -> n (Either Condition a, [(ThreadId, ThreadAction)], Trace) replay run = run (Scheduler (const sched)) where sched runnable cs ((t, Stop):ts) = case findThread t runnable of Just t' -> (Just t', ts) Nothing -> sched runnable cs ts sched runnable _ ((t, _):ts) = (findThread t runnable, ts) sched _ _ _ = (Nothing, []) -- find a thread ignoring names findThread tid0 = fmap fst . find (\(tid,_) -> fromId tid == fromId tid0) ------------------------------------------------------------------------------- -- * Schedule simplification -- | Simplify a trace by permuting adjacent independent actions to -- reduce context switching. simplify :: Bool -> MemType -> ConcurrencyState -> [(ThreadId, ThreadAction)] -> [(ThreadId, ThreadAction)] simplify safeIO memtype cstate0 trc0 = loop (length trc0) (prepare trc0) where prepare = dropCommits safeIO memtype cstate0 . lexicoNormalForm safeIO memtype cstate0 step = pushForward safeIO memtype cstate0 . pullBack safeIO memtype cstate0 loop 0 trc = trc loop n trc = let trc' = step trc in if trc' /= trc then loop (n-1) trc' else trc -- | Put a trace into lexicographic (by thread ID) normal form. lexicoNormalForm :: Bool -> MemType -> ConcurrencyState -> [(ThreadId, ThreadAction)] -> [(ThreadId, ThreadAction)] lexicoNormalForm safeIO memtype cstate0 = go where go trc = let trc' = permuteBy safeIO memtype cstate0 (repeat (>)) trc in if trc == trc' then trc else go trc' -- | Swap adjacent independent actions in the trace if a predicate -- holds. permuteBy :: Bool -> MemType -> ConcurrencyState -> [ThreadId -> ThreadId -> Bool] -> [(ThreadId, ThreadAction)] -> [(ThreadId, ThreadAction)] permuteBy safeIO memtype = go where go ds (p:ps) (t1@(tid1, ta1):t2@(tid2, ta2):trc) | independent safeIO ds tid1 ta1 tid2 ta2 && p tid1 tid2 = go' ds ps t2 (t1 : trc) | otherwise = go' ds ps t1 (t2 : trc) go _ _ trc = trc go' ds ps t@(tid, ta) trc = t : go (updateCState memtype ds tid ta) ps trc -- | Throw away commit actions which are followed by a memory barrier. dropCommits :: Bool -> MemType -> ConcurrencyState -> [(ThreadId, ThreadAction)] -> [(ThreadId, ThreadAction)] dropCommits _ SequentialConsistency = const id dropCommits safeIO memtype = go where go ds (t1@(tid1, ta1@(CommitIORef _ iorefid)):t2@(tid2, ta2):trc) | isBarrier (simplifyAction ta2) && numBuffered ds iorefid == 1 = go ds (t2:trc) | independent safeIO ds tid1 ta1 tid2 ta2 = t2 : go (updateCState memtype ds tid2 ta2) (t1:trc) go ds (t@(tid,ta):trc) = t : go (updateCState memtype ds tid ta) trc go _ [] = [] -- | Attempt to reduce context switches by \"pulling\" thread actions -- back to a prior execution of that thread. -- -- Simple example, say we have @[(tidA, act1), (tidB, act2), (tidA, -- act3)]@, where @act2@ and @act3@ are independent. In this case -- 'pullBack' will swap them, giving the sequence @[(tidA, act1), -- (tidA, act3), (tidB, act2)]@. It works for arbitrary separations. pullBack :: Bool -> MemType -> ConcurrencyState -> [(ThreadId, ThreadAction)] -> [(ThreadId, ThreadAction)] pullBack safeIO memtype = go where go ds (t1@(tid1, ta1):trc@((tid2, _):_)) = let ds' = updateCState memtype ds tid1 ta1 trc' = if tid1 /= tid2 then maybe trc (uncurry (:)) (findAction tid1 ds' trc) else trc in t1 : go ds' trc' go _ trc = trc findAction tid0 = fgo where fgo ds (t@(tid, ta):trc) | tid == tid0 = Just (t, trc) | otherwise = case fgo (updateCState memtype ds tid ta) trc of Just (ft@(ftid, fa), trc') | independent safeIO ds tid ta ftid fa -> Just (ft, t:trc') _ -> Nothing fgo _ _ = Nothing -- | Attempt to reduce context switches by \"pushing\" thread actions -- forward to a future execution of that thread. -- -- This is kind of the opposite of 'pullBack', but there are cases -- where one applies but not the other. -- -- Simple example, say we have @[(tidA, act1), (tidB, act2), (tidA, -- act3)]@, where @act1@ and @act2@ are independent. In this case -- 'pushForward' will swap them, giving the sequence @[(tidB, act2), -- (tidA, act1), (tidA, act3)]@. It works for arbitrary separations. pushForward :: Bool -> MemType -> ConcurrencyState -> [(ThreadId, ThreadAction)] -> [(ThreadId, ThreadAction)] pushForward safeIO memtype = go where go ds (t1@(tid1, ta1):trc@((tid2, _):_)) = let ds' = updateCState memtype ds tid1 ta1 in if tid1 /= tid2 then maybe (t1 : go ds' trc) (go ds) (findAction tid1 ta1 ds trc) else t1 : go ds' trc go _ trc = trc findAction tid0 ta0 = fgo where fgo ds (t@(tid, ta):trc) | tid == tid0 = Just ((tid0, ta0) : t : trc) | independent safeIO ds tid0 ta0 tid ta = (t:) <$> fgo (updateCState memtype ds tid ta) trc | otherwise = Nothing fgo _ _ = Nothing -- | Re-number threads and IORefs. -- -- Permuting forks or newIORefs makes the existing numbering invalid, -- which then causes problems for scheduling. Just re-numbering -- threads isn't enough, as IORef IDs are used to determine commit -- thread IDs. -- -- Renumbered things will not fix their names, so don't rely on those -- at all. renumber :: MemType -- ^ The memory model determines how commit threads are numbered. -> Int -- ^ First free thread ID. -> Int -- ^ First free @IORef@ ID. -> [(ThreadId, ThreadAction)] -> [(ThreadId, ThreadAction)] renumber memtype tid0 crid0 = snd . mapAccumL go (I.empty, tid0, I.empty, crid0) where go s@(tidmap, _, cridmap, _) (_, CommitIORef tid crid) = let tid' = renumbered tidmap tid crid' = renumbered cridmap crid act' = CommitIORef tid' crid' in case memtype of PartialStoreOrder -> (s, (commitThreadId tid' (Just crid'), act')) _ -> (s, (commitThreadId tid' Nothing, act')) go s@(tidmap, _, _, _) (tid, act) = let (s', act') = updateAction s act in (s', (renumbered tidmap tid, act')) -- I can't help but feel there should be some generic programming -- solution to this sort of thing (and to the many other functions -- operating over @ThreadAction@s / @Lookahead@s) updateAction (tidmap, nTId, cridmap, nCRId) (Fork old) = let tidmap' = I.insert (fromId old) nTId tidmap nTId' = nTId + 1 in ((tidmap', nTId', cridmap, nCRId), Fork (toId nTId)) updateAction (tidmap, nTId, cridmap, nCRId) (ForkOS old) = let tidmap' = I.insert (fromId old) nTId tidmap nTId' = nTId + 1 in ((tidmap', nTId', cridmap, nCRId), ForkOS (toId nTId)) updateAction s@(tidmap, _, _, _) (PutMVar mvid olds) = (s, PutMVar mvid (map (renumbered tidmap) olds)) updateAction s@(tidmap, _, _, _) (TryPutMVar mvid b olds) = (s, TryPutMVar mvid b (map (renumbered tidmap) olds)) updateAction s@(tidmap, _, _, _) (TakeMVar mvid olds) = (s, TakeMVar mvid (map (renumbered tidmap) olds)) updateAction s@(tidmap, _, _, _) (TryTakeMVar mvid b olds) = (s, TryTakeMVar mvid b (map (renumbered tidmap) olds)) updateAction (tidmap, nTId, cridmap, nCRId) (NewIORef old) = let cridmap' = I.insert (fromId old) nCRId cridmap nCRId' = nCRId + 1 in ((tidmap, nTId, cridmap', nCRId'), NewIORef (toId nCRId)) updateAction s@(_, _, cridmap, _) (ReadIORef old) = (s, ReadIORef (renumbered cridmap old)) updateAction s@(_, _, cridmap, _) (ReadIORefCas old) = (s, ReadIORefCas (renumbered cridmap old)) updateAction s@(_, _, cridmap, _) (ModIORef old) = (s, ModIORef (renumbered cridmap old)) updateAction s@(_, _, cridmap, _) (ModIORefCas old) = (s, ModIORefCas (renumbered cridmap old)) updateAction s@(_, _, cridmap, _) (WriteIORef old) = (s, WriteIORef (renumbered cridmap old)) updateAction s@(_, _, cridmap, _) (CasIORef old b) = (s, CasIORef (renumbered cridmap old) b) updateAction s@(tidmap, _, _, _) (STM tas olds) = (s, STM tas (map (renumbered tidmap) olds)) updateAction s@(tidmap, _, _, _) (ThrowTo old b) = (s, ThrowTo (renumbered tidmap old) b) updateAction s@(tidmap, _, _, _) (BlockedThrowTo old) = (s, BlockedThrowTo (renumbered tidmap old)) updateAction s act = (s, act) renumbered :: (Coercible a Id, Coercible Id a) => I.IntMap Int -> a -> a renumbered idmap id_ = toId $ I.findWithDefault (fromId id_) (fromId id_) idmap ------------------------------------------------------------------------------- -- * Utilities -- | Helper function for constructing IDs of any sort. toId :: Coercible Id a => Int -> a toId = coerce . Id Nothing -- | Helper function for deconstructing IDs of any sort. fromId :: Coercible a Id => a -> Int fromId a = let (Id _ id_) = coerce a in id_