{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE RankNTypes #-}
module Test.DejaFu.SCT.Internal where
import Control.Monad.Conc.Class (MonadConc)
import Data.Coerce (Coercible, coerce)
import qualified Data.IntMap.Strict as I
import Data.List (find, mapAccumL)
import Data.Maybe (fromMaybe)
import GHC.Stack (HasCallStack)
import Test.DejaFu.Conc
import Test.DejaFu.Conc.Internal (Context(..), DCSnapshot(..))
import Test.DejaFu.Conc.Internal.Memory (commitThreadId)
import Test.DejaFu.Internal
import Test.DejaFu.Schedule (Scheduler(..))
import Test.DejaFu.SCT.Internal.DPOR
import Test.DejaFu.Types
import Test.DejaFu.Utils
-- | Drive a systematic concurrency testing loop over a 'ConcT'
-- computation.
--
-- When the computation supports snapshotting ('canDCSnapshot'), it is
-- first run up to the snapshot point:
--
-- * a failure at that point is returned as the only result;
-- * a successful snapshot becomes the starting point of every
--   subsequent execution;
-- * anything else is reported through the debug-fatal hook (which is
--   'fatal' when '_debugFatal' is set, and the debug printer otherwise)
--   before falling back to running from scratch.
--
-- The exploration itself is delegated to 'sct''.
sct :: (MonadConc n, HasCallStack)
  => Settings n a
  -- ^ The SCT settings.
  -> ([ThreadId] -> s)
  -- ^ Build the initial state from the threads that exist at the start.
  -> (s -> Maybe t)
  -- ^ State predicate: 'Nothing' ends the exploration.
  -> ((Scheduler g -> g -> n (Either Failure a, g, Trace)) -> s -> t -> n (s, Maybe (Either Failure a, Trace)))
  -- ^ Run the computation and update the state.
  -> ConcT n a
  -- ^ The computation under test.
  -> n [(Either Failure a, Trace)]
sct settings s0 sfun srun conc =
    if canDCSnapshot conc
      then runForDCSnapshot conc >>= onSnapshotOutcome
      else fromScratch
  where
    -- Decide how to continue based on the attempt to build a snapshot.
    onSnapshotOutcome outcome = case outcome of
      Just (Right snap, _) -> fromSnapshot snap
      Just (Left err, trc) -> pure [(Left err, trc)]
      _ -> do
        reportFatal "Failed to construct snapshot, continuing without."
        fromScratch

    -- Explore from the very start: only the initial thread exists and
    -- renumbering of fresh thread / CRef IDs starts at 1.
    fromScratch =
      sct' settings (s0 [initialThread]) sfun (srun runFromStart) runFromStart (toId 1) (toId 1)

    -- Explore from a snapshot: the initial state is built from the
    -- snapshot's threads, and renumbering continues after the highest
    -- IDs recorded in the snapshot's ID source.
    fromSnapshot snap =
      let ids = cIdSource (dcsContext snap)
      in sct'
           settings
           (s0 (fst (threadsFromDCSnapshot snap)))
           sfun
           (srun (runFromSnapshot snap))
           (runFromSnapshot snap)
           (toId (fst (_tids ids) + 1))
           (toId (fst (_crids ids) + 1))

    runFromStart sched st = runConcurrent sched (_memtype settings) st conc
    runFromSnapshot snap sched st = runWithDCSnapshot sched (_memtype settings) st snap

    reportFatal = if _debugFatal settings then fatal else logDebug
    logDebug = fromMaybe (const (pure ())) (_debugPrint settings)
-- | The core SCT loop.
--
-- Each iteration consults the state predicate. On 'Nothing' the loop
-- ends; otherwise it runs one step with the step function and handles
-- the optional result as follows:
--
-- * a result is passed to the settings' discard function, if any;
--   'DiscardResultAndTrace' drops it, and 'DiscardTrace' keeps it with
--   an empty trace;
-- * kept results are deduplicated against earlier ones using the
--   settings' equality function, if one is set (failures compare with
--   '==');
-- * new results are optionally simplified with 'simplifyExecution'
--   before being returned.
--
-- Whenever a step produced a result (kept or not), the early-exit
-- predicate is checked on it before the next iteration.
sct' :: (MonadConc n, HasCallStack)
  => Settings n a
  -- ^ The SCT settings.
  -> s
  -- ^ Initial state.
  -> (s -> Maybe t)
  -- ^ State predicate: 'Nothing' ends the loop.
  -> (s -> t -> n (s, Maybe (Either Failure a, Trace)))
  -- ^ Run one step and update the state.
  -> (forall x. Scheduler x -> x -> n (Either Failure a, x, Trace))
  -- ^ Runner handed to 'simplifyExecution' for replaying traces.
  -> ThreadId
  -- ^ Starting thread ID handed to 'simplifyExecution'.
  -> CRefId
  -- ^ Starting CRef ID handed to 'simplifyExecution'.
  -> n [(Either Failure a, Trace)]
