module Test.DejaFu.SCT.Internal where
import Control.DeepSeq (NFData(..))
import Control.Exception (MaskingState(..))
import Data.Char (ord)
import Data.Function (on)
import qualified Data.Foldable as F
import Data.List (intercalate, nubBy, partition, sortOn)
import Data.List.NonEmpty (NonEmpty(..), toList)
import Data.Map.Strict (Map)
import Data.Maybe (catMaybes, fromJust, isNothing, listToMaybe)
import qualified Data.Map.Strict as M
import Data.Set (Set)
import qualified Data.Set as S
import Data.Sequence (Seq, (|>))
import qualified Data.Sequence as Sq
import System.Random (RandomGen, randomR)
import Test.DejaFu.Common
import Test.DejaFu.Schedule (Decision(..), Scheduler, decisionOf, tidOf)
-- | DPOR execution is represented as a tree of states. Each node is
-- characterised by the decisions that lead to it.
data DPOR = DPOR
  { dporRunnable :: Set ThreadId
  -- ^ The threads runnable at this node.
  , dporTodo :: Map ThreadId Bool
  -- ^ Decisions still to be explored from this node. Each thread is
  -- paired with the flag of the backtracking step that added it (see
  -- 'incorporateBacktrackSteps'). NOTE(review): presumably "added
  -- conservatively"; confirm.
  , dporDone :: Map ThreadId DPOR
  -- ^ Explored child subtrees, keyed by the thread that was run.
  , dporSleep :: Map ThreadId ThreadAction
  -- ^ The sleep set: threads, with their actions, that the scheduler
  -- filters out of its choices.
  , dporTaken :: Map ThreadId ThreadAction
  -- ^ Threads already taken from this node, with the action they
  -- performed.
  , dporAction :: Maybe ThreadAction
  -- ^ The action that led to this node. 'Nothing' at the root.
  } deriving (Eq, Show)
-- | Fully evaluate every field of a DPOR node, including its subtrees.
instance NFData DPOR where
  rnf (DPOR runnable todo done sleep taken action) =
    rnf runnable `seq` rnf todo `seq` rnf done `seq`
    rnf sleep `seq` rnf taken `seq` rnf action
-- | One step of an execution trace, annotated with the information
-- needed to compute backtracking points.
data BacktrackStep = BacktrackStep
  { bcktThreadid :: ThreadId
  -- ^ The thread that ran at this step.
  , bcktDecision :: Decision
  -- ^ The scheduling decision made at this step.
  , bcktAction :: ThreadAction
  -- ^ The action the thread performed.
  , bcktRunnable :: Map ThreadId Lookahead
  -- ^ The threads runnable at this step, with their next action.
  , bcktBacktracks :: Map ThreadId Bool
  -- ^ The threads to backtrack to at this step, each with a flag.
  , bcktState :: DepState
  -- ^ The dependency state after this step's action.
  } deriving (Eq, Show)
-- | Fully evaluate every field of a backtracking step.
instance NFData BacktrackStep where
  rnf (BacktrackStep tid decision act runnable backtracks depstate) =
    rnf tid `seq` rnf decision `seq` rnf act `seq`
    rnf runnable `seq` rnf backtracks `seq` rnf depstate
-- | The initial DPOR tree: a root node where only the initial thread is
-- runnable. That thread is the single to-do decision, with its flag set
-- to 'False'. Nothing is explored, asleep or taken yet.
initialState :: DPOR
initialState =
  DPOR (S.singleton initialThread)
       (M.singleton initialThread False)
       M.empty
       M.empty
       M.empty
       Nothing
-- | Find a schedule prefix that ends at a to-do decision in the DPOR
-- tree, if there is one.
--
-- The result is a triple:
--
-- * the thread IDs from the root to the to-do point;
-- * the flag stored with that to-do entry;
-- * the sleep set to start from, which is the node's sleep set unioned
--   with the decisions already taken from it.
findSchedulePrefix
  :: (ThreadId -> Bool)
  -- ^ Partitioning function. At each level of the tree it is applied
  -- to the first thread of every candidate prefix. Candidates that pass
  -- are preferred; the others are used only when none pass.
  -> DPOR
  -> Maybe ([ThreadId], Bool, Map ThreadId ThreadAction)
