module Test.DejaFu.SCT.Internal where
import Control.Applicative ((<|>))
import Control.DeepSeq (NFData(..))
import Control.Exception (MaskingState(..))
import qualified Data.Foldable as F
import Data.Function (on)
import Data.List (nubBy, partition, sortOn)
import Data.List.NonEmpty (toList)
import Data.Map.Strict (Map)
import qualified Data.Map.Strict as M
import Data.Maybe (fromJust, fromMaybe, isJust, isNothing,
listToMaybe)
import Data.Sequence (Seq, (|>))
import qualified Data.Sequence as Sq
import Data.Set (Set)
import qualified Data.Set as S
import System.Random (RandomGen, randomR)
import Test.DejaFu.Common
import Test.DejaFu.Schedule (Scheduler(..), decisionOf, tidOf)
-- | One node of the DPOR exploration tree: a single scheduling point,
-- what has been tried from it, and what is still pending.
data DPOR = DPOR
  { dporRunnable :: Set ThreadId
    -- ^ Threads runnable at this point.
  , dporTodo     :: Map ThreadId Bool
    -- ^ Threads still to be tried from here. Entries whose flag is
    -- @True@ are added even if the thread is in the sleep set.
  , dporNext     :: Maybe (ThreadId, DPOR)
    -- ^ The thread taken from here on the explored path, with the
    -- subtree below it.
  , dporDone     :: Set ThreadId
    -- ^ Threads that have already been taken from here.
  , dporSleep    :: Map ThreadId ThreadAction
    -- ^ The sleep set at this point.
  , dporTaken    :: Map ThreadId ThreadAction
    -- ^ Threads taken from here, with the action each performed.
  } deriving (Eq, Show)

-- | Deep evaluation: every field is forced in turn.
instance NFData DPOR where
  rnf (DPOR runnable todo next done sleep taken) =
    rnf runnable `seq` rnf todo `seq` rnf next `seq`
    rnf done `seq` rnf sleep `seq` rnf taken
-- | One step of an executed trace, annotated with what is needed to
-- decide where to backtrack.
data BacktrackStep = BacktrackStep
  { bcktThreadid   :: ThreadId
    -- ^ The thread that ran at this step.
  , bcktDecision   :: Decision
    -- ^ The scheduling decision that selected it.
  , bcktAction     :: ThreadAction
    -- ^ The action it performed.
  , bcktRunnable   :: Map ThreadId Lookahead
    -- ^ Runnable threads recorded for this step, with their lookaheads.
  , bcktBacktracks :: Map ThreadId Bool
    -- ^ Threads to backtrack to at this step, each with a flag that
    -- is carried over into the tree's to-do set.
  , bcktState      :: DepState
    -- ^ Dependency state after this step's action.
  } deriving (Eq, Show)

-- | Deep evaluation: every field is forced in turn.
instance NFData BacktrackStep where
  rnf (BacktrackStep tid decision act runnable backtracks depstate) =
    rnf tid `seq` rnf decision `seq` rnf act `seq`
    rnf runnable `seq` rnf backtracks `seq` rnf depstate
-- | The root of a fresh exploration tree: only the initial thread is
-- runnable, and it is the single pending to-do entry (flag @False@).
-- Nothing has been taken, done, or put to sleep yet.
initialState :: DPOR
initialState =
  DPOR
    (S.singleton initialThread)
    (M.singleton initialThread False)
    Nothing
    S.empty
    M.empty
    M.empty
-- | Find the next schedule prefix to explore.
--
-- The deepest node on the explored path that has a to-do entry wins.
-- Within a node, to-do threads comparing @>= 'initialThread'@ are
-- preferred over the others. The result is the list of threads
-- leading to the chosen to-do thread, that thread's to-do flag, and
-- the sleep set to start from (the node's sleep set plus its taken
-- threads). Returns @Nothing@ when the tree has no to-do entries left.
findSchedulePrefix
  :: DPOR
  -> Maybe ([ThreadId], Bool, Map ThreadId ThreadAction)
findSchedulePrefix dpor =
  maybe atThisNode descend (dporNext dpor)
  where
    -- Try deeper in the tree first, falling back to this node.
    descend (tid, child) = (prependTo tid <$> findSchedulePrefix child) <|> atThisNode
    prependTo tid (prefix, flag, sleepSet) = (tid : prefix, flag, sleepSet)

    atThisNode = toPrefix <$> listToMaybe (preferred ++ others)
    (preferred, others) = partition ((>= initialThread) . fst) (M.toList (dporTodo dpor))
    toPrefix (t, flag) = ([t], flag, sleepHere)

    sleepHere = dporSleep dpor `M.union` dporTaken dpor
-- | Add a newly executed trace to the exploration tree.
--
-- The trace is followed down the existing path for as long as it
-- takes the same threads as those recorded in 'dporNext'. At the
-- first node where it diverges (or where there is no child yet), the
-- newly taken thread is removed from the to-do set and marked as
-- done, and a fresh linear subtree is built from the rest of the trace.
--
-- When @conservative@ is @False@, the newly taken thread and its
-- action are also recorded in 'dporTaken' at the divergence node.
--
-- Calls 'err' if diverging would replace a child that still has
-- to-do entries, or if the trace runs out before diverging.
incorporateTrace
  :: MemType
  -> Bool
  -> Trace
  -> DPOR
  -> DPOR
incorporateTrace memtype conservative trace dpor0 = grow initialDepState (initialDPORThread dpor0) trace dpor0 where
  -- @tid@ is the thread that ran at the previous step (initially the
  -- root's first runnable thread); with the decision @d@ it gives the
  -- thread @tid'@ that ran at this step.
  grow state tid trc@((d, _, a):rest) dpor =
    let tid' = tidOf tid d
        state' = updateDepState state tid' a
    in case dporNext dpor of
      Just (t, child)
        -- The trace agrees with the existing path: keep descending.
        | t == tid' -> dpor { dporNext = Just (tid', grow state' tid' rest child) }
        -- Replacing this child would discard its unexplored to-dos.
        | hasTodos child -> err "incorporateTrace" "replacing child with todos!"
      -- Divergence (or no child yet): graft the rest of the trace here.
      _ ->
        let taken = M.insert tid' a (dporTaken dpor)
            sleep = dporSleep dpor `M.union` dporTaken dpor
        in dpor { dporTaken = if conservative then dporTaken dpor else taken
                , dporTodo = M.delete tid' (dporTodo dpor)
                , dporNext = Just (tid', subtree state' tid' sleep trc)
                , dporDone = S.insert tid' (dporDone dpor)
                }
  grow _ _ [] _ = err "incorporateTrace" "trace exhausted without reading a to-do point!"

  -- True if this node, or any node along its explored path, has to-dos.
  hasTodos dpor = not (M.null (dporTodo dpor)) || (case dporNext dpor of Just (_, dpor') -> hasTodos dpor'; _ -> False)

  -- Build a linear subtree for a trace suffix whose head step was
  -- taken by @tid@. Sleep-set entries dependent with that step are
  -- dropped. Each node records only the following step as done and
  -- taken, and starts with no to-dos.
  subtree state tid sleep ((_, _, a):rest) =
    let state' = updateDepState state tid a
        sleep' = M.filterWithKey (\t a' -> not $ dependent memtype state' tid a t a') sleep
    in DPOR
        { dporRunnable = S.fromList $ case rest of
            ((_, runnable, _):_) -> map fst runnable
            [] -> []
        , dporTodo = M.empty
        , dporNext = case rest of
            ((d', _, _):_) ->
              let tid' = tidOf tid d'
              in Just (tid', subtree state' tid' sleep' rest)
            [] -> Nothing
        , dporDone = case rest of
            ((d', _, _):_) -> S.singleton (tidOf tid d')
            [] -> S.empty
        , dporSleep = sleep'
        , dporTaken = case rest of
            ((d', _, a'):_) -> M.singleton (tidOf tid d') a'
            [] -> M.empty
        }
  subtree _ _ _ [] = err "incorporateTrace" "subtree suffix empty!"
-- | Compute the backtracking steps for a finished execution.
--
-- The scheduler's recorded backtracking points are walked in lockstep
-- with the trace. Each point holds the runnable threads with their
-- lookaheads, plus a list of threads that become initial backtracking
-- targets with flag @False@. Each step becomes a 'BacktrackStep'.
-- After each step, the supplied 'BacktrackFunc' is asked to add
-- backtracking targets to the steps seen so far.
--
-- When the @Bool@ argument is @True@, the search done after the
-- final trace step treats every earlier step of another thread as a
-- candidate, not only the dependent ones.
findBacktrackSteps
  :: MemType
  -> BacktrackFunc
  -> Bool
  -> Seq ([(ThreadId, Lookahead)], [ThreadId])
  -> Trace
  -> [BacktrackStep]
findBacktrackSteps memtype backtrack boundKill = go initialDepState S.empty initialThread [] . F.toList where
  -- Walk backtracking points and trace steps together, stopping as
  -- soon as either list runs out.
  go state allThreads tid bs ((e,i):is) ((d,_,a):ts) =
    let tid' = tidOf tid d
        state' = updateDepState state tid' a
        this = BacktrackStep
          { bcktThreadid = tid'
          , bcktDecision = d
          , bcktAction = a
          , bcktRunnable = M.fromList e
          , bcktBacktracks = M.fromList $ map (\i' -> (i', False)) i
          , bcktState = state'
          }
        bs' = doBacktrack killsEarly allThreads' e (bs++[this])
        runnable = S.fromList (M.keys $ bcktRunnable this)
        allThreads' = allThreads `S.union` runnable
        -- Only for the last trace step, and only if requested.
        killsEarly = null ts && boundKill
    in go state' allThreads' tid' bs' is ts
  go _ _ _ bs _ _ = bs

  -- For each enabled thread u and every other thread v seen so far,
  -- find the most recent step at which v ran an action dependent with
  -- u's lookahead (any step of v, if killsEarly), and ask 'backtrack'
  -- to add u there.
  doBacktrack killsEarly allThreads enabledThreads bs =
    let tagged = reverse $ zip [0..] bs
        idxs = [ (head is, False, u)
               | (u, n) <- enabledThreads
               , v <- S.toList allThreads
               , u /= v
               , let is = idxs' u n v tagged
               , not $ null is]
        -- Candidate indices, newest first. The search stops at the
        -- first step matching 'isSubC'.
        idxs' u n v = go' True where
          go' final ((i,b):rest)
            | isSubC final b = []
            | bcktThreadid b == v && (killsEarly || isDependent b) = i : go' False rest
            | otherwise = go' False rest
          go' _ [] = []
          -- A 'Subconcurrency' by the initial thread, or a 'Stop' by
          -- the initial thread at any step other than the newest.
          isSubC final b = case bcktAction b of
            Stop -> not final && bcktThreadid b == initialThread
            Subconcurrency -> bcktThreadid b == initialThread
            _ -> False
          isDependent b = dependent' memtype (bcktState b) (bcktThreadid b) (bcktAction b) u n
    in backtrack bs idxs
-- | Merge backtracking steps into the tree, one step per node along
-- the explored path.
--
-- At each node, a step's backtracking targets are added to the to-do
-- set, except for:
--
-- * the thread already taken on the path;
-- * threads already done;
-- * threads in the sleep set, for entries whose flag is @False@.
--
-- Existing to-do entries win over new ones on a key clash. The child
-- becomes an 'err' call, raised only when demanded, if the path is
-- shorter than the steps or takes a different thread than a step
-- records.
incorporateBacktrackSteps :: [BacktrackStep] -> DPOR -> DPOR
incorporateBacktrackSteps [] dpor = dpor
incorporateBacktrackSteps (step:steps) dpor =
  dpor { dporTodo = dporTodo dpor `M.union` M.fromList (filter wanted candidates)
       , dporNext = Just (stepTid, descend (dporNext dpor))
       }
  where
    stepTid = bcktThreadid step
    candidates = M.toList (bcktBacktracks step)

    wanted (t, flag) =
      Just t /= fmap fst (dporNext dpor)
        && t `S.notMember` dporDone dpor
        && (flag || t `M.notMember` dporSleep dpor)

    descend (Just (t, sub))
      | t /= stepTid = err "incorporateBacktrackSteps" "incorporating wrong trace!"
      | otherwise = incorporateBacktrackSteps steps sub
    descend Nothing = err "incorporateBacktrackSteps" "child is missing!"
-- | State threaded through 'dporSched' during a single execution.
data DPORSchedState k = DPORSchedState
  { schedSleep     :: Map ThreadId ThreadAction
    -- ^ The sleep set: threads excluded from free choice, with the
    -- action each was put to sleep with.
  , schedPrefix    :: [ThreadId]
    -- ^ Threads still to be scheduled, in order, before free choice.
  , schedBPoints   :: Seq ([(ThreadId, Lookahead)], [ThreadId])
    -- ^ Backtracking points recorded so far, one per scheduling step.
  , schedIgnore    :: Bool
    -- ^ Set when every in-bound candidate was in the sleep set.
  , schedBoundKill :: Bool
    -- ^ Set when there were candidates but none was within the bound.
  , schedDepState  :: DepState
    -- ^ Dependency state, advanced past each prior action.
  , schedBState    :: Maybe k
    -- ^ State of the incremental bound function.
  } deriving (Eq, Show)

-- | Deep evaluation: every field is forced in turn.
instance NFData k => NFData (DPORSchedState k) where
  rnf (DPORSchedState sleep prefix bpoints ignore bkill depstate bstate) =
    rnf sleep `seq` rnf prefix `seq` rnf bpoints `seq` rnf ignore `seq`
    rnf bkill `seq` rnf depstate `seq` rnf bstate
-- | Scheduler state for a new execution, starting from the given
-- sleep set and schedule prefix. No backtracking points are recorded
-- yet, both flags are clear, and there is no bound state.
initialDPORSchedState :: Map ThreadId ThreadAction
  -> [ThreadId]
  -> DPORSchedState k
initialDPORSchedState sleep prefix =
  DPORSchedState sleep prefix Sq.empty False False initialDepState Nothing
-- | An incrementally computed bound. It receives the bound state so
-- far, the previous thread and action (if any), and the decision plus
-- lookahead of the thread under consideration. It returns the new
-- bound state, or @Nothing@ if that step is not allowed.
type IncrementalBoundFunc k
  = Maybe k -> Maybe (ThreadId, ThreadAction) -> (Decision, Lookahead) -> Maybe k

-- | Add backtracking targets to a list of steps. Each request is a
-- step index, a flag to store with the target, and the thread to
-- backtrack to.
type BacktrackFunc
  = [BacktrackStep] -> [(Int, Bool, ThreadId)] -> [BacktrackStep]
-- | Build a 'BacktrackFunc' that adds each requested thread as a
-- backtracking target at the requested step.
--
-- Requests are sorted by step index, and only the first request for
-- each index is kept.
--
-- At a target step, if the thread is runnable there and @toAll@
-- returns @False@, only that thread is added. A missing entry is
-- inserted, and an existing @False@ entry is upgraded when the
-- request's flag is @True@. Otherwise every runnable thread at that
-- step is added with the request's flag.
--
-- Calls 'err' if a request's index lies beyond the end of the steps.
backtrackAt
  :: (ThreadId -> BacktrackStep -> Bool)
  -> BacktrackFunc
backtrackAt toAll bs0 = backtrackAt' . nubBy ((==) `on` fst') . sortOn fst' where
  fst' (x,_,_) = x

  backtrackAt' ((i,c,t):is) = go i bs0 i c t is
  backtrackAt' [] = bs0

  -- @go i0 steps n flag tid rest@: @i0@ is the absolute index of the
  -- current request, and @n@ counts down the steps left until it.
  go i0 (b:bs) 0 c tid is
    | not (toAll tid b) && tid `M.member` bcktRunnable b =
      let val = M.lookup tid $ bcktBacktracks b
          b' = if isNothing val || (val == Just False && c)
            then b { bcktBacktracks = backtrackTo tid c b }
            else b
      in b' : next bs is
    | otherwise =
      let b' = b { bcktBacktracks = backtrackAll c b }
      in b' : next bs is
    where
      -- Requests are de-duplicated and sorted, so i' > i0. The next
      -- target lies i' - i0 - 1 steps past the current one.
      next rest ((i', c', t'):is') = go i' rest (i' - i0 - 1) c' t' is'
      next rest [] = rest
  go i0 (b:bs) i c tid is = b : go i0 bs (i - 1) c tid is
  go _ [] _ _ _ _ = err "backtrackAt" "ran out of schedule whilst backtracking!"

  -- Backtrack to a single thread.
  backtrackTo tid c = M.insert tid c . bcktBacktracks

  -- Backtrack to every runnable thread.
  backtrackAll c = M.map (const c) . bcktRunnable
-- | The DPOR scheduler.
--
-- While a schedule prefix remains, its next thread is taken, with no
-- bound or sleep-set filtering. After that, candidates are built as
-- follows:
--
-- * If the prior action did not yield and its thread is still
--   runnable, that thread is the only candidate; otherwise every
--   runnable thread is.
-- * Yielding threads are moved to the end.
-- * If any candidate would kill daemon threads, all runnable threads
--   are offered, with the killing ones last.
--
-- Candidates outside the bound, or in the sleep set (after removing
-- entries dependent with the prior action), are dropped. The first
-- remaining candidate is run, and the others are recorded as
-- backtracking targets. If none remain, @Nothing@ is returned and
-- 'schedIgnore' / 'schedBoundKill' record why.
dporSched
  :: MemType
  -> IncrementalBoundFunc k
  -> Scheduler (DPORSchedState k)
dporSched memtype boundf = Scheduler $ \prior threads s ->
  let
    -- Record this step's in-bound runnable threads and the untaken
    -- alternatives, and advance the dependency state past 'prior'.
    nextState rest = s
      { schedBPoints = schedBPoints s |> (restrictToBound fst threads', rest)
      , schedDepState = nextDepState
      }
    nextDepState = let ds = schedDepState s in maybe ds (uncurry $ updateDepState ds) prior

    initialise = tryDaemons . yieldsToEnd $ case prior of
      Just (tid, act)
        | not (didYield act) && tid `elem` tids -> [tid]
      _ -> tids

    -- Deliberately partitions all runnable threads ('tids'), not just
    -- the candidates, so threads that would be killed are offered
    -- before the killing thread.
    tryDaemons ts
      | any doesKill ts = case partition doesKill tids of
          (kills, nokills) -> nokills ++ kills
      | otherwise = ts
    doesKill t = killsDaemons t (action t)

    -- Keep only elements whose thread the bound function accepts.
    restrictToBound f =
      filter (\x -> let t = f x in isJust $ boundf (schedBState s) prior (decision t, action t))

    yieldsToEnd ts = case partition (willYield . action) ts of
      (yields, noyields) -> noyields ++ yields

    decision = decisionOf (fst <$> prior) (S.fromList tids)
    action t = fromJust $ lookup t threads'
    tids = map fst threads'
    threads' = toList threads
  in case schedPrefix s of
    (t:ts) ->
      let bstate' = boundf (schedBState s) prior (decision t, action t)
      in (Just t, (nextState []) { schedPrefix = ts, schedBState = bstate' })

    [] ->
      let choices = restrictToBound id initialise
          checkDep t a = case prior of
            Just (tid, act) -> dependent memtype (schedDepState s) tid act t a
            Nothing -> False
          -- Remove sleep-set entries dependent with the prior action.
          ssleep' = M.filterWithKey (\t a -> not $ checkDep t a) $ schedSleep s
          choices' = filter (`notElem` M.keys ssleep') choices
          signore' = not (null choices) && all (`elem` M.keys ssleep') choices
          sbkill' = not (null initialise) && null choices
      in case choices' of
        (nextTid:rest) ->
          let bstate' = boundf (schedBState s) prior (decision nextTid, action nextTid)
          in (Just nextTid, (nextState rest) { schedSleep = ssleep', schedBState = bstate' })

        [] ->
          (Nothing, (nextState []) { schedIgnore = signore', schedBoundKill = sbkill', schedBState = Nothing })
-- | State of the random scheduler.
data RandSchedState g = RandSchedState
  { schedWeights :: Map ThreadId Int
    -- ^ Per-thread weights. A runnable thread without an entry is
    -- given one by the scheduler's weight function.
  , schedGen     :: g
    -- ^ The random number generator.
  } deriving (Eq, Show)

-- | Deep evaluation of both fields.
instance NFData g => NFData (RandSchedState g) where
  rnf (RandSchedState weights gen) = rnf weights `seq` rnf gen
-- | Initial random-scheduler state from optional preset weights (no
-- weights when @Nothing@) and a generator.
initialRandSchedState :: Maybe (Map ThreadId Int) -> g -> RandSchedState g
initialRandSchedState presetWeights = RandSchedState weights
  where
    weights = case presetWeights of
      Just ws -> ws
      Nothing -> M.empty
-- | A weighted random scheduler.
--
-- Every runnable thread gets a weight: the one stored in the state
-- if present, otherwise a fresh one drawn with @weightf@. A thread is
-- then picked by drawing an index in @[0, total weight - 1]@ with
-- 'randomR' and walking the weights. Assuming non-negative weights,
-- this picks threads in proportion to their weight. Only the weights
-- of currently runnable threads are kept in the new state.
randSched :: RandomGen g => (g -> (Int, g)) -> Scheduler (RandSchedState g)
randSched weightf = Scheduler $ \_ threads s ->
  let
    -- Walk the weights, subtracting each one until idx falls inside.
    pick idx ((x, f):xs)
      | idx < f = Just x
      | otherwise = pick (idx - f) xs
    pick _ [] = Nothing
    (choice, g'') = randomR (0, sum (map snd enabled) - 1) g'
    enabled = M.toList $ M.filterWithKey (\tid _ -> tid `elem` tids) weights'

    -- Weights for the runnable threads, drawing new ones where needed.
    (weights', g') = foldr assignWeight (M.empty, schedGen s) tids
    assignWeight tid ~(ws, g0) =
      let (w, g) = maybe (weightf g0) (\known -> (known, g0)) (M.lookup tid (schedWeights s))
      in (M.insert tid w ws, g)

    -- The runnable threads.
    tids = map fst (toList threads)
  in (pick choice enabled, RandSchedState weights' g'')
-- | Whether action @a1@ of thread @t1@ is dependent with action @a2@
-- of thread @t2@, given the dependency state.
--
-- If @a2@ can be rewound to a lookahead, there are two cases. When
-- both are STM actions, they are dependent iff their 'tvarsOf' sets
-- overlap. Otherwise 'dependent'' decides, unless @a1@ is a block and
-- the lookahead is a barrier. In all remaining cases the simplified
-- action types are compared with 'dependentActions'.
dependent :: MemType -> DepState -> ThreadId -> ThreadAction -> ThreadId -> ThreadAction -> Bool
dependent _ _ _ (SetNumCapabilities a) _ (GetNumCapabilities b) = a /= b
dependent _ ds _ (ThrowTo t) t2 a = t == t2 && canInterrupt ds t2 a
dependent memtype ds t1 a1 t2 a2 = case rewind a2 of
    Just l2
      | isSTM a1 && isSTM a2
        -> not . S.null $ tvarsOf a1 `S.intersection` tvarsOf a2
      | not (isBlock a1 && isBarrier (simplifyLookahead l2)) ->
        dependent' memtype ds t1 a1 t2 l2
    _ -> dependentActions memtype ds (simplifyAction a1) (simplifyAction a2)
  where
    isSTM (STM _ _) = True
    isSTM (BlockedSTM _) = True
    isSTM _ = False
-- | Whether action @a1@ of thread @t1@ is dependent with what thread
-- @t2@ will do next (lookahead @l2@), given the dependency state.
dependent' :: MemType -> DepState -> ThreadId -> ThreadAction -> ThreadId -> Lookahead -> Bool
dependent' memtype ds t1 a1 t2 l2 = case (a1, l2) of
  -- Lifted IO is always treated as dependent with lifted IO.
  (LiftIO, WillLiftIO) -> True
  -- Throwing to a thread that is about to stop, or stopping while the
  -- other thread is about to throw to us, is treated as independent.
  (ThrowTo t, WillStop) | t == t2 -> False
  (Stop, WillThrowTo t) | t == t1 -> False
  -- Otherwise a throw is dependent only with its target, and only if
  -- the target can be interrupted in its current state.
  (ThrowTo t, _) -> t == t2 && canInterruptL ds t2 l2
  (_, WillThrowTo t) -> t == t1 && canInterrupt ds t1 a1
  -- An STM action and an upcoming STM transaction are always
  -- treated as dependent.
  (STM _ _, WillSTM) -> True
  (GetNumCapabilities a, WillSetNumCapabilities b) -> a /= b
  (SetNumCapabilities _, WillGetNumCapabilities) -> True
  (SetNumCapabilities a, WillSetNumCapabilities b) -> a /= b
  -- A block against an upcoming barrier is treated as independent;
  -- everything else compares simplified action types.
  _ | isBlock a1 && isBarrier (simplifyLookahead l2) -> False
    | otherwise -> dependentActions memtype ds (simplifyAction a1) (simplifyLookahead l2)
-- | Dependency between two simplified actions, given the memory model
-- and the buffered-write state. Cases are tried in order.
dependentActions :: MemType -> DepState -> ActionType -> ActionType -> Bool
dependentActions memtype ds a1 a2 = case (a1, a2) of
    -- When a1 is one of these CRef accesses and both actions refer to
    -- the same CRef (per 'crefOf'), they are dependent, except that two
    -- unsynchronised reads are not. If a2 is a partially synchronised
    -- commit to that CRef, these cases are skipped.
    (UnsynchronisedRead r1, _) | same crefOf && a2 /= PartiallySynchronisedCommit r1 -> a2 /= UnsynchronisedRead r1
    (UnsynchronisedWrite r1, _) | same crefOf && a2 /= PartiallySynchronisedCommit r1 -> True
    (PartiallySynchronisedWrite r1, _) | same crefOf && a2 /= PartiallySynchronisedCommit r1 -> True
    (PartiallySynchronisedModify r1, _) | same crefOf && a2 /= PartiallySynchronisedCommit r1 -> True
    (SynchronisedModify r1, _) | same crefOf && a2 /= PartiallySynchronisedCommit r1 -> True
    -- An unsynchronised write and a commit ('isCommit') to the same
    -- CRef are independent while that CRef has buffered writes.
    (UnsynchronisedWrite r1, _) | same crefOf && isCommit a2 r1 && isBuffered ds r1 -> False
    (_, UnsynchronisedWrite r2) | same crefOf && isCommit a1 r2 && isBuffered ds r2 -> False
    -- An unsynchronised read and a barrier are dependent only if the
    -- read's CRef has buffered writes and the model is not
    -- 'SequentialConsistency'.
    (UnsynchronisedRead r1, _) | isBarrier a2 -> isBuffered ds r1 && memtype /= SequentialConsistency
    (_, UnsynchronisedRead r2) | isBarrier a1 -> isBuffered ds r2 && memtype /= SequentialConsistency
    -- Otherwise: on a shared CRef, dependent if either action
    -- synchronises it; with no shared CRef, dependent iff both touch
    -- the same MVar (per 'mvarOf').
    (_, _) -> case getSame crefOf of
      Just r -> synchronises a1 r || synchronises a2 r
      _ -> same mvarOf
  where
    -- Both actions project to the same 'Just' value.
    same :: Eq a => (ActionType -> Maybe a) -> Bool
    same = isJust . getSame

    -- The common projection of both actions, if they agree.
    getSame :: Eq a => (ActionType -> Maybe a) -> Maybe a
    getSame f =
      let f1 = f a1
          f2 = f a2
      in if f1 == f2 then f1 else Nothing
-- | Extra state needed to decide whether two actions are dependent.
data DepState = DepState
  { depCRState   :: Map CRefId Bool
    -- ^ CRefs marked by 'WriteCRef' and not yet removed by a
    -- 'CommitCRef' or cleared by a barrier.
  , depMaskState :: Map ThreadId MaskingState
    -- ^ The recorded masking state of each thread.
  } deriving (Eq, Show)

-- | Forces the CRef map fully. Masking states are forced to WHNF,
-- which fully evaluates an enumeration such as 'MaskingState'.
instance NFData DepState where
  rnf (DepState crefs masks) =
    rnf crefs `seq` rnf (map (\(t, m) -> (t, m `seq` ())) (M.toList masks))
-- | Dependency state at the start of an execution: no buffered CRefs
-- and no masking states recorded.
initialDepState :: DepState
initialDepState = DepState { depCRState = M.empty, depMaskState = M.empty }
-- | Advance the dependency state past one action by thread @tid@.
updateDepState :: DepState -> ThreadId -> ThreadAction -> DepState
updateDepState depstate tid act =
  let crefs = depCRState depstate
      masks = depMaskState depstate
  in DepState (updateCRState act crefs) (updateMaskState tid act masks)
-- | Track buffered CRef writes:
--
-- * 'WriteCRef' marks its CRef;
-- * 'CommitCRef' unmarks its CRef;
-- * any other action whose simplified form is a barrier clears every
--   mark.
updateCRState :: ThreadAction -> Map CRefId Bool -> Map CRefId Bool
updateCRState act = case act of
  CommitCRef _ r -> M.delete r
  WriteCRef r -> M.insert r True
  _ | isBarrier (simplifyAction act) -> const M.empty
    | otherwise -> id
-- | Track per-thread masking state. A forked thread inherits its
-- parent's recorded state, if any. 'SetMasking' and 'ResetMasking'
-- record the new state for the acting thread. Other actions leave the
-- map unchanged.
updateMaskState :: ThreadId -> ThreadAction -> Map ThreadId MaskingState -> Map ThreadId MaskingState
updateMaskState tid act = case act of
  Fork child -> \masks -> maybe masks (\ms -> M.insert child ms masks) (M.lookup tid masks)
  SetMasking _ ms -> M.insert tid ms
  ResetMasking _ ms -> M.insert tid ms
  _ -> id
-- | Whether the CRef is currently marked as having a buffered write.
isBuffered :: DepState -> CRefId -> Bool
isBuffered depstate r = M.lookup r (depCRState depstate) == Just True
-- | Whether thread @tid@, given its action @act@, can be interrupted
-- by an exception:
--
-- * always, when it is unmasked;
-- * never, when it is masked uninterruptibly;
-- * when masked interruptibly, only for the blocked actions listed
--   below.
canInterrupt :: DepState -> ThreadId -> ThreadAction -> Bool
canInterrupt depstate tid act
  | isMaskedInterruptible depstate tid = interruptibleWhileMasked act
  | otherwise = not (isMaskedUninterruptible depstate tid)
  where
    interruptibleWhileMasked (BlockedPutMVar _) = True
    interruptibleWhileMasked (BlockedReadMVar _) = True
    interruptibleWhileMasked (BlockedTakeMVar _) = True
    interruptibleWhileMasked (BlockedSTM _) = True
    interruptibleWhileMasked (BlockedThrowTo _) = True
    interruptibleWhileMasked _ = False
-- | Lookahead version of 'canInterrupt'. Thread @tid@, about to do
-- @lh@, can be interrupted:
--
-- * always, when it is unmasked;
-- * never, when it is masked uninterruptibly;
-- * when masked interruptibly, only for the lookaheads listed below.
canInterruptL :: DepState -> ThreadId -> Lookahead -> Bool
canInterruptL depstate tid lh
  | isMaskedInterruptible depstate tid = interruptibleWhileMasked lh
  | otherwise = not (isMaskedUninterruptible depstate tid)
  where
    interruptibleWhileMasked (WillPutMVar _) = True
    interruptibleWhileMasked (WillReadMVar _) = True
    interruptibleWhileMasked (WillTakeMVar _) = True
    interruptibleWhileMasked WillSTM = True
    interruptibleWhileMasked (WillThrowTo _) = True
    interruptibleWhileMasked _ = False
-- | Whether thread @tid@ is recorded as 'MaskedInterruptible'.
isMaskedInterruptible :: DepState -> ThreadId -> Bool
isMaskedInterruptible depstate tid =
  maybe False (== MaskedInterruptible) (M.lookup tid (depMaskState depstate))
-- | Whether thread @tid@ is recorded as 'MaskedUninterruptible'.
isMaskedUninterruptible :: DepState -> ThreadId -> Bool
isMaskedUninterruptible depstate tid =
  maybe False (== MaskedUninterruptible) (M.lookup tid (depMaskState depstate))
-- | The smallest runnable thread at this node. Partial: 'S.elemAt'
-- errors if no thread is runnable.
initialDPORThread :: DPOR -> ThreadId
initialDPORThread dpor = S.elemAt 0 (dporRunnable dpor)
-- | Whether a completed action is a 'Yield' or a 'ThreadDelay'.
didYield :: ThreadAction -> Bool
didYield act = case act of
  Yield -> True
  ThreadDelay _ -> True
  _ -> False
-- | Whether an upcoming action is a 'WillYield' or a 'WillThreadDelay'.
willYield :: Lookahead -> Bool
willYield lh = case lh of
  WillYield -> True
  WillThreadDelay _ -> True
  _ -> False
-- | Whether thread @t@ doing @lh@ would kill daemon threads. This is
-- true exactly when the initial thread is about to stop.
killsDaemons :: ThreadId -> Lookahead -> Bool
killsDaemons t lh = case lh of
  WillStop -> t == initialThread
  _ -> False
-- | Raise an internal-error exception, prefixed with the name of the
-- function that detected it.
err :: String -> String -> a
err func msg = error (concat [func, ": (internal error) ", msg])