module Test.DejaFu.SCT.Internal where
import Control.Monad.Conc.Class (MonadConc)
import Control.Monad.Ref (MonadRef)
import Data.Coerce (Coercible, coerce)
import qualified Data.IntMap.Strict as I
import Data.List (find, mapAccumL)
import Data.Maybe (fromMaybe)
import Test.DejaFu.Conc
import Test.DejaFu.Conc.Internal.Memory (commitThreadId)
import Test.DejaFu.Internal
import Test.DejaFu.Schedule (Scheduler(..))
import Test.DejaFu.SCT.Internal.DPOR
import Test.DejaFu.Types
import Test.DejaFu.Utils
-- | Entry point for systematic concurrency testing.
--
-- When 'canDCSnapshot' says the computation can be snapshotted, a snapshot
-- is requested with 'runForDCSnapshot':
--
-- * a snapshot was produced: exploration starts from it, with the initial
--   state built from the thread IDs reported by 'threadsFromDCSnapshot';
-- * the snapshot run failed: that failure (and its trace) is the only
--   result;
-- * anything else: a debug message is emitted and the full computation is
--   explored from 'initialThread'.
--
-- Without snapshot support the full computation is always explored.
sct :: (MonadConc n, MonadRef r n)
  => Settings n a
  -> ([ThreadId] -> s)
  -> (s -> Maybe t)
  -> ((Scheduler g -> g -> n (Either Failure a, g, Trace)) -> s -> t -> n (s, Maybe (Either Failure a, Trace)))
  -> ConcT r n a
  -> n [(Either Failure a, Trace)]
sct settings s0 sfun srun conc
  | canDCSnapshot conc = runForDCSnapshot conc >>= \case
      Just (Right snap, _) -> fromSnapshot snap
      Just (Left f, trace) -> pure [(Left f, trace)]
      _ -> debugPrint "Failed to construct snapshot, continuing without." >> fromScratch
  | otherwise = fromScratch
  where
    -- Explore every execution of the whole computation.
    fromScratch = sct' settings (s0 [initialThread]) sfun (srun runFull) runFull

    -- Explore executions resuming from a captured snapshot.
    fromSnapshot snap =
      sct' settings (s0 (fst (threadsFromDCSnapshot snap))) sfun (srun (runSnap snap)) (runSnap snap)

    -- These two runners are used both at the caller's scheduler-state type
    -- (via srun) and at the rank-N type sct' expects, so they must stay
    -- polymorphic local functions.
    runFull sched s = runConcurrent sched (_memtype settings) s conc
    runSnap snap sched s = runWithDCSnapshot sched (_memtype settings) s snap

    debugPrint = fromMaybe (const (pure ())) (_debugPrint settings)
-- | Main exploration loop.
--
-- Starting from @s0@, the loop works as follows:
--
-- * it repeatedly asks @sfun@ for the next thing to try and runs it with
--   @srun@;
-- * it stops when @sfun@ returns 'Nothing', or when the most recently
--   produced result satisfies the '_earlyExit' predicate.
--
-- Each produced result is then filtered:
--
-- * '_discard' may drop a result together with its trace, or keep the
--   result but drop its trace;
-- * when '_equality' is set, a result equal to one already reported is
--   skipped; without it every result is reported;
-- * when '_simplify' is set, each reported non-empty trace is passed through
--   'simplifyExecution'.
sct' :: (MonadConc n, MonadRef r n)
  => Settings n a
  -> s
  -> (s -> Maybe t)
  -> (s -> t -> n (s, Maybe (Either Failure a, Trace)))
  -> (forall x. Scheduler x -> x -> n (Either Failure a, x, Trace))
  -> n [(Either Failure a, Trace)]