findSchedulePrefix predicate = listToMaybe . go where
  -- All candidate prefixes below this node. The order is determined by
  -- 'concatPartition'. The list is lazy, so 'listToMaybe' only forces
  -- as much as it needs.
  go dpor =
    let prefixes = here dpor : map go' (M.toList $ dporDone dpor)
    in case concatPartition (\(t:_,_,_) -> predicate t) prefixes of
         ([], choices) -> choices
         (choices, _) -> choices
  -- Prefixes found under a child, extended with the child's thread ID.
  go' (tid, dpor) = (\(ts,c,slp) -> (tid:ts,c,slp)) <$> go dpor
  -- Prefixes that end at a to-do decision of this very node.
  here dpor = [([t], c, sleeps dpor) | (t, c) <- M.toList $ dporTodo dpor]
  -- Sleep set for a new branch: this node's sleep set plus every
  -- decision already taken from it.
  sleeps dpor = dporSleep dpor `M.union` dporTaken dpor
-- | Add a new trace to the DPOR tree.
--
-- The trace is followed down the existing tree. At the first step with
-- no explored child, that step is removed from the node's to-do set and
-- a fresh linear subtree is built from the rest of the trace.
incorporateTrace
  :: (DepState -> ThreadId -> ThreadAction -> ThreadId -> ThreadAction -> Bool)
  -- ^ Dependency function. It is used to wake (drop) sleep-set entries
  -- that depend on the step just taken.
  -> Bool
  -- ^ When 'True', the thread taken at the branching node is not
  -- recorded in that node's 'dporTaken'.
  -> Trace
  -- ^ The execution trace. Each step is a triple of the decision, the
  -- runnable threads and the action performed.
  -> DPOR
  -> DPOR