sct' settings s0 sfun srun run nTId nCRId = loop Nothing [] s0
  where
    -- Stop as soon as the latest result satisfies the early-exit predicate.
    loop (Just prev) _ _ | stopEarly prev = pure []
    loop _ seen !st = maybe (pure []) (runOne seen st) (sfun st)

    -- Run one step and dispatch on what it produced.
    runOne seen st t = do
      (st', outcome) <- srun st t
      case outcome of
        Nothing -> loop Nothing seen st'
        Just (res, trc) -> case discardWith res of
          Just DiscardResultAndTrace -> loop (Just res) seen st'
          Just DiscardTrace -> record res [] seen st'
          Nothing -> record res trc seen st'

    -- Unique results are tracked in a plain list because no @Ord a@
    -- instance is available; a result equal to one already seen is not
    -- returned again.
    record res trc seen st = case _equality settings of
      Just f
        | any (sameAs f res) seen -> loop (Just res) seen st
        | otherwise -> emit res trc (res : seen) st
      Nothing -> emit res trc seen st

    sameAs cmp (Right a1) (Right a2) = cmp a1 a2
    sameAs _ (Left e1) (Left e2) = e1 == e2
    sameAs _ _ _ = False

    -- Return a result, simplifying its trace first when enabled and
    -- the trace is non-empty.
    emit res [] seen st = ((res, []) :) <$> loop (Just res) seen st
    emit res trc seen st
      | _simplify settings = do
          shrunk <- simplifyExecution settings run nTId nCRId res trc
          (shrunk :) <$> loop (Just res) seen st
      | otherwise = ((res, trc) :) <$> loop (Just res) seen st

    stopEarly = fromMaybe (const False) (_earlyExit settings)
    discardWith = fromMaybe (const Nothing) (_discard settings)
-- | Try to shrink the trace of a newly-found result.
--
-- The trace is converted with 'toTIdTrace' and passed through
-- 'simplify'. If that changes nothing, the original result and trace are
-- returned. Otherwise the simplified trace is renumbered, replayed, and
-- the replayed result and trace are returned, but only if the replayed
-- result agrees with the original:
--
-- * two failures must be equal;
-- * two successes must satisfy the settings' equality function when
--   one is set, and are accepted unconditionally when it is not.
--
-- On disagreement the debug-fatal hook is invoked and the original
-- result and trace are kept.
simplifyExecution :: (MonadConc n, HasCallStack)
  => Settings n a
  -- ^ The SCT settings.
  -> (forall x. Scheduler x -> x -> n (Either Failure a, x, Trace))
  -- ^ Runner used to replay the simplified trace.
  -> ThreadId
  -- ^ First thread ID to assign when renumbering forks.
  -> CRefId
  -- ^ First CRef ID to assign when renumbering CRef creations.
  -> Either Failure a
  -- ^ The result to simplify.
  -> Trace
  -- ^ The trace that produced it.
  -> n (Either Failure a, Trace)
simplifyExecution settings run nTId nCRId res trace =
    if simplified == asTIds
      then do
        say ("Simplifying new result '" ++ render res ++ "': no simplification possible!")
        pure (res, trace)
      else do
        say ("Simplifying new result '" ++ render res ++ "': OK!")
        (res', _, trace') <- replay run (renumberFrom simplified)
        if agrees res res'
          then pure (res', trace')
          else do
            complain ("Got a different result after simplifying: '" ++ render res ++ "' /= '" ++ render res' ++ "'")
            pure (res, trace)
  where
    asTIds = toTIdTrace trace
    simplified = simplify (_memtype settings) asTIds
    renumberFrom = renumber (_memtype settings) (fromId nTId) (fromId nCRId)

    -- Accepting two successes without an equality function is a leap of
    -- faith: nothing checks that the values actually match.
    agrees (Left e1) (Left e2) = e1 == e2
    agrees (Right a1) (Right a2) = maybe True (\eq -> eq a1 a2) (_equality settings)
    agrees _ _ = False

    complain = if _debugFatal settings then fatal else say
    say = fromMaybe (const (pure ())) (_debugPrint settings)
    render = either show (fromMaybe (const "_") (_debugShow settings))
-- | Re-run a computation, forcing it to follow a given schedule.
--
-- Each scheduling decision takes the next schedule entry and picks the
-- runnable thread with the same numeric ID; thread names are ignored.
--
-- * A 'Stop' entry whose thread is not runnable is skipped.
-- * For any other entry whose thread is not runnable, the decision is
--   'Nothing'.
-- * Once the schedule is exhausted, every decision is 'Nothing'.
replay :: MonadConc n
  => (forall x. Scheduler x -> x -> n (Either Failure a, x, Trace))
  -- ^ Runner to execute the computation with the replay scheduler.
  -> [(ThreadId, ThreadAction)]
  -- ^ The schedule to follow.
  -> n (Either Failure a, [(ThreadId, ThreadAction)], Trace)
replay run = run (Scheduler (const pick))
  where
    pick runnable ((tid, Stop) : rest) =
      maybe (pick runnable rest) (\chosen -> (Just chosen, rest)) (lookupRunnable tid runnable)
    pick runnable ((tid, _) : rest) = (lookupRunnable tid runnable, rest)
    pick _ [] = (Nothing, [])

    -- Match on the numeric part of the ID only.
    lookupRunnable wanted runnable =
      fst <$> find (\(tid, _) -> fromId tid == fromId wanted) runnable
-- | Simplify a trace by reordering (and possibly dropping) actions.
--
-- The trace first goes through 'lexicoNormalForm' and then
-- 'dropCommits'. After that, rounds of 'pullBack' followed by
-- 'pushForward' are applied until a round leaves the trace unchanged.
-- The number of rounds is capped at the length of the original trace.
simplify :: MemType -> [(ThreadId, ThreadAction)] -> [(ThreadId, ThreadAction)]
simplify memtype trc0 = go (length trc0) (normalise trc0)
  where
    normalise = dropCommits memtype . lexicoNormalForm memtype
    once = pushForward memtype . pullBack memtype

    go fuel trc
      | fuel == 0 = trc
      | next == trc = trc
      | otherwise = go (fuel - 1) next
      where
        next = once trc
-- | Repeatedly run 'permuteBy' with @(>)@ on every step until a pass
-- changes nothing.
--
-- In each pass, whenever two adjacent actions are independent and the
-- first has the larger thread ID, they are swapped. The effect is that
-- independent neighbours end up ordered by ascending thread ID.
lexicoNormalForm :: MemType -> [(ThreadId, ThreadAction)] -> [(ThreadId, ThreadAction)]
lexicoNormalForm memtype = settle
  where
    settle trc
      | permuted == trc = trc
      | otherwise = settle permuted
      where
        permuted = permuteBy memtype (repeat (>)) trc
-- | A single left-to-right bubble pass over a trace.
--
-- Each step consumes the next predicate from the list and looks at the
-- current action and the one after it. The pair is swapped when both
-- hold:
--
-- * the two actions are 'independent' under the dependency state built
--   from the actions already emitted;
-- * the predicate holds for their thread IDs.
--
-- Only the first action of the (possibly swapped) pair is emitted; the
-- other is carried forward and compared with the next action. The pass
-- stops, returning the rest of the trace unchanged, when the predicates
-- run out or fewer than two actions remain.
permuteBy
  :: MemType
  -> [ThreadId -> ThreadId -> Bool]
  -> [(ThreadId, ThreadAction)]
  -> [(ThreadId, ThreadAction)]
permuteBy memtype = walk initialDepState
  where
    walk ds (p : ps) (t1@(tid1, ta1) : t2@(tid2, ta2) : rest) =
      let shouldSwap = independent ds tid1 ta1 tid2 ta2 && p tid1 tid2
          (emitted, carried) = if shouldSwap then (t2, t1) else (t1, t2)
      in emit ds ps emitted (carried : rest)
    walk _ _ trc = trc

    -- Output one action and fold it into the dependency state.
    emit ds ps t@(tid, ta) trc = t : walk (updateDepState memtype ds tid ta) ps trc
-- | Remove or delay 'CommitCRef' actions where possible.
--
-- Under 'SequentialConsistency' this is the identity. Otherwise, for a
-- 'CommitCRef' followed by another action:
--
-- * if the follower's simplified action is a barrier ('isBarrier'), the
--   commit is dropped;
-- * if the two are 'independent', they are swapped, and the commit is
--   then considered again against the following action;
-- * in any other case the commit stays where it is.
--
-- All other actions pass through unchanged.
dropCommits :: MemType -> [(ThreadId, ThreadAction)] -> [(ThreadId, ThreadAction)]
dropCommits SequentialConsistency = id
dropCommits memtype = walk initialDepState
  where
    walk _ [] = []
    walk ds (t1@(tid1, ta1) : rest) = case (ta1, rest) of
      (CommitCRef _ _, t2@(tid2, ta2) : rest')
        | isBarrier (simplifyAction ta2) -> walk ds rest
        | independent ds tid1 ta1 tid2 ta2 ->
            t2 : walk (updateDepState memtype ds tid2 ta2) (t1 : rest')
      -- Reached when no guard above applies (including a trailing commit).
      _ -> t1 : walk (updateDepState memtype ds tid1 ta1) rest
-- | One pass that groups a thread's actions together by moving actions
-- earlier in the trace.
--
-- The trace is walked left to right. Suppose an action by thread @T@ is
-- immediately followed by an action of a different thread. Then the
-- next action of @T@ further along is moved to sit directly after it,
-- provided it is 'independent' of every action it would jump over. Each
-- such check uses the dependency state just before the jumped-over
-- action. If @T@ has no later action, or the move is not permitted, the
-- trace is left as it is at that point.
pullBack :: MemType -> [(ThreadId, ThreadAction)] -> [(ThreadId, ThreadAction)]
pullBack memtype = go initialDepState where
  go ds (t1@(tid1, ta1):trc@((tid2, _):_)) =
    let ds' = updateDepState memtype ds tid1 ta1
        -- Only look for something to pull back when the next action
        -- belongs to a different thread.
        trc' = if tid1 /= tid2
          then maybe trc (uncurry (:)) (findAction tid1 ds' trc)
          else trc
    in t1 : go ds' trc'
  -- Fewer than two actions left: nothing to reorder.
  go _ trc = trc

  -- Find the first action of @tid0@ and return it paired with the rest
  -- of the trace (with that action removed), but only if it is
  -- independent of every action that precedes it in the searched trace.
  findAction tid0 = fgo where
    fgo ds (t@(tid, ta):trc)
      | tid == tid0 = Just (t, trc)
      | otherwise = case fgo (updateDepState memtype ds tid ta) trc of
          -- @ds@ is the dependency state just before @t@.
          Just (ft@(ftid, fa), trc')
            | independent ds tid ta ftid fa -> Just (ft, t:trc')
          _ -> Nothing
    fgo _ _ = Nothing
-- | One pass that groups a thread's actions together by moving actions
-- later in the trace.
--
-- The trace is walked left to right. Suppose an action by thread @T@ is
-- immediately followed by an action of a different thread. Then it is
-- moved to sit directly before the next action of @T@ further along,
-- provided it is 'independent' of every action it would jump over. Each
-- such check uses the dependency state just before the jumped-over
-- action. If @T@ has no later action, or the move is not permitted, the
-- action stays where it is.
pushForward :: MemType -> [(ThreadId, ThreadAction)] -> [(ThreadId, ThreadAction)]
pushForward memtype = go initialDepState where
  go ds (t1@(tid1, ta1):trc@((tid2, _):_)) =
    let ds' = updateDepState memtype ds tid1 ta1
    in if tid1 /= tid2
       -- After a successful move, continue on the rearranged trace with
       -- the state from before @t1@, since @t1@ is no longer at the front.
       then maybe (t1 : go ds' trc) (go ds) (findAction tid1 ta1 ds trc)
       else t1 : go ds' trc
  -- Fewer than two actions left: nothing to reorder.
  go _ trc = trc

  -- Carry @(tid0, ta0)@ past the actions it is independent of and
  -- re-insert it immediately before the next action of @tid0@. Fails if
  -- a dependent action comes first, or if @tid0@ has no later action.
  findAction tid0 ta0 = fgo where
    fgo ds (t@(tid, ta):trc)
      | tid == tid0 = Just ((tid0, ta0) : t : trc)
      | independent ds tid0 ta0 tid ta = (t:) <$> fgo (updateDepState memtype ds tid ta) trc
      | otherwise = Nothing
    fgo _ _ = Nothing
-- | Re-allocate thread and CRef identifiers in a trace, in the order in
-- which the trace creates them.
--
-- Fresh thread IDs start at the first 'Int' argument and fresh CRef IDs
-- at the second. Each 'Fork' / 'ForkOS' / 'NewCRef' claims the next
-- number. Later references to the old ID are rewritten to match:
--
-- * the acting thread of each entry;
-- * the thread lists carried by MVar and STM actions;
-- * 'ThrowTo' / 'BlockedThrowTo' targets;
-- * CRef operations.
--
-- IDs that are never created within the trace keep their number.
-- For 'CommitCRef' entries the acting thread is recomputed with
-- 'commitThreadId'; the CRef is passed along only under
-- 'PartialStoreOrder'.
renumber
  :: MemType
  -> Int
  -> Int
  -> [(ThreadId, ThreadAction)]
  -> [(ThreadId, ThreadAction)]
renumber memtype tid0 crid0 = snd . mapAccumL step (I.empty, tid0, I.empty, crid0)
  where
    -- The state is (thread map, next thread ID, CRef map, next CRef ID).
    step st@(tidmap, _, cridmap, _) (_, CommitCRef tid crid) =
      let tid' = renumbered tidmap tid
          crid' = renumbered cridmap crid
          committer = case memtype of
            PartialStoreOrder -> commitThreadId tid' (Just crid')
            _ -> commitThreadId tid' Nothing
      in (st, (committer, CommitCRef tid' crid'))
    step st@(tidmap, _, _, _) (tid, act) =
      -- The acting thread is looked up in the map as it was *before*
      -- this action (a fork does not rename its parent).
      let (st', act') = rename st act
      in (st', (renumbered tidmap tid, act'))

    rename st@(tidmap, nTId, cridmap, nCRId) act = case act of
        Fork old -> freshThread Fork old
        ForkOS old -> freshThread ForkOS old
        NewCRef old ->
          ((tidmap, nTId, I.insert (fromId old) nCRId cridmap, nCRId + 1), NewCRef (toId nCRId))
        PutMVar mvid olds -> (st, PutMVar mvid (threads olds))
        TryPutMVar mvid b olds -> (st, TryPutMVar mvid b (threads olds))
        TakeMVar mvid olds -> (st, TakeMVar mvid (threads olds))
        TryTakeMVar mvid b olds -> (st, TryTakeMVar mvid b (threads olds))
        ReadCRef old -> (st, ReadCRef (cref old))
        ReadCRefCas old -> (st, ReadCRefCas (cref old))
        ModCRef old -> (st, ModCRef (cref old))
        ModCRefCas old -> (st, ModCRefCas (cref old))
        WriteCRef old -> (st, WriteCRef (cref old))
        CasCRef old b -> (st, CasCRef (cref old) b)
        STM tas olds -> (st, STM tas (threads olds))
        ThrowTo old -> (st, ThrowTo (thread old))
        BlockedThrowTo old -> (st, BlockedThrowTo (thread old))
        _ -> (st, act)
      where
        -- Map the forked thread's old ID to the next fresh one.
        freshThread con old =
          ((I.insert (fromId old) nTId tidmap, nTId + 1, cridmap, nCRId), con (toId nTId))
        thread t = renumbered tidmap t
        threads ts = map thread ts
        cref c = renumbered cridmap c
-- | Apply a renumbering map to an identifier.
--
-- The identifier's number is looked up in the map; if it has no entry,
-- the number is kept. The result is rebuilt with 'toId', so any name the
-- identifier carried is not preserved.
renumbered :: (Coercible a Id, Coercible Id a) => I.IntMap Int -> a -> a
renumbered idmap id_ =
  let old = fromId id_
  in toId (I.findWithDefault old old idmap)
-- | Build an identifier of any 'Id'-coercible type from a bare number,
-- with no name attached.
toId :: Coercible Id a => Int -> a
toId n = coerce (Id Nothing n)
-- | Extract the bare number from an identifier of any 'Id'-coercible
-- type, ignoring its name.
fromId :: Coercible a Id => a -> Int
fromId ident = case coerce ident of
  Id _ n -> n