sct' settings s0 sfun srun run = loop Nothing [] s0
  where
    -- The previous result is tested for early exit before the state is
    -- forced; only the second equation is strict in the state.
    loop (Just r) _ _ | stopEarly r = pure []
    loop _ seen !st = maybe (pure []) (runStep seen st) (sfun st)

    -- Run one step. A 'Nothing' outcome from srun yields no result.
    runStep seen st t = do
      (st', outcome) <- srun st t
      case outcome of
        Nothing -> loop Nothing seen st'
        Just (r, trc) -> case discardOf r of
          Just DiscardResultAndTrace -> loop (Just r) seen st'
          Just DiscardTrace -> report r [] seen st'
          Nothing -> report r trc seen st'

    -- Only deduplicate when an equality on results has been configured.
    report = maybe emit dedup (_equality settings)

    dedup eq r trc seen st
      | any (sameResult eq r) seen = loop (Just r) seen st
      | otherwise = emit r trc (r : seen) st

    sameResult eq (Right a1) (Right a2) = eq a1 a2
    sameResult _ (Left e1) (Left e2) = e1 == e2
    sameResult _ _ _ = False

    -- Record a result, first simplifying its trace when that is enabled and
    -- there is a trace to simplify.
    emit r trc seen st
      | null trc || not (_simplify settings) = ((r, trc) :) <$> rest
      | otherwise = do
          shrunk <- simplifyExecution settings run r trc
          (shrunk :) <$> rest
      where
        rest = loop (Just r) seen st

    stopEarly = fromMaybe (const False) (_earlyExit settings)
    discardOf = fromMaybe (const Nothing) (_discard settings)
-- | Try to replace a result's trace with a simpler one.
--
-- The trace is reduced to @(ThreadId, ThreadAction)@ pairs and passed
-- through 'simplify'. If nothing changes, the original result and trace are
-- returned unchanged. Otherwise the simplified schedule is replayed:
--
-- * if the replay's result agrees with the original, the replayed result and
--   trace are returned;
-- * if it does not, a debug message is emitted and the original result and
--   trace are kept.
--
-- Two failures agree when they are '=='. Two successes agree when
-- '_equality' says so, or always when no equality is configured.
simplifyExecution :: (MonadConc n, MonadRef r n)
  => Settings n a
  -> (forall x. Scheduler x -> x -> n (Either Failure a, x, Trace))
  -> Either Failure a
  -> Trace
  -> n (Either Failure a, Trace)
simplifyExecution settings run res trace =
  if tidTrace == simplifiedTrace
    then do
      debugPrint ("Simplifying new result '" ++ render res ++ "': no simplification possible!")
      pure (res, trace)
    else do
      debugPrint ("Simplifying new result '" ++ render res ++ "': OK!")
      (res', _, trace') <- replay run simplifiedTrace
      if agrees res'
        then pure (res', trace')
        else do
          debugPrint ("Got a different result after simplifying: '" ++ render res ++ "' /= '" ++ render res' ++ "'")
          pure (res, trace)
  where
    tidTrace = toTIdTrace trace
    simplifiedTrace = simplify (_memtype settings) tidTrace

    agrees res' = case (res, res') of
      (Left e1, Left e2) -> e1 == e2
      (Right a1, Right a2) -> maybe True (\eq -> eq a1 a2) (_equality settings)
      _ -> False

    debugPrint = fromMaybe (const (pure ())) (_debugPrint settings)
    render = either show (fromMaybe (const "_") (_debugShow settings))
-- | Re-run a computation, following a fixed schedule.
--
-- The scheduler's state is the part of the schedule that has not been
-- followed yet, and its first argument is ignored. At each decision the
-- scheduler:
--
-- * takes the next @(thread, action)@ pair and picks the runnable thread
--   whose numeric ID matches, ignoring any name;
-- * skips a 'Stop' entry whose thread is not currently runnable;
-- * picks 'Nothing' once the schedule is exhausted.
replay :: (MonadConc n, MonadRef r n)
  => (forall x. Scheduler x -> x -> n (Either Failure a, x, Trace))
  -> [(ThreadId, ThreadAction)]
  -> n (Either Failure a, [(ThreadId, ThreadAction)], Trace)
replay run = run (Scheduler (\_ -> pick))
  where
    pick runnable pending = case pending of
      [] -> (Nothing, [])
      (t, act) : rest -> case (act, lookupIgnoringName t runnable) of
        (Stop, Nothing) -> pick runnable rest
        (_, found) -> (found, rest)

    lookupIgnoringName tid0 =
      fmap fst . find (\(tid, _) -> fromId tid == fromId tid0)
-- | Simplify a trace by reordering independent actions.
--
-- The trace is first put into lexicographic normal form and has its commits
-- handled by 'dropCommits'. Then 'pullBack' followed by 'pushForward' is
-- applied repeatedly until the trace stops changing. The number of rounds is
-- capped at the length of the input trace.
simplify :: MemType -> [(ThreadId, ThreadAction)] -> [(ThreadId, ThreadAction)]
simplify memtype trc0 = loop (length trc0) (prepare trc0)
  where
    prepare = dropCommits memtype . lexicoNormalForm memtype
    step = pushForward memtype . pullBack memtype

    -- The round counter bounds the number of rewrite passes; it counts down
    -- from the trace length and stops the loop when it reaches zero.
    loop 0 trc = trc
    loop n trc =
      let trc' = step trc
      in if trc' /= trc then loop (n - 1) trc' else trc
-- | Repeatedly apply 'permuteBy' with the @(>)@ ordering on thread IDs until
-- the trace no longer changes, and return that fixed point.
lexicoNormalForm :: MemType -> [(ThreadId, ThreadAction)] -> [(ThreadId, ThreadAction)]
lexicoNormalForm memtype = settle
  where
    settle trc
      | trc == permuted = trc
      | otherwise = settle permuted
      where
        permuted = permuteBy memtype (repeat (>)) trc
-- | One left-to-right pass over a trace that may swap adjacent actions.
--
-- At each position the current action and the one after it are swapped when
-- both of these hold:
--
-- * they are 'independent' under the dependency state accumulated so far;
-- * the next comparator from the list returns 'True' for their thread IDs.
--
-- The action placed first is emitted and folded into the dependency state.
-- The other action is carried forward to be compared with the next one. The
-- pass ends when fewer than two actions remain or the comparators run out.
permuteBy
  :: MemType
  -> [ThreadId -> ThreadId -> Bool]
  -> [(ThreadId, ThreadAction)]
  -> [(ThreadId, ThreadAction)]
permuteBy memtype = walk initialDepState
  where
    walk ds (cmp : cmps) (x@(xtid, xact) : y@(ytid, yact) : rest) =
      let swapped = independent ds xtid xact ytid yact && cmp xtid ytid
          (front, back) = if swapped then (y, x) else (x, y)
      in advance ds cmps front (back : rest)
    walk _ _ trc = trc

    advance ds cmps t@(tid, act) rest =
      t : walk (updateDepState memtype ds tid act) cmps rest
-- | Remove or delay 'CommitCRef' actions in a trace.
--
-- Under 'SequentialConsistency' the trace is returned unchanged. For any
-- other memory model, each commit is handled as follows:
--
-- * a commit immediately followed by an action whose 'simplifyAction' is a
--   barrier is removed from the trace;
-- * a commit that is 'independent' of the following action is moved after
--   it, and keeps being moved while that remains true;
-- * otherwise the commit stays where it is.
--
-- The dependency state is updated with each action that is emitted.
dropCommits :: MemType -> [(ThreadId, ThreadAction)] -> [(ThreadId, ThreadAction)]
dropCommits SequentialConsistency = id
dropCommits memtype = go initialDepState where
  go ds (t1@(tid1, ta1@(CommitCRef _ _)):t2@(tid2, ta2):trc)
    -- Next action is a barrier: drop the commit (t1) entirely.
    | isBarrier (simplifyAction ta2) = go ds (t2:trc)
    -- Commit is independent of the next action: emit that action first and
    -- carry the commit along to be reconsidered at the next position.
    | independent ds tid1 ta1 tid2 ta2 = t2 : go (updateDepState memtype ds tid2 ta2) (t1:trc)
  -- Not a commit, or neither guard above matched: keep the action in place.
  go ds (t@(tid,ta):trc) = t : go (updateDepState memtype ds tid ta) trc
  go _ [] = []
-- | One pass that tries to keep a thread running for longer.
--
-- After each action, if the following action belongs to a different thread,
-- the rest of the trace is searched for the same thread's next action. That
-- action is pulled forward to directly follow, but only if it is
-- 'independent' of every action it would be moved in front of. If no such
-- action exists, or one of those independence checks fails, the trace is
-- left unchanged at that point.
pullBack :: MemType -> [(ThreadId, ThreadAction)] -> [(ThreadId, ThreadAction)]
pullBack memtype = go initialDepState where
  go ds (t1@(tid1, ta1):trc@((tid2, _):_)) =
    let ds' = updateDepState memtype ds tid1 ta1
        -- Only look ahead when the next action is by a different thread.
        trc' = if tid1 /= tid2
          then maybe trc (uncurry (:)) (findAction tid1 ds' trc)
          else trc
    in t1 : go ds' trc'
  go _ trc = trc

  -- Find the first action by tid0, returning it together with the trace with
  -- that action removed. The result is 'Nothing' if tid0 has no action in
  -- the trace, or if the found action is not independent of every action
  -- that precedes it in this trace.
  findAction tid0 = fgo where
    fgo ds (t@(tid, ta):trc)
      | tid == tid0 = Just (t, trc)
      | otherwise = case fgo (updateDepState memtype ds tid ta) trc of
        Just (ft@(ftid, fa), trc')
          | independent ds tid ta ftid fa -> Just (ft, t:trc')
        _ -> Nothing
    fgo _ _ = Nothing
-- | One pass that tries to delay actions so they sit next to their thread's
-- following action.
--
-- When an action is followed by an action of a different thread, the action
-- is moved to just before its own thread's next action in the trace. This
-- happens only if it is 'independent' of every action it is moved past. The
-- pass then continues from the same position with the same dependency state.
-- If the move is not possible, the action is kept in place.
pushForward :: MemType -> [(ThreadId, ThreadAction)] -> [(ThreadId, ThreadAction)]
pushForward memtype = go initialDepState where
  go ds (t1@(tid1, ta1):trc@((tid2, _):_)) =
    let ds' = updateDepState memtype ds tid1 ta1
    in if tid1 /= tid2
       -- On success the rewritten trace no longer starts with t1, so it is
       -- reprocessed with the unchanged state ds.
       then maybe (t1 : go ds' trc) (go ds) (findAction tid1 ta1 ds trc)
       else t1 : go ds' trc
  go _ trc = trc

  -- Re-insert (tid0, ta0) immediately before the first action of tid0 in the
  -- trace. Returns 'Nothing' if an intervening action is not independent of
  -- it, or if tid0 has no further action.
  findAction tid0 ta0 = fgo where
    fgo ds (t@(tid, ta):trc)
      | tid == tid0 = Just ((tid0, ta0) : t : trc)
      | independent ds tid0 ta0 tid ta = (t:) <$> fgo (updateDepState memtype ds tid ta) trc
      | otherwise = Nothing
    fgo _ _ = Nothing
-- | Rewrite the thread and 'CRef' IDs appearing in a trace.
--
-- Threads created by 'Fork' / 'ForkOS' are assigned consecutive fresh IDs
-- starting at @tid0@. CRefs created by 'NewCRef' are assigned consecutive
-- fresh IDs starting at @crid0@. Every later reference to an old ID (the
-- acting thread, MVar blockers, STM wakeups, throw targets, CRef operations)
-- is rewritten using the mapping built so far. IDs not in the mapping are
-- left unchanged, and MVar IDs are never changed.
--
-- For a 'CommitCRef' entry the acting thread is rebuilt with 'commitThreadId'
-- from the renumbered IDs; the CRef ID is passed to it only under
-- 'PartialStoreOrder'.
renumber
  :: MemType
  -> Int
  -> Int
  -> [(ThreadId, ThreadAction)]
  -> [(ThreadId, ThreadAction)]
renumber memtype tid0 crid0 = snd . mapAccumL step (I.empty, tid0, I.empty, crid0)
  where
    -- State: (thread ID map, next thread ID, CRef ID map, next CRef ID).
    step st@(tids, _, crefs, _) (tid, act) = case act of
        CommitCRef ctid crid ->
          let ctid' = renumbered tids ctid
              crid' = renumbered crefs crid
              committer = case memtype of
                PartialStoreOrder -> commitThreadId ctid' (Just crid')
                _ -> commitThreadId ctid' Nothing
          in (st, (committer, CommitCRef ctid' crid'))
        _ ->
          let (st', act') = rewrite st act
          in (st', (renumbered tids tid, act'))

    rewrite st@(tids, nextT, crefs, nextC) act = case act of
        Fork old -> forked Fork old
        ForkOS old -> forked ForkOS old
        PutMVar mvid olds -> keep (PutMVar mvid (threads olds))
        TryPutMVar mvid b olds -> keep (TryPutMVar mvid b (threads olds))
        TakeMVar mvid olds -> keep (TakeMVar mvid (threads olds))
        TryTakeMVar mvid b olds -> keep (TryTakeMVar mvid b (threads olds))
        NewCRef old ->
          ((tids, nextT, I.insert (fromId old) nextC crefs, nextC + 1), NewCRef (toId nextC))
        ReadCRef old -> keep (ReadCRef (cref old))
        ReadCRefCas old -> keep (ReadCRefCas (cref old))
        ModCRef old -> keep (ModCRef (cref old))
        ModCRefCas old -> keep (ModCRefCas (cref old))
        WriteCRef old -> keep (WriteCRef (cref old))
        CasCRef old b -> keep (CasCRef (cref old) b)
        STM tas olds -> keep (STM tas (threads olds))
        ThrowTo old -> keep (ThrowTo (renumbered tids old))
        BlockedThrowTo old -> keep (BlockedThrowTo (renumbered tids old))
        _ -> keep act
      where
        keep a = (st, a)
        threads olds = map (renumbered tids) olds
        cref old = renumbered crefs old
        -- Allocate the next fresh thread ID for a forked thread.
        forked con old =
          ((I.insert (fromId old) nextT tids, nextT + 1, crefs, nextC), con (toId nextT))
-- | Look up the new numeric value of an ID in a renumbering map.
--
-- The ID is returned with its numeric value replaced, or with that value
-- unchanged when it is not in the map. Any name attached to the ID is
-- dropped ('toId' always uses 'Nothing').
renumbered :: (Coercible a Id, Coercible Id a) => I.IntMap Int -> a -> a
renumbered idmap id_ =
  let old = fromId id_
  in toId (fromMaybe old (I.lookup old idmap))
-- | Build an unnamed ID with the given number, coerced to any ID newtype
-- over 'Id'.
toId :: Coercible Id a => Int -> a
toId n = coerce (Id Nothing n)
-- | Extract the number from any ID newtype over 'Id', ignoring its name.
fromId :: Coercible a Id => a -> Int
fromId a = case coerce a of
  Id _ n -> n