incorporateTrace dependency conservative trace dpor0 = grow initialDepState (initialDPORThread dpor0) trace dpor0 where
  -- Follow the trace through already-explored children. At the first
  -- decision with no explored child:
  -- * record it as taken (unless conservative);
  -- * remove it from the to-do set;
  -- * hang a new subtree there.
  grow state tid trc@((d, _, a):rest) dpor =
    let tid' = tidOf tid d
        state' = updateDepState state tid' a
    in case M.lookup tid' (dporDone dpor) of
         Just dpor' ->
           let done = M.insert tid' (grow state' tid' rest dpor') (dporDone dpor)
           in dpor { dporDone = done }
         Nothing ->
           let taken = M.insert tid' a (dporTaken dpor)
               sleep = dporSleep dpor `M.union` dporTaken dpor
               -- NOTE(review): 'subtree' re-applies this step's action to
               -- state'. That appears harmless, because updateCRState and
               -- updateMaskState give the same result when one action is
               -- applied twice. Confirm before changing updateDepState.
               done = M.insert tid' (subtree state' tid' sleep trc) (dporDone dpor)
           in dpor { dporTaken = if conservative then dporTaken dpor else taken
                   , dporTodo = M.delete tid' (dporTodo dpor)
                   , dporDone = done
                   }
  grow _ _ [] _ = err "incorporateTrace" "trace exhausted without reading a to-do point!"
  -- Build a linear chain of nodes from a trace suffix. Each node:
  -- * has no to-do decisions;
  -- * records the single next step as taken;
  -- * keeps only the sleep-set entries independent of the action just
  --   performed.
  subtree state tid sleep ((_, _, a):rest) =
    let state' = updateDepState state tid a
        sleep' = M.filterWithKey (\t a' -> not $ dependency state' tid a t a') sleep
    in DPOR
        { dporRunnable = S.fromList $ case rest of
            ((_, runnable, _):_) -> map fst runnable
            [] -> []
        , dporTodo = M.empty
        , dporDone = M.fromList $ case rest of
            ((d', _, _):_) ->
              let tid' = tidOf tid d'
              in [(tid', subtree state' tid' sleep' rest)]
            [] -> []
        , dporSleep = sleep'
        , dporTaken = case rest of
            ((d', _, a'):_) -> M.singleton (tidOf tid d') a'
            [] -> M.empty
        , dporAction = Just a
        }
  subtree _ _ _ [] = err "incorporateTrace" "subtree suffix empty!"
-- | Produce a list of backtracking steps from an execution trace,
-- including the new backtracking points found for it.
findBacktrackSteps
  :: (DepState -> ThreadId -> ThreadAction -> ThreadId -> Lookahead -> Bool)
  -- ^ Dependency function between an executed step and a thread's
  -- upcoming action.
  -> BacktrackFunc
  -- ^ Adds the requested (index, flag, thread) backtracking points to
  -- the steps.
  -> Bool
  -- ^ When 'True', at the final step of the trace every earlier step
  -- of another known thread is treated as dependent.
  -> Seq (NonEmpty (ThreadId, Lookahead), [ThreadId])
  -- ^ For each step: the runnable threads with their lookahead, and the
  -- threads initially marked for backtracking. NOTE(review): presumably
  -- the scheduler's 'schedBPoints'; confirm at the call site.
  -> Trace
  -> [BacktrackStep]
findBacktrackSteps dependency backtrack boundKill = go initialDepState S.empty initialThread [] . F.toList where
  -- Walk the per-step scheduler data and the trace in lockstep. Each
  -- step is appended as a 'BacktrackStep', and the backtracking points
  -- are recomputed for every runnable thread at that step.
  go state allThreads tid bs ((e,i):is) ((d,_,a):ts) =
    let tid' = tidOf tid d
        state' = updateDepState state tid' a
        this = BacktrackStep
          { bcktThreadid = tid'
          , bcktDecision = d
          , bcktAction = a
          , bcktRunnable = M.fromList . toList $ e
          , bcktBacktracks = M.fromList $ map (\i' -> (i', False)) i
          , bcktState = state'
          }
        bs' = doBacktrack killsEarly allThreads' (toList e) (bs++[this])
        runnable = S.fromList (M.keys $ bcktRunnable this)
        allThreads' = allThreads `S.union` runnable
        -- Only the final step of the trace is affected by boundKill.
        killsEarly = null ts && boundKill
    in go state' allThreads' tid' bs' is ts
  go _ _ _ bs _ _ = bs
  -- For each runnable thread u (lookahead n) and each other known
  -- thread v, find the most recent step of v dependent with u's next
  -- action. Then ask 'backtrack' to add u at that step.
  doBacktrack killsEarly allThreads enabledThreads bs =
    let -- Steps paired with their index, most recent first.
        tagged = reverse $ zip [0..] bs
        -- 'head' is safe: the comprehension keeps only non-empty 'is'.
        idxs = [ (head is, False, u)
               | (u, n) <- enabledThreads
               , v <- S.toList allThreads
               , u /= v
               , let is = idxs' u n v tagged
               , not $ null is]
        -- Indices of v's dependent steps, most recent first. The search
        -- stops at a subconcurrency boundary of the initial thread.
        idxs' u n v = catMaybes . go' True where
          go' final ((i,b):rest)
            | isSubC final b = []
            | bcktThreadid b == v && (killsEarly || isDependent b) = Just i : go' False rest
            | otherwise = go' False rest
          go' _ [] = []
          -- A 'Stop' by the initial thread is a boundary unless it is
          -- the most recent step. A 'Subconcurrency' by it always is.
          isSubC final b = case bcktAction b of
            Stop -> not final && bcktThreadid b == initialThread
            Subconcurrency -> bcktThreadid b == initialThread
            _ -> False
          isDependent b = dependency (bcktState b) (bcktThreadid b) (bcktAction b) u n
    in backtrack bs idxs
-- | Add new backtracking points to the DPOR tree as to-do decisions.
--
-- A point is added only if it meets all of these conditions:
-- * it is within the bound;
-- * it has not already been explored;
-- * its thread is not asleep, unless the step's flag is set.
incorporateBacktrackSteps
  :: ([(Decision, ThreadAction)] -> (Decision, Lookahead) -> Bool)
  -- ^ Bound function. Given the schedule prefix so far and a candidate
  -- (decision, lookahead), says whether the candidate is within the
  -- bound.
  -> [BacktrackStep]
  -> DPOR
  -> DPOR
incorporateBacktrackSteps bv = go Nothing [] where
  -- Walk the steps down the tree, following the explored child for
  -- each step's thread.
  go priorTid pref (b:bs) bpor =
    let bpor' = doBacktrack priorTid pref b bpor
        tid = bcktThreadid b
        pref' = pref ++ [(bcktDecision b, bcktAction b)]
        -- NOTE(review): assumes the steps describe a path already in
        -- the tree (e.g. after 'incorporateTrace'); 'fromJust' fails
        -- otherwise.
        child = go (Just tid) pref' bs . fromJust $ M.lookup tid (dporDone bpor)
    in bpor' { dporDone = M.insert tid child $ dporDone bpor' }
  go _ _ [] bpor = bpor

  -- Merge one step's backtracking points into this node's to-do set.
  -- 'M.union' is left-biased, so an existing to-do entry keeps its flag.
  doBacktrack priorTid pref b bpor =
    let todo' = [ x
                | x@(t,c) <- M.toList $ bcktBacktracks b
                , let decision = decisionOf priorTid (dporRunnable bpor) t
                , let lahead = fromJust . M.lookup t $ bcktRunnable b
                , bv pref (decision, lahead)
                , M.notMember t (dporDone bpor)
                , c || M.notMember t (dporSleep bpor)
                ]
    in bpor { dporTodo = dporTodo bpor `M.union` M.fromList todo' }
-- | State carried by the DPOR scheduler ('dporSched') through a single
-- execution.
data DPORSchedState = DPORSchedState
  { schedSleep :: Map ThreadId ThreadAction
  -- ^ The current sleep set. Threads in it are not chosen.
  , schedPrefix :: [ThreadId]
  -- ^ Remaining schedule prefix to replay before making free choices.
  , schedBPoints :: Seq (NonEmpty (ThreadId, Lookahead), [ThreadId])
  -- ^ For each step: the runnable threads, and the in-bound,
  -- non-sleeping alternatives not chosen.
  , schedIgnore :: Bool
  -- ^ Set when every in-bound choice was in the sleep set.
  , schedBoundKill :: Bool
  -- ^ Set when there were candidate threads but none was in bound.
  , schedDepState :: DepState
  -- ^ Dependency state after the prior action.
  } deriving (Eq, Show)
-- | Fully evaluate every field of the DPOR scheduler state.
instance NFData DPORSchedState where
  rnf (DPORSchedState sleep prefix bpoints ignore bkill depstate) =
    rnf sleep `seq` rnf prefix `seq` rnf bpoints `seq`
    rnf ignore `seq` rnf bkill `seq` rnf depstate
-- | Initial DPOR scheduler state for an execution.
--
-- It starts with the given sleep set and schedule prefix, no recorded
-- steps, both flags cleared, and the initial dependency state.
initialDPORSchedState :: Map ThreadId ThreadAction
                      -- ^ The initial sleep set.
                      -> [ThreadId]
                      -- ^ The schedule prefix to replay.
                      -> DPORSchedState
initialDPORSchedState sleep prefix =
  DPORSchedState sleep prefix Sq.empty False False initialDepState
-- | A bounding function. Given the schedule prefix so far and a
-- candidate (decision, lookahead), it returns whether extending the
-- prefix with that candidate stays within the bound.
type BoundFunc
  = [(Decision, ThreadAction)] -> (Decision, Lookahead) -> Bool
-- | A backtracking function. It takes the steps of an execution and a
-- list of (step index, flag, thread) requests, and returns the steps
-- with those backtracking points added.
type BacktrackFunc
  = [BacktrackStep] -> [(Int, Bool, ThreadId)] -> [BacktrackStep]
-- | Add backtracking points to a list of steps.
--
-- * If the requested thread is runnable at the step (and @toAll@ says
--   no), only that thread is added.
-- * Otherwise every runnable thread at the step is added.
-- * An already-present point is not re-added, unless the request would
--   upgrade its flag from 'False' to 'True'.
backtrackAt
  :: (ThreadId -> BacktrackStep -> Bool)
  -- ^ If this returns 'True', backtrack to all runnable threads rather
  -- than just the given one.
  -> BacktrackFunc
backtrackAt toAll bs0 = backtrackAt' . nubBy ((==) `on` fst') . sortOn fst' where
  fst' (x,_,_) = x

  -- Requests are handled in ascending index order. Only the first
  -- request per index survives 'nubBy'. 'go' needs distinct indices,
  -- since it counts the gap to the next one as @i' - i0 - 1@.
  backtrackAt' ((i,c,t):is) = go i bs0 i c t is
  backtrackAt' [] = bs0

  -- @go i0 steps i c tid is@:
  -- * @i0@ is the absolute index of the current request;
  -- * @i@ is the number of steps still to skip before reaching it.
  go i0 (b:bs) 0 c tid is
    | not (toAll tid b) && tid `M.member` bcktRunnable b =
      let val = M.lookup tid $ bcktBacktracks b
          b' = if isNothing val || (val == Just False && c)
                 then b { bcktBacktracks = backtrackTo tid c b }
                 else b
      in b' : case is of
        ((i',c',t'):is') -> go i' bs (i' - i0 - 1) c' t' is'
        [] -> bs
    | otherwise =
      let b' = b { bcktBacktracks = backtrackAll c b }
      in b' : case is of
        ((i',c',t'):is') -> go i' bs (i' - i0 - 1) c' t' is'
        [] -> bs
  go i0 (b:bs) i c tid is = b : go i0 bs (i - 1) c tid is
  go _ [] _ _ _ _ = err "backtrackAt" "ran out of schedule whilst backtracking!"

  -- Backtrack to a single thread.
  backtrackTo tid c = M.insert tid c . bcktBacktracks

  -- Backtrack to every thread runnable at the step.
  backtrackAll c = M.map (const c) . bcktRunnable
-- | The DPOR scheduler.
--
-- It replays the schedule prefix first, then makes in-bound,
-- non-sleeping choices. At each step it records the runnable threads
-- and the untried alternatives, for computing backtracking points later.
dporSched
  :: (DepState -> ThreadId -> ThreadAction -> ThreadId -> ThreadAction -> Bool)
  -- ^ Dependency function. It is used to wake sleeping threads that
  -- depend on the prior action.
  -> BoundFunc
  -- ^ Bound function. Candidates outside the bound are not chosen.
  -> Scheduler DPORSchedState
dporSched dependency inBound trc prior threads s = schedule where
  schedule = case schedPrefix s of
    (d:ds) -> (Just d, (nextState []) { schedPrefix = ds })
    [] ->
      let choices = restrictToBound initialise
          checkDep t a = case prior of
            Just (tid, act) -> dependency (schedDepState s) tid act t a
            Nothing -> False
          -- Wake every sleeping thread that depends on the prior action.
          ssleep' = M.filterWithKey (\t a -> not $ checkDep t a) $ schedSleep s
          choices' = filter (`M.notMember` ssleep') choices
          -- Every in-bound choice is asleep.
          signore' = not (null choices) && all (`M.member` ssleep') choices
          -- There were candidates, but none was in bound.
          sbkill' = not (null initialise) && null choices
      in case choices' of
        (nextTid:rest) -> (Just nextTid, (nextState rest) { schedSleep = ssleep' })
        [] -> (Nothing, (nextState []) { schedIgnore = signore', schedBoundKill = sbkill' })

  -- Record this step's runnable threads and untried alternatives, and
  -- advance the dependency state past the prior action.
  nextState rest = s
    { schedBPoints = schedBPoints s |> (threads, rest)
    , schedDepState = nextDepState
    }
  nextDepState = let ds = schedDepState s in maybe ds (uncurry $ updateDepState ds) prior

  -- Candidate threads, before bounding:
  -- * keep running the prior thread if it didn't yield and is still
  --   runnable; otherwise consider every runnable thread;
  -- * yielding threads go last.
  initialise = tryDaemons . yieldsToEnd $ case prior of
    Just (tid, act)
      | not (didYield act) && tid `elem` tids -> [tid]
    _ -> tids'

  -- If any candidate would kill the daemon threads, reorder all runnable
  -- threads (not only the candidates) so that the killers come last.
  tryDaemons ts
    | any doesKill ts = case partition doesKill tids' of
        (kills, nokills) -> nokills ++ kills
    | otherwise = ts
  doesKill t = killsDaemons t (action t)

  restrictToBound = filter (\t -> inBound trc (decision t, action t))

  yieldsToEnd ts = case partition (willYield . action) ts of
    (yields, noyields) -> noyields ++ yields

  decision = decisionOf (fst <$> prior) (S.fromList tids')
  action t = fromJust $ lookup t threads'
  tids = fst <$> threads
  threads' = toList threads
  tids' = toList tids
-- | State of the weighted random scheduler ('randSched').
data RandSchedState g = RandSchedState
  { schedWeights :: Map ThreadId Int
  -- ^ The weight assigned to each thread seen so far.
  , schedGen :: g
  -- ^ The random number generator.
  } deriving (Eq, Show)
-- | Fully evaluate the weight map and the generator.
instance NFData g => NFData (RandSchedState g) where
  rnf (RandSchedState weights gen) = rnf weights `seq` rnf gen
-- | Initial random scheduler state: no weights assigned yet, and the
-- given generator.
initialRandSchedState :: g -> RandSchedState g
initialRandSchedState gen = RandSchedState { schedWeights = M.empty, schedGen = gen }
-- | Weighted random scheduler.
--
-- Each newly seen thread gets a random weight in @[1, 50]@. A runnable
-- thread is then picked by drawing an index over the total weight of the
-- runnable threads.
randSched :: RandomGen g => Scheduler (RandSchedState g)
randSched _ _ threads s = (pick choice enabled, RandSchedState weights' g'') where
  -- Walk the weighted list, subtracting each weight until the index
  -- falls within a thread's share.
  pick idx ((x, f):xs)
    | idx < f = Just x
    | otherwise = pick (idx - f) xs
  pick _ [] = Nothing

  -- 'randomR' includes both bounds, so the valid indices are
  -- 0 .. total - 1.
  (choice, g'') = randomR (0, sum (map snd enabled) - 1) g'

  -- Weights of the currently runnable threads.
  enabled = M.toList $ M.filterWithKey (\tid _ -> tid `elem` tids) weights'

  -- The weights, with any new threads added.
  weights' = schedWeights s `M.union` M.fromList newWeights
  (newWeights, g') = foldr assignWeight ([], schedGen s) $ filter (`M.notMember` schedWeights s) tids
  assignWeight tid ~(ws, g0) =
    let (w, g) = randomR (1, 50) g0
    in ((tid, w):ws, g)

  -- The runnable threads.
  tids = map fst (toList threads)
-- | State used by dependency functions.
data DepState = DepState
  { depCRState :: Map CRefId Bool
  -- ^ CRefs with an uncommitted write. An entry is set to 'True' by a
  -- write, removed by its commit, and everything is cleared at a
  -- barrier.
  , depMaskState :: Map ThreadId MaskingState
  -- ^ The known masking state of each thread. A forked thread inherits
  -- its parent's entry.
  } deriving (Eq, Show)
-- | Fully evaluate the CRef map. The masking map gets its thread IDs
-- fully evaluated and its masking states forced to WHNF; for an
-- enumeration that is full evaluation.
instance NFData DepState where
  rnf (DepState crefs masks) =
    rnf crefs `seq` M.foldrWithKey (\t m rest -> rnf t `seq` m `seq` rest) () masks
-- | Initial dependency state: no buffered CRef writes and no recorded
-- masking states.
initialDepState :: DepState
initialDepState = DepState { depCRState = M.empty, depMaskState = M.empty }
-- | Update the dependency state with the action performed by a thread.
updateDepState :: DepState -> ThreadId -> ThreadAction -> DepState
updateDepState depstate tid act = DepState crefs' masks'
  where
    crefs' = updateCRState act (depCRState depstate)
    masks' = updateMaskState tid act (depMaskState depstate)
-- | Update the CRef buffer state with an action:
-- * a commit removes the CRef;
-- * a write marks it as buffered;
-- * any action that simplifies to a barrier clears everything;
-- * other actions leave the state unchanged.
updateCRState :: ThreadAction -> Map CRefId Bool -> Map CRefId Bool
updateCRState act = case act of
  CommitCRef _ r -> M.delete r
  WriteCRef r -> M.insert r True
  _ | isBarrier (simplifyAction act) -> const M.empty
    | otherwise -> id
-- | Update the masking state with the action performed by a thread:
-- * a forked thread inherits the forking thread's entry, if it has one;
-- * setting or resetting the masking state records the new state.
updateMaskState :: ThreadId -> ThreadAction -> Map ThreadId MaskingState -> Map ThreadId MaskingState
updateMaskState tid act masks = case act of
  Fork tid2 -> maybe masks (\ms -> M.insert tid2 ms masks) (M.lookup tid masks)
  SetMasking _ ms -> M.insert tid ms masks
  ResetMasking _ ms -> M.insert tid ms masks
  _ -> masks
-- | Check whether a CRef currently has a buffered (uncommitted) write.
isBuffered :: DepState -> CRefId -> Bool
isBuffered depstate r = M.lookup r (depCRState depstate) == Just True
-- | Check whether a thread performing the given action can be
-- interrupted, given its masking state:
-- * unmasked: always;
-- * masked interruptibly: only during a blocked action;
-- * masked uninterruptibly: never.
canInterrupt :: DepState -> ThreadId -> ThreadAction -> Bool
canInterrupt depstate tid act
  | isMaskedInterruptible depstate tid = isBlocked act
  | isMaskedUninterruptible depstate tid = False
  | otherwise = True
  where
    isBlocked a = case a of
      BlockedPutMVar _ -> True
      BlockedReadMVar _ -> True
      BlockedTakeMVar _ -> True
      BlockedSTM _ -> True
      BlockedThrowTo _ -> True
      _ -> False
-- | Lookahead counterpart of 'canInterrupt'. Under an interruptible
-- mask, only a thread about to perform a potentially blocking operation
-- can be interrupted.
canInterruptL :: DepState -> ThreadId -> Lookahead -> Bool
canInterruptL depstate tid lh
  | isMaskedInterruptible depstate tid = mayBlock lh
  | isMaskedUninterruptible depstate tid = False
  | otherwise = True
  where
    mayBlock l = case l of
      WillPutMVar _ -> True
      WillReadMVar _ -> True
      WillTakeMVar _ -> True
      WillSTM -> True
      WillThrowTo _ -> True
      _ -> False
-- | Check whether a thread is recorded as 'MaskedInterruptible'.
isMaskedInterruptible :: DepState -> ThreadId -> Bool
isMaskedInterruptible depstate tid =
  maybe False (== MaskedInterruptible) (M.lookup tid (depMaskState depstate))
-- | Check whether a thread is recorded as 'MaskedUninterruptible'.
isMaskedUninterruptible :: DepState -> ThreadId -> Bool
isMaskedUninterruptible depstate tid =
  maybe False (== MaskedUninterruptible) (M.lookup tid (depMaskState depstate))
-- | The smallest thread runnable at the root of a DPOR tree. This is
-- partial: it errors if the root has no runnable threads.
initialDPORThread :: DPOR -> ThreadId
initialDPORThread dpor = S.elemAt 0 (dporRunnable dpor)
-- | Check whether an executed action was a 'Yield'.
didYield :: ThreadAction -> Bool
didYield act = case act of
  Yield -> True
  _ -> False
-- | Check whether an upcoming action is a yield.
willYield :: Lookahead -> Bool
willYield lh = case lh of
  WillYield -> True
  _ -> False
-- | Check whether a thread's upcoming action would kill the daemon
-- threads. This holds when the initial thread is about to stop.
killsDaemons :: ThreadId -> Lookahead -> Bool
killsDaemons t lh = case lh of
  WillStop -> t == initialThread
  _ -> False
-- | Render a whole 'DPOR' tree as a Graphviz DOT graph, with no
-- subtrees filtered out.
toDot :: (ThreadId -> String)
      -- ^ Show a thread ID.
      -> (ThreadAction -> String)
      -- ^ Show a thread action.
      -> DPOR
      -> String
toDot showTid showAct = toDotFiltered (\_ _ -> True) showTid showAct
-- | Render a 'DPOR' tree as a Graphviz DOT graph. The renderer only
-- descends into children for which the check returns 'True'.
toDotFiltered :: (ThreadId -> DPOR -> Bool)
              -- ^ Subtree predicate, applied to each child's thread ID and
              -- subtree.
              -> (ThreadId -> String)
              -- ^ Show a thread ID.
              -> (ThreadAction -> String)
              -- ^ Show a thread action.
              -> DPOR
              -> String
toDotFiltered check showTid showAct = digraph . go "L" where
  digraph str = "digraph {\n" ++ str ++ "\n}"

  go l b = unlines $ node l b : edges l b

  node n b = n ++ " [label=\"" ++ escape (label b) ++ "\"]"

  -- A child's node ID is the parent's ID, then "_", then the encoded
  -- thread ID. Without a separator, different paths could produce the
  -- same ID and be merged: threads "1" then "2" versus thread "12".
  edges l b = [ edge l l' i ++ go l' b'
              | (i, b') <- M.toList (dporDone b)
              , check i b'
              , let l' = l ++ "_" ++ tidId i
              ]

  label b = showLst id
    [ maybe "Nothing" (("Just " ++) . showAct) $ dporAction b
    , "Run:" ++ showLst showTid (S.toList $ dporRunnable b)
    , "Tod:" ++ showLst showTid (M.keys $ dporTodo b)
    , "Slp:" ++ showLst (\(t,a) -> "(" ++ showTid t ++ ", " ++ showAct a ++ ")")
        (M.toList $ dporSleep b)
    ]

  edge n1 n2 l = n1 ++ " -> " ++ n2 ++ " [label=\"" ++ escape (showTid l) ++ "\"]\n"

  showLst showf xs = "[" ++ intercalate ", " (map showf xs) ++ "]"

  -- Encode a thread ID using only DOT identifier characters. Each
  -- character becomes 'c' followed by its decimal code. The 'c' prefix
  -- makes the encoding unambiguous: "12" gives "c49c50", never "c4950".
  tidId = concatMap (\ch -> 'c' : show (ord ch)) . showTid

  -- Escape double quotes so user-supplied text cannot end a quoted DOT
  -- string early. In DOT quoted strings, \" is the only escape.
  escape = concatMap (\ch -> if ch == '"' then "\\\"" else [ch])
-- | Abort with an internal-error message, prefixed with the name of the
-- function that detected the problem.
err :: String -> String -> a
err func msg = error $ concat [func, ": (internal error) ", msg]
-- | Partition the elements of a list of lists by a predicate, and
-- concatenate the results. The first component holds the elements that
-- satisfy @p@; the second holds the rest.
--
-- The outer fold is a 'foldl' of 'foldr'. As a result:
-- * in each component, elements of later inner lists come before
--   elements of earlier ones;
-- * the order within each inner list is preserved.
--
-- For example, @concatPartition even [[1,2],[3,4]] == ([4,2],[3,1])@.
--
-- NOTE(review): 'foldl' is deliberate here. 'foldl'' would force each
-- intermediate accumulator, and with it predicate applications, earlier.
concatPartition :: (a -> Bool) -> [[a]] -> ([a], [a])
concatPartition p = foldl (foldr select) ([], []) where
  -- The lazy pattern on the accumulator (as in 'Data.List.partition')
  -- lets each step build its result pair without first forcing the
  -- rest of the fold.
  select a ~(ts, fs)
    | p a = (a:ts, fs)
    | otherwise = (ts, a:fs)