{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE RankNTypes #-}
module Test.DejaFu.SCT.Internal where
import Data.Coerce (Coercible, coerce)
import qualified Data.IntMap.Strict as I
import Data.List (find, mapAccumL)
import Data.Maybe (fromMaybe)
import GHC.Stack (HasCallStack)
import Test.DejaFu.Conc
import Test.DejaFu.Conc.Internal (Context(..))
import Test.DejaFu.Conc.Internal.Memory (commitThreadId)
import Test.DejaFu.Conc.Internal.Program
import Test.DejaFu.Internal
import Test.DejaFu.Schedule (Scheduler(..))
import Test.DejaFu.SCT.Internal.DPOR
import Test.DejaFu.Types
import Test.DejaFu.Utils
-- | Entry point for systematic exploration of a 'Program'.
--
-- The program is first passed to 'recordSnapshot', and the outcome
-- decides how exploration proceeds:
--
-- * no snapshot: exploration starts from 'initialCState' and the single
--   'initialThread', and new thread and @IORef@ identifiers are numbered
--   from 1;
-- * a failed recording: its condition and trace are the only result;
-- * a snapshot: exploration resumes from the snapshot's context and
--   threads, and new identifiers are numbered one past the counters
--   held in the snapshot's identifier source.
--
-- The actual loop lives in 'sct''.
sct :: (MonadDejaFu n, HasCallStack)
  => Settings n a
  -- ^ Exploration settings.
  -> ([ThreadId] -> s)
  -- ^ Build the initial exploration state from the threads that exist
  -- when exploration begins.
  -> (s -> Maybe t)
  -- ^ Pick the next thing to run from the state; 'Nothing' ends exploration.
  -> (ConcurrencyState -> (Scheduler g -> g -> n (Either Condition a, g, Trace)) -> s -> t -> n (s, Maybe (Either Condition a, Trace)))
  -- ^ Run one step of exploration, given the starting concurrency state
  -- and a function which executes the program under a scheduler.
  -> Program pty n a
  -- ^ The program under test.
  -> n [(Either Condition a, Trace)]
sct settings s0 sfun srun conc = do
    recorded <- recordSnapshot conc
    case recorded of
      Nothing -> exploreFromStart
      Just (Left failure, trc) -> pure [(Left failure, trc)]
      Just (Right snap, _) -> exploreFromSnapshot snap
  where
    memtype = _memtype settings

    -- Without a snapshot: run the whole program each time.
    exploreFromStart =
      sct' settings initialCState (s0 [initialThread]) sfun
        (srun initialCState runFromStart)
        runFromStart
        (toId 1)
        (toId 1)

    -- With a snapshot: restore its context for every execution.
    exploreFromSnapshot snap =
      let ctx = contextFromSnapshot snap
          ids = cIdSource ctx
          cstate = cCState ctx
      in sct' settings cstate (s0 (fst (threadsFromSnapshot snap))) sfun
           (srun cstate (runFromSnapshot snap))
           (runFromSnapshot snap)
           (toId (fst (_tids ids) + 1))
           (toId (fst (_iorids ids) + 1))

    runFromStart sched st = runConcurrent sched memtype st conc
    runFromSnapshot snap sched st = runSnapshot sched memtype st snap
-- | The exploration loop shared by full and snapshot-based runs.
--
-- Starting from the given state, the selector is asked for the next
-- thing to run.  Each run yields a new state and possibly a result.
-- Results are handled as follows:
--
-- * exploration stops (returning no further results) once the most
--   recent result satisfies '_earlyExit';
-- * 'Abort' results are skipped entirely unless '_showAborts' is set;
-- * '_discard' may drop the trace, or the result and its trace;
-- * when '_equality' is set, a result equal to one already reported is
--   not reported again (failures are compared with '==');
-- * when '_simplify' is set and the trace is non-empty, the trace is
--   passed through 'simplifyExecution' before being reported.
--
-- The last two arguments are the first free thread and @IORef@
-- identifiers, which are needed by 'simplifyExecution'.
sct' :: (MonadDejaFu n, HasCallStack)
  => Settings n a
  -> ConcurrencyState
  -> s
  -> (s -> Maybe t)
  -> (s -> t -> n (s, Maybe (Either Condition a, Trace)))
  -> (forall x. Scheduler x -> x -> n (Either Condition a, x, Trace))
  -> ThreadId
  -> IORefId
  -> n [(Either Condition a, Trace)]
sct' settings cstate0 s0 sfun srun run nTId nCRId = explore Nothing [] s0
  where
    -- 'prev' is the most recent result (if any), 'seen' the results
    -- reported so far (only tracked when '_equality' is set).
    explore prev seen s
      | maybe False earlyExit prev = pure []
      | otherwise = s `seq` maybe (pure []) (runOnce prev seen s) (sfun s)

    runOnce prev seen s t = srun s t >>= \(s', outcome) -> case outcome of
      Nothing -> explore Nothing seen s'
      Just (Left Abort, _) | hideAborts -> explore prev seen s'
      Just (res, trc) -> case discard res of
        Nothing -> record res trc seen s'
        Just DiscardTrace -> record res [] seen s'
        Just DiscardResultAndTrace -> explore (Just res) seen s'

    -- A list is used for the seen results because only a user-supplied
    -- equality (no ordering) is available for success values.
    record = case _equality settings of
      Nothing -> emit
      Just same -> \res trc seen s ->
        if any (agree same res) seen
          then explore (Just res) seen s
          else emit res trc (res : seen) s

    agree same (Right x) (Right y) = same x y
    agree _ (Left e) (Left e') = e == e'
    agree _ _ _ = False

    emit res trc seen s
      | null trc || not (_simplify settings) = ((res, trc) :) <$> rest
      | otherwise = do
          shrunk <- simplifyExecution settings cstate0 run nTId nCRId res trc
          (shrunk :) <$> rest
      where
        rest = explore (Just res) seen s

    earlyExit res = maybe False ($ res) (_earlyExit settings)
    discard res = _discard settings >>= ($ res)
    hideAborts = not (_showAborts settings)
-- | Attempt to simplify the trace of a newly found result.
--
-- The trace is reduced to @(thread, action)@ pairs and run through
-- 'simplify'.  If that changes nothing, the original result and trace
-- are returned.  Otherwise the simplified steps are renumbered (starting
-- from the given thread and @IORef@ identifiers) and replayed.  The
-- replayed result replaces the original only when the two agree: equal
-- failures, or successes accepted by '_equality'.  When no equality is
-- configured, any two successes are accepted, which is a risky case.
-- On disagreement the mismatch is reported via '_debugFatal' /
-- '_debugPrint' and the original pair is returned.
simplifyExecution :: (MonadDejaFu n, HasCallStack)
  => Settings n a
  -> ConcurrencyState
  -> (forall x. Scheduler x -> x -> n (Either Condition a, x, Trace))
  -> ThreadId
  -> IORefId
  -> Either Condition a
  -> Trace
  -> n (Either Condition a, Trace)
simplifyExecution settings cstate0 run nTId nCRId res trace =
    if original == shortened
      then do
        logDebug ("Simplifying new result '" ++ describe res ++ "': no simplification possible!")
        pure (res, trace)
      else do
        logDebug ("Simplifying new result '" ++ describe res ++ "': OK!")
        (res', _, trace') <- replay run (renumberForReplay shortened)
        if agrees res res'
          then pure (res', trace')
          else do
            logFatal ("Got a different result after simplifying: '" ++ describe res ++ "' /= '" ++ describe res' ++ "'")
            pure (res, trace)
  where
    original = toTIdTrace trace
    shortened = simplify (_safeIO settings) (_memtype settings) cstate0 original
    renumberForReplay = renumber (_memtype settings) (fromId nTId) (fromId nCRId)

    agrees (Left e1) (Left e2) = e1 == e2
    agrees (Right a1) (Right a2) = maybe True (\same -> same a1 a2) (_equality settings)
    agrees _ _ = False

    logDebug = fromMaybe (const (pure ())) (_debugPrint settings)
    logFatal = if _debugFatal settings then fatal else logDebug
    describe = either show (fromMaybe (const "_") (_debugShow settings))
-- | Re-run a program, scheduling threads in the order given by a list
-- of @(thread, action)@ steps, which serves as the scheduler's state.
--
-- On each decision the next step is consumed and the runnable thread
-- with the same identifier number (names are ignored) is chosen.  A
-- 'Stop' step whose thread is not runnable is skipped and the following
-- step is tried instead.  Any other step is consumed even when its
-- thread is not runnable, in which case no thread is chosen.  Once the
-- steps run out, no thread is chosen.
replay :: MonadDejaFu n
  => (forall x. Scheduler x -> x -> n (Either Condition a, x, Trace))
  -> [(ThreadId, ThreadAction)]
  -> n (Either Condition a, [(ThreadId, ThreadAction)], Trace)
replay run = run (Scheduler (\_ -> pick))
  where
    pick runnable cs steps = case steps of
      [] -> (Nothing, [])
      (want, Stop) : rest -> case runnableWithId want runnable of
        Nothing -> pick runnable cs rest
        found -> (found, rest)
      (want, _) : rest -> (runnableWithId want runnable, rest)

    -- Compare identifiers by number only, ignoring names.
    runnableWithId want =
      fmap fst . find (\(tid, _) -> fromId tid == fromId want)
-- | Simplify a trace by reordering (and, for relaxed memory, dropping
-- commits among) independent actions.
--
-- The trace is first put into lexicographic normal form and passed
-- through 'dropCommits'.  Then 'pullBack' followed by 'pushForward' is
-- applied repeatedly until a round leaves the trace unchanged, or until
-- @length trc0@ rounds have been performed.
simplify
  :: Bool
  -> MemType
  -> ConcurrencyState
  -> [(ThreadId, ThreadAction)]
  -> [(ThreadId, ThreadAction)]
simplify safeIO memtype cstate0 trc0 = untilStable (length trc0) normalised
  where
    normalised =
      dropCommits safeIO memtype cstate0 (lexicoNormalForm safeIO memtype cstate0 trc0)

    untilStable fuel trc
      | fuel == 0 = trc
      | improved == trc = trc
      | otherwise = untilStable (fuel - 1) improved
      where
        improved = pushForward safeIO memtype cstate0 (pullBack safeIO memtype cstate0 trc)
-- | Repeat 'permuteBy' passes using @(>)@ at every position until a pass
-- leaves the trace unchanged.
--
-- In each pass, adjacent independent actions are swapped whenever the
-- first action's thread identifier is greater than the second's.
lexicoNormalForm
  :: Bool
  -> MemType
  -> ConcurrencyState
  -> [(ThreadId, ThreadAction)]
  -> [(ThreadId, ThreadAction)]
lexicoNormalForm safeIO memtype cstate0 = settle
  where
    settle trc
      | trc == next = trc
      | otherwise = settle next
      where
        next = permuteBy safeIO memtype cstate0 (repeat (>)) trc
-- | One left-to-right pass over a trace, possibly swapping adjacent
-- actions.
--
-- At each position the next predicate is applied to the two thread
-- identifiers at the front of the trace.  If the actions are
-- 'independent' in the concurrency state reached so far, and the
-- predicate holds, the pair is swapped.  The first of the resulting pair
-- is emitted, the state is advanced past it, and the pass continues with
-- the other action at the front.  The pass stops, returning the
-- remainder unchanged, when the predicates run out or fewer than two
-- actions remain.
permuteBy
  :: Bool
  -> MemType
  -> ConcurrencyState
  -> [ThreadId -> ThreadId -> Bool]
  -> [(ThreadId, ThreadAction)]
  -> [(ThreadId, ThreadAction)]
permuteBy safeIO memtype = pass
  where
    pass ds (swapIf : preds) (x@(tidX, taX) : y@(tidY, taY) : trc) =
      if independent safeIO ds tidX taX tidY taY && swapIf tidX tidY
        then emit ds preds y (x : trc)
        else emit ds preds x (y : trc)
    pass _ _ trc = trc

    emit ds preds t@(tid, ta) trc =
      t : pass (updateCState memtype ds tid ta) preds trc
-- | Remove or delay @IORef@ commit steps.  This is a no-op under
-- 'SequentialConsistency'.
--
-- For a 'CommitIORef' step followed by another step:
--
-- * if the following action is a barrier (per 'isBarrier' of its
--   'simplifyAction') and exactly one write to that @IORef@ is buffered,
--   the commit is dropped (presumably because the barrier flushes the
--   write anyway; confirm);
-- * otherwise, if the two are 'independent', they are swapped and the
--   commit is examined again against the step after;
-- * otherwise the commit is kept in place.
dropCommits
  :: Bool
  -> MemType
  -> ConcurrencyState
  -> [(ThreadId, ThreadAction)]
  -> [(ThreadId, ThreadAction)]
dropCommits _ SequentialConsistency = const id
dropCommits safeIO memtype = walk
  where
    walk _ [] = []
    walk ds (t@(tid, ta) : rest) = case (ta, rest) of
      (CommitIORef _ ref, next@(ntid, nta) : rest')
        | isBarrier (simplifyAction nta) && numBuffered ds ref == 1 ->
            walk ds (next : rest')
        | independent safeIO ds tid ta ntid nta ->
            next : walk (updateCState memtype ds ntid nta) (t : rest')
      _ -> t : walk (updateCState memtype ds tid ta) rest
-- | Pull actions of the same thread back towards one another.
--
-- Whenever two adjacent steps belong to different threads, the rest of
-- the trace is searched for the next step of the first thread.  If that
-- step is 'independent' of every step it would be moved across, it is
-- moved to immediately follow the first step.  Presumably this reduces
-- context switches.
pullBack
  :: Bool
  -> MemType
  -> ConcurrencyState
  -> [(ThreadId, ThreadAction)]
  -> [(ThreadId, ThreadAction)]
pullBack safeIO memtype = walk
  where
    walk ds (t1@(tid1, ta1) : rest@((tid2, _) : _)) =
      let ds' = updateCState memtype ds tid1 ta1
          rest'
            | tid1 == tid2 = rest
            | otherwise = maybe rest (uncurry (:)) (extract tid1 ds' rest)
      in t1 : walk ds' rest'
    walk _ trc = trc

    -- Find the first step of 'target' and lift it out of the trace,
    -- provided it commutes with everything before it.
    extract target ds (t@(tid, ta) : trc)
      | tid == target = Just (t, trc)
      | otherwise =
          extract target (updateCState memtype ds tid ta) trc >>= \(found@(ftid, fa), trc') ->
            if independent safeIO ds tid ta ftid fa
              then Just (found, t : trc')
              else Nothing
    extract _ _ [] = Nothing
-- | Push actions forward towards the next action of the same thread.
--
-- Whenever two adjacent steps belong to different threads, the first
-- step is moved forward until it sits just before the next step of its
-- own thread.  This happens only if it is 'independent' of every step in
-- between.  After a successful move, processing resumes, from the same
-- state, at the head of the rearranged remainder.
pushForward
  :: Bool
  -> MemType
  -> ConcurrencyState
  -> [(ThreadId, ThreadAction)]
  -> [(ThreadId, ThreadAction)]
pushForward safeIO memtype = walk
  where
    walk ds (t1@(tid1, ta1) : rest@((tid2, _) : _))
      | tid1 == tid2 = keep
      | otherwise = case sink tid1 ta1 ds rest of
          Nothing -> keep
          Just moved -> walk ds moved
      where
        keep = t1 : walk (updateCState memtype ds tid1 ta1) rest
    walk _ trc = trc

    -- Re-insert (tid0, ta0) just before the next step of tid0, provided
    -- it commutes with every step it passes.
    sink tid0 ta0 ds (t@(tid, ta) : trc)
      | tid == tid0 = Just ((tid0, ta0) : t : trc)
      | independent safeIO ds tid0 ta0 tid ta =
          fmap (t :) (sink tid0 ta0 (updateCState memtype ds tid ta) trc)
      | otherwise = Nothing
    sink _ _ _ [] = Nothing
-- | Rewrite thread and @IORef@ identifiers in a trace so that newly
-- created ones are numbered sequentially from the two given integers.
--
-- The trace is walked left to right with a 4-tuple accumulator:
-- (old-to-new thread id map, next thread id, old-to-new @IORef@ id map,
-- next @IORef@ id).  'Fork' / 'ForkOS' take the next thread id, and
-- 'NewIORef' takes the next @IORef@ id.  Every other identifier in a
-- recognised action, and the acting thread of each step, goes through
-- 'renumbered', which leaves identifiers missing from the map unchanged.
--
-- NOTE(review): 'simplifyExecution' applies this before replaying a
-- simplified trace, presumably because reordered forks/allocations would
-- otherwise carry identifiers that differ from the ones handed out
-- during replay; confirm.
renumber
  :: MemType
  -- ^ Memory model; decides how the thread of a commit step is computed.
  -> Int
  -- ^ First thread id to hand out.
  -> Int
  -- ^ First @IORef@ id to hand out.
  -> [(ThreadId, ThreadAction)]
  -> [(ThreadId, ThreadAction)]
renumber memtype tid0 crid0 = snd . mapAccumL go (I.empty, tid0, I.empty, crid0) where
  -- Commit steps: renumber the thread and IORef ids in the action, then
  -- recompute the step's thread with 'commitThreadId' (the IORef id is
  -- passed only under 'PartialStoreOrder').  The original step thread is
  -- discarded and the accumulator is left unchanged.
  go s@(tidmap, _, cridmap, _) (_, CommitIORef tid crid) =
    let tid' = renumbered tidmap tid
        crid' = renumbered cridmap crid
        act' = CommitIORef tid' crid'
    in case memtype of
         PartialStoreOrder -> (s, (commitThreadId tid' (Just crid'), act'))
         _ -> (s, (commitThreadId tid' Nothing, act'))
  -- Every other step: rewrite the action (possibly handing out new ids)
  -- and map the acting thread through the thread map as it was *before*
  -- this step, so a fork step's own thread is unaffected by the id it
  -- hands out.
  go s@(tidmap, _, _, _) (tid, act) =
    let (s', act') = updateAction s act
    in (s', (renumbered tidmap tid, act'))
  -- Forks: take the next thread id and record old -> new.
  updateAction (tidmap, nTId, cridmap, nCRId) (Fork old) =
    let tidmap' = I.insert (fromId old) nTId tidmap
        nTId' = nTId + 1
    in ((tidmap', nTId', cridmap, nCRId), Fork (toId nTId))
  updateAction (tidmap, nTId, cridmap, nCRId) (ForkOS old) =
    let tidmap' = I.insert (fromId old) nTId tidmap
        nTId' = nTId + 1
    in ((tidmap', nTId', cridmap, nCRId), ForkOS (toId nTId))
  -- MVar actions: renumber the thread ids they carry; the MVar id itself
  -- is left as-is.
  updateAction s@(tidmap, _, _, _) (PutMVar mvid olds) =
    (s, PutMVar mvid (map (renumbered tidmap) olds))
  updateAction s@(tidmap, _, _, _) (TryPutMVar mvid b olds) =
    (s, TryPutMVar mvid b (map (renumbered tidmap) olds))
  updateAction s@(tidmap, _, _, _) (TakeMVar mvid olds) =
    (s, TakeMVar mvid (map (renumbered tidmap) olds))
  updateAction s@(tidmap, _, _, _) (TryTakeMVar mvid b olds) =
    (s, TryTakeMVar mvid b (map (renumbered tidmap) olds))
  -- IORef creation: take the next IORef id and record old -> new.
  updateAction (tidmap, nTId, cridmap, nCRId) (NewIORef old) =
    let cridmap' = I.insert (fromId old) nCRId cridmap
        nCRId' = nCRId + 1
    in ((tidmap, nTId, cridmap', nCRId'), NewIORef (toId nCRId))
  -- Other IORef actions: renumber the IORef id through the IORef map.
  updateAction s@(_, _, cridmap, _) (ReadIORef old) =
    (s, ReadIORef (renumbered cridmap old))
  updateAction s@(_, _, cridmap, _) (ReadIORefCas old) =
    (s, ReadIORefCas (renumbered cridmap old))
  updateAction s@(_, _, cridmap, _) (ModIORef old) =
    (s, ModIORef (renumbered cridmap old))
  updateAction s@(_, _, cridmap, _) (ModIORefCas old) =
    (s, ModIORefCas (renumbered cridmap old))
  updateAction s@(_, _, cridmap, _) (WriteIORef old) =
    (s, WriteIORef (renumbered cridmap old))
  updateAction s@(_, _, cridmap, _) (CasIORef old b) =
    (s, CasIORef (renumbered cridmap old) b)
  -- STM and exception actions: renumber the thread ids they carry.
  updateAction s@(tidmap, _, _, _) (STM tas olds) =
    (s, STM tas (map (renumbered tidmap) olds))
  updateAction s@(tidmap, _, _, _) (ThrowTo old b) =
    (s, ThrowTo (renumbered tidmap old) b)
  updateAction s@(tidmap, _, _, _) (BlockedThrowTo old) =
    (s, BlockedThrowTo (renumbered tidmap old))
  -- Any other action carries no ids that are renumbered here.
  updateAction s act = (s, act)
-- | Translate an identifier through an old-to-new number map.  Numbers
-- missing from the map are kept as they are.  The rebuilt identifier
-- carries no name.
renumbered :: (Coercible a Id, Coercible Id a) => I.IntMap Int -> a -> a
renumbered idmap id_ = toId (fromMaybe old (I.lookup old idmap))
  where
    old = fromId id_
-- | Build an unnamed identifier with the given number, coerced to any
-- newtype wrapper around 'Id'.
toId :: Coercible Id a => Int -> a
toId n = coerce (Id Nothing n)
-- | Extract the number of an identifier (any newtype wrapper around
-- 'Id'), ignoring its name.
fromId :: Coercible a Id => a -> Int
fromId a = case coerce a of
  Id _ n -> n