{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE MultiWayIf #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RecordWildCards #-}
module Test.DejaFu.Conc.Internal where
import Control.Exception (Exception,
MaskingState(..),
toException)
import qualified Control.Monad.Conc.Class as C
import Data.Foldable (foldrM, toList)
import Data.Functor (void)
import Data.List (sortOn)
import qualified Data.Map.Strict as M
import Data.Maybe (fromMaybe, isJust,
isNothing)
import Data.Monoid ((<>))
import Data.Sequence (Seq, (<|))
import qualified Data.Sequence as Seq
import GHC.Stack (HasCallStack)
import Test.DejaFu.Conc.Internal.Common
import Test.DejaFu.Conc.Internal.Memory
import Test.DejaFu.Conc.Internal.STM
import Test.DejaFu.Conc.Internal.Threading
import Test.DejaFu.Internal
import Test.DejaFu.Schedule
import Test.DejaFu.Types
-- | The trace of an execution, in order. Each entry records the
-- scheduling 'Decision' that was taken, the other runnable threads
-- (with their 'Lookahead') that could have been chosen instead, and
-- the 'ThreadAction' the chosen thread then performed.
type SeqTrace
  = Seq (Decision, [(ThreadId, Lookahead)], ThreadAction)
-- | The result of running a concurrent computation until it stops
-- (normal termination, failure, abort, or a snapshot point).
data CResult n g a = CResult
  { finalContext :: Context n g
  -- ^ The execution context at the moment the run stopped.
  , finalRef :: C.CRef n (Maybe (Either Failure a))
  -- ^ The outcome: @Just (Left failure)@ when the run failed, or
  -- @Just (Right x)@ when the computation produced @x@. It may still
  -- be 'Nothing' if the run stopped before either was written.
  , finalRestore :: Maybe (Threads n -> n ())
  -- ^ When run for a snapshot, the accumulated action that re-performs
  -- the effects needed to restore state; 'Nothing' otherwise.
  , finalTrace :: SeqTrace
  -- ^ The trace of the run.
  , finalDecision :: Maybe (ThreadId, ThreadAction)
  -- ^ The last thread scheduled and the action it performed, if any.
  }
-- | A snapshot of an execution's state.
--
-- NOTE(review): the @dcs@ prefix and the field types (which line up
-- with the context / restore / ref arguments of
-- 'runConcurrencyWithSnapshot') suggest this captures the state after
-- a 'dontCheck' prefix so it can be resumed later — confirm against
-- callers.
data DCSnapshot n a = DCSnapshot
  { dcsContext :: Context n ()
  -- ^ The execution context, with unit scheduler state.
  , dcsRestore :: Threads n -> n ()
  -- ^ Action that re-establishes mutable state, given a thread map.
  , dcsRef :: C.CRef n (Maybe (Either Failure a))
  -- ^ Where the execution's outcome is written.
  }
-- | Run a concurrent computation from scratch: start from an empty
-- thread map and an empty write buffer, execute it with the given
-- scheduler and memory model, then kill every thread still present in
-- the final context before returning the result.
runConcurrency :: (C.MonadConc n, HasCallStack)
  => Bool
  -> Scheduler g
  -> MemType
  -> g
  -> IdSource
  -> Int
  -> ModelConc n a
  -> n (CResult n g a)
runConcurrency forSnapshot sched memtype g idsrc caps ma = do
    res <- runConcurrency' forSnapshot sched memtype ctx0 ma
    killAllThreads (finalContext res)
    pure res
  where
    -- The starting context: no threads and nothing buffered yet.
    ctx0 = Context
      { cSchedState = g
      , cIdSource   = idsrc
      , cThreads    = M.empty
      , cWriteBuf   = emptyBuffer
      , cCaps       = caps
      }
-- | Run a concurrent computation in an existing context.
--
-- The computation becomes the initial thread. When the RTS supports
-- bound threads, that thread is made bound. Scheduling then proceeds
-- via 'runThreads'. Threads are NOT killed afterwards; that is left
-- to the caller.
runConcurrency' :: (C.MonadConc n, HasCallStack)
  => Bool
  -> Scheduler g
  -> MemType
  -> Context n g
  -> ModelConc n a
  -> n (CResult n g a)
runConcurrency' forSnapshot sched memtype ctx ma = do
  -- 'cont' is the initial thread's action; 'resultRef' receives its result.
  (cont, resultRef) <- runRefCont AStop (Just . Right) (runModelConc ma)
  let launched = launch' Unmasked initialThread (const cont) (cThreads ctx)
  bound <- if C.rtsSupportsBoundThreads
             then makeBound initialThread launched
             else pure launched
  runThreads forSnapshot sched memtype resultRef ctx { cThreads = bound }
-- | Resume execution from a previously captured context.
--
-- Every thread marked as bound in the context is re-bound. The
-- snapshot's restore action is then run against the resulting thread
-- map, and the computation continues (not in snapshot mode). Finally,
-- all remaining threads are killed.
runConcurrencyWithSnapshot :: (C.MonadConc n, HasCallStack)
  => Scheduler g
  -> MemType
  -> Context n g
  -> (Threads n -> n ())
  -> C.CRef n (Maybe (Either Failure a))
  -> n (CResult n g a)
runConcurrencyWithSnapshot sched memtype ctx restore ref = do
    rebound <- foldrM makeBound (cThreads ctx) boundTids
    restore rebound
    result <- runThreads False sched memtype ref ctx { cThreads = rebound }
    killAllThreads (finalContext result)
    pure result
  where
    -- Ids of the threads that carry a '_bound' entry in the snapshot.
    boundTids = M.keys (M.filter (isJust . _bound) (cThreads ctx))
-- | Call 'kill' on every thread in the context's thread map.
--
-- Each 'kill' is applied to the same (original) map, and the updated
-- maps it returns are discarded.
killAllThreads :: (C.MonadConc n, HasCallStack) => Context n g -> n ()
killAllThreads ctx = mapM_ (\t -> kill t threads) (M.keys threads)
  where
    threads = cThreads ctx
-- | The state of an execution in progress.
data Context n g = Context
  { cSchedState :: g
  -- ^ The scheduler's state, passed to and returned from the scheduler
  -- at every scheduling decision.
  , cIdSource :: IdSource
  -- ^ Source of fresh identifiers (used by e.g. 'nextTId', 'nextMVId',
  -- 'nextCRId', and STM transactions).
  , cThreads :: Threads n
  -- ^ All live threads, keyed by 'ThreadId'.
  , cWriteBuf :: WriteBuffer n
  -- ^ Buffered CRef writes not yet committed (filled under
  -- 'TotalStoreOrder' and 'PartialStoreOrder'; flushed by
  -- 'synchronised' actions).
  , cCaps :: Int
  -- ^ The capability count returned by @getNumCapabilities@.
  }
-- | The main scheduling loop.
--
-- Repeatedly checks for termination or deadlock, asks the scheduler for
-- a thread, and runs that thread for one step, accumulating a trace.
-- On failure, the 'Failure' is written into @ref@. On success, @ref@ is
-- expected to have been filled in by the computation itself. If
-- @forSnapshot@ is set, the restore actions of all executed steps are
-- chained together and returned in 'finalRestore'.
runThreads :: (C.MonadConc n, HasCallStack)
  => Bool
  -> Scheduler g
  -> MemType
  -> C.CRef n (Maybe (Either Failure a))
  -> Context n g
  -> n (CResult n g a)
runThreads forSnapshot sched memtype ref = schedule (const $ pure ()) Seq.empty Nothing where
  -- Record the failure in 'ref', then terminate.
  die reason finalR finalT finalD finalC = do
    C.writeCRef ref (Just $ Left reason)
    stop finalR finalT finalD finalC

  -- Terminate without touching 'ref' (the caller is responsible for it
  -- having been written, if it needs to be).
  stop finalR finalT finalD finalC = pure CResult
    { finalContext = finalC
    , finalRef = ref
    , finalRestore = if forSnapshot then Just finalR else Nothing
    , finalTrace = finalT
    , finalDecision = finalD
    }

  -- Check for termination / deadlock, otherwise pick a thread and
  -- hand over to 'step'.
  --   restore: accumulated restore action so far
  --   sofar:   trace so far
  --   prior:   the previously scheduled thread and its action, if any
  schedule restore sofar prior ctx
    | isTerminated = stop restore sofar prior ctx
    | isDeadlocked = die Deadlock restore sofar prior ctx
    | isSTMLocked = die STMDeadlock restore sofar prior ctx
    | otherwise =
      let ctx' = ctx { cSchedState = g' }
      in case choice of
           Just chosen -> case M.lookup chosen threadsc of
             Just thread
               -- The scheduler picked a blocked thread: a scheduler bug.
               | isBlocked thread -> die InternalError restore sofar prior ctx'
               | otherwise ->
                 -- Continue: same thread as last time.
                 -- Start: the previous thread is not runnable any more
                 --   (or there was no previous thread).
                 -- SwitchTo: pre-empting a still-runnable thread.
                 let decision
                       | Just chosen == (fst <$> prior) = Continue
                       | (fst <$> prior) `notElem` map (Just . fst) runnable' = Start chosen
                       | otherwise = SwitchTo chosen
                     alternatives = filter (\(t, _) -> t /= chosen) runnable'
                 in step decision alternatives chosen thread restore sofar prior ctx'
             -- The scheduler picked a thread that does not exist.
             Nothing -> die InternalError restore sofar prior ctx'
           -- The scheduler declined to pick any thread.
           Nothing -> die Abort restore sofar prior ctx'
    where
      (choice, g') = scheduleThread sched prior (efromList runnable') (cSchedState ctx)
      -- NOTE(review): 'M.assocs' already yields keys in ascending order,
      -- so this 'sortOn fst' is a no-op.
      runnable' = [(t, lookahead (_continuation a)) | (t, a) <- sortOn fst $ M.assocs runnable]
      runnable = M.filter (not . isBlocked) threadsc
      -- Presumably the real threads plus pseudo-threads for pending
      -- buffered commits (cf. the 'ACommit' action) — confirm in
      -- 'addCommitThreads'.
      threadsc = addCommitThreads (cWriteBuf ctx) threads
      threads = cThreads ctx
      isBlocked = isJust . _blocking
      -- The execution is over once the initial thread has gone.
      isTerminated = initialThread `notElem` M.keys threads
      -- Deadlock: no real thread can run, and the initial thread is
      -- blocked on an MVar or on a masked 'throwTo' target. The
      -- 'undefined' payloads imply '~=' inspects only the constructor —
      -- NOTE(review): confirm in the definition of '~='.
      isDeadlocked = M.null (M.filter (not . isBlocked) threads) &&
        (((~= OnMVarFull undefined) <$> M.lookup initialThread threads) == Just True ||
         ((~= OnMVarEmpty undefined) <$> M.lookup initialThread threads) == Just True ||
         ((~= OnMask undefined) <$> M.lookup initialThread threads) == Just True)
      -- STM deadlock: no real thread can run, and the initial thread is
      -- blocked in an STM 'retry'.
      isSTMLocked = M.null (M.filter (not . isBlocked) threads) &&
        ((~= OnTVar []) <$> M.lookup initialThread threads) == Just True

  -- Run the chosen thread for one step, extend the trace and the
  -- restore action, tidy the context, and either loop or stop.
  step decision alternatives chosen thread restore sofar prior ctx = do
      (res, actOrTrc, actionSnap) <- stepThread
          forSnapshot
          (isNothing prior)
          sched
          memtype
          chosen
          (_continuation thread)
          ctx
      let sofar' = sofar <> getTrc actOrTrc
      let prior' = getPrior actOrTrc
      -- Only accumulate restore actions when a snapshot is wanted.
      let restore' threads' =
            if forSnapshot
            then restore threads' >> actionSnap threads'
            else restore threads'
      let ctx' = fixContext chosen res ctx
      case res of
        Succeeded _ ->
          schedule restore' sofar' prior' ctx'
        Failed failure ->
          die failure restore' sofar' prior' ctx'
        -- A snapshot point: the step's own restore action becomes the
        -- snapshot's restore action.
        Snap _ ->
          stop actionSnap sofar' prior' ctx'
    where
      getTrc (Single a) = Seq.singleton (decision, alternatives, a)
      getTrc (SubC as _) = (decision, alternatives, Subconcurrency) <| as
      getPrior (Single a) = Just (chosen, a)
      getPrior (SubC _ finalD) = finalD
-- | Tidy up the context after a step.
--
-- Commit pseudo-threads are always removed. After a successful step,
-- if the chosen thread is not known to be uninterruptible, any threads
-- blocked in a 'throwTo' waiting on it are released. A failed step
-- keeps the previous context (third argument). A snapshot uses the
-- context it carries.
fixContext :: ThreadId -> What n g -> Context n g -> Context n g
fixContext chosen outcome previous = case outcome of
    Succeeded next -> withThreads next (releaseWaiters (cThreads next))
    Failed _       -> withThreads previous (cThreads previous)
    Snap next      -> withThreads next (cThreads next)
  where
    -- Install a thread map with the commit pseudo-threads stripped out.
    withThreads c ts = c { cThreads = delCommitThreads ts }
    -- Only a thread known to be uninterruptible keeps its waiters blocked.
    releaseWaiters ts
      | (interruptible <$> M.lookup chosen ts) == Just False = ts
      | otherwise = unblockWaitingOn chosen ts
-- | Clear the blocking state of every thread blocked with @OnMask tid@,
-- i.e. waiting to deliver an asynchronous exception to @tid@. All
-- other threads are left unchanged.
unblockWaitingOn :: ThreadId -> Threads n -> Threads n
unblockWaitingOn tid threads = release <$> threads
  where
    release thread
      | Just (OnMask target) <- _blocking thread
      , target == tid
      = thread { _blocking = Nothing }
      | otherwise = thread
-- | What a single step contributes to the trace.
data Act
  = Single ThreadAction
  -- ^ An ordinary step that performed one action.
  | SubC SeqTrace (Maybe (ThreadId, ThreadAction))
  -- ^ A subconcurrency step: the nested execution's trace, and its
  -- final decision (which becomes the new \"prior\" decision for the
  -- outer scheduler).
  deriving (Eq, Show)
-- | The outcome of executing a single step.
data What n g
  = Succeeded (Context n g)
  -- ^ The step completed; scheduling continues with this context.
  | Failed Failure
  -- ^ The execution failed with the given reason.
  | Snap (Context n g)
  -- ^ Stop here and report this context as a snapshot point (produced
  -- by 'ADontCheck' when running for a snapshot).
-- | Execute one primitive 'Action' of thread @tid@.
--
-- Arguments, in order: whether we are running to produce a snapshot;
-- whether there was no prior scheduling decision (named @isFirst@ in
-- the 'ADontCheck' case); the scheduler and memory model (used by
-- nested executions and by CRef writes/commits); the acting thread;
-- the action itself; and the current context.
--
-- Returns the step's outcome ('What'), the trace entry ('Act'), and an
-- effect that the scheduling loop chains into the snapshot restore
-- action (most actions have none, giving @const (pure ())@).
stepThread :: (C.MonadConc n, HasCallStack)
  => Bool
  -> Bool
  -> Scheduler g
  -> MemType
  -> ThreadId
  -> Action n
  -> Context n g
  -> n (What n g, Act, Threads n -> n ())
-- Fork: allocate a fresh 'ThreadId', launch the child, and continue the
-- parent with the child's id.
stepThread _ _ _ _ tid (AFork n a b) = \ctx@Context{..} -> pure $
  let (idSource', newtid) = nextTId n cIdSource
      threads' = launch tid newtid a cThreads
  in ( Succeeded ctx { cThreads = goto (b newtid) tid threads', cIdSource = idSource' }
     , Single (Fork newtid)
     , const (pure ())
     )
-- As 'AFork', but the child is additionally made bound.
stepThread _ _ _ _ tid (AForkOS n a b) = \ctx@Context{..} -> do
  let (idSource', newtid) = nextTId n cIdSource
  let threads' = launch tid newtid a cThreads
  threads'' <- makeBound newtid threads'
  pure ( Succeeded ctx { cThreads = goto (b newtid) tid threads'', cIdSource = idSource' }
       , Single (ForkOS newtid)
       , const (pure ())
       )
-- Report whether the current thread has a '_bound' entry.
stepThread _ _ _ _ tid (AIsBound c) = \ctx@Context{..} -> do
  let isBound = isJust . _bound $ elookup tid cThreads
  pure ( Succeeded ctx { cThreads = goto (c isBound) tid cThreads }
       , Single (IsCurrentThreadBound isBound)
       , const (pure ())
       )
-- Return the current thread's id.
stepThread _ _ _ _ tid (AMyTId c) = \ctx@Context{..} ->
  pure ( Succeeded ctx { cThreads = goto (c tid) tid cThreads }
       , Single MyThreadId
       , const (pure ())
       )
-- Read the modelled capability count.
stepThread _ _ _ _ tid (AGetNumCapabilities c) = \ctx@Context{..} ->
  pure ( Succeeded ctx { cThreads = goto (c cCaps) tid cThreads }
       , Single (GetNumCapabilities cCaps)
       , const (pure ())
       )
-- Overwrite the modelled capability count.
stepThread _ _ _ _ tid (ASetNumCapabilities i c) = \ctx@Context{..} ->
  pure ( Succeeded ctx { cThreads = goto c tid cThreads, cCaps = i }
       , Single (SetNumCapabilities i)
       , const (pure ())
       )
-- Yield: no state change beyond recording the action.
stepThread _ _ _ _ tid (AYield c) = \ctx@Context{..} ->
  pure ( Succeeded ctx { cThreads = goto c tid cThreads }
       , Single Yield
       , const (pure ())
       )
-- Delay: the duration is only recorded in the trace; no time passes.
stepThread _ _ _ _ tid (ADelay n c) = \ctx@Context{..} ->
  pure ( Succeeded ctx { cThreads = goto c tid cThreads }
       , Single (ThreadDelay n)
       , const (pure ())
       )
-- Create an empty MVar backed by a fresh CRef. Its restore effect
-- empties it again.
stepThread _ _ _ _ tid (ANewMVar n c) = \ctx@Context{..} -> do
  let (idSource', newmvid) = nextMVId n cIdSource
  ref <- C.newCRef Nothing
  let mvar = ModelMVar newmvid ref
  pure ( Succeeded ctx { cThreads = goto (c mvar) tid cThreads, cIdSource = idSource' }
       , Single (NewMVar newmvid)
       , const (C.writeCRef ref Nothing)
       )
-- MVar operations are 'synchronised' (the write buffer is flushed
-- first). The helpers return: success flag, new thread map, woken
-- threads, and a restore effect. A failed blocking operation is
-- recorded as its Blocked* variant.
stepThread _ _ _ _ tid (APutMVar mvar@ModelMVar{..} a c) = synchronised $ \ctx@Context{..} -> do
  (success, threads', woken, effect) <- putIntoMVar mvar a c tid cThreads
  pure ( Succeeded ctx { cThreads = threads' }
       , Single (if success then PutMVar mvarId woken else BlockedPutMVar mvarId)
       , const effect
       )
stepThread _ _ _ _ tid (ATryPutMVar mvar@ModelMVar{..} a c) = synchronised $ \ctx@Context{..} -> do
  (success, threads', woken, effect) <- tryPutIntoMVar mvar a c tid cThreads
  pure ( Succeeded ctx { cThreads = threads' }
       , Single (TryPutMVar mvarId success woken)
       , const effect
       )
-- Reads do not change the MVar's contents, so they have no restore effect.
stepThread _ _ _ _ tid (AReadMVar mvar@ModelMVar{..} c) = synchronised $ \ctx@Context{..} -> do
  (success, threads', _, _) <- readFromMVar mvar c tid cThreads
  pure ( Succeeded ctx { cThreads = threads' }
       , Single (if success then ReadMVar mvarId else BlockedReadMVar mvarId)
       , const (pure ())
       )
stepThread _ _ _ _ tid (ATryReadMVar mvar@ModelMVar{..} c) = synchronised $ \ctx@Context{..} -> do
  (success, threads', _, _) <- tryReadFromMVar mvar c tid cThreads
  pure ( Succeeded ctx { cThreads = threads' }
       , Single (TryReadMVar mvarId success)
       , const (pure ())
       )
stepThread _ _ _ _ tid (ATakeMVar mvar@ModelMVar{..} c) = synchronised $ \ctx@Context{..} -> do
  (success, threads', woken, effect) <- takeFromMVar mvar c tid cThreads
  pure ( Succeeded ctx { cThreads = threads' }
       , Single (if success then TakeMVar mvarId woken else BlockedTakeMVar mvarId)
       , const effect
       )
stepThread _ _ _ _ tid (ATryTakeMVar mvar@ModelMVar{..} c) = synchronised $ \ctx@Context{..} -> do
  (success, threads', woken, effect) <- tryTakeFromMVar mvar c tid cThreads
  pure ( Succeeded ctx { cThreads = threads' }
       , Single (TryTakeMVar mvarId success woken)
       , const effect
       )
-- Create a CRef whose backing state is @(M.empty, 0, a)@; the restore
-- effect resets it to that value. NOTE(review): the empty map and 0
-- presumably hold per-thread pending writes and a write counter —
-- confirm in the Memory module.
stepThread _ _ _ _ tid (ANewCRef n a c) = \ctx@Context{..} -> do
  let (idSource', newcrid) = nextCRId n cIdSource
  let val = (M.empty, 0, a)
  ref <- C.newCRef val
  let cref = ModelCRef newcrid ref
  pure ( Succeeded ctx { cThreads = goto (c cref) tid cThreads, cIdSource = idSource' }
       , Single (NewCRef newcrid)
       , const (C.writeCRef ref val)
       )
-- Plain read: not synchronised; 'readCRef' is given the reading
-- thread's id.
stepThread _ _ _ _ tid (AReadCRef cref@ModelCRef{..} c) = \ctx@Context{..} -> do
  val <- readCRef cref tid
  pure ( Succeeded ctx { cThreads = goto (c val) tid cThreads }
       , Single (ReadCRef crefId)
       , const (pure ())
       )
-- Read, producing a ticket for a later compare-and-swap.
stepThread _ _ _ _ tid (AReadCRefCas cref@ModelCRef{..} c) = \ctx@Context{..} -> do
  tick <- readForTicket cref tid
  pure ( Succeeded ctx { cThreads = goto (c tick) tid cThreads }
       , Single (ReadCRefCas crefId)
       , const (pure ())
       )
-- Atomic modify: read, apply @f@, and write the new value immediately.
stepThread _ _ _ _ tid (AModCRef cref@ModelCRef{..} f c) = synchronised $ \ctx@Context{..} -> do
  (new, val) <- f <$> readCRef cref tid
  effect <- writeImmediate cref new
  pure ( Succeeded ctx { cThreads = goto (c val) tid cThreads }
       , Single (ModCRef crefId)
       , const effect
       )
-- Atomic modify via ticket + 'casCRef'; the CAS success flag is ignored.
stepThread _ _ _ _ tid (AModCRefCas cref@ModelCRef{..} f c) = synchronised $ \ctx@Context{..} -> do
  tick@(ModelTicket _ _ old) <- readForTicket cref tid
  let (new, val) = f old
  (_, _, effect) <- casCRef cref tid tick new
  pure ( Succeeded ctx { cThreads = goto (c val) tid cThreads }
       , Single (ModCRefCas crefId)
       , const effect
       )
-- Write: behaviour depends on the memory model.
stepThread _ _ _ memtype tid (AWriteCRef cref@ModelCRef{..} a c) = \ctx@Context{..} -> case memtype of
  -- Sequential consistency: write through immediately.
  SequentialConsistency -> do
    effect <- writeImmediate cref a
    pure ( Succeeded ctx { cThreads = goto c tid cThreads }
         , Single (WriteCRef crefId)
         , const effect
         )
  -- TSO: buffer the write, keyed by thread only.
  TotalStoreOrder -> do
    wb' <- bufferWrite cWriteBuf (tid, Nothing) cref a
    pure ( Succeeded ctx { cThreads = goto c tid cThreads, cWriteBuf = wb' }
         , Single (WriteCRef crefId)
         , const (pure ())
         )
  -- PSO: buffer the write, keyed by thread and CRef.
  PartialStoreOrder -> do
    wb' <- bufferWrite cWriteBuf (tid, Just crefId) cref a
    pure ( Succeeded ctx { cThreads = goto c tid cThreads, cWriteBuf = wb' }
         , Single (WriteCRef crefId)
         , const (pure ())
         )
-- Compare-and-swap against a ticket.
stepThread _ _ _ _ tid (ACasCRef cref@ModelCRef{..} tick a c) = synchronised $ \ctx@Context{..} -> do
  (suc, tick', effect) <- casCRef cref tid tick a
  pure ( Succeeded ctx { cThreads = goto (c (suc, tick')) tid cThreads }
       , Single (CasCRef crefId suc)
       , const effect
       )
-- Commit one buffered write (thread @t@, CRef @c@), using the same
-- buffer key shape as 'AWriteCRef' under the active memory model. Under
-- SC nothing is ever buffered, so reaching this is a fatal error.
stepThread _ _ _ memtype _ (ACommit t c) = \ctx@Context{..} -> do
  wb' <- case memtype of
    SequentialConsistency ->
      fatal "stepThread.ACommit" "Attempting to commit under SequentialConsistency"
    TotalStoreOrder ->
      commitWrite cWriteBuf (t, Nothing)
    PartialStoreOrder ->
      commitWrite cWriteBuf (t, Just c)
  pure ( Succeeded ctx { cWriteBuf = wb' }
       , Single (CommitCRef t c)
       , const (pure ())
       )
-- Run an STM transaction. The restore effect re-runs the transaction
-- (discarding its result).
stepThread _ _ _ _ tid (AAtom stm c) = synchronised $ \ctx@Context{..} -> do
  let transaction = runTransaction stm cIdSource
  let effect = const (void transaction)
  (res, idSource', trace) <- transaction
  case res of
    -- Committed: wake threads blocked on any TVar that was written.
    Success _ written val -> do
      let (threads', woken) = wake (OnTVar written) cThreads
      pure ( Succeeded ctx { cThreads = goto (c val) tid threads', cIdSource = idSource' }
           , Single (STM trace woken)
           , effect
           )
    -- Retried: block until one of the touched TVars changes.
    Retry touched -> do
      let threads' = block (OnTVar touched) tid cThreads
      pure ( Succeeded ctx { cThreads = threads', cIdSource = idSource'}
           , Single (BlockedSTM trace)
           , effect
           )
    -- Threw: propagate as a synchronous exception in this thread, but
    -- record the step as an STM action.
    Exception e -> do
      let act = STM trace []
      res' <- stepThrow act tid e ctx
      pure $ case res' of
        (Succeeded ctx', _, effect') -> (Succeeded ctx' { cIdSource = idSource' }, Single act, effect')
        (Failed err, _, effect') -> (Failed err, Single act, effect')
        (Snap _, _, _) -> fatal "stepThread.AAtom" "Unexpected snapshot while propagating STM exception"
-- Run a lifted action of the underlying monad. The restore effect
-- re-runs it.
stepThread _ _ _ _ tid (ALift na) = \ctx@Context{..} -> do
  let effect threads = runLiftedAct tid threads na
  a <- effect cThreads
  pure (Succeeded ctx { cThreads = goto a tid cThreads }
       , Single LiftIO
       , void <$> effect
       )
-- Throw synchronously in the current thread.
stepThread _ _ _ _ tid (AThrow e) = stepThrow Throw tid e
-- Throw asynchronously to thread @t@. If @t@ is interruptible, the
-- exception is delivered now. Otherwise the thrower blocks with
-- @OnMask t@. A non-existent target is a no-op.
stepThread _ _ _ _ tid (AThrowTo t e c) = synchronised $ \ctx@Context{..} ->
  let threads' = goto c tid cThreads
      blocked = block (OnMask t) tid cThreads
  in case M.lookup t cThreads of
       Just thread
         | interruptible thread -> stepThrow (ThrowTo t) t e ctx { cThreads = threads' }
         | otherwise -> pure
           ( Succeeded ctx { cThreads = blocked }
           , Single (BlockedThrowTo t)
           , const (pure ())
           )
       Nothing -> pure
         (Succeeded ctx { cThreads = threads' }
         , Single (ThrowTo t)
         , const (pure ())
         )
-- Install handler @h@ and run @ma@. Its continuation pops the handler
-- again (via 'APopCatching') before continuing with @c@.
stepThread _ _ _ _ tid (ACatching h ma c) = \ctx@Context{..} -> pure $
  let a = runModelConc ma (APopCatching . c)
      e exc = runModelConc (h exc) c
  in ( Succeeded ctx { cThreads = goto a tid (catching e tid cThreads) }
     , Single Catching
     , const (pure ())
     )
-- Remove the innermost handler.
stepThread _ _ _ _ tid (APopCatching a) = \ctx@Context{..} ->
  pure ( Succeeded ctx { cThreads = goto a tid (uncatching tid cThreads) }
       , Single PopCatching
       , const (pure ())
       )
-- Run @ma@ with masking state @m@. @ma@ receives an @umask@ that runs
-- an action under the previous state @m'@ and then re-applies @m@.
-- When @ma@ finishes, the state is reset to @m'@.
stepThread _ _ _ _ tid (AMasking m ma c) = \ctx@Context{..} -> pure $
  let resetMask typ ms = ModelConc $ \k -> AResetMask typ True ms $ k ()
      umask mb = resetMask True m' >> mb >>= \b -> resetMask False m >> pure b
      m' = _masking $ elookup tid cThreads
      a = runModelConc (ma umask) (AResetMask False False m' . c)
  in ( Succeeded ctx { cThreads = goto a tid (mask m tid cThreads) }
     , Single (SetMasking False m)
     , const (pure ())
     )
-- Set the masking state; @b1@ selects whether it is recorded as
-- 'SetMasking' or 'ResetMasking'.
stepThread _ _ _ _ tid (AResetMask b1 b2 m c) = \ctx@Context{..} ->
  pure ( Succeeded ctx { cThreads = goto c tid (mask m tid cThreads) }
       , Single ((if b1 then SetMasking else ResetMasking) b2 m)
       , const (pure ())
       )
stepThread _ _ _ _ tid (AReturn c) = \ctx@Context{..} ->
  pure ( Succeeded ctx { cThreads = goto c tid cThreads }
       , Single Return
       , const (pure ())
       )
-- Thread end: run the final action, then remove the thread.
stepThread _ _ _ _ tid (AStop na) = \ctx@Context{..} -> do
  na
  threads' <- kill tid cThreads
  pure ( Succeeded ctx { cThreads = threads' }
       , Single Stop
       , const (pure ())
       )
-- Subconcurrency is only permitted when not snapshotting and when the
-- acting thread is the only thread. The nested execution starts from a
-- fresh thread map but shares the id source and scheduler state, and
-- its result is handed to @c@ after an 'AStopSub' step.
stepThread forSnapshot _ sched memtype tid (ASub ma c) = \ctx ->
  if | forSnapshot -> pure (Failed IllegalSubconcurrency, Single Subconcurrency, const (pure ()))
     | M.size (cThreads ctx) > 1 -> pure (Failed IllegalSubconcurrency, Single Subconcurrency, const (pure ()))
     | otherwise -> do
         res <- runConcurrency False sched memtype (cSchedState ctx) (cIdSource ctx) (cCaps ctx) ma
         out <- efromJust <$> C.readCRef (finalRef res)
         pure ( Succeeded ctx
                  { cThreads = goto (AStopSub (c out)) tid (cThreads ctx)
                  , cIdSource = cIdSource (finalContext res)
                  , cSchedState = cSchedState (finalContext res)
                  }
              , SubC (finalTrace res) (finalDecision res)
              , const (pure ())
              )
stepThread _ _ _ _ tid (AStopSub c) = \ctx@Context{..} ->
  pure ( Succeeded ctx { cThreads = goto c tid cThreads }
       , Single StopSubconcurrency
       , const (pure ())
       )
-- 'dontCheck' is only permitted as the first action. The current
-- thread is replaced by @ma@, which runs to completion under
-- 'SequentialConsistency' with 'dcSched' (step budget @lb@). On success
-- the thread is relaunched with continuation @c a@. When snapshotting,
-- 'Snap' tells the scheduling loop to stop here.
stepThread forSnapshot isFirst _ _ tid (ADontCheck lb ma c) = \ctx ->
  if | isFirst -> do
         threads' <- kill tid (cThreads ctx)
         let dcCtx = ctx { cThreads = threads', cSchedState = lb }
         res <- runConcurrency' forSnapshot dcSched SequentialConsistency dcCtx ma
         out <- efromJust <$> C.readCRef (finalRef res)
         case out of
           Right a -> do
             let threads'' = launch' Unmasked tid (const (c a)) (cThreads (finalContext res))
             threads''' <- (if C.rtsSupportsBoundThreads then makeBound tid else pure) threads''
             -- Restore the outer scheduler state in the returned context.
             pure ( (if forSnapshot then Snap else Succeeded) (finalContext res)
                      { cThreads = threads''', cSchedState = cSchedState ctx }
                  , Single (DontCheck (toList (finalTrace res)))
                  , fromMaybe (const (pure ())) (finalRestore res)
                  )
           Left f -> pure
             ( Failed f
             , Single (DontCheck (toList (finalTrace res)))
             , const (pure ())
             )
     | otherwise -> pure
       ( Failed IllegalDontCheck
       , Single (DontCheck [])
       , const (pure ())
       )
-- | Raise exception @e@ in thread @tid@, recording @act@ in the trace.
--
-- * If a handler is found, the thread continues inside it.
-- * If not, and @tid@ is the initial thread, the whole execution fails
--   with 'UncaughtException'.
-- * Otherwise the unhandled thread is killed and execution continues.
--
-- No restore effect is produced in any case.
stepThrow :: (C.MonadConc n, Exception e)
  => ThreadAction
  -> ThreadId
  -> e
  -> Context n g
  -> n (What n g, Act, Threads n -> n ())
stepThrow act tid e ctx@Context{..} =
    case propagate exc tid cThreads of
      Just handled -> continueWith handled
      Nothing
        | tid /= initialThread -> kill tid cThreads >>= continueWith
        | otherwise -> pure (Failed (UncaughtException exc), Single act, const (pure ()))
  where
    -- The thrown value, boxed for propagation and failure reporting.
    exc = toException e
    -- Carry on with a new thread map.
    continueWith threads =
      pure (Succeeded ctx { cThreads = threads }, Single act, const (pure ()))
-- | Wrap a step so that it acts as a memory barrier. All buffered CRef
-- writes are flushed via 'writeBarrier' first, and the step then runs
-- with an empty write buffer.
synchronised :: C.MonadConc n
  => (Context n g -> n (What n g, Act, Threads n -> n ()))
  -> Context n g
  -> n (What n g, Act, Threads n -> n ())
synchronised ma ctx@Context{..} =
  writeBarrier cWriteBuf >> ma ctx { cWriteBuf = emptyBuffer }
-- | The scheduler used inside 'dontCheck'.
--
-- It picks threads with 'roundRobinSchedNP'. Its state is an optional
-- step budget: @Just 0@ aborts (no thread is chosen), @Just k@ is
-- decremented on every decision, and 'Nothing' means unbounded.
dcSched :: Scheduler (Maybe Int)
dcSched = Scheduler pick
  where
    pick _ _ (Just 0) = (Nothing, Just 0)
    pick prior runnable budget =
      ( fst (scheduleThread roundRobinSchedNP prior runnable ())
      , subtract 1 <$> budget
